跳转至

第21章:性能优化

全面提升RAG系统的响应速度和吞吐量


📚 章节概述

本章将学习如何对RAG系统进行全面性能优化,提升响应速度和降低成本。

学习目标

完成本章后,你将能够: - ✅ 识别性能瓶颈 - ✅ 实施多层缓存策略 - ✅ 优化数据库查询 - ✅ 优化向量检索性能 - ✅ 实施批处理和并发 - ✅ 降低API调用成本

预计时间

  • 理论学习:60分钟
  • 实践操作:90-120分钟
  • 总计:约3-4小时

1. 性能分析

1.1 性能指标

关键指标

延迟(Latency):
  - P50: 50%请求的响应时间
  - P95: 95%请求的响应时间
  - P99: 99%请求的响应时间

吞吐量(Throughput):
  - QPS: 每秒查询数
  - RPS: 每秒请求数

资源使用:
  - CPU使用率
  - 内存使用率
  - 网络 I/O
  - 磁盘 I/O

1.2 性能分析工具

Python Profiling

import cProfile
import pstats
from io import StringIO

def profile_query(func):
    """性能分析装饰器"""
    def wrapper(*args, **kwargs):
        pr = cProfile.Profile()
        pr.enable()
        result = func(*args, **kwargs)
        pr.disable()

        # 输出统计信息
        s = StringIO()
        ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
        ps.print_stats(20)  # 打印前20个
        print(s.getvalue())
        return result
    return wrapper

# 使用
@profile_query
async def rag_query(text: str):
    # RAG逻辑
    pass

内存分析

import tracemalloc

def trace_memory(func):
    """内存分析装饰器"""
    def wrapper(*args, **kwargs):
        tracemalloc.start()
        result = func(*args, **kwargs)

        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics('lineno')
        print("[Top 10 memory usage]")
        for stat in top_stats[:10]:
            print(stat)

        tracemalloc.stop()
        return result
    return wrapper

1.3 性能瓶颈识别

常见瓶颈

  1. LLM API调用
  2. 网络延迟(100-500ms)
  3. 模型推理时间(1-5s)
  4. Token限制处理

  5. 向量检索

  6. 大规模向量搜索
  7. 向量维度过高
  8. 索引效率低

  9. 数据库查询

  10. N+1查询问题
  11. 缺少索引
  12. 连接池不足

  13. I/O操作

  14. 同步I/O阻塞
  15. 大文件读取
  16. 网络传输

2. 缓存策略

2.1 多层缓存架构

┌─────────────────────────────────┐
│     Client (Browser)            │  L1: 浏览器缓存
└────────────┬────────────────────┘
┌────────────▼────────────────────┐
│     CDN / Load Balancer         │  L2: CDN缓存
└────────────┬────────────────────┘
┌────────────▼────────────────────┐
│     Application (Redis)         │  L3: 应用缓存
└────────────┬────────────────────┘
┌────────────▼────────────────────┐
│     Database / Vector Store     │  L4: 数据库
└─────────────────────────────────┘

2.2 Redis缓存实现

import redis
import pickle
from typing import Optional, Any
from functools import wraps
import hashlib

# Redis客户端
redis_client = redis.Redis(
    host='localhost',
    port=6379,
    db=0,
    decode_responses=False
)

