第11章:性能优化

RAG系统响应太慢?通过缓存、批处理、并发和内存优化,将系统吞吐量提升2-3倍,响应时间降低50%!


📚 学习目标

学完本章后,你将能够:

  • 识别RAG系统的性能瓶颈
  • 实现多层缓存策略
  • 应用批处理优化
  • 使用并发提升吞吐量
  • 优化内存使用
  • 将系统性能提升2-3倍

预计学习时间:3小时 难度等级:⭐⭐⭐☆☆


前置知识

在开始本章学习前,你需要具备:

  • 完成模块1的基础RAG实现
  • 理解高级RAG模式(第10章)
  • 熟悉Python异步编程

Environment requirements:

  • Python >= 3.9
  • redis (for caching)
  • Profiling tools (py-spy, memory_profiler)


11.1 性能瓶颈分析

11.1.1 RAG系统性能剖析

典型RAG流程的时间分解

用户查询 → RAG系统 → 答案
  │         │
  │         ├─ [1] 查询处理: ~50ms
  │         ├─ [2] 嵌入查询: ~100ms
  │         ├─ [3] 向量检索: ~200ms
  │         ├─ [4] 文档加载: ~100ms
  │         ├─ [5] LLM生成: ~2000ms ⚠️ 瓶颈
  │         └─ [6] 后处理: ~50ms
  总计: ~2500ms

各部分占比:
  LLM生成: 80% ⚠️
  向量检索: 8%
  查询嵌入: 4%
  文档加载: 4%
  其他: 4%
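The stage shares above can be recomputed directly from the stage timings; a small sketch:

```python
# Stage timings (ms) taken from the breakdown above
stage_ms = {
    "query processing": 50,
    "query embedding": 100,
    "vector retrieval": 200,
    "document loading": 100,
    "LLM generation": 2000,
    "post-processing": 50,
}

total_ms = sum(stage_ms.values())  # 2500 ms
shares = {name: ms / total_ms for name, ms in stage_ms.items()}

# Print stages sorted by share, largest first
for name, share in sorted(shares.items(), key=lambda kv: -kv[1]):
    print(f"{name}: {share:.0%}")
```

Sorting by share makes the bottleneck obvious at a glance: LLM generation dominates at 80%, so that is where caching and batching pay off most.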

11.1.2 性能分析工具

工具1:cProfile(CPU分析)

# 文件名:profiling_example.py
"""
性能分析示例
"""

import cProfile
import pstats
import io
from pstats import SortKey


def rag_pipeline(query: str) -> str:
    """
    模拟RAG流程
    """
    # 步骤1:查询处理
    import time
    time.sleep(0.05)

    # 步骤2:嵌入查询
    time.sleep(0.1)

    # 步骤3:向量检索
    time.sleep(0.2)

    # 步骤4:文档加载
    time.sleep(0.1)

    # 步骤5:LLM生成
    time.sleep(2.0)

    # 步骤6:后处理
    time.sleep(0.05)

    return "答案"


# 性能分析
if __name__ == "__main__":
    # 创建分析器
    profiler = cProfile.Profile()

    # 运行分析
    print("运行性能分析...")
    profiler.enable()

    # 执行函数
    result = rag_pipeline("Python性能优化")

    profiler.disable()

    # 输出结果
    s = io.StringIO()
    ps = pstats.Stats(profiler, stream=s).sort_stats(SortKey.CUMULATIVE)

    # 打印Top-10函数
    ps.print_stats(10)
    print(s.getvalue())

输出示例

运行性能分析...

         8 function calls in 2.501 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    2.501    2.501 profiling_example.py:11(rag_pipeline)
        6    2.500    0.417    2.500    0.417 {built-in method time.sleep}

(cProfile aggregates calls to the same function into one row, so the six time.sleep calls appear as a single entry rather than one line per call.)

工具2:memory_profiler(内存分析)

# 安装
# pip install memory_profiler

from memory_profiler import profile


@profile
def rag_pipeline_memory(query: str) -> str:
    """
    分析内存使用
    """
    # 模拟数据加载
    large_data = [i for i in range(1000000)]  # ~8MB

    # 模拟嵌入
    embeddings = [[0.1] * 768 for _ in range(100)]  # ~300KB

    # 模拟文档
    documents = ["文档内容" * 1000 for _ in range(100)]  # ~1MB

    return "答案"


if __name__ == "__main__":
    result = rag_pipeline_memory("测试查询")

输出示例

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
     1     50.0 MiB     50.0 MiB           1   @profile
     2                                         def rag_pipeline_memory(query):
     3     58.0 MiB      8.0 MiB           1       large_data = [i for i in range(1000000)]
     4     58.3 MiB      0.3 MiB           1       embeddings = [[0.1] * 768 for _ in range(100)]
     5     59.3 MiB      1.0 MiB           1       documents = ["文档内容" * 1000 for _ in range(100)]
     6     59.3 MiB      0.0 MiB           1       return "答案"

11.2 缓存策略

11.2.1 为什么需要缓存?

问题:重复计算浪费资源

场景:FAQ问答系统

问题列表:
- "如何重置密码?"  (每天100次)
- "如何退款?"      (每天50次)
- "如何联系客服?"  (每天30次)

无缓存:
  每次查询都需要:
  - 嵌入查询: 100ms
  - 向量检索: 200ms
  - LLM生成: 2000ms
  总计: 2300ms

有缓存:
  首次查询: 2300ms
  后续查询: 10ms (直接返回缓存)

节省: 99.6%的时间!
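The savings claimed for this scenario can be checked with a quick daily-aggregate calculation (assuming each distinct question misses the cache once per day and hits it thereafter):

```python
# Daily query counts from the FAQ scenario above
daily_counts = {"reset password": 100, "refund": 50, "contact support": 30}

UNCACHED_MS = 2300   # embed + retrieve + generate
CACHED_MS = 10       # cache lookup

total_queries = sum(daily_counts.values())
# Each distinct question is computed once per day; the rest hit the cache
misses = len(daily_counts)
hits = total_queries - misses

time_uncached = total_queries * UNCACHED_MS
time_cached = misses * UNCACHED_MS + hits * CACHED_MS

saving = 1 - time_cached / time_uncached
print(f"{saving:.1%} of total time saved")
```

The 99.6% figure above is the per-repeated-query saving (1 − 10/2300); aggregated over a day including the three initial misses, the saving is still about 98%.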

11.2.2 多层缓存架构

┌─────────────────────────────────────────────────────┐
│               多层缓存架构                          │
├─────────────────────────────────────────────────────┤
│                                                     │
│  L1: 内存缓存 (LRU)                                 │
│  ├─ 容量: 1000条                                    │
│  ├─ TTL: 1小时                                      │
│  ├─ 命中率: ~60%                                    │
│  └─ 响应时间: <1ms                                  │
│     ↓ (未命中)                                      │
│  L2: Redis缓存                                      │
│  ├─ 容量: 10000条                                   │
│  ├─ TTL: 24小时                                     │
│  ├─ 命中率: ~30%                                    │
│  └─ 响应时间: <10ms                                 │
│     ↓ (未命中)                                      │
│  L3: 向量数据库 + LLM                                │
│  ├─ 完整RAG流程                                     │
│  ├─ 响应时间: ~2000ms                               │
│  └─ 结果回填到L1和L2                                │
│                                                     │
└─────────────────────────────────────────────────────┘

Overall hit rate: 90% (60% + 30%)
Average response time: 0.6×1ms + 0.3×10ms + 0.1×2000ms ≈ 204ms
Speedup: ≈12× (vs. ~2500ms without caching)
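As a sanity check on the arithmetic above, the expected latency follows from the tier hit rates; a small helper (the name `avg_latency_ms` is ours, not part of the chapter's code):

```python
def avg_latency_ms(l1_rate: float, l2_rate: float,
                   l1_ms: float = 1.0, l2_ms: float = 10.0,
                   miss_ms: float = 2000.0) -> float:
    """Expected latency of the three-tier cache: hit-rate-weighted average."""
    miss_rate = 1.0 - l1_rate - l2_rate
    return l1_rate * l1_ms + l2_rate * l2_ms + miss_rate * miss_ms


latency = avg_latency_ms(0.6, 0.3)
print(f"average: {latency:.1f} ms")  # 0.6*1 + 0.3*10 + 0.1*2000 = 203.6 ms
```

The formula also shows where further gains come from: pushing L1's hit rate from 60% to 70% shaves another ~0.9ms, while shaving 10% off the miss rate saves ~200ms.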

11.2.3 缓存实现

L1: 内存缓存

# 文件名:cache_strategies.py
"""
缓存策略实现
"""

from typing import Any, Optional, Dict
from functools import lru_cache
import hashlib
import json
import time


class MemoryCache:
    """
    内存缓存(LRU)

    Args:
        max_size: 最大缓存条目数
        ttl: 生存时间(秒)

    Example:
        >>> cache = MemoryCache(max_size=1000, ttl=3600)
        >>> cache.set("key", {"answer": "..."})
        >>> result = cache.get("key")
    """

    def __init__(self, max_size: int = 1000, ttl: int = 3600):
        self.max_size = max_size
        self.ttl = ttl
        self.cache: Dict[str, Dict] = {}

    def _generate_key(self, query: str, **kwargs) -> str:
        """
        生成缓存键
        """
        # 包含查询和参数
        data = {'query': query, **kwargs}
        data_str = json.dumps(data, sort_keys=True)
        return hashlib.md5(data_str.encode()).hexdigest()

    def get(self, query: str, **kwargs) -> Optional[Any]:
        """
        获取缓存
        """
        key = self._generate_key(query, **kwargs)

        if key in self.cache:
            entry = self.cache[key]

            # 检查是否过期
            if time.time() - entry['timestamp'] < self.ttl:
                entry['hits'] += 1
                return entry['value']
            else:
                # 过期,删除
                del self.cache[key]

        return None

    def set(self, query: str, value: Any, **kwargs):
        """
        设置缓存
        """
        key = self._generate_key(query, **kwargs)

        # 检查容量
        if len(self.cache) >= self.max_size:
            # LRU删除(简化版:删除第一个)
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]

        # 存储缓存
        self.cache[key] = {
            'value': value,
            'timestamp': time.time(),
            'hits': 0
        }

    def clear(self):
        """
        清空缓存
        """
        self.cache.clear()

    def get_stats(self) -> Dict:
        """
        Return cache statistics.

        Note: hits are counted per entry in get(); misses are not tracked
        here, so computing a true hit rate requires an external request
        counter (see CachedRAGSystem below).
        """
        total_hits = sum(entry['hits'] for entry in self.cache.values())

        return {
            'size': len(self.cache),
            'max_size': self.max_size,
            'total_hits': total_hits
        }


# 使用示例
if __name__ == "__main__":
    cache = MemoryCache(max_size=100, ttl=60)

    # 设置缓存
    cache.set("如何重置密码?", {
        'answer': '请在设置页面点击"重置密码"...',
        'sources': ['doc1', 'doc2']
    })

    # 获取缓存
    result = cache.get("如何重置密码?")
    print(f"缓存命中: {result}")

    # 再次获取(命中次数+1)
    result = cache.get("如何重置密码?")
    print(f"再次命中: {result}")

    # 统计信息
    stats = cache.get_stats()
    print(f"\n缓存统计: {stats}")

L2: Redis缓存

# 安装
# pip install redis

import hashlib
import json
from typing import Any, Dict, Optional

import redis


class RedisCache:
    """
    Redis缓存

    Args:
        host: Redis主机
        port: Redis端口
        db: Redis数据库编号
        ttl: 生存时间(秒)

    Example:
        >>> cache = RedisCache(host='localhost', port=6379, ttl=86400)
        >>> cache.set("key", {"answer": "..."})
        >>> result = cache.get("key")
    """

    def __init__(self,
                 host: str = 'localhost',
                 port: int = 6379,
                 db: int = 0,
                 ttl: int = 86400):

        self.ttl = ttl
        self.redis_client = redis.Redis(
            host=host,
            port=port,
            db=db,
            decode_responses=True
        )

    def _generate_key(self, query: str, **kwargs) -> str:
        """
        生成缓存键
        """
        data = {'query': query, **kwargs}
        data_str = json.dumps(data, sort_keys=True)
        return f"rag:cache:{hashlib.md5(data_str.encode()).hexdigest()}"

    def get(self, query: str, **kwargs) -> Optional[Any]:
        """
        获取缓存
        """
        key = self._generate_key(query, **kwargs)

        try:
            value = self.redis_client.get(key)
            if value:
                return json.loads(value)
        except Exception as e:
            print(f"Redis获取失败: {e}")

        return None

    def set(self, query: str, value: Any, **kwargs):
        """
        设置缓存
        """
        key = self._generate_key(query, **kwargs)

        try:
            value_json = json.dumps(value, ensure_ascii=False)
            self.redis_client.setex(key, self.ttl, value_json)
        except Exception as e:
            print(f"Redis设置失败: {e}")

    def clear(self):
        """
        清空所有RAG缓存
        """
        try:
            # 删除所有rag:cache:*键
            for key in self.redis_client.scan_iter("rag:cache:*"):
                self.redis_client.delete(key)
        except Exception as e:
            print(f"Redis清空失败: {e}")

    def get_stats(self) -> Dict:
        """
        获取统计信息
        """
        try:
            # 统计键数量
            keys = list(self.redis_client.scan_iter("rag:cache:*"))
            return {
                'size': len(keys),
                'ttl': self.ttl
            }
        except Exception as e:
            print(f"Redis统计失败: {e}")
            return {'size': 0, 'ttl': self.ttl}


# 使用示例
if __name__ == "__main__":
    # 注意:需要先启动Redis服务
    # redis-server

    try:
        cache = RedisCache(host='localhost', port=6379, ttl=3600)

        # 设置缓存
        cache.set("如何退款?", {
            'answer': '您可以在订单页面申请退款...',
            'sources': ['doc3']
        })

        # 获取缓存
        result = cache.get("如何退款?")
        print(f"Redis缓存: {result}")

        # 统计信息
        stats = cache.get_stats()
        print(f"\nRedis统计: {stats}")

    except Exception as e:
        print(f"Redis未运行,跳过示例: {e}")

A complete caching wrapper

# 文件名:cached_rag.py
"""
带缓存的RAG系统
"""

from typing import Callable, Dict, Optional

from cache_strategies import MemoryCache, RedisCache


class CachedRAGSystem:
    """
    带缓存的RAG系统

    Args:
        rag_pipeline: RAG处理函数
        l1_cache: L1缓存(内存)
        l2_cache: L2缓存(Redis),可选

    Example:
        >>> rag = CachedRAGSystem(rag_pipeline, l1_cache, l2_cache)
        >>> result = rag.query("如何重置密码?")
        >>> # 首次查询:完整流程
        >>> result = rag.query("如何重置密码?")
        >>> # 第二次:从L1缓存返回
    """

    def __init__(self,
                 rag_pipeline: Callable,
                 l1_cache: MemoryCache,
                 l2_cache: Optional[RedisCache] = None):

        self.rag_pipeline = rag_pipeline
        self.l1_cache = l1_cache
        self.l2_cache = l2_cache

        # 统计信息
        self.stats = {
            'total_queries': 0,
            'l1_hits': 0,
            'l2_hits': 0,
            'misses': 0
        }

    def query(self, query: str, **kwargs) -> Dict:
        """
        查询(带缓存)
        """
        self.stats['total_queries'] += 1

        # L1缓存查找
        result = self.l1_cache.get(query, **kwargs)
        if result is not None:
            self.stats['l1_hits'] += 1
            return {
                'answer': result['answer'],
                'sources': result['sources'],
                'cache': 'L1'
            }

        # L2缓存查找
        if self.l2_cache:
            result = self.l2_cache.get(query, **kwargs)
            if result is not None:
                self.stats['l2_hits'] += 1

                # 回填L1缓存
                self.l1_cache.set(query, result, **kwargs)

                return {
                    'answer': result['answer'],
                    'sources': result['sources'],
                    'cache': 'L2'
                }

        # 缓存未命中,执行RAG流程
        self.stats['misses'] += 1

        print("缓存未命中,执行RAG流程...")
        result = self.rag_pipeline(query, **kwargs)

        # 存储到缓存
        cache_data = {
            'answer': result['answer'],
            'sources': result.get('sources', [])
        }

        self.l1_cache.set(query, cache_data, **kwargs)

        if self.l2_cache:
            self.l2_cache.set(query, cache_data, **kwargs)

        return {
            'answer': result['answer'],
            'sources': result.get('sources', []),
            'cache': 'None'
        }

    def get_stats(self) -> Dict:
        """
        获取缓存统计
        """
        total = self.stats['total_queries']
        l1_rate = self.stats['l1_hits'] / total if total > 0 else 0
        l2_rate = self.stats['l2_hits'] / total if total > 0 else 0
        miss_rate = self.stats['misses'] / total if total > 0 else 0

        return {
            'total_queries': total,
            'l1_hits': self.stats['l1_hits'],
            'l2_hits': self.stats['l2_hits'],
            'misses': self.stats['misses'],
            'l1_hit_rate': l1_rate,
            'l2_hit_rate': l2_rate,
            'overall_hit_rate': l1_rate + l2_rate,
            'miss_rate': miss_rate
        }


# 使用示例
if __name__ == "__main__":
    # 模拟RAG流程
    def mock_rag_pipeline(query: str, **kwargs) -> Dict:
        print(f"  [RAG] 处理查询: {query}")
        import time
        time.sleep(0.5)  # 模拟耗时
        return {
            'answer': f'关于"{query}"的答案...',
            'sources': ['doc1', 'doc2']
        }

    # 创建缓存
    l1_cache = MemoryCache(max_size=100, ttl=60)
    l2_cache = None  # RedisCache()  # 如果有Redis

    # 创建带缓存的RAG系统
    rag = CachedRAGSystem(mock_rag_pipeline, l1_cache, l2_cache)

    # 测试查询
    test_queries = [
        "如何重置密码?",
        "如何重置密码?",  # 重复查询
        "如何退款?",
        "如何重置密码?",  # 再次重复
        "如何联系客服?",
        "如何退款?",  # 重复查询
    ]

    print("\n" + "="*80)
    print("带缓存的RAG系统测试")
    print("="*80 + "\n")

    for query in test_queries:
        print(f"查询: {query}")
        result = rag.query(query)
        print(f"  缓存: {result['cache']}")
        print(f"  答案: {result['answer'][:50]}...")
        print()

    # 统计信息
    stats = rag.get_stats()

    print("="*80)
    print("缓存统计")
    print("="*80)
    print(f"总查询数: {stats['total_queries']}")
    print(f"L1命中: {stats['l1_hits']} ({stats['l1_hit_rate']*100:.1f}%)")
    print(f"L2命中: {stats['l2_hits']} ({stats['l2_hit_rate']*100:.1f}%)")
    print(f"未命中: {stats['misses']} ({stats['miss_rate']*100:.1f}%)")
    print(f"整体命中率: {stats['overall_hit_rate']*100:.1f}%")
    print("\n性能提升:")
    print(f"  节省时间: {stats['l1_hits'] + stats['l2_hits']} * 500ms = {(stats['l1_hits'] + stats['l2_hits']) * 0.5:.1f}秒")
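The wrapper above wires the caches in explicitly through a class; the same idea can also be packaged as a decorator. A minimal in-memory sketch (the `cached` decorator is our illustration, not part of the chapter's code):

```python
import functools
import hashlib
import json
import time


def cached(ttl: float = 3600, max_size: int = 1000):
    """Decorator: memoize a query function in a small in-memory store."""
    def decorator(func):
        store = {}  # key -> (timestamp, value)

        @functools.wraps(func)
        def wrapper(query, **kwargs):
            key = hashlib.md5(
                json.dumps({'query': query, **kwargs}, sort_keys=True).encode()
            ).hexdigest()
            entry = store.get(key)
            if entry and time.time() - entry[0] < ttl:
                return entry[1]          # fresh cache hit
            result = func(query, **kwargs)
            if len(store) >= max_size:   # simple FIFO eviction
                store.pop(next(iter(store)))
            store[key] = (time.time(), result)
            return result
        return wrapper
    return decorator


calls = []

@cached(ttl=60)
def answer(query: str) -> str:
    calls.append(query)  # side effect, to observe cache hits
    return f"answer to {query}"


answer("how do I reset my password?")
answer("how do I reset my password?")  # served from cache
print(len(calls))  # 1
```

A decorator is convenient for a single function; the class-based wrapper remains the better fit when you need L2 backfill and shared hit statistics.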

11.3 批处理优化

11.3.1 为什么需要批处理?

问题:逐个处理效率低

场景:批量处理100个查询

无批处理:
  for query in queries:
      result = process(query)  # 每个查询独立处理
  总时间: 100 * 500ms = 50秒

批处理:
  batch_process(queries)  # 批量处理
  总时间: 10秒

提升: 5倍

11.3.2 批处理实现

# 文件名:batch_processing.py
"""
批处理优化实现
"""

from typing import Any, Callable, Dict, List
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


class BatchProcessor:
    """
    批处理器

    Args:
        batch_size: 批大小
        process_function: 处理函数
        max_workers: 最大工作线程数

    Example:
        >>> processor = BatchProcessor(
        ...     batch_size=10,
        ...     process_function=rag_query,
        ...     max_workers=5
        ... )
        >>> results = processor.process(queries)
    """

    def __init__(self,
                 batch_size: int = 10,
                 process_function: Callable = None,
                 max_workers: int = 4):

        self.batch_size = batch_size
        self.process_function = process_function
        self.max_workers = max_workers

    def process(self, items: List[Any]) -> List[Any]:
        """
        批量处理

        Args:
            items: 待处理的项目列表

        Returns:
            处理结果列表
        """
        results = []

        # 分批
        batches = [
            items[i:i + self.batch_size]
            for i in range(0, len(items), self.batch_size)
        ]

        print(f"总共 {len(items)} 个项目,分为 {len(batches)} 批")

        # 并发处理批次
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # 提交批次任务
            future_to_batch = {
                executor.submit(self._process_batch, batch): i
                for i, batch in enumerate(batches)
            }

            # Collect results keyed by batch index, so the original
            # order can be restored afterwards (as_completed yields
            # futures in completion order, not submission order)
            batch_outputs = {}
            for future in as_completed(future_to_batch):
                batch_idx = future_to_batch[future]
                try:
                    batch_outputs[batch_idx] = future.result()
                    print(f"批次 {batch_idx + 1}/{len(batches)} 完成")
                except Exception as e:
                    print(f"批次 {batch_idx + 1} 失败: {e}")
                    batch_outputs[batch_idx] = []

        # Concatenate in batch order to restore the original ordering
        for i in range(len(batches)):
            results.extend(batch_outputs.get(i, []))
        return results

    def _process_batch(self, batch: List[Any]) -> List[Any]:
        """
        处理单个批次
        """
        if self.process_function:
            return [self.process_function(item) for item in batch]
        return batch


# RAG批处理示例
class BatchRAGSystem:
    """
    批量RAG系统

    支持批量查询、批量嵌入等
    """

    def __init__(self, retriever, llm_client):
        self.retriever = retriever
        self.llm_client = llm_client

    def batch_query(self, queries: List[str],
                   batch_size: int = 10) -> List[Dict]:
        """
        批量查询

        Args:
            queries: 查询列表
            batch_size: 批大小

        Returns:
            答案列表
        """
        results = []

        # 批量嵌入(如果支持)
        all_embeddings = self._batch_embed(queries, batch_size)

        # 批量检索
        for i in range(0, len(queries), batch_size):
            batch_queries = queries[i:i + batch_size]
            batch_embeddings = all_embeddings[i:i + batch_size]

            # 检索
            batch_results = self._batch_retrieve(batch_queries, batch_embeddings)

            # 生成答案
            for query, retrieved_docs in zip(batch_queries, batch_results):
                answer = self._generate_answer(query, retrieved_docs)
                results.append(answer)

        return results

    def _batch_embed(self, queries: List[str],
                    batch_size: int) -> List[List[float]]:
        """
        批量嵌入查询

        优化:批量调用嵌入模型
        """
        all_embeddings = []

        for i in range(0, len(queries), batch_size):
            batch = queries[i:i + batch_size]
            # 批量嵌入(假设retriever支持)
            embeddings = self.retriever.embed_batch(batch)
            all_embeddings.extend(embeddings)

        return all_embeddings

    def _batch_retrieve(self,
                       queries: List[str],
                       embeddings: List[List[float]]) -> List[List[Dict]]:
        """
        批量检索
        """
        results = []
        for query, embedding in zip(queries, embeddings):
            docs = self.retriever.retrieve_by_embedding(embedding)
            results.append(docs)
        return results

    def _generate_answer(self, query: str, docs: List[Dict]) -> Dict:
        """
        生成答案
        """
        context = "\n".join([doc['text'] for doc in docs[:3]])

        # 调用LLM
        prompt = f"基于以下信息回答问题:\n\n{context}\n\n问题:{query}\n答案:"

        # 模拟LLM调用
        return {
            'answer': f'基于{len(docs)}个文档生成的答案...',
            'sources': [doc['id'] for doc in docs]
        }


# 使用示例
if __name__ == "__main__":
    # 模拟RAG系统
    class MockRAGSystem:
        def query(self, query: str) -> Dict:
            time.sleep(0.1)  # 模拟处理时间
            return {'query': query, 'answer': f'答案: {query}'}

    rag = MockRAGSystem()

    # 测试查询
    queries = [f"查询{i+1}" for i in range(20)]

    print("\n" + "="*80)
    print("批处理性能测试")
    print("="*80 + "\n")

    # 无批处理
    print("方式1: 逐个处理")
    start = time.time()
    results_single = [rag.query(q) for q in queries]
    time_single = time.time() - start
    print(f"  时间: {time_single:.2f}\n")

    # 批处理
    print("方式2: 批处理")
    processor = BatchProcessor(
        batch_size=5,
        process_function=rag.query,
        max_workers=4
    )

    start = time.time()
    results_batch = processor.process(queries)
    time_batch = time.time() - start
    print(f"  时间: {time_batch:.2f}\n")

    # 对比
    print("="*80)
    print("性能对比")
    print("="*80)
    print(f"逐个处理: {time_single:.2f}秒")
    print(f"批处理:   {time_batch:.2f}秒")
    print(f"提升:     {time_single/time_batch:.2f}x")
    print(f"节省:     {((time_single - time_batch)/time_single * 100):.1f}%")

11.4 并发优化

11.4.1 Python并发模型

Concurrency models compared:

┌─────────────────────────────────────────────────────┐
│  Model             Best for              GIL impact │
├─────────────────────────────────────────────────────┤
│  Threads           I/O-bound             I/O waits  │
│  (threading)       - network requests    release the│
│                    - file I/O            GIL; CPU   │
│                    - database queries    work is    │
│                                          serialized │
├─────────────────────────────────────────────────────┤
│  Processes         CPU-bound             Not subject│
│  (multiprocessing) - heavy computation   to the GIL │
│                    - data processing                │
│                    - model inference                │
├─────────────────────────────────────────────────────┤
│  Async I/O         High-concurrency I/O  Single-    │
│  (asyncio)         - many network calls  threaded;  │
│                    - WebSocket           GIL is not │
│                    - streaming data      the        │
│                                          bottleneck │
└─────────────────────────────────────────────────────┘
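The I/O-bound row of the table can be observed directly: even under the GIL, threads overlap waiting time (simulated here with time.sleep, during which the GIL is released). A minimal timing sketch:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def io_task(_: int) -> None:
    time.sleep(0.1)  # simulated I/O wait; the GIL is released while sleeping


N = 8

# Sequential: waits stack up, ~N * 0.1s
start = time.time()
for i in range(N):
    io_task(i)
sequential = time.time() - start

# Threaded: waits overlap, roughly one 0.1s wait in total
start = time.time()
with ThreadPoolExecutor(max_workers=N) as pool:
    list(pool.map(io_task, range(N)))
threaded = time.time() - start

print(f"sequential {sequential:.2f}s, threaded {threaded:.2f}s")
```

Swap the sleep for a CPU-bound loop and the threaded version would show little or no speedup, which is exactly the case where multiprocessing in the second row applies.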

11.4.2 并发实现

多线程(适合I/O密集)

# 文件名:concurrent_rag.py
"""
并发RAG实现
"""

from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
import time
from typing import List, Dict, Callable


class ConcurrentRAGSystem:
    """
    并发RAG系统

    支持多线程处理多个查询
    """

    def __init__(self, rag_pipeline: Callable, max_workers: int = 10):
        self.rag_pipeline = rag_pipeline
        self.max_workers = max_workers

    def query_concurrent(self, queries: List[str]) -> List[Dict]:
        """
        并发查询(多线程)

        Args:
            queries: 查询列表

        Returns:
            答案列表(保持原始顺序)
        """
        results = [None] * len(queries)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # 提交所有任务
            future_to_index = {
                executor.submit(self.rag_pipeline, query): idx
                for idx, query in enumerate(queries)
            }

            # 收集结果
            for future in as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    results[idx] = future.result()
                    print(f"查询 {idx + 1}/{len(queries)} 完成")
                except Exception as e:
                    print(f"查询 {idx + 1} 失败: {e}")
                    results[idx] = {'error': str(e)}

        return results

    def query_sequential(self, queries: List[str]) -> List[Dict]:
        """
        顺序查询(对比)
        """
        results = []
        for query in queries:
            result = self.rag_pipeline(query)
            results.append(result)
        return results


# 异步RAG系统
class AsyncRAGSystem:
    """
    异步RAG系统

    使用asyncio实现高并发
    """

    def __init__(self):
        pass

    async def query_async(self, query: str) -> Dict:
        """
        异步查询
        """
        # 异步嵌入
        embedding = await self._embed_async(query)

        # 异步检索
        docs = await self._retrieve_async(embedding)

        # 异步生成答案
        answer = await self._generate_async(query, docs)

        return answer

    async def _embed_async(self, query: str) -> List[float]:
        """
        异步嵌入(模拟)
        """
        await asyncio.sleep(0.1)  # 模拟I/O
        return [0.1] * 768

    async def _retrieve_async(self, embedding: List[float]) -> List[Dict]:
        """
        异步检索(模拟)
        """
        await asyncio.sleep(0.2)  # 模拟I/O
        return [{'id': 'doc1', 'text': '相关文档'}]

    async def _generate_async(self, query: str, docs: List[Dict]) -> Dict:
        """
        异步生成答案(模拟)
        """
        await asyncio.sleep(1.0)  # 模拟LLM调用
        return {'answer': f'异步答案: {query}'}

    async def batch_query_async(self, queries: List[str]) -> List[Dict]:
        """
        批量异步查询
        """
        tasks = [self.query_async(q) for q in queries]
        return await asyncio.gather(*tasks)


# 使用示例
if __name__ == "__main__":
    # 模拟RAG流程
    def mock_rag_pipeline(query: str) -> Dict:
        """模拟RAG流程(I/O密集)"""
        time.sleep(0.5)  # 模拟I/O等待
        return {'query': query, 'answer': f'答案: {query}'}

    # 测试查询
    queries = [f"查询{i+1}" for i in range(10)]

    print("\n" + "="*80)
    print("并发性能测试")
    print("="*80 + "\n")

    # 方式1:顺序处理
    print("方式1: 顺序处理")
    rag = ConcurrentRAGSystem(mock_rag_pipeline, max_workers=1)

    start = time.time()
    results_sequential = rag.query_sequential(queries)
    time_sequential = time.time() - start

    print(f"  时间: {time_sequential:.2f}秒")
    print(f"  平均每个查询: {time_sequential/len(queries):.2f}\n")

    # 方式2:并发处理(多线程)
    print("方式2: 并发处理(多线程)")
    rag = ConcurrentRAGSystem(mock_rag_pipeline, max_workers=10)

    start = time.time()
    results_concurrent = rag.query_concurrent(queries)
    time_concurrent = time.time() - start

    print(f"  时间: {time_concurrent:.2f}秒")
    print(f"  平均每个查询: {time_concurrent/len(queries):.2f}\n")

    # 方式3:异步处理
    print("方式3: 异步处理(asyncio)")
    async_rag = AsyncRAGSystem()

    start = time.time()
    results_async = asyncio.run(async_rag.batch_query_async(queries))
    time_async = time.time() - start

    print(f"  时间: {time_async:.2f}秒")
    print(f"  平均每个查询: {time_async/len(queries):.2f}\n")

    # 对比
    print("="*80)
    print("性能对比")
    print("="*80)
    print(f"顺序处理:  {time_sequential:.2f}秒 (基准)")
    print(f"多线程:    {time_concurrent:.2f}秒 ({time_sequential/time_concurrent:.2f}x)")
    print(f"异步:      {time_async:.2f}秒 ({time_sequential/time_async:.2f}x)")

11.5 内存优化

11.5.1 内存问题识别

常见内存问题

问题1:文档全部加载到内存
  现象:内存占用高
  解决:流式处理,按需加载

问题2:嵌入向量占用大
  现象:100万文档 * 768维 * 4字节 = 2.9GB
  解决:使用量化(float16 -> 2字节)

问题3:缓存无限增长
  现象:内存持续增长
  解决:LRU淘汰,设置大小上限

问题4:LLM上下文累积
  现象:长对话导致内存增长
  解决:限制上下文长度,定期清理
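The figure in problem 2 follows from a quick back-of-the-envelope calculation:

```python
# Values from problem 2 above: 1M documents, 768-dim embeddings
num_docs = 1_000_000
dim = 768

bytes_f32 = num_docs * dim * 4  # float32: 4 bytes per value
bytes_f16 = num_docs * dim * 2  # float16: 2 bytes per value

print(f"float32: {bytes_f32 / 2**30:.1f} GiB")  # ~2.9 GiB
print(f"float16: {bytes_f16 / 2**30:.1f} GiB")  # ~1.4 GiB
```

The same three numbers (document count, dimension, bytes per value) let you budget memory for any corpus before building the index.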

11.5.2 优化策略

策略1:生成器

# 文件名:memory_optimization.py
"""
内存优化实现
"""

def load_documents_lazy(file_path: str):
    """
    惰性加载文档(生成器)

    优点:不会一次性加载所有文档到内存

    Yield:
        文档文本
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            yield line.strip()


# 对比:一次性加载 vs 惰性加载

def load_documents_all(file_path: str) -> List[str]:
    """
    一次性加载所有文档

    问题:大文件会导致内存溢出
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]


# 使用示例
if __name__ == "__main__":
    # 模拟大文件
    import tempfile
    import os

    # 创建临时文件
    temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt')
    for i in range(100000):
        temp_file.write(f"文档{i}的内容...\n")
    temp_file_path = temp_file.name
    temp_file.close()

    print("\n" + "="*80)
    print("内存优化对比")
    print("="*80 + "\n")

    # 方式1:一次性加载
    print("方式1: 一次性加载")
    import tracemalloc
    tracemalloc.start()

    docs_all = load_documents_all(temp_file_path)
    current, peak = tracemalloc.get_traced_memory()
    print(f"  当前内存: {current / 1024 / 1024:.2f} MB")
    print(f"  峰值内存: {peak / 1024 / 1024:.2f} MB")
    print(f"  文档数量: {len(docs_all)}\n")

    tracemalloc.stop()

    # 方式2:惰性加载
    print("方式2: 惰性加载(生成器)")
    tracemalloc.start()

    docs_lazy = load_documents_lazy(temp_file_path)
    count = 0
    for doc in docs_lazy:
        count += 1
        if count >= 10:  # 只处理前10个
            break

    current, peak = tracemalloc.get_traced_memory()
    print(f"  当前内存: {current / 1024 / 1024:.2f} MB")
    print(f"  峰值内存: {peak / 1024 / 1024:.2f} MB")
    print(f"  处理文档: {count}\n")

    tracemalloc.stop()

    # 清理
    os.unlink(temp_file_path)

策略2:向量量化

from typing import Dict

import numpy as np


class QuantizedVectorStore:
    """
    量化的向量存储

    使用float16代替float32,节省50%内存
    """

    def __init__(self):
        self.vectors = None
        self.dtype = np.float16  # 使用半精度浮点数

    def add_vectors(self, vectors: np.ndarray):
        """
        添加向量(自动量化)
        """
        if vectors.dtype != self.dtype:
            vectors = vectors.astype(self.dtype)
        self.vectors = vectors

    def get_memory_usage(self) -> Dict:
        """
        获取内存使用情况
        """
        if self.vectors is None:
            return {'memory_mb': 0}

        memory_mb = self.vectors.nbytes / 1024 / 1024
        return {
            'memory_mb': memory_mb,
            'num_vectors': len(self.vectors),
            'dimension': self.vectors.shape[1],
            'dtype': str(self.vectors.dtype)
        }


# 对比:float32 vs float16
if __name__ == "__main__":
    print("\n" + "="*80)
    print("向量量化对比")
    print("="*80 + "\n")

    # 生成示例向量
    num_vectors = 100000
    dimension = 768

    print(f"向量数量: {num_vectors}")
    print(f"向量维度: {dimension}\n")

    # float32
    vectors_f32 = np.random.randn(num_vectors, dimension).astype(np.float32)
    memory_f32 = vectors_f32.nbytes / 1024 / 1024

    print(f"Float32:")
    print(f"  内存: {memory_f32:.2f} MB")
    print(f"  每个: {vectors_f32[0].nbytes / 1024:.2f} KB\n")

    # float16
    vectors_f16 = vectors_f32.astype(np.float16)
    memory_f16 = vectors_f16.nbytes / 1024 / 1024

    print(f"Float16:")
    print(f"  内存: {memory_f16:.2f} MB")
    print(f"  每个: {vectors_f16[0].nbytes / 1024:.2f} KB\n")

    print(f"节省: {(memory_f32 - memory_f16) / memory_f32 * 100:.1f}%")

策略3:上下文管理

import time
from typing import Dict


class ContextManagedRAG:
    """
    带上下文管理的RAG系统

    自动清理不再使用的资源
    """

    def __init__(self, max_context_length: int = 10):
        self.max_context_length = max_context_length
        self.conversation_history = []

    def query(self, query: str) -> Dict:
        """
        查询(自动管理上下文)
        """
        # 添加到历史
        self.conversation_history.append({
            'query': query,
            'timestamp': time.time()
        })

        # 限制历史长度
        if len(self.conversation_history) > self.max_context_length:
            # 移除最旧的记录
            removed = self.conversation_history.pop(0)
            print(f"清理旧上下文: {removed['query'][:30]}...")

        # 获取最近的历史
        recent_history = self.conversation_history[-5:]

        # 处理查询
        return {'answer': f'基于{len(recent_history)}轮上下文的答案'}

    def clear_history(self):
        """
        清空历史
        """
        self.conversation_history.clear()
        print("上下文已清空")
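The pop(0) bookkeeping above can also be delegated to the standard library: collections.deque with maxlen discards the oldest entry automatically on append. A sketch of the same history policy:

```python
from collections import deque

# maxlen bounds the history; appending beyond it silently drops the oldest
history = deque(maxlen=10)

for i in range(15):
    history.append({'query': f'question {i}'})

print(len(history))         # 10
print(history[0]['query'])  # 'question 5' -- the 5 oldest were dropped
```

deque also makes the intent explicit in the type itself, rather than in eviction code scattered through query().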

11.6 综合优化案例

11.6.1 完整的优化方案

# 文件名:optimized_rag_system.py
"""
综合优化的RAG系统

集成:缓存 + 批处理 + 并发 + 内存优化
"""

from typing import Dict, List
import time
from concurrent.futures import ThreadPoolExecutor

from cache_strategies import MemoryCache


class OptimizedRAGSystem:
    """
    优化的RAG系统

    特性:
    1. 多层缓存(L1内存 + L2Redis)
    2. 批处理
    3. 并发处理
    4. 内存优化

    性能目标:
    - 缓存命中率 > 80%
    - 吞吐量 > 100 queries/s
    - 平均响应时间 < 100ms(命中缓存)
    """

    def __init__(self,
                 rag_pipeline,
                 l1_cache_size: int = 1000,
                 max_workers: int = 10):

        # RAG流程
        self.rag_pipeline = rag_pipeline

        # 缓存
        self.l1_cache = MemoryCache(max_size=l1_cache_size, ttl=3600)
        self.l2_cache = None  # 可选:RedisCache()

        # 并发
        self.max_workers = max_workers

        # 统计
        self.stats = {
            'total_queries': 0,
            'cache_hits': 0,
            'cache_misses': 0,
            'total_time': 0.0
        }

    def query(self, query: str) -> Dict:
        """
        查询(带缓存)
        """
        start_time = time.time()
        self.stats['total_queries'] += 1

        # L1缓存查找
        result = self.l1_cache.get(query)
        if result is not None:
            self.stats['cache_hits'] += 1
            self.stats['total_time'] += time.time() - start_time
            return {**result, 'cache': 'L1'}

        # L2缓存查找(如果配置)
        if self.l2_cache:
            result = self.l2_cache.get(query)
            if result is not None:
                self.stats['cache_hits'] += 1
                # 回填L1
                self.l1_cache.set(query, result)
                self.stats['total_time'] += time.time() - start_time
                return {**result, 'cache': 'L2'}

        # 缓存未命中,执行RAG
        self.stats['cache_misses'] += 1
        result = self.rag_pipeline(query)

        # 存储到缓存
        cache_data = {
            'answer': result['answer'],
            'sources': result.get('sources', [])
        }
        self.l1_cache.set(query, cache_data)

        if self.l2_cache:
            self.l2_cache.set(query, cache_data)

        self.stats['total_time'] += time.time() - start_time
        return {**result, 'cache': 'None'}

    def batch_query(self, queries: List[str]) -> List[Dict]:
        """
        批量查询(并发,结果顺序与输入一致)
        """
        results = [None] * len(queries)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # 记录每个future对应的输入下标,保证结果按输入顺序归位
            future_to_index = {
                executor.submit(self.query, query): idx
                for idx, query in enumerate(queries)
            }

            for future, idx in future_to_index.items():
                results[idx] = future.result()

        return results

    def get_stats(self) -> Dict:
        """
        获取统计信息
        """
        total = self.stats['total_queries']
        hits = self.stats['cache_hits']
        misses = self.stats['cache_misses']

        avg_time = self.stats['total_time'] / total if total > 0 else 0
        hit_rate = hits / total if total > 0 else 0

        return {
            'total_queries': total,
            'cache_hits': hits,
            'cache_misses': misses,
            'hit_rate': hit_rate,
            'avg_response_time_ms': avg_time * 1000,
            'throughput_qps': total / max(self.stats['total_time'], 0.001)
        }


# 使用示例
if __name__ == "__main__":
    # 模拟RAG流程
    def mock_rag(query: str) -> Dict:
        time.sleep(0.2)  # 模拟处理
        return {
            'answer': f'答案: {query}',
            'sources': ['doc1']
        }

    # 创建优化的RAG系统
    rag = OptimizedRAGSystem(
        rag_pipeline=mock_rag,
        l1_cache_size=1000,
        max_workers=10
    )

    # 测试:包含重复查询
    test_queries = [
        "查询1", "查询2", "查询1",  # 重复
        "查询3", "查询2",          # 重复
        "查询4", "查询5", "查询1",  # 重复
        "查询6"
    ]

    print("\n" + "="*80)
    print("优化RAG系统测试")
    print("="*80 + "\n")

    # 单个查询
    for query in test_queries:
        result = rag.query(query)
        print(f"{query} → 缓存: {result['cache']}")

    # 统计
    stats = rag.get_stats()

    print("\n" + "="*80)
    print("性能统计")
    print("="*80)
    print(f"总查询数: {stats['total_queries']}")
    print(f"缓存命中: {stats['cache_hits']}")
    print(f"缓存未命中: {stats['cache_misses']}")
    print(f"命中率: {stats['hit_rate']*100:.1f}%")
    print(f"平均响应时间: {stats['avg_response_time_ms']:.2f} ms")
    print(f"吞吐量: {stats['throughput_qps']:.2f} QPS")

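batch_query 的加速效果可以用一个独立的小实验验证。以下为示意(mock_query 的耗时与线程数均为假设值,思路与上文 batch_query 相同):

```python
import time
from concurrent.futures import ThreadPoolExecutor


def mock_query(q: str) -> str:
    time.sleep(0.1)  # 模拟一次 RAG 查询的耗时
    return f"答案: {q}"


queries = [f"查询{i}" for i in range(10)]

# 顺序执行:10 次 × 0.1s ≈ 1s
t0 = time.time()
seq_results = [mock_query(q) for q in queries]
seq_time = time.time() - t0

# 并发执行:10 个线程同时处理 ≈ 0.1s
t0 = time.time()
with ThreadPoolExecutor(max_workers=10) as executor:
    # executor.map 保证返回结果的顺序与输入一致
    par_results = list(executor.map(mock_query, queries))
par_time = time.time() - t0

print(f"顺序: {seq_time:.2f}s | 并发: {par_time:.2f}s | 加速: {seq_time / par_time:.1f}x")
```

由于 mock_query 是纯 I/O 等待(sleep),线程池几乎能拿到线性加速;真实 RAG 查询中包含 CPU 计算时,加速比会低一些。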
练习题

练习1:基础练习 - 实现内存缓存

题目:实现一个LRU缓存

要求:

1. 支持get和set操作
2. 容量满时自动淘汰最久未使用的项
3. 支持TTL(过期时间)
4. 提供统计信息

提示:

  • 使用collections.OrderedDict维护访问顺序
  • 记录每个key的访问时间戳
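按上述提示,一个可作为起点的骨架如下(仅示意:统计信息与更多边界情况留给读者补全):

```python
import time
from collections import OrderedDict


class LRUCache:
    """LRU + TTL 缓存骨架"""

    def __init__(self, max_size: int = 100, ttl: float = 3600):
        self.max_size = max_size
        self.ttl = ttl
        self.data = OrderedDict()  # key -> (value, 写入时间戳)

    def get(self, key):
        if key not in self.data:
            return None
        value, ts = self.data[key]
        if time.time() - ts > self.ttl:  # 已过期,删除并视为未命中
            del self.data[key]
            return None
        self.data.move_to_end(key)  # 标记为最近使用
        return value

    def set(self, key, value):
        if key in self.data:
            self.data.move_to_end(key)
        self.data[key] = (value, time.time())
        if len(self.data) > self.max_size:
            self.data.popitem(last=False)  # 淘汰最久未使用的项


cache = LRUCache(max_size=2, ttl=10)
cache.set("a", 1)
cache.set("b", 2)
cache.get("a")         # 访问a,使b成为最久未使用
cache.set("c", 3)      # 容量满,淘汰b
print(cache.get("b"))  # None
print(cache.get("a"))  # 1
```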


练习2:进阶练习 - 批处理优化

题目:优化批量查询性能

要求:

1. 实现批量查询接口
2. 支持并发处理
3. 对比顺序和批处理的性能
4. 达到3x以上加速


练习3:挑战项目 - 完整的优化系统

项目描述:构建一个生产级的高性能RAG系统

功能需求:

1. ✅ 多层缓存
2. ✅ 批处理
3. ✅ 并发处理
4. ✅ 内存优化
5. ✅ 性能监控

性能目标:

  • 缓存命中率 > 80%
  • 吞吐量 > 50 QPS
  • P95延迟 < 500ms
  • 内存占用 < 4GB(百万级文档)


总结

本章要点回顾

  1. 性能瓶颈识别
     • LLM生成是最大瓶颈(~80%时间)
     • 使用profiling工具定位问题
     • 重点关注热点函数

  2. 缓存策略
     • L1内存缓存:快速,容量小
     • L2 Redis缓存:稍慢,容量大
     • 多层缓存:90%+命中率

  3. 批处理优化
     • 批量嵌入:减少API调用
     • 批量检索:提升吞吐量
     • 批量生成:利用并行能力

  4. 并发优化
     • 多线程:适合I/O密集
     • 多进程:适合CPU密集
     • 异步:适合高并发

  5. 内存优化
     • 生成器:惰性加载
     • 向量量化:节省50%内存
     • 上下文管理:限制历史长度
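以向量量化为例:把 float64 降为 float32,向量内存立即减半。这里用标准库 array 演示思路(实际系统中通常用 numpy 的 astype 或向量库自带的量化选项):

```python
from array import array

# 1 个 768 维向量,float64(类型码 'd',8 字节/维)
vec_f64 = array('d', [0.1] * 768)

# 量化为 float32(类型码 'f',4 字节/维),精度略有损失
vec_f32 = array('f', vec_f64)

size_f64 = vec_f64.itemsize * len(vec_f64)
size_f32 = vec_f32.itemsize * len(vec_f32)
print(f"float64: {size_f64} 字节, float32: {size_f32} 字节, "
      f"节省 {1 - size_f32 / size_f64:.0%}")
# float64: 6144 字节, float32: 3072 字节, 节省 50%
```

对大多数嵌入模型的向量,float32 的精度足够检索使用;更激进的 int8 量化可再省一半,但需评估召回率损失。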

学习检查清单

  • 能够使用profiling工具分析性能
  • 理解多层缓存架构
  • 掌握批处理优化方法
  • 能够应用并发处理
  • 理解内存优化策略
  • 能够构建高性能RAG系统



本章结束

有任何问题或建议?欢迎提交Issue或PR到教程仓库!