Chapter 11: Performance Optimization
Is your RAG system too slow? With caching, batching, concurrency, and memory optimization, you can raise throughput 2-3x and cut response time by 50%.
📚 Learning Objectives
After completing this chapter, you will be able to:
- Identify performance bottlenecks in a RAG system
- Implement a multi-layer caching strategy
- Apply batch-processing optimizations
- Use concurrency to increase throughput
- Optimize memory usage
- Improve overall system performance by 2-3x
Estimated study time: 3 hours. Difficulty: ⭐⭐⭐☆☆
Prerequisites
Before starting this chapter, you should have:
- Completed the basic RAG implementation in Module 1
- An understanding of advanced RAG patterns (Chapter 10)
- Familiarity with Python asynchronous programming
Environment requirements: Python >= 3.9; redis (for caching); profiling tools (py-spy, memory_profiler)
11.1 Bottleneck Analysis
11.1.1 Profiling the RAG Pipeline
Time breakdown of a typical RAG pipeline:
User query → RAG system → Answer
│ │
│ ├─ [1] Query processing: ~50ms
│ ├─ [2] Query embedding: ~100ms
│ ├─ [3] Vector retrieval: ~200ms
│ ├─ [4] Document loading: ~100ms
│ ├─ [5] LLM generation: ~2000ms ⚠️ bottleneck
│ └─ [6] Post-processing: ~50ms
│
Total: ~2500ms
Share of total time:
LLM generation: 80% ⚠️
Vector retrieval: 8%
Query embedding: 4%
Document loading: 4%
Other: 4%
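The percentages above follow directly from the per-stage timings; a quick sanity check (stage names are taken from the diagram):

```python
# Stage timings (ms) from the breakdown above
stages = {
    "query_processing": 50,
    "query_embedding": 100,
    "vector_retrieval": 200,
    "document_loading": 100,
    "llm_generation": 2000,
    "post_processing": 50,
}

total = sum(stages.values())  # 2500 ms
shares = {name: t / total * 100 for name, t in stages.items()}

print(f"total: {total} ms")
for name, pct in shares.items():
    print(f"{name}: {pct:.0f}%")  # llm_generation dominates at 80%
```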
11.1.2 Profiling Tools
Tool 1: cProfile (CPU profiling)
# File: profiling_example.py
"""
Profiling example.
"""
import cProfile
import io
import pstats
import time
from pstats import SortKey

def rag_pipeline(query: str) -> str:
    """Simulate a RAG pipeline."""
    time.sleep(0.05)  # Step 1: query processing
    time.sleep(0.1)   # Step 2: query embedding
    time.sleep(0.2)   # Step 3: vector retrieval
    time.sleep(0.1)   # Step 4: document loading
    time.sleep(2.0)   # Step 5: LLM generation
    time.sleep(0.05)  # Step 6: post-processing
    return "answer"

if __name__ == "__main__":
    # Create the profiler
    profiler = cProfile.Profile()
    print("Running profiler...")
    profiler.enable()
    result = rag_pipeline("Python performance optimization")
    profiler.disable()
    # Print the top-10 functions by cumulative time
    s = io.StringIO()
    ps = pstats.Stats(profiler, stream=s).sort_stats(SortKey.CUMULATIVE)
    ps.print_stats(10)
    print(s.getvalue())
Sample output (note that cProfile aggregates all six time.sleep calls into one line):
Running profiler...
         8 function calls in 2.501 seconds
   Ordered by: cumulative time
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    2.501    2.501 profiling_example.py:12(rag_pipeline)
        6    2.500    0.417    2.500    0.417 {built-in method time.sleep}
   ...
Tool 2: memory_profiler (memory profiling)
# Install: pip install memory_profiler
from memory_profiler import profile

@profile
def rag_pipeline_memory(query: str) -> str:
    """Profile memory usage."""
    # Simulated data load
    large_data = [i for i in range(1000000)]  # ~8MB
    # Simulated embeddings
    embeddings = [[0.1] * 768 for _ in range(100)]  # ~300KB
    # Simulated documents
    documents = ["document text" * 1000 for _ in range(100)]  # ~1MB
    return "answer"

if __name__ == "__main__":
    result = rag_pipeline_memory("test query")
Sample output:
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
     1     50.0 MiB     50.0 MiB            1   @profile
     2                                          def rag_pipeline_memory(query):
     3     58.0 MiB      8.0 MiB            1       large_data = [i for i in range(1000000)]
     4     58.3 MiB      0.3 MiB            1       embeddings = [[0.1] * 768 for _ in range(100)]
     5     59.3 MiB      1.0 MiB            1       documents = ["document text" * 1000 for _ in range(100)]
     6     59.3 MiB      0.0 MiB            1       return "answer"
11.2 Caching Strategies
11.2.1 Why Cache?
Problem: repeated computation wastes resources.
Scenario: an FAQ answering system
Frequent questions:
- "How do I reset my password?" (100 times/day)
- "How do I get a refund?" (50 times/day)
- "How do I contact support?" (30 times/day)
Without caching, every query pays the full cost:
- Query embedding: 100ms
- Vector retrieval: 200ms
- LLM generation: 2000ms
Total: 2300ms
With caching:
First query: 2300ms
Subsequent queries: 10ms (served straight from cache)
Savings: 99.6% of the time!
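For in-process caching of a deterministic function, Python's standard library already offers the simplest possible form. A minimal sketch (`answer_faq` is a hypothetical stand-in for the full pipeline); note that lru_cache has no TTL and only accepts hashable arguments, which is why the chapter builds a custom cache:

```python
from functools import lru_cache

@lru_cache(maxsize=1024)  # keeps the 1024 most recently used answers
def answer_faq(question: str) -> str:
    # Stand-in for the expensive embed -> retrieve -> generate pipeline
    return f"answer to: {question}"

answer_faq("How do I reset my password?")  # computed (miss)
answer_faq("How do I reset my password?")  # served from cache (hit)
print(answer_faq.cache_info())  # hits=1, misses=1
```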
11.2.2 多层缓存架构¶
┌─────────────────────────────────────────────────────┐
│ 多层缓存架构 │
├─────────────────────────────────────────────────────┤
│ │
│ L1: 内存缓存 (LRU) │
│ ├─ 容量: 1000条 │
│ ├─ TTL: 1小时 │
│ ├─ 命中率: ~60% │
│ └─ 响应时间: <1ms │
│ ↓ (未命中) │
│ L2: Redis缓存 │
│ ├─ 容量: 10000条 │
│ ├─ TTL: 24小时 │
│ ├─ 命中率: ~30% │
│ └─ 响应时间: <10ms │
│ ↓ (未命中) │
│ L3: 向量数据库 + LLM │
│ ├─ 完整RAG流程 │
│ ├─ 响应时间: ~2000ms │
│ └─ 结果回填到L1和L2 │
│ │
└─────────────────────────────────────────────────────┘
整体命中率: 90% (60% + 30%)
平均响应时间: 0.6*1ms + 0.3*10ms + 0.1*2000ms = 210ms
提升: 10.7倍
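The expected response time is just a hit-rate-weighted average of the per-layer latencies; a quick check:

```python
# (hit_rate, latency_ms) per layer, taken from the diagram above
layers = [(0.6, 1), (0.3, 10), (0.1, 2000)]

expected_ms = sum(rate * latency for rate, latency in layers)
print(f"expected latency: {expected_ms:.1f} ms")        # 203.6 ms
print(f"speedup vs 2500 ms baseline: {2500 / expected_ms:.1f}x")
```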
11.2.3 Cache Implementations
L1: in-memory cache
# File: cache_strategies.py
"""
Caching strategy implementations.
"""
import hashlib
import json
import time
from typing import Any, Dict, Optional

class MemoryCache:
    """
    In-memory cache with TTL and simple eviction.

    Args:
        max_size: maximum number of cached entries
        ttl: time-to-live in seconds

    Example:
        >>> cache = MemoryCache(max_size=1000, ttl=3600)
        >>> cache.set("key", {"answer": "..."})
        >>> result = cache.get("key")
    """
    def __init__(self, max_size: int = 1000, ttl: int = 3600):
        self.max_size = max_size
        self.ttl = ttl
        self.cache: Dict[str, Dict] = {}

    def _generate_key(self, query: str, **kwargs) -> str:
        """Build a cache key from the query and any extra parameters."""
        data = {'query': query, **kwargs}
        data_str = json.dumps(data, sort_keys=True)
        return hashlib.md5(data_str.encode()).hexdigest()

    def get(self, query: str, **kwargs) -> Optional[Any]:
        """Look up a cached value; returns None on miss or expiry."""
        key = self._generate_key(query, **kwargs)
        if key in self.cache:
            entry = self.cache[key]
            # Check expiry
            if time.time() - entry['timestamp'] < self.ttl:
                entry['hits'] += 1
                return entry['value']
            else:
                # Expired: delete
                del self.cache[key]
        return None

    def set(self, query: str, value: Any, **kwargs):
        """Store a value, evicting the oldest entry when full."""
        key = self._generate_key(query, **kwargs)
        if len(self.cache) >= self.max_size:
            # Simplified eviction: dicts preserve insertion order, so this
            # drops the oldest *inserted* entry (FIFO, not true LRU)
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        self.cache[key] = {
            'value': value,
            'timestamp': time.time(),
            'hits': 0
        }

    def clear(self):
        """Empty the cache."""
        self.cache.clear()

    def get_stats(self) -> Dict:
        """Basic statistics. Note: a true hit *rate* also needs a miss
        counter; here we only report hits recorded on live entries."""
        total_hits = sum(entry['hits'] for entry in self.cache.values())
        return {
            'size': len(self.cache),
            'max_size': self.max_size,
            'total_hits': total_hits
        }

# Usage example
if __name__ == "__main__":
    cache = MemoryCache(max_size=100, ttl=60)
    # Set a cache entry
    cache.set("How do I reset my password?", {
        'answer': 'Click "Reset password" on the settings page...',
        'sources': ['doc1', 'doc2']
    })
    # Get it back
    result = cache.get("How do I reset my password?")
    print(f"Cache hit: {result}")
    # Get it again (hit count +1)
    result = cache.get("How do I reset my password?")
    print(f"Hit again: {result}")
    # Statistics
    stats = cache.get_stats()
    print(f"\nCache stats: {stats}")
L2: Redis cache
# Install: pip install redis
import hashlib
import json
from typing import Any, Dict, Optional

import redis

class RedisCache:
    """
    Redis-backed cache.

    Args:
        host: Redis host
        port: Redis port
        db: Redis database number
        ttl: time-to-live in seconds

    Example:
        >>> cache = RedisCache(host='localhost', port=6379, ttl=86400)
        >>> cache.set("key", {"answer": "..."})
        >>> result = cache.get("key")
    """
    def __init__(self,
                 host: str = 'localhost',
                 port: int = 6379,
                 db: int = 0,
                 ttl: int = 86400):
        self.ttl = ttl
        self.redis_client = redis.Redis(
            host=host,
            port=port,
            db=db,
            decode_responses=True
        )

    def _generate_key(self, query: str, **kwargs) -> str:
        """Build a namespaced cache key."""
        data = {'query': query, **kwargs}
        data_str = json.dumps(data, sort_keys=True)
        return f"rag:cache:{hashlib.md5(data_str.encode()).hexdigest()}"

    def get(self, query: str, **kwargs) -> Optional[Any]:
        """Look up a cached value; returns None on miss or error."""
        key = self._generate_key(query, **kwargs)
        try:
            value = self.redis_client.get(key)
            if value:
                return json.loads(value)
        except Exception as e:
            print(f"Redis get failed: {e}")
        return None

    def set(self, query: str, value: Any, **kwargs):
        """Store a value with the configured TTL."""
        key = self._generate_key(query, **kwargs)
        try:
            value_json = json.dumps(value, ensure_ascii=False)
            self.redis_client.setex(key, self.ttl, value_json)
        except Exception as e:
            print(f"Redis set failed: {e}")

    def clear(self):
        """Delete all RAG cache keys."""
        try:
            # Delete every rag:cache:* key
            for key in self.redis_client.scan_iter("rag:cache:*"):
                self.redis_client.delete(key)
        except Exception as e:
            print(f"Redis clear failed: {e}")

    def get_stats(self) -> Dict:
        """Count cached keys."""
        try:
            keys = list(self.redis_client.scan_iter("rag:cache:*"))
            return {
                'size': len(keys),
                'ttl': self.ttl
            }
        except Exception as e:
            print(f"Redis stats failed: {e}")
            return {'size': 0, 'ttl': self.ttl}

# Usage example
if __name__ == "__main__":
    # Note: a Redis server must be running first
    # redis-server
    try:
        cache = RedisCache(host='localhost', port=6379, ttl=3600)
        # Set a cache entry
        cache.set("How do I get a refund?", {
            'answer': 'You can request a refund from the order page...',
            'sources': ['doc3']
        })
        # Get it back
        result = cache.get("How do I get a refund?")
        print(f"Redis cache: {result}")
        # Statistics
        stats = cache.get_stats()
        print(f"\nRedis stats: {stats}")
    except Exception as e:
        print(f"Redis not running, skipping example: {e}")
A complete caching wrapper
# File: cached_rag.py
"""
RAG system with multi-layer caching.
"""
import time
from typing import Any, Callable, Dict, Optional

# MemoryCache and RedisCache are defined in cache_strategies.py above
from cache_strategies import MemoryCache, RedisCache

class CachedRAGSystem:
    """
    RAG system with L1/L2 caching.

    Args:
        rag_pipeline: the RAG processing function
        l1_cache: L1 cache (in-memory)
        l2_cache: L2 cache (Redis), optional

    Example:
        >>> rag = CachedRAGSystem(rag_pipeline, l1_cache, l2_cache)
        >>> result = rag.query("How do I reset my password?")
        >>> # first query: full pipeline
        >>> result = rag.query("How do I reset my password?")
        >>> # second time: served from L1
    """
    def __init__(self,
                 rag_pipeline: Callable,
                 l1_cache: MemoryCache,
                 l2_cache: Optional[RedisCache] = None):
        self.rag_pipeline = rag_pipeline
        self.l1_cache = l1_cache
        self.l2_cache = l2_cache
        # Statistics
        self.stats = {
            'total_queries': 0,
            'l1_hits': 0,
            'l2_hits': 0,
            'misses': 0
        }

    def query(self, query: str, **kwargs) -> Dict:
        """Answer a query, consulting L1 then L2 before running the pipeline."""
        self.stats['total_queries'] += 1
        # L1 lookup
        result = self.l1_cache.get(query, **kwargs)
        if result is not None:
            self.stats['l1_hits'] += 1
            return {
                'answer': result['answer'],
                'sources': result['sources'],
                'cache': 'L1'
            }
        # L2 lookup
        if self.l2_cache:
            result = self.l2_cache.get(query, **kwargs)
            if result is not None:
                self.stats['l2_hits'] += 1
                # Promote to L1
                self.l1_cache.set(query, result, **kwargs)
                return {
                    'answer': result['answer'],
                    'sources': result['sources'],
                    'cache': 'L2'
                }
        # Miss: run the full RAG pipeline
        self.stats['misses'] += 1
        print("Cache miss, running RAG pipeline...")
        result = self.rag_pipeline(query, **kwargs)
        # Write back to both cache layers
        cache_data = {
            'answer': result['answer'],
            'sources': result.get('sources', [])
        }
        self.l1_cache.set(query, cache_data, **kwargs)
        if self.l2_cache:
            self.l2_cache.set(query, cache_data, **kwargs)
        return {
            'answer': result['answer'],
            'sources': result.get('sources', []),
            'cache': 'None'
        }

    def get_stats(self) -> Dict:
        """Cache hit/miss statistics."""
        total = self.stats['total_queries']
        l1_rate = self.stats['l1_hits'] / total if total > 0 else 0
        l2_rate = self.stats['l2_hits'] / total if total > 0 else 0
        miss_rate = self.stats['misses'] / total if total > 0 else 0
        return {
            'total_queries': total,
            'l1_hits': self.stats['l1_hits'],
            'l2_hits': self.stats['l2_hits'],
            'misses': self.stats['misses'],
            'l1_hit_rate': l1_rate,
            'l2_hit_rate': l2_rate,
            'overall_hit_rate': l1_rate + l2_rate,
            'miss_rate': miss_rate
        }

# Usage example
if __name__ == "__main__":
    # Simulated RAG pipeline
    def mock_rag_pipeline(query: str, **kwargs) -> Dict:
        print(f"  [RAG] processing query: {query}")
        time.sleep(0.5)  # simulate latency
        return {
            'answer': f'Answer about "{query}"...',
            'sources': ['doc1', 'doc2']
        }

    # Create the caches
    l1_cache = MemoryCache(max_size=100, ttl=60)
    l2_cache = None  # RedisCache() if Redis is available

    # Create the cached RAG system
    rag = CachedRAGSystem(mock_rag_pipeline, l1_cache, l2_cache)

    # Test queries, including repeats
    test_queries = [
        "How do I reset my password?",
        "How do I reset my password?",  # repeat
        "How do I get a refund?",
        "How do I reset my password?",  # repeat again
        "How do I contact support?",
        "How do I get a refund?",       # repeat
    ]

    print("\n" + "=" * 80)
    print("Cached RAG system test")
    print("=" * 80 + "\n")

    for query in test_queries:
        print(f"Query: {query}")
        result = rag.query(query)
        print(f"  cache: {result['cache']}")
        print(f"  answer: {result['answer'][:50]}...")
        print()

    # Statistics
    stats = rag.get_stats()
    print("=" * 80)
    print("Cache statistics")
    print("=" * 80)
    print(f"Total queries: {stats['total_queries']}")
    print(f"L1 hits: {stats['l1_hits']} ({stats['l1_hit_rate']*100:.1f}%)")
    print(f"L2 hits: {stats['l2_hits']} ({stats['l2_hit_rate']*100:.1f}%)")
    print(f"Misses: {stats['misses']} ({stats['miss_rate']*100:.1f}%)")
    print(f"Overall hit rate: {stats['overall_hit_rate']*100:.1f}%")
    print("\nPerformance gain:")
    print(f"  time saved: {stats['l1_hits'] + stats['l2_hits']} x 500ms = {(stats['l1_hits'] + stats['l2_hits']) * 0.5:.1f}s")
11.3 Batch Processing
11.3.1 Why Batch?
Problem: processing items one at a time is inefficient.
Scenario: processing 100 queries
Without batching:
for query in queries:
    result = process(query)  # each query handled independently
Total time: 100 × 500ms = 50s
With batching:
batch_process(queries)  # process in batches
Total time: ~10s
Speedup: ~5x
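The core of any batch optimizer is the split step. A tiny helper showing the slicing idiom the implementation below also uses (`chunked` is a hypothetical name, not a standard function):

```python
from typing import Iterator, List, TypeVar

T = TypeVar("T")

def chunked(items: List[T], batch_size: int) -> Iterator[List[T]]:
    """Yield successive batches of at most batch_size items."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

batches = list(chunked(list(range(12)), 5))
print([len(b) for b in batches])  # [5, 5, 2]
```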
11.3.2 Batch Implementation
# File: batch_processing.py
"""
Batch-processing optimization.
"""
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Dict, List

class BatchProcessor:
    """
    Generic batch processor.

    Args:
        batch_size: items per batch
        process_function: function applied to each item
        max_workers: maximum worker threads

    Example:
        >>> processor = BatchProcessor(
        ...     batch_size=10,
        ...     process_function=rag_query,
        ...     max_workers=5
        ... )
        >>> results = processor.process(queries)
    """
    def __init__(self,
                 batch_size: int = 10,
                 process_function: Callable = None,
                 max_workers: int = 4):
        self.batch_size = batch_size
        self.process_function = process_function
        self.max_workers = max_workers

    def process(self, items: List[Any]) -> List[Any]:
        """
        Process items in concurrent batches.

        Args:
            items: list of items to process

        Returns:
            results, in the original input order
        """
        # Split into batches
        batches = [
            items[i:i + self.batch_size]
            for i in range(0, len(items), self.batch_size)
        ]
        print(f"{len(items)} items in {len(batches)} batches")

        # Process batches concurrently
        batch_results: Dict[int, List[Any]] = {}
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit one task per batch
            future_to_batch = {
                executor.submit(self._process_batch, batch): i
                for i, batch in enumerate(batches)
            }
            # Collect results as they finish
            for future in as_completed(future_to_batch):
                batch_idx = future_to_batch[future]
                try:
                    batch_results[batch_idx] = future.result()
                    print(f"Batch {batch_idx + 1}/{len(batches)} done")
                except Exception as e:
                    print(f"Batch {batch_idx + 1} failed: {e}")
                    batch_results[batch_idx] = []

        # Reassemble in the original order (as_completed yields futures
        # in completion order, not submission order)
        results: List[Any] = []
        for i in range(len(batches)):
            results.extend(batch_results.get(i, []))
        return results

    def _process_batch(self, batch: List[Any]) -> List[Any]:
        """Process one batch sequentially."""
        if self.process_function:
            return [self.process_function(item) for item in batch]
        return batch

# Batched RAG example
class BatchRAGSystem:
    """
    Batched RAG system.

    Supports batch querying, batch embedding, and so on.
    """
    def __init__(self, retriever, llm_client):
        self.retriever = retriever
        self.llm_client = llm_client

    def batch_query(self, queries: List[str],
                    batch_size: int = 10) -> List[Dict]:
        """
        Batch query.

        Args:
            queries: list of queries
            batch_size: items per batch

        Returns:
            list of answers
        """
        results = []
        # Batch embedding (when the backend supports it)
        all_embeddings = self._batch_embed(queries, batch_size)
        # Batch retrieval
        for i in range(0, len(queries), batch_size):
            batch_queries = queries[i:i + batch_size]
            batch_embeddings = all_embeddings[i:i + batch_size]
            # Retrieve
            batch_results = self._batch_retrieve(batch_queries, batch_embeddings)
            # Generate answers
            for query, retrieved_docs in zip(batch_queries, batch_results):
                answer = self._generate_answer(query, retrieved_docs)
                results.append(answer)
        return results

    def _batch_embed(self, queries: List[str],
                     batch_size: int) -> List[List[float]]:
        """
        Embed queries in batches.

        Optimization: one embedding-model call per batch instead of per query.
        """
        all_embeddings = []
        for i in range(0, len(queries), batch_size):
            batch = queries[i:i + batch_size]
            # Batch embedding (assumes the retriever exposes embed_batch)
            embeddings = self.retriever.embed_batch(batch)
            all_embeddings.extend(embeddings)
        return all_embeddings

    def _batch_retrieve(self,
                        queries: List[str],
                        embeddings: List[List[float]]) -> List[List[Dict]]:
        """Retrieve documents for each embedding."""
        results = []
        for query, embedding in zip(queries, embeddings):
            docs = self.retriever.retrieve_by_embedding(embedding)
            results.append(docs)
        return results

    def _generate_answer(self, query: str, docs: List[Dict]) -> Dict:
        """Generate an answer (mocked; a real version would send `prompt` to the LLM)."""
        context = "\n".join(doc['text'] for doc in docs[:3])
        prompt = f"Answer the question using the context below:\n\n{context}\n\nQuestion: {query}\nAnswer:"
        # Simulated LLM call
        return {
            'answer': f'Answer generated from {len(docs)} documents...',
            'sources': [doc['id'] for doc in docs]
        }

# Usage example
if __name__ == "__main__":
    # Simulated RAG system
    class MockRAGSystem:
        def query(self, query: str) -> Dict:
            time.sleep(0.1)  # simulate processing time
            return {'query': query, 'answer': f'Answer: {query}'}

    rag = MockRAGSystem()

    # Test queries
    queries = [f"query {i+1}" for i in range(20)]

    print("\n" + "=" * 80)
    print("Batch-processing benchmark")
    print("=" * 80 + "\n")

    # No batching
    print("Method 1: one at a time")
    start = time.time()
    results_single = [rag.query(q) for q in queries]
    time_single = time.time() - start
    print(f"  time: {time_single:.2f}s\n")

    # Batched
    print("Method 2: batched")
    processor = BatchProcessor(
        batch_size=5,
        process_function=rag.query,
        max_workers=4
    )
    start = time.time()
    results_batch = processor.process(queries)
    time_batch = time.time() - start
    print(f"  time: {time_batch:.2f}s\n")

    # Comparison
    print("=" * 80)
    print("Comparison")
    print("=" * 80)
    print(f"Sequential: {time_single:.2f}s")
    print(f"Batched:    {time_batch:.2f}s")
    print(f"Speedup:    {time_single/time_batch:.2f}x")
    print(f"Saved:      {((time_single - time_batch)/time_single * 100):.1f}%")
11.4 Concurrency
11.4.1 Python Concurrency Models
Comparison of concurrency models:
┌─────────────────────────────────────────────────────┐
│ Model                Best for           GIL impact  │
├─────────────────────────────────────────────────────┤
│ Threads              I/O-bound          limited     │
│ (threading)          - network requests             │
│                      - file I/O                     │
│                      - database queries             │
├─────────────────────────────────────────────────────┤
│ Processes            CPU-bound          not limited │
│ (multiprocessing)    - heavy computation            │
│                      - data processing              │
│                      - model inference              │
├─────────────────────────────────────────────────────┤
│ Async I/O            high-concurrency   single      │
│ (asyncio)            I/O                thread, GIL │
│                      - many requests    not an issue│
│                      - WebSocket        for I/O     │
│                      - real-time streams            │
└─────────────────────────────────────────────────────┘
11.4.2 Concurrency Implementations
Threads (good for I/O-bound work)
# File: concurrent_rag.py
"""
Concurrent RAG implementations.
"""
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, List

class ConcurrentRAGSystem:
    """
    Concurrent RAG system.

    Handles multiple queries with a thread pool.
    """
    def __init__(self, rag_pipeline: Callable, max_workers: int = 10):
        self.rag_pipeline = rag_pipeline
        self.max_workers = max_workers

    def query_concurrent(self, queries: List[str]) -> List[Dict]:
        """
        Concurrent queries (threads).

        Args:
            queries: list of queries

        Returns:
            answers, in the original input order
        """
        results = [None] * len(queries)
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all tasks
            future_to_index = {
                executor.submit(self.rag_pipeline, query): idx
                for idx, query in enumerate(queries)
            }
            # Collect results
            for future in as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    results[idx] = future.result()
                    print(f"Query {idx + 1}/{len(queries)} done")
                except Exception as e:
                    print(f"Query {idx + 1} failed: {e}")
                    results[idx] = {'error': str(e)}
        return results

    def query_sequential(self, queries: List[str]) -> List[Dict]:
        """Sequential queries (for comparison)."""
        results = []
        for query in queries:
            result = self.rag_pipeline(query)
            results.append(result)
        return results

# Async RAG system
class AsyncRAGSystem:
    """
    Async RAG system.

    Uses asyncio for high concurrency.
    """
    async def query_async(self, query: str) -> Dict:
        """One async query: embed, retrieve, then generate."""
        # Async embedding
        embedding = await self._embed_async(query)
        # Async retrieval
        docs = await self._retrieve_async(embedding)
        # Async answer generation
        answer = await self._generate_async(query, docs)
        return answer

    async def _embed_async(self, query: str) -> List[float]:
        """Async embedding (simulated)."""
        await asyncio.sleep(0.1)  # simulated I/O
        return [0.1] * 768

    async def _retrieve_async(self, embedding: List[float]) -> List[Dict]:
        """Async retrieval (simulated)."""
        await asyncio.sleep(0.2)  # simulated I/O
        return [{'id': 'doc1', 'text': 'relevant document'}]

    async def _generate_async(self, query: str, docs: List[Dict]) -> Dict:
        """Async answer generation (simulated)."""
        await asyncio.sleep(1.0)  # simulated LLM call
        return {'answer': f'async answer: {query}'}

    async def batch_query_async(self, queries: List[str]) -> List[Dict]:
        """Run many async queries concurrently."""
        tasks = [self.query_async(q) for q in queries]
        return await asyncio.gather(*tasks)

# Usage example
if __name__ == "__main__":
    # Simulated RAG pipeline
    def mock_rag_pipeline(query: str) -> Dict:
        """Simulated RAG pipeline (I/O-bound)."""
        time.sleep(0.5)  # simulated I/O wait
        return {'query': query, 'answer': f'Answer: {query}'}

    # Test queries
    queries = [f"query {i+1}" for i in range(10)]

    print("\n" + "=" * 80)
    print("Concurrency benchmark")
    print("=" * 80 + "\n")

    # Method 1: sequential
    print("Method 1: sequential")
    rag = ConcurrentRAGSystem(mock_rag_pipeline, max_workers=1)
    start = time.time()
    results_sequential = rag.query_sequential(queries)
    time_sequential = time.time() - start
    print(f"  time: {time_sequential:.2f}s")
    print(f"  per query: {time_sequential/len(queries):.2f}s\n")

    # Method 2: concurrent (threads)
    print("Method 2: concurrent (threads)")
    rag = ConcurrentRAGSystem(mock_rag_pipeline, max_workers=10)
    start = time.time()
    results_concurrent = rag.query_concurrent(queries)
    time_concurrent = time.time() - start
    print(f"  time: {time_concurrent:.2f}s")
    print(f"  per query: {time_concurrent/len(queries):.2f}s\n")

    # Method 3: async (asyncio)
    print("Method 3: async (asyncio)")
    async_rag = AsyncRAGSystem()
    start = time.time()
    results_async = asyncio.run(async_rag.batch_query_async(queries))
    time_async = time.time() - start
    print(f"  time: {time_async:.2f}s")
    print(f"  per query: {time_async/len(queries):.2f}s\n")

    # Comparison
    print("=" * 80)
    print("Comparison")
    print("=" * 80)
    print(f"Sequential: {time_sequential:.2f}s (baseline)")
    print(f"Threads:    {time_concurrent:.2f}s ({time_sequential/time_concurrent:.2f}x)")
    print(f"Async:      {time_async:.2f}s ({time_sequential/time_async:.2f}x)")
11.5 Memory Optimization
11.5.1 Identifying Memory Problems
Common memory problems:
Problem 1: loading all documents into memory
Symptom: high memory usage
Fix: stream documents, load on demand
Problem 2: embedding vectors take a lot of space
Symptom: 1M documents × 768 dims × 4 bytes ≈ 2.9GB
Fix: quantize (float32 → float16 halves it to 2 bytes per value)
Problem 3: unbounded cache growth
Symptom: memory keeps growing
Fix: LRU eviction with a size cap
Problem 4: accumulated LLM context
Symptom: long conversations grow memory
Fix: cap context length, prune periodically
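Problem 2's numbers are easy to verify, and they motivate the float16 quantization shown later in this section:

```python
num_docs, dim = 1_000_000, 768

bytes_f32 = num_docs * dim * 4  # float32: 4 bytes per value
bytes_f16 = num_docs * dim * 2  # float16: 2 bytes per value

print(f"float32: {bytes_f32 / 1024**3:.2f} GiB")  # ~2.86 GiB
print(f"float16: {bytes_f16 / 1024**3:.2f} GiB")  # ~1.43 GiB
```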
11.5.2 Optimization Strategies
Strategy 1: generators
# File: memory_optimization.py
"""
Memory-optimization examples.
"""
from typing import Iterator, List

def load_documents_lazy(file_path: str) -> Iterator[str]:
    """
    Load documents lazily (generator).

    Advantage: documents are never all in memory at once.

    Yields:
        one document per line
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            yield line.strip()

# Comparison: load-everything vs lazy loading
def load_documents_all(file_path: str) -> List[str]:
    """
    Load all documents at once.

    Problem: large files can exhaust memory.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]

# Usage example
if __name__ == "__main__":
    import os
    import tempfile
    import tracemalloc

    # Create a large temporary file
    temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt')
    for i in range(100000):
        temp_file.write(f"content of document {i}...\n")
    temp_file_path = temp_file.name
    temp_file.close()

    print("\n" + "=" * 80)
    print("Memory-usage comparison")
    print("=" * 80 + "\n")

    # Method 1: load everything
    print("Method 1: load everything at once")
    tracemalloc.start()
    docs_all = load_documents_all(temp_file_path)
    current, peak = tracemalloc.get_traced_memory()
    print(f"  current memory: {current / 1024 / 1024:.2f} MB")
    print(f"  peak memory:    {peak / 1024 / 1024:.2f} MB")
    print(f"  documents:      {len(docs_all)}\n")
    tracemalloc.stop()

    # Method 2: lazy loading
    print("Method 2: lazy loading (generator)")
    tracemalloc.start()
    docs_lazy = load_documents_lazy(temp_file_path)
    count = 0
    for doc in docs_lazy:
        count += 1
        if count >= 10:  # only process the first 10
            break
    current, peak = tracemalloc.get_traced_memory()
    print(f"  current memory: {current / 1024 / 1024:.2f} MB")
    print(f"  peak memory:    {peak / 1024 / 1024:.2f} MB")
    print(f"  documents processed: {count}\n")
    tracemalloc.stop()

    # Clean up
    os.unlink(temp_file_path)
Strategy 2: vector quantization
import numpy as np
from typing import Dict

class QuantizedVectorStore:
    """
    Quantized vector store.

    Uses float16 instead of float32, saving 50% memory
    at a small cost in precision.
    """
    def __init__(self):
        self.vectors = None
        self.dtype = np.float16  # half-precision floats

    def add_vectors(self, vectors: np.ndarray):
        """Add vectors (quantized automatically)."""
        if vectors.dtype != self.dtype:
            vectors = vectors.astype(self.dtype)
        self.vectors = vectors

    def get_memory_usage(self) -> Dict:
        """Report memory usage of the stored vectors."""
        if self.vectors is None:
            return {'memory_mb': 0}
        memory_mb = self.vectors.nbytes / 1024 / 1024
        return {
            'memory_mb': memory_mb,
            'num_vectors': len(self.vectors),
            'dimension': self.vectors.shape[1],
            'dtype': str(self.vectors.dtype)
        }

# Comparison: float32 vs float16
if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("Vector quantization comparison")
    print("=" * 80 + "\n")

    # Generate sample vectors
    num_vectors = 100000
    dimension = 768
    print(f"vectors:   {num_vectors}")
    print(f"dimension: {dimension}\n")

    # float32
    vectors_f32 = np.random.randn(num_vectors, dimension).astype(np.float32)
    memory_f32 = vectors_f32.nbytes / 1024 / 1024
    print("Float32:")
    print(f"  total:      {memory_f32:.2f} MB")
    print(f"  per vector: {vectors_f32[0].nbytes / 1024:.2f} KB\n")

    # float16
    vectors_f16 = vectors_f32.astype(np.float16)
    memory_f16 = vectors_f16.nbytes / 1024 / 1024
    print("Float16:")
    print(f"  total:      {memory_f16:.2f} MB")
    print(f"  per vector: {vectors_f16[0].nbytes / 1024:.2f} KB\n")

    print(f"Saved: {(memory_f32 - memory_f16) / memory_f32 * 100:.1f}%")
Strategy 3: context management
import time
from typing import Dict

class ContextManagedRAG:
    """
    RAG system with managed conversation context.

    Automatically prunes history so memory stays bounded.
    """
    def __init__(self, max_context_length: int = 10):
        self.max_context_length = max_context_length
        self.conversation_history = []

    def query(self, query: str) -> Dict:
        """Answer a query, pruning old context automatically."""
        # Append to history
        self.conversation_history.append({
            'query': query,
            'timestamp': time.time()
        })
        # Cap history length
        if len(self.conversation_history) > self.max_context_length:
            # Drop the oldest turn
            removed = self.conversation_history.pop(0)
            print(f"Pruned old context: {removed['query'][:30]}...")
        # Use only the most recent turns
        recent_history = self.conversation_history[-5:]
        # Process the query
        return {'answer': f'Answer based on {len(recent_history)} turns of context'}

    def clear_history(self):
        """Clear all history."""
        self.conversation_history.clear()
        print("Context cleared")
11.6 A Combined Optimization Case Study
11.6.1 The Full Optimization Stack
# File: optimized_rag_system.py
"""
Fully optimized RAG system.

Combines: caching + batching + concurrency + memory optimization.
"""
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

# MemoryCache (and optionally RedisCache) from cache_strategies.py above
from cache_strategies import MemoryCache

class OptimizedRAGSystem:
    """
    Optimized RAG system.

    Features:
    1. Multi-layer caching (L1 memory + optional L2 Redis)
    2. Batching
    3. Concurrency
    4. Memory optimization

    Performance targets:
    - cache hit rate > 80%
    - throughput > 100 queries/s
    - average response time < 100ms (on cache hits)
    """
    def __init__(self,
                 rag_pipeline,
                 l1_cache_size: int = 1000,
                 max_workers: int = 10):
        # RAG pipeline
        self.rag_pipeline = rag_pipeline
        # Caches
        self.l1_cache = MemoryCache(max_size=l1_cache_size, ttl=3600)
        self.l2_cache = None  # optional: RedisCache()
        # Concurrency
        self.max_workers = max_workers
        # Statistics
        self.stats = {
            'total_queries': 0,
            'cache_hits': 0,
            'cache_misses': 0,
            'total_time': 0.0
        }

    def query(self, query: str) -> Dict:
        """Answer a single query, consulting the caches first."""
        start_time = time.time()
        self.stats['total_queries'] += 1
        # L1 lookup
        result = self.l1_cache.get(query)
        if result is not None:
            self.stats['cache_hits'] += 1
            self.stats['total_time'] += time.time() - start_time
            return {**result, 'cache': 'L1'}
        # L2 lookup (if configured)
        if self.l2_cache:
            result = self.l2_cache.get(query)
            if result is not None:
                self.stats['cache_hits'] += 1
                # Promote to L1
                self.l1_cache.set(query, result)
                self.stats['total_time'] += time.time() - start_time
                return {**result, 'cache': 'L2'}
        # Miss: run the RAG pipeline
        self.stats['cache_misses'] += 1
        result = self.rag_pipeline(query)
        # Write back to the caches
        cache_data = {
            'answer': result['answer'],
            'sources': result.get('sources', [])
        }
        self.l1_cache.set(query, cache_data)
        if self.l2_cache:
            self.l2_cache.set(query, cache_data)
        self.stats['total_time'] += time.time() - start_time
        return {**result, 'cache': 'None'}

    def batch_query(self, queries: List[str]) -> List[Dict]:
        """Answer many queries concurrently; results keep the input order."""
        results = [None] * len(queries)
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_index = {
                executor.submit(self.query, query): idx
                for idx, query in enumerate(queries)
            }
            for future, idx in future_to_index.items():
                results[idx] = future.result()
        return results

    def get_stats(self) -> Dict:
        """Aggregate performance statistics."""
        total = self.stats['total_queries']
        hits = self.stats['cache_hits']
        misses = self.stats['cache_misses']
        avg_time = self.stats['total_time'] / total if total > 0 else 0
        hit_rate = hits / total if total > 0 else 0
        return {
            'total_queries': total,
            'cache_hits': hits,
            'cache_misses': misses,
            'hit_rate': hit_rate,
            'avg_response_time_ms': avg_time * 1000,
            'throughput_qps': total / max(self.stats['total_time'], 0.001)
        }

# Usage example
if __name__ == "__main__":
    # Simulated RAG pipeline
    def mock_rag(query: str) -> Dict:
        time.sleep(0.2)  # simulated processing
        return {
            'answer': f'Answer: {query}',
            'sources': ['doc1']
        }

    # Create the optimized RAG system
    rag = OptimizedRAGSystem(
        rag_pipeline=mock_rag,
        l1_cache_size=1000,
        max_workers=10
    )

    # Test set with repeated queries
    test_queries = [
        "query 1", "query 2", "query 1",  # repeat
        "query 3", "query 2",             # repeat
        "query 4", "query 5", "query 1",  # repeat
        "query 6"
    ]

    print("\n" + "=" * 80)
    print("Optimized RAG system test")
    print("=" * 80 + "\n")

    # Single queries
    for query in test_queries:
        result = rag.query(query)
        print(f"{query} → cache: {result['cache']}")

    # Statistics
    stats = rag.get_stats()
    print("\n" + "=" * 80)
    print("Performance statistics")
    print("=" * 80)
    print(f"Total queries: {stats['total_queries']}")
    print(f"Cache hits:    {stats['cache_hits']}")
    print(f"Cache misses:  {stats['cache_misses']}")
    print(f"Hit rate:      {stats['hit_rate']*100:.1f}%")
    print(f"Avg response:  {stats['avg_response_time_ms']:.2f} ms")
    print(f"Throughput:    {stats['throughput_qps']:.2f} QPS")
Exercises
Exercise 1 (basic): implement an in-memory cache
Task: implement an LRU cache.
Requirements:
1. Support get and set operations
2. Evict the least recently used entry when full
3. Support TTL (expiry)
4. Provide statistics
Hints:
- Use collections.OrderedDict to track access order
- Record an access timestamp per key
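One possible starting skeleton following those hints (class and method names are suggestions, not a reference solution; the statistics requirement is left for you to extend):

```python
import time
from collections import OrderedDict

class LRUCache:
    def __init__(self, max_size: int = 100, ttl: float = 60.0):
        self.max_size, self.ttl = max_size, ttl
        self._data = OrderedDict()  # key -> (value, timestamp)
        self.hits = self.misses = 0

    def get(self, key):
        entry = self._data.get(key)
        if entry is None or time.time() - entry[1] >= self.ttl:
            self._data.pop(key, None)  # drop expired entry if present
            self.misses += 1
            return None
        self._data.move_to_end(key)  # mark as most recently used
        self.hits += 1
        return entry[0]

    def set(self, key, value):
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = (value, time.time())
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict least recently used

cache = LRUCache(max_size=2)
cache.set("a", 1); cache.set("b", 2)
cache.get("a")         # "a" becomes most recently used
cache.set("c", 3)      # evicts "b", the least recently used
print(cache.get("b"))  # None
print(cache.get("a"))  # 1
```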
Exercise 2 (intermediate): batch-processing optimization
Task: optimize batch query performance.
Requirements:
1. Implement a batch query interface
2. Support concurrent processing
3. Compare sequential vs batched performance
4. Achieve at least a 3x speedup
Exercise 3 (challenge): a complete optimized system
Project: build a production-grade, high-performance RAG system.
Functional requirements:
1. ✅ Multi-layer caching
2. ✅ Batching
3. ✅ Concurrency
4. ✅ Memory optimization
5. ✅ Performance monitoring
Performance targets:
- cache hit rate > 80%
- throughput > 50 QPS
- P95 latency < 500ms
- memory < 4GB (for millions of documents)
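The P95 target can be checked against recorded latencies with the standard library alone; a minimal sketch (the sample values are hypothetical):

```python
import statistics

# Hypothetical per-request latency samples in ms
latencies = [120, 80, 95, 450, 210, 130, 90, 600, 110, 100] * 10

# quantiles(n=100) returns the 1st..99th percentile cut points,
# so index 94 is the 95th percentile
p95 = statistics.quantiles(latencies, n=100)[94]
print(f"P95 latency: {p95:.0f} ms")
```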
Summary
Key Takeaways
- Bottleneck identification
  - LLM generation is the biggest bottleneck (~80% of total time)
  - Use profiling tools to locate problems
  - Focus on hot functions
- Caching strategy
  - L1 in-memory cache: fast, small capacity
  - L2 Redis cache: slightly slower, large capacity
  - Multi-layer caching: 90%+ hit rate
- Batch processing
  - Batch embedding: fewer API calls
  - Batch retrieval: higher throughput
  - Batch generation: exploits parallelism
- Concurrency
  - Threads: good for I/O-bound work
  - Processes: good for CPU-bound work
  - Async I/O: good for high-concurrency workloads
- Memory optimization
  - Generators: lazy loading
  - Vector quantization: ~50% memory savings
  - Context management: bounded history length
Learning Checklist
- Can use profiling tools to analyze performance
- Understand multi-layer cache architecture
- Know batch-processing optimization methods
- Can apply concurrent processing
- Understand memory-optimization strategies
- Can build a high-performance RAG system
Next Steps
- Next chapter: Chapter 12: Capstone Project Optimization
- Related chapters:
  - Chapter 9: Hybrid Retrieval and Reranking
  - Chapter 10: Advanced RAG Patterns
- Further reading:
  - Python profiling: https://docs.python.org/3/library/profile.html
  - Redis caching: https://redis.io/docs/manual/patterns/caching/
End of chapter.
Questions or suggestions? Open an issue or PR in the tutorial repository!