def cache_result(ttl: int = 3600, key_prefix: str = ""):
    """
    缓存装饰器

    Args:
        ttl: 缓存过期时间(秒)
        key_prefix: 缓存键前缀
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # 生成缓存键
            key_data = f"{key_prefix}:{func.__name__}:{args}:{kwargs}"
            cache_key = hashlib.md5(key_data.encode()).hexdigest()

            # 尝试从缓存获取
            cached = redis_client.get(cache_key)
            if cached:
                return pickle.loads(cached)

            # 执行函数
            result = await func(*args, **kwargs)

            # 存入缓存
            redis_client.setex(
                cache_key,
                ttl,
                pickle.dumps(result)
            )

            return result
        return wrapper
    return decorator

# 使用示例
@cache_result(ttl=1800, key_prefix="rag_query")
async def query_with_cache(text: str):
    # RAG查询逻辑
    return await rag_query(text)

2.3 查询结果缓存

class QueryCache:
    """查询结果缓存"""

    def __init__(self, redis_client, max_size: int = 1000):
        self.redis = redis_client
        self.max_size = max_size
        self.cache = {}

    async def get(self, query: str) -> Optional[dict]:
        """获取缓存"""
        cache_key = f"query:{hashlib.md5(query.encode()).hexdigest()}"

        # L1: 内存缓存
        if cache_key in self.cache:
            return self.cache[cache_key]

        # L2: Redis缓存
        cached = self.redis.get(cache_key)
        if cached:
            result = pickle.loads(cached)
            self.cache[cache_key] = result
            return result

        return None

    async def set(self, query: str, result: dict, ttl: int = 1800):
        """设置缓存"""
        cache_key = f"query:{hashlib.md5(query.encode()).hexdigest()}"

        # 内存缓存(LRU)
        if len(self.cache) >= self.max_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        self.cache[cache_key] = result

        # Redis缓存
        self.redis.setex(cache_key, ttl, pickle.dumps(result))

    async def invalidate(self, query: str):
        """使缓存失效"""
        cache_key = f"query:{hashlib.md5(query.encode()).hexdigest()}"
        if cache_key in self.cache:
            del self.cache[cache_key]
        self.redis.delete(cache_key)

2.4 向量缓存

class EmbeddingCache:
    """向量嵌入缓存"""

    def __init__(self, redis_client):
        self.redis = redis_client

    async def get_embeddings(self, texts: list[str]) -> dict[str, np.ndarray]:
        """批量获取嵌入向量"""
        cache_keys = [f"emb:{hashlib.md5(t.encode()).hexdigest()}" for t in texts]

        # 批量获取
        cached_values = self.redis.mget(cache_keys)

        results = {}
        missing_indices = []

        for i, (text, cached) in enumerate(zip(texts, cached_values)):
            if cached:
                results[text] = pickle.loads(cached)
            else:
                missing_indices.append(i)

        return results, missing_indices

    async def set_embeddings(self, texts: list[str], embeddings: list[np.ndarray]):
        """批量设置嵌入向量"""
        pipe = self.redis.pipeline()
        for text, emb in zip(texts, embeddings):
            cache_key = f"emb:{hashlib.md5(text.encode()).hexdigest()}"
            pipe.setex(cache_key, 86400, pickle.dumps(emb))  # 24小时
        pipe.execute()

3. 数据库优化

3.1 索引优化

-- 创建索引
CREATE INDEX idx_documents_created_at ON documents(created_at);
CREATE INDEX idx_documents_metadata ON documents USING GIN(metadata);
CREATE INDEX idx_queries_user_id ON queries(user_id);
CREATE INDEX idx_queries_created_at ON queries(created_at DESC);

-- 复合索引
CREATE INDEX idx_documents_user_created ON documents(user_id, created_at DESC);

-- 部分索引(只索引活跃数据)
CREATE INDEX idx_active_documents ON documents(id)
WHERE status = 'active';

3.2 查询优化

# 优化前(N+1问题)
async def get_documents_with_tags_optimized(document_ids: list[int]):
    """批量获取文档和标签"""
    # 使用IN查询代替多次查询
    documents = await db.execute(
        select(Document)
        .where(Document.id.in_(document_ids))
        .options(selectinload(Document.tags))  # 预加载关联
    )

    return {doc.id: doc for doc in documents}

# 使用批量操作
async def bulk_insert_embeddings(documents: list[dict]):
    """批量插入嵌入向量"""
    embeddings = [Embedding(**doc) for doc in documents]

    # 批量插入(比逐个插入快10-100倍)
    async with database.session() as session:
        session.add_all(embeddings)
        await session.commit()

3.3 连接池配置

from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.pool import QueuePool

# 创建连接池
engine = create_async_engine(
    "postgresql+asyncpg://user:pass@localhost/ragdb",
    poolclass=QueuePool,
    pool_size=20,          # 连接池大小
    max_overflow=40,       # 最大溢出连接数
    pool_timeout=30,       # 获取连接超时
    pool_recycle=3600,     # 连接回收时间
    pool_pre_ping=True,    # 连接前ping检查
    echo=False
)

4. 向量检索优化

4.1 索引选择

import chromadb
from chromadb.config import Settings

# 配置ChromaDB
client = chromadb.HttpClient(
    host="localhost",
    port=8000,
    settings=Settings(
        anonymized_telemetry=False,
        allow_reset=True
    )
)

# 创建带索引的集合
collection = client.create_collection(
    name="documents",
    metadata={
        "hnsw:space": "cosine",
        "hnsw:construction_ef": 200,  # 构建时的参数
        "hnsw:M": 16                  # 图的连接数
    }
)

4.2 批量检索

async def batch_search(
    queries: list[str],
    collection,
    n_results: int = 10
) -> list[list[dict]]:
    """批量检索"""
    # 批量生成嵌入
    embeddings = await embed_texts(queries)

    # 批量检索
    results = collection.query(
        query_embeddings=embeddings,
        n_results=n_results
    )

    return results

4.3 预过滤

async def filtered_search(
    query: str,
    filters: dict,
    collection
) -> list[dict]:
    """带预过滤的检索"""

    # 先用where子句过滤
    results = collection.query(
        query_embeddings=[await embed_text(query)],
        where=filters,  # {"category": "tech", "date": {"$gt": "2024-01-01"}}
        n_results=10
    )

    return results

5. 并发和异步

5.1 异步I/O

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def parallel_retrieval(query: str):
    """并行检索多个数据源"""

    tasks = [
        vector_search(query),
        keyword_search(query),
        hybrid_search(query)
    ]

    results = await asyncio.gather(*tasks, return_exceptions=True)

    return {
        "vector": results[0],
        "keyword": results[1],
        "hybrid": results[2]
    }

5.2 批处理

class BatchProcessor:
    """批处理处理器"""

    def __init__(self, batch_size: int = 10, timeout: float = 1.0):
        self.batch_size = batch_size
        self.timeout = timeout
        self.queue = asyncio.Queue()
        self.task = None

    async def start(self):
        """启动批处理任务"""
        self.task = asyncio.create_task(self._process_batch())

    async def _process_batch(self):
        """处理批次"""
        while True:
            batch = []
            deadline = asyncio.time() + self.timeout

            # 收集批次
            while len(batch) < self.batch_size:
                try:
                    item = await asyncio.wait_for(
                        self.queue.get(),
                        timeout=deadline - asyncio.time()
                    )
                    batch.append(item)
                except asyncio.TimeoutError:
                    break

            if batch:
                await self._process(batch)

    async def _process(self, batch: list):
        """处理批次数据"""
        # 批量处理逻辑
        results = await batch_embed([item['text'] for item in batch])

        for item, result in zip(batch, results):
            if 'future' in item:
                item['future'].set_result(result)

    async def submit(self, text: str) -> Any:
        """提交待处理项目"""
        future = asyncio.Future()
        await self.queue.put({'text': text, 'future': future})
        return await future

6. LLM调用优化

6.1 Token优化

async def optimize_context(
    query: str,
    documents: list[dict],
    max_tokens: int = 3000
) -> str:
    """优化上下文长度"""

    # 计算每个文档的token数
    for doc in documents:
        doc['token_count'] = count_tokens(doc['text'])

    # 按相关性排序
    documents.sort(key=lambda x: x['score'], reverse=True)

    # 贪心选择文档(尽可能多的高相关文档)
    selected_docs = []
    total_tokens = 0

    for doc in documents:
        if total_tokens + doc['token_count'] <= max_tokens:
            selected_docs.append(doc)
            total_tokens += doc['token_count']
        else:
            break

    return format_context(selected_docs)

6.2 批量生成

async def batch_generate(
    queries: list[str],
    batch_size: int = 5
) -> list[str]:
    """批量生成回答"""

    results = []
    for i in range(0, len(queries), batch_size):
        batch = queries[i:i+batch_size]

        # 并发调用LLM
        tasks = [generate_answer(q) for q in batch]
        batch_results = await asyncio.gather(*tasks)

        results.extend(batch_results)

        # 避免速率限制
        if i + batch_size < len(queries):
            await asyncio.sleep(1)

    return results

7. 性能测试

7.1 负载测试

# 使用Locust进行负载测试
from locust import HttpUser, task, between

class RAGUser(HttpUser):
    wait_time = between(1, 3)

    @task
    def query(self):
        query_text = "What is RAG?"
        self.client.post(
            "/query",
            json={"text": query_text}
        )

    @task(3)
    def health_check(self):
        self.client.get("/health")

运行负载测试

locust -f locustfile.py --host=http://localhost:8000

7.2 基准测试

import time
from statistics import mean, median

def benchmark(func, iterations: int = 100):
    """性能基准测试"""

    times = []
    results = []

    for _ in range(iterations):
        start = time.perf_counter()
        result = func()
        end = time.perf_counter()

        times.append(end - start)
        results.append(result)

    return {
        "mean": mean(times),
        "median": median(times),
        "min": min(times),
        "max": max(times),
        "p95": sorted(times)[int(iterations * 0.95)],
        "p99": sorted(times)[int(iterations * 0.99)],
        "throughput": iterations / sum(times)
    }

# 使用
result = benchmark(lambda: query("test query"))
print(f"P95延迟: {result['p95']*1000:.2f}ms")
print(f"吞吐量: {result['throughput']:.2f} QPS")

8. 实战练习

练习1:实施缓存策略

任务: 1. 实现Redis缓存 2. 缓存查询结果和嵌入向量 3. 测试缓存命中率 4. 优化缓存失效策略

验证

# 测试缓存效果
# 无缓存:1000ms
# 有缓存:10ms
# 提升:100倍


练习2:优化数据库查询

任务: 1. 添加适当索引 2. 优化慢查询 3. 实现批量操作 4. 配置连接池

验证

-- 查看慢查询
SELECT query, mean_exec_time, calls
FROM pg_stat_statements
ORDER BY mean_exec_time DESC
LIMIT 10;


练习3:并发优化

任务: 1. 实现异步处理 2. 配置批处理 3. 测试并发性能 4. 找到最优并发数

验证

# 负载测试
# 单线程:10 QPS
# 并发20:150 QPS
# 提升:15倍


9. 最佳实践

9.1 优化原则

  • Measure First: 先测量再优化
  • Optimize Hot Path: 优化热点路径
  • Avoid Premature Optimization: 避免过早优化
  • Trade-offs: 在速度、成本、质量间平衡

9.2 监控指标

  • 缓存命中率: >80%
  • P95延迟: <2s
  • 错误率: <1%
  • 资源使用: CPU <70%, 内存 <80%

10. 总结

关键要点

  1. 性能分析
  2. 识别瓶颈
  3. 使用profiling工具
  4. 监控关键指标

  5. 缓存策略

  6. 多层缓存
  7. LRU失效
  8. 批量操作

  9. 优化技术

  10. 异步处理
  11. 批处理
  12. 索引优化

下一步

  • 学习安全实践(第22章)
  • 最佳实践(第23章)

恭喜完成第21章! 🎉

你已经掌握RAG系统性能优化的核心技能!

下一步:第22章 - 安全实践