第21章:性能优化¶
全面提升RAG系统的响应速度和吞吐量
📚 章节概述¶
本章将学习如何对RAG系统进行全面性能优化,提升响应速度和降低成本。
学习目标¶
完成本章后,你将能够:

- ✅ 识别性能瓶颈
- ✅ 实施多层缓存策略
- ✅ 优化数据库查询
- ✅ 优化向量检索性能
- ✅ 实施批处理和并发
- ✅ 降低API调用成本
预计时间¶
- 理论学习:60分钟
- 实践操作:90-120分钟
- 总计:约2.5-3小时
1. 性能分析¶
1.1 性能指标¶
关键指标:

- 延迟(Latency):
    - P50:50% 的请求在该时间内完成
    - P95:95% 的请求在该时间内完成
    - P99:99% 的请求在该时间内完成
- 吞吐量(Throughput):
    - QPS:每秒查询数
    - RPS:每秒请求数
- 资源使用:
    - CPU使用率
    - 内存使用率
    - 网络 I/O
    - 磁盘 I/O
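为了更直观,下面用一小段 Python 演示如何从一组延迟样本中计算 P50/P95/P99(样本数据为假设值,实际可从日志或监控系统导出):

```python
import numpy as np

# 假设的延迟样本(秒)
latencies = np.array([0.12, 0.30, 0.45, 0.80, 1.20, 2.50, 0.33, 0.41])

p50, p95, p99 = np.percentile(latencies, [50, 95, 99])
print(f"P50={p50*1000:.0f}ms  P95={p95*1000:.0f}ms  P99={p99*1000:.0f}ms")

# 吞吐量:单位时间内处理的请求数(这里假设请求串行处理)
qps = len(latencies) / latencies.sum()
print(f"QPS≈{qps:.2f}")
```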
1.2 性能分析工具¶
Python Profiling:
```python
import cProfile
import pstats
from io import StringIO
from functools import wraps

def profile_query(func):
    """性能分析装饰器(适用于 async 函数)"""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        pr = cProfile.Profile()
        pr.enable()
        result = await func(*args, **kwargs)
        pr.disable()

        # 输出统计信息(按累计耗时排序,打印前20条)
        s = StringIO()
        ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
        ps.print_stats(20)
        print(s.getvalue())

        return result
    return wrapper

# 使用
@profile_query
async def rag_query(text: str):
    # RAG逻辑
    pass
```
内存分析:
```python
import tracemalloc

def trace_memory(func):
    """内存分析装饰器"""
    def wrapper(*args, **kwargs):
        tracemalloc.start()
        result = func(*args, **kwargs)

        # 按代码行统计内存占用,打印前10条
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics('lineno')
        print("[Top 10 memory usage]")
        for stat in top_stats[:10]:
            print(stat)

        tracemalloc.stop()
        return result
    return wrapper
```
1.3 性能瓶颈识别¶
常见瓶颈:

1. LLM API调用:
    - 网络延迟(100-500ms)
    - 模型推理时间(1-5s)
    - Token限制处理
2. 向量检索:
    - 大规模向量搜索
    - 向量维度过高
    - 索引效率低
3. 数据库查询:
    - N+1查询问题
    - 缺少索引
    - 连接池不足
4. I/O操作:
    - 同步I/O阻塞
    - 大文件读取
    - 网络传输
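定位这些瓶颈最直接的办法,是对 RAG 流水线的各个阶段分别计时。下面是一个示意性的片段(embed_text、vector_search、generate_answer 均为假设存在的异步函数,仅用于说明思路):

```python
import time

async def timed_rag_query(query: str) -> dict:
    """按阶段计时,定位流水线中的瓶颈(示意代码)"""
    timings = {}

    t0 = time.perf_counter()
    query_embedding = await embed_text(query)          # 假设的嵌入函数
    timings["embedding"] = time.perf_counter() - t0

    t0 = time.perf_counter()
    docs = await vector_search(query_embedding, k=10)  # 假设的检索函数
    timings["retrieval"] = time.perf_counter() - t0

    t0 = time.perf_counter()
    answer = await generate_answer(query, docs)        # 假设的LLM调用
    timings["generation"] = time.perf_counter() - t0

    print({k: f"{v*1000:.1f}ms" for k, v in timings.items()})
    return {"answer": answer, "timings": timings}
```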
2. 缓存策略¶
2.1 多层缓存架构¶
```text
┌─────────────────────────────────┐
│      Client (Browser)           │  L1: 浏览器缓存
└────────────┬────────────────────┘
             │
┌────────────▼────────────────────┐
│      CDN / Load Balancer        │  L2: CDN缓存
└────────────┬────────────────────┘
             │
┌────────────▼────────────────────┐
│      Application (Redis)        │  L3: 应用缓存
└────────────┬────────────────────┘
             │
┌────────────▼────────────────────┐
│      Database / Vector Store    │  L4: 数据库
└─────────────────────────────────┘
```
2.2 Redis缓存实现¶
```python
import redis
import pickle
import hashlib
from typing import Optional, Any
from functools import wraps

# Redis客户端
redis_client = redis.Redis(
    host='localhost',
    port=6379,
    db=0,
    decode_responses=False
)

def cache_result(ttl: int = 3600, key_prefix: str = ""):
    """
    缓存装饰器

    Args:
        ttl: 缓存过期时间(秒)
        key_prefix: 缓存键前缀
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # 生成缓存键
            key_data = f"{key_prefix}:{func.__name__}:{args}:{kwargs}"
            cache_key = hashlib.md5(key_data.encode()).hexdigest()

            # 尝试从缓存获取
            cached = redis_client.get(cache_key)
            if cached:
                return pickle.loads(cached)

            # 执行函数
            result = await func(*args, **kwargs)

            # 存入缓存
            redis_client.setex(cache_key, ttl, pickle.dumps(result))
            return result
        return wrapper
    return decorator

# 使用示例
@cache_result(ttl=1800, key_prefix="rag_query")
async def query_with_cache(text: str):
    # RAG查询逻辑
    return await rag_query(text)
```
2.3 查询结果缓存¶
```python
from collections import OrderedDict
from typing import Optional

class QueryCache:
    """查询结果缓存(内存 + Redis 两级)"""

    def __init__(self, redis_client, max_size: int = 1000):
        self.redis = redis_client
        self.max_size = max_size
        self.cache = OrderedDict()  # 用OrderedDict维护访问顺序,实现LRU

    async def get(self, query: str) -> Optional[dict]:
        """获取缓存"""
        cache_key = f"query:{hashlib.md5(query.encode()).hexdigest()}"

        # L1: 内存缓存
        if cache_key in self.cache:
            self.cache.move_to_end(cache_key)  # 标记为最近使用
            return self.cache[cache_key]

        # L2: Redis缓存
        cached = self.redis.get(cache_key)
        if cached:
            result = pickle.loads(cached)
            self.cache[cache_key] = result
            return result

        return None

    async def set(self, query: str, result: dict, ttl: int = 1800):
        """设置缓存"""
        cache_key = f"query:{hashlib.md5(query.encode()).hexdigest()}"

        # 内存缓存(LRU:超出容量时淘汰最久未使用的条目)
        if len(self.cache) >= self.max_size:
            self.cache.popitem(last=False)
        self.cache[cache_key] = result

        # Redis缓存
        self.redis.setex(cache_key, ttl, pickle.dumps(result))

    async def invalidate(self, query: str):
        """使缓存失效"""
        cache_key = f"query:{hashlib.md5(query.encode()).hexdigest()}"
        if cache_key in self.cache:
            del self.cache[cache_key]
        self.redis.delete(cache_key)
```
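QueryCache 可以直接套在查询入口上使用。下面是一个示意用法(rag_query 为前文假设的RAG查询函数):

```python
query_cache = QueryCache(redis_client, max_size=1000)

async def cached_rag_query(text: str) -> dict:
    """先查缓存,未命中时执行完整RAG流程并写回缓存(示意代码)"""
    cached = await query_cache.get(text)
    if cached is not None:
        return cached

    result = await rag_query(text)
    await query_cache.set(text, result, ttl=1800)
    return result
```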
2.4 向量缓存¶
```python
import numpy as np

class EmbeddingCache:
    """向量嵌入缓存"""

    def __init__(self, redis_client):
        self.redis = redis_client

    async def get_embeddings(
        self, texts: list[str]
    ) -> tuple[dict[str, np.ndarray], list[int]]:
        """批量获取嵌入向量,返回 (命中结果, 未命中文本的下标)"""
        cache_keys = [f"emb:{hashlib.md5(t.encode()).hexdigest()}" for t in texts]

        # 批量获取
        cached_values = self.redis.mget(cache_keys)

        results = {}
        missing_indices = []
        for i, (text, cached) in enumerate(zip(texts, cached_values)):
            if cached:
                results[text] = pickle.loads(cached)
            else:
                missing_indices.append(i)

        return results, missing_indices

    async def set_embeddings(self, texts: list[str], embeddings: list[np.ndarray]):
        """批量设置嵌入向量"""
        pipe = self.redis.pipeline()
        for text, emb in zip(texts, embeddings):
            cache_key = f"emb:{hashlib.md5(text.encode()).hexdigest()}"
            pipe.setex(cache_key, 86400, pickle.dumps(emb))  # 缓存24小时
        pipe.execute()
```
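get_embeddings 只返回缓存命中的部分,未命中的文本仍需调用嵌入模型并回填缓存。下面是一个示意性的使用流程(其中 embed_texts 为假设存在的批量嵌入函数):

```python
async def get_or_compute_embeddings(texts: list[str], cache: EmbeddingCache) -> list[np.ndarray]:
    """优先读缓存,只对未命中的文本调用嵌入模型(示意代码)"""
    cached, missing_indices = await cache.get_embeddings(texts)

    if missing_indices:
        missing_texts = [texts[i] for i in missing_indices]
        new_embeddings = await embed_texts(missing_texts)   # 假设的批量嵌入函数
        await cache.set_embeddings(missing_texts, new_embeddings)
        cached.update(dict(zip(missing_texts, new_embeddings)))

    # 按原始顺序返回
    return [cached[t] for t in texts]
```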
3. 数据库优化¶
3.1 索引优化¶
```sql
-- 创建索引
CREATE INDEX idx_documents_created_at ON documents(created_at);
CREATE INDEX idx_documents_metadata ON documents USING GIN(metadata);
CREATE INDEX idx_queries_user_id ON queries(user_id);
CREATE INDEX idx_queries_created_at ON queries(created_at DESC);

-- 复合索引
CREATE INDEX idx_documents_user_created ON documents(user_id, created_at DESC);

-- 部分索引(只索引活跃数据)
CREATE INDEX idx_active_documents ON documents(id)
WHERE status = 'active';
```
3.2 查询优化¶
```python
from sqlalchemy import select
from sqlalchemy.orm import selectinload

# 优化后:用IN查询 + 预加载关联,避免N+1问题
async def get_documents_with_tags_optimized(document_ids: list[int]):
    """批量获取文档和标签"""
    result = await db.execute(
        select(Document)
        .where(Document.id.in_(document_ids))
        .options(selectinload(Document.tags))  # 预加载关联,避免逐条查询标签
    )
    documents = result.scalars().all()
    return {doc.id: doc for doc in documents}

# 使用批量操作
async def bulk_insert_embeddings(documents: list[dict]):
    """批量插入嵌入向量"""
    embeddings = [Embedding(**doc) for doc in documents]

    # 批量插入(比逐条插入快一到两个数量级)
    async with database.session() as session:
        session.add_all(embeddings)
        await session.commit()
```
3.3 连接池配置¶
```python
from sqlalchemy.ext.asyncio import create_async_engine

# 创建带连接池的异步引擎(异步引擎默认使用异步适配的队列连接池)
engine = create_async_engine(
    "postgresql+asyncpg://user:pass@localhost/ragdb",
    pool_size=20,          # 连接池大小
    max_overflow=40,       # 最大溢出连接数
    pool_timeout=30,       # 获取连接超时(秒)
    pool_recycle=3600,     # 连接回收时间(秒)
    pool_pre_ping=True,    # 取用连接前先ping检查
    echo=False
)
```
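连接池只有在会话被正确复用时才能发挥作用。下面是一个示意用法(假设已定义 Document 模型),通过 async_sessionmaker 从上面的 engine 获取会话:

```python
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker

# 会话工厂:每个请求从连接池借用一个连接,用完自动归还
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)

async def fetch_recent_documents(limit: int = 20):
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(Document).order_by(Document.created_at.desc()).limit(limit)
        )
        return result.scalars().all()
```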
4. 向量检索优化¶
4.1 索引选择¶
```python
import chromadb
from chromadb.config import Settings

# 配置ChromaDB
client = chromadb.HttpClient(
    host="localhost",
    port=8000,
    settings=Settings(
        anonymized_telemetry=False,
        allow_reset=True
    )
)

# 创建带HNSW索引参数的集合
collection = client.create_collection(
    name="documents",
    metadata={
        "hnsw:space": "cosine",          # 距离度量
        "hnsw:construction_ef": 200,     # 构建索引时的搜索宽度
        "hnsw:M": 16                     # 图中每个节点的连接数
    }
)
```
4.2 批量检索¶
```python
async def batch_search(
    queries: list[str],
    collection,
    n_results: int = 10
) -> list[list[dict]]:
    """批量检索"""
    # 批量生成嵌入
    embeddings = await embed_texts(queries)

    # 一次请求完成多个查询的检索
    results = collection.query(
        query_embeddings=embeddings,
        n_results=n_results
    )
    return results
```
4.3 预过滤¶
```python
async def filtered_search(
    query: str,
    filters: dict,
    collection
) -> list[dict]:
    """带预过滤的检索"""
    # 先用where子句按元数据过滤,再做向量检索
    results = collection.query(
        query_embeddings=[await embed_text(query)],
        where=filters,  # 例如 {"category": "tech", "date": {"$gt": "2024-01-01"}}
        n_results=10
    )
    return results
```
5. 并发和异步¶
5.1 异步I/O¶
```python
import asyncio

async def parallel_retrieval(query: str):
    """并行检索多个数据源"""
    tasks = [
        vector_search(query),
        keyword_search(query),
        hybrid_search(query)
    ]
    # return_exceptions=True:单个数据源失败不会中断整体,
    # 调用方需要检查结果是否为异常对象
    results = await asyncio.gather(*tasks, return_exceptions=True)

    return {
        "vector": results[0],
        "keyword": results[1],
        "hybrid": results[2]
    }
```
5.2 批处理¶
```python
import asyncio
import time
from typing import Any

class BatchProcessor:
    """批处理处理器:把零散请求攒成批次后统一处理"""

    def __init__(self, batch_size: int = 10, timeout: float = 1.0):
        self.batch_size = batch_size
        self.timeout = timeout
        self.queue = asyncio.Queue()
        self.task = None

    async def start(self):
        """启动批处理任务"""
        self.task = asyncio.create_task(self._process_batch())

    async def _process_batch(self):
        """循环收集并处理批次"""
        while True:
            batch = []
            deadline = time.monotonic() + self.timeout

            # 收集批次:凑够batch_size或到达超时时间即处理
            while len(batch) < self.batch_size:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    item = await asyncio.wait_for(self.queue.get(), timeout=remaining)
                    batch.append(item)
                except asyncio.TimeoutError:
                    break

            if batch:
                await self._process(batch)

    async def _process(self, batch: list):
        """处理批次数据"""
        # 批量生成嵌入,再把结果分发回各个等待中的Future
        results = await batch_embed([item['text'] for item in batch])
        for item, result in zip(batch, results):
            item['future'].set_result(result)

    async def submit(self, text: str) -> Any:
        """提交待处理项目,等待批处理结果"""
        future = asyncio.get_running_loop().create_future()
        await self.queue.put({'text': text, 'future': future})
        return await future
```
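下面是一个示意用法:多个并发请求把文本提交给同一个 BatchProcessor,由它合并成批次后调用 batch_embed(上文假设存在的批量嵌入函数):

```python
async def main():
    processor = BatchProcessor(batch_size=10, timeout=0.5)
    await processor.start()

    # 并发提交的请求会被自动合并成批次处理
    texts = [f"查询文本 {i}" for i in range(25)]
    embeddings = await asyncio.gather(*(processor.submit(t) for t in texts))
    print(f"共获得 {len(embeddings)} 个嵌入向量")

asyncio.run(main())
```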
6. LLM调用优化¶
6.1 Token优化¶
```python
async def optimize_context(
    query: str,
    documents: list[dict],
    max_tokens: int = 3000
) -> str:
    """优化上下文长度"""
    # 计算每个文档的token数
    for doc in documents:
        doc['token_count'] = count_tokens(doc['text'])

    # 按相关性从高到低排序
    documents.sort(key=lambda x: x['score'], reverse=True)

    # 贪心选择:在token预算内尽可能多地保留高相关文档
    selected_docs = []
    total_tokens = 0

    for doc in documents:
        if total_tokens + doc['token_count'] <= max_tokens:
            selected_docs.append(doc)
            total_tokens += doc['token_count']
        else:
            break

    return format_context(selected_docs)
```
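上面用到的 count_tokens 需要与实际使用的模型分词器一致。下面给出一个基于 tiktoken 的示意实现(假设使用 OpenAI 系列模型;换用其他模型时应替换为对应的分词器):

```python
import tiktoken

def count_tokens(text: str, model: str = "gpt-4o-mini") -> int:
    """用tiktoken估算文本的token数(示意实现)"""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # 未收录的模型名退回到通用编码
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))
```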
6.2 批量生成¶
```python
async def batch_generate(
    queries: list[str],
    batch_size: int = 5
) -> list[str]:
    """批量生成回答"""
    results = []

    for i in range(0, len(queries), batch_size):
        batch = queries[i:i+batch_size]

        # 并发调用LLM
        tasks = [generate_answer(q) for q in batch]
        batch_results = await asyncio.gather(*tasks)
        results.extend(batch_results)

        # 批次之间稍作等待,避免触发速率限制
        if i + batch_size < len(queries):
            await asyncio.sleep(1)

    return results
```
7. 性能测试¶
7.1 负载测试¶
```python
# 使用Locust进行负载测试
from locust import HttpUser, task, between

class RAGUser(HttpUser):
    wait_time = between(1, 3)

    @task
    def query(self):
        query_text = "What is RAG?"
        self.client.post(
            "/query",
            json={"text": query_text}
        )

    @task(3)
    def health_check(self):
        self.client.get("/health")
```
运行负载测试:
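假设上面的脚本保存为 locustfile.py,且服务运行在本地8000端口,可以用类似下面的命令以无界面模式运行(用户数、加压速率和时长均为示例值):

```bash
locust -f locustfile.py --host http://localhost:8000 \
       --users 50 --spawn-rate 5 --run-time 5m --headless
```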
7.2 基准测试¶
```python
import time
from statistics import mean, median

def benchmark(func, iterations: int = 100):
    """性能基准测试"""
    times = []
    results = []

    for _ in range(iterations):
        start = time.perf_counter()
        result = func()
        end = time.perf_counter()
        times.append(end - start)
        results.append(result)

    sorted_times = sorted(times)
    return {
        "mean": mean(times),
        "median": median(times),
        "min": min(times),
        "max": max(times),
        "p95": sorted_times[int(iterations * 0.95)],
        "p99": sorted_times[int(iterations * 0.99)],
        "throughput": iterations / sum(times)
    }

# 使用
result = benchmark(lambda: query("test query"))
print(f"P95延迟: {result['p95']*1000:.2f}ms")
print(f"吞吐量: {result['throughput']:.2f} QPS")
```
8. 实战练习¶
练习1:实施缓存策略¶
任务:

1. 实现Redis缓存
2. 缓存查询结果和嵌入向量
3. 测试缓存命中率
4. 优化缓存失效策略
验证:
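可以从 Redis 的统计信息中读取全局命中/未命中次数来估算命中率,例如(redis_client 为前文创建的客户端):

```python
stats = redis_client.info("stats")
hits = stats["keyspace_hits"]
misses = stats["keyspace_misses"]
total = hits + misses
print(f"缓存命中率: {hits / total:.1%}" if total else "暂无缓存访问记录")
```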
练习2:优化数据库查询¶
任务:

1. 添加适当索引
2. 优化慢查询
3. 实现批量操作
4. 配置连接池
验证:
```sql
-- 查看慢查询(需要先启用 pg_stat_statements 扩展)
SELECT query, mean_exec_time, calls
FROM pg_stat_statements
ORDER BY mean_exec_time DESC
LIMIT 10;
```
练习3:并发优化¶
任务:

1. 实现异步处理
2. 配置批处理
3. 测试并发性能
4. 找到最优并发数
验证:
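可以写一个小脚本,在不同并发度下向服务发压,比较P95延迟和吞吐量,找到吞吐量不再提升、延迟开始恶化的拐点。下面是一个基于 httpx 的示意脚本(接口地址和查询内容均为示例):

```python
import asyncio
import time
import httpx

async def run_load(concurrency: int, total_requests: int = 200) -> None:
    """以指定并发度发送请求并统计耗时(示意代码)"""
    sem = asyncio.Semaphore(concurrency)
    latencies = []

    async def one_request(client: httpx.AsyncClient):
        async with sem:
            start = time.perf_counter()
            await client.post("http://localhost:8000/query", json={"text": "What is RAG?"})
            latencies.append(time.perf_counter() - start)

    start = time.perf_counter()
    async with httpx.AsyncClient(timeout=30) as client:
        await asyncio.gather(*(one_request(client) for _ in range(total_requests)))
    elapsed = time.perf_counter() - start

    p95 = sorted(latencies)[int(len(latencies) * 0.95)]
    print(f"并发={concurrency}  P95={p95*1000:.0f}ms  吞吐量={total_requests/elapsed:.1f} QPS")

async def main():
    for c in (1, 5, 10, 20, 50):
        await run_load(c)

asyncio.run(main())
```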
9. 最佳实践¶
9.1 优化原则¶
- Measure First: 先测量再优化
- Optimize Hot Path: 优化热点路径
- Avoid Premature Optimization: 避免过早优化
- Trade-offs: 在速度、成本、质量间平衡
9.2 监控指标¶
- 缓存命中率: >80%
- P95延迟: <2s
- 错误率: <1%
- 资源使用: CPU <70%, 内存 <80%
10. 总结¶
关键要点¶
- 性能分析
    - 识别瓶颈
    - 使用profiling工具
    - 监控关键指标
- 缓存策略
    - 多层缓存
    - LRU失效
    - 批量操作
- 优化技术
    - 异步处理
    - 批处理
    - 索引优化
下一步¶
- 学习安全实践(第22章)
- 最佳实践(第23章)
恭喜完成第21章! 🎉
你已经掌握RAG系统性能优化的核心技能!
下一步:第22章 - 安全实践