第22章:安全实践¶
保护RAG系统安全,防范常见安全威胁
📚 章节概述¶
本章将学习如何为RAG系统实施全面的安全措施,保护数据和系统安全。
学习目标¶
完成本章后,你将能够:
- ✅ 理解RAG系统的安全威胁
- ✅ 实施API认证和授权
- ✅ 保护敏感数据
- ✅ 防范常见攻击
- ✅ 实施安全扫描
- ✅ 满足合规要求
预计时间¶
- 理论学习:50分钟
- 实践操作:60-90分钟
- 总计:约2-2.5小时(110-140分钟)
1. 安全威胁分析¶
1.1 常见威胁¶
OWASP Top 10(2017版):
1. 注入攻击:SQL注入、NoSQL注入
2. 认证失效:弱密码、会话管理不当
3. 敏感数据泄露:未加密的敏感信息
4. XML外部实体(XXE)
5. 访问控制失效
6. 安全配置错误
7. 跨站脚本攻击(XSS)
8. 不安全的反序列化
9. 使用含有已知漏洞的组件
10. 日志记录和监控不足

注:2021版对上述类别做了合并和重新排序,并新增了“不安全设计”和“服务端请求伪造(SSRF)”等类别。
1.2 RAG系统特定威胁¶
数据安全:
- 用户查询数据泄露
- 文档内容未授权访问
- API密钥泄露
LLM安全:
- 提示注入攻击
- 数据提取攻击
- 模型越狱
系统安全:
- DoS攻击
- 资源耗尽
- 恶意查询
2. API认证和授权¶
2.1 JWT认证¶
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from datetime import datetime, timedelta
import jwt
import os

app = FastAPI()

# JWT configuration.
# The signing key must come from the environment: anyone who knows it can
# forge valid tokens, so a literal committed to source code is not a secret.
SECRET_KEY = os.getenv("JWT_SECRET_KEY")
if not SECRET_KEY:
    # Fail fast instead of silently signing tokens with a guessable key.
    raise RuntimeError(
        "JWT_SECRET_KEY is not set; generate one with "
        "python -c 'import secrets; print(secrets.token_urlsafe(32))'"
    )
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # lifetime of tokens issued by POST /token
security = HTTPBearer()  # extracts the token from "Authorization: Bearer <token>"
def create_access_token(data: dict, expires_delta: timedelta = None):
    """Create a signed JWT access token.

    Args:
        data: Claims to embed, e.g. ``{"sub": username}``. The dict is copied,
            not mutated.
        expires_delta: Token lifetime. ``None`` selects the 15-minute default.
            An explicit ``timedelta(0)`` is honoured rather than being treated
            as "not given".

    Returns:
        The encoded token produced by ``jwt.encode`` (a ``str`` with PyJWT 2.x),
        signed with ``SECRET_KEY`` and ``ALGORITHM``.
    """
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    from datetime import timezone

    to_encode = data.copy()
    # `is not None` so that a zero-length timedelta is not mistaken for "unset".
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate a bearer JWT and return its decoded claims.

    Used as a FastAPI dependency: ``security`` (an ``HTTPBearer``) supplies the
    token taken from the ``Authorization: Bearer <token>`` header.

    Returns:
        The decoded payload dict.

    Raises:
        HTTPException: 401 if ``jwt.decode`` rejects the token (bad signature,
            expired ``exp``, malformed, ...) or the payload has no string ``sub``.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        # RFC 6750: a 401 for bearer-token auth should name the expected scheme.
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(
            credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM]
        )
    except jwt.PyJWTError as exc:
        raise credentials_error from exc
    # login() issues tokens with a "sub" claim; reject anything without one.
    username = payload.get("sub")
    if not isinstance(username, str):
        raise credentials_error
    return payload
# Login endpoint
@app.post("/token")
async def login(username: str, password: str):
    """Exchange a username/password for a JWT access token.

    Demo only: the credentials are hard-coded. A real service would look the
    user up in a database and verify a password hash.

    NOTE(review): plain ``str`` parameters are read from the query string by
    FastAPI, so the password ends up in URLs and access logs. Prefer a form
    body (e.g. ``OAuth2PasswordRequestForm``) in real code.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 when the credentials are wrong.
    """
    import secrets

    # Constant-time comparisons, so response timing does not reveal how much of
    # a guess was right. Bytes are used because compare_digest rejects non-ASCII
    # str. The results are combined with `&` so neither check short-circuits.
    username_ok = secrets.compare_digest(username.encode("utf-8"), b"admin")
    password_ok = secrets.compare_digest(password.encode("utf-8"), b"secret")
    if not (username_ok & password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token = create_access_token(
        data={"sub": username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}
# Protected endpoint
@app.get("/protected")
async def protected_route(payload: dict = Depends(verify_token)):
    """Greet the caller; the ``verify_token`` dependency rejects bad tokens first."""
    user = payload['sub']
    return {"message": f"Hello {user}"}
2.2 API密钥认证¶
from fastapi import Header, HTTPException
async def verify_api_key(x_api_key: str = Header(...)):
    """Validate the ``X-API-Key`` request header.

    Allowed keys come from the comma-separated ``VALID_API_KEYS`` environment
    variable, which is read on every call. Surrounding whitespace is ignored
    and empty entries are discarded.

    Returns:
        The accepted API key.

    Raises:
        HTTPException: 401 when the key is not one of the configured keys.
    """
    import os
    import secrets

    # "".split(",") == [""]: without dropping empty entries, an unset variable
    # would make an empty header value a valid key (authentication bypass).
    valid_keys = [
        key.strip()
        for key in os.getenv("VALID_API_KEYS", "").split(",")
        if key.strip()
    ]
    presented = x_api_key.encode("utf-8")
    # Constant-time compare against every key; do not stop at the first match.
    matched = False
    for key in valid_keys:
        matched |= secrets.compare_digest(presented, key.encode("utf-8"))
    if not matched:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return x_api_key
@app.get("/query")
async def query(
    text: str,
    api_key: str = Depends(verify_api_key),
):
    """Query endpoint guarded by an API key.

    The ``verify_api_key`` dependency runs first and rejects requests that lack
    a valid ``X-API-Key`` header. The body is a stub that always returns the
    same placeholder answer.
    """
    placeholder = {"result": "answer"}
    return placeholder
2.3 OAuth2¶
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.oauth2 import OAuth2
# OAuth2 password flow: clients obtain their token from the "token" endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@app.get("/users/me")
async def read_users_me(token: str = Depends(oauth2_scheme)):
    """Echo back the bearer token extracted by the OAuth2 scheme.

    Demo only: the token is not decoded or validated here.
    """
    response = {"token": token}
    return response
3. 数据保护¶
3.1 敏感数据加密¶
from cryptography.fernet import Fernet
import os
# Encryption key, read from the environment (never hard-code it).
# Generate a suitable key once with Fernet.generate_key().
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY")
if not ENCRYPTION_KEY:
    # Fail fast with a clear message instead of an opaque error from Fernet(None).
    raise RuntimeError(
        "ENCRYPTION_KEY is not set; generate one with Fernet.generate_key()"
    )
cipher_suite = Fernet(ENCRYPTION_KEY)
def encrypt_data(data: str) -> bytes:
    """Encrypt a text value with the module-level Fernet cipher.

    The string is UTF-8 encoded first; the returned bytes are a Fernet token.
    """
    plaintext = data.encode("utf-8")
    return cipher_suite.encrypt(plaintext)
def decrypt_data(encrypted_data: bytes) -> str:
    """Decrypt a Fernet token (as produced by encrypt_data) back into text.

    ``cipher_suite.decrypt`` raises ``cryptography.fernet.InvalidToken`` if the
    token was tampered with or was made with a different key.
    """
    plaintext = cipher_suite.decrypt(encrypted_data)
    return plaintext.decode("utf-8")
# Usage example
@app.post("/store-sensitive-data")
async def store_sensitive_data(data: str):
    """Encrypt ``data`` and persist only the ciphertext.

    NOTE(review): ``data`` is a plain ``str`` parameter, i.e. a query-string
    value in FastAPI, so the plaintext can leak into URLs and access logs
    before it is encrypted. Send it in a request body instead.
    ``db`` is not defined in this snippet; presumably an async database
    handle with ``$1``-style placeholders (asyncpg/databases) — confirm.
    """
    # Encrypt before the value reaches the database.
    encrypted = encrypt_data(data)
    await db.execute(
        "INSERT INTO sensitive_data (encrypted_data) VALUES ($1)",
        encrypted,
    )
    return {"status": "stored"}
@app.get("/retrieve-sensitive-data")
async def retrieve_sensitive_data(id: int):
    """Fetch a stored ciphertext by row id and return it decrypted.

    NOTE(review): there is no authentication/authorization dependency here, so
    any caller can read any row by guessing ids. Add an auth ``Depends`` and an
    ownership check in real code.

    Raises:
        HTTPException: 404 when no row has this id.
    """
    row = await db.fetch_one(
        "SELECT encrypted_data FROM sensitive_data WHERE id = $1",
        id,
    )
    # Assumes fetch_one returns None for "no row" (as the `databases` library
    # does) — confirm for your driver. Previously this crashed with a 500.
    if row is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Not found",
        )
    return {"data": decrypt_data(row['encrypted_data'])}
3.2 环境变量管理¶
from pydantic import BaseSettings
class Settings(BaseSettings):
    """Application settings loaded from environment variables and ``.env``.

    None of the fields has a default, so every one of them must be provided
    or instantiation fails.

    NOTE(review): ``from pydantic import BaseSettings`` only works with
    pydantic v1; in v2 the class lives in the separate ``pydantic-settings``
    package — confirm the pinned version.
    """
    # API settings
    API_KEY: str
    API_SECRET: str
    # Database settings
    DATABASE_URL: str
    # LLM API settings
    OPENAI_API_KEY: str
    # Encryption settings
    ENCRYPTION_KEY: str
    class Config:
        # Also read variables from a local ".env" file, decoded as UTF-8.
        env_file = ".env"
        env_file_encoding = "utf-8"
# Instantiated once at import time; raises if a required variable is missing.
settings = Settings()
# Usage: report whether credentials are configured (never their values).
@app.get("/")
async def root():
    """Report which credentials are configured without exposing them."""
    api_ready = bool(settings.API_KEY)
    llm_ready = bool(settings.OPENAI_API_KEY)
    return {"api_configured": api_ready, "llm_configured": llm_ready}
3.3 Secret管理(K8s)¶
# secret.yaml
# The values below are placeholders. Do not commit real secrets in a manifest;
# create the Secret at deploy time (e.g. `kubectl create secret generic` or an
# external secret manager) and keep real values out of version control.
# NOTE(review): Kubernetes stores Secret data base64-encoded, not encrypted,
# unless encryption at rest is enabled for etcd — check the cluster config.
apiVersion: v1
kind: Secret
metadata:
  name: rag-secrets
type: Opaque
stringData:
  api-key: "your-api-key"
  openai-api-key: "sk-xxxxx"
  encryption-key: "your-encryption-key"
4. 输入验证和清理¶
4.1 输入验证¶
from pydantic import BaseModel, validator, Field
from typing import Optional
class QueryRequest(BaseModel):
    """Validated body of a RAG query request."""

    # Length and numeric ranges are enforced by the Field constraints.
    text: str = Field(..., min_length=1, max_length=2000)
    top_k: Optional[int] = Field(5, ge=1, le=20)
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)

    @validator('text')
    def validate_text(cls, v):
        """Strip ``text``, then reject blank input and obvious markup injection.

        Runs after the Field constraints, so the 2000-character limit has
        already been enforced (the former explicit length check was dead code).

        Raises:
            ValueError: if the stripped text is empty or contains a
                deny-listed pattern.
        """
        stripped = v.strip()
        # min_length=1 is checked on the raw value, so "   " would otherwise
        # pass validation and turn into an empty query.
        if not stripped:
            raise ValueError("Text must not be blank")
        lowered = stripped.lower()
        # Deny-list only: escaping on output is the real XSS defence.
        # '<script' (no '>') also catches tags with attributes, e.g. <script src=...>.
        for pattern in ('<script', 'javascript:', 'onerror='):
            if pattern in lowered:
                raise ValueError(f"Dangerous pattern detected: {pattern}")
        return stripped
@app.post("/query")
async def query(request: QueryRequest):
    """Run a RAG query on the already-validated request body.

    NOTE(review): ``request.temperature`` is validated but not passed on.
    """
    return await rag_query(request.text, request.top_k)
4.2 SQL注入防护¶
from sqlalchemy import text
# ❌ Unsafe: user input is interpolated straight into the SQL text.
async def unsafe_query(user_id: str):
    """Anti-pattern kept for illustration — do NOT use.

    A ``user_id`` such as ``' OR '1'='1`` changes the meaning of the query.
    """
    return await db.execute(f"SELECT * FROM users WHERE id = '{user_id}'")  # dangerous!
# ✅ Safe: parameterized query
async def safe_query(user_id: str):
    """Look up a user by id using a bound parameter.

    ``user_id`` is passed separately from the SQL text, so it is always treated
    as a value and never parsed as SQL.
    """
    statement = text("SELECT * FROM users WHERE id = :user_id")
    params = {"user_id": user_id}
    return await db.execute(statement, params)
4.3 XSS防护¶
from fastapi.responses import JSONResponse
import html
def sanitize_output(text: str) -> str:
    """Return ``text`` with HTML-special characters escaped.

    ``&``, ``<``, ``>``, ``"`` and ``'`` become character references, so the
    value is inert if a client inserts it into an HTML page.
    """
    escaped = html.escape(text, quote=True)
    return escaped
@app.get("/search")
async def search(query: str):
    """Search endpoint that HTML-escapes document text before returning it (XSS defence).

    Each hit is reduced to its escaped ``title`` and the escaped first 200
    characters of its ``content``. Truncation happens before escaping, so no
    character reference is cut in half.
    """
    results = await search_documents(query)
    safe_results = [
        {
            "title": sanitize_output(item['title']),
            "content": sanitize_output(item['content'][:200]),
        }
        for item in results
    ]
    return JSONResponse(content=safe_results)
5. 提示注入防护¶
5.1 识别攻击¶
import re

# Heuristic prompt-injection signatures, compiled once at import time and
# matched case-insensitively. This is a deny-list that only catches low-effort
# attacks; layer it with output filtering and least-privilege tool access.
_INJECTION_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"ignore (all )?(previous|above) instructions?",
        r"forget (everything|all instructions)",
        r"override (your )?(programming|instructions)",
        r"disregard (previous|above)",
        # A bare "instead of" rejected ordinary questions such as
        # "what can I use instead of Redis?"; only flag it when it targets
        # the model's own instructions.
        r"instead of (the |your )?(previous|above|original|system) (instructions?|prompt|rules)",
        r"new instructions?",
        r"system:",
        r"\[SYSTEM\]",
        r"<SYSTEM>",
        # \b rather than ">" so tags with attributes (<script src=...>) match too.
        r"<script\b",
        r"<iframe\b",
    )
)


def detect_prompt_injection(user_input: str) -> bool:
    """Return True if ``user_input`` matches any known prompt-injection pattern."""
    return any(pattern.search(user_input) for pattern in _INJECTION_PATTERNS)
@app.post("/chat")
async def chat(message: str):
    """Chat endpoint that refuses input flagged as a prompt-injection attempt.

    Raises:
        HTTPException: 400 when ``detect_prompt_injection`` flags ``message``.
    """
    suspicious = detect_prompt_injection(message)
    if suspicious:
        raise HTTPException(status_code=400, detail="Invalid input detected")
    return await generate_response(message)
5.2 输出过滤¶
def filter_llm_output(output: str) -> str:
    """Redact suspicious passages from an LLM response.

    Two regex redactions, each replaced by ``"[Filtered]"``:

    1. From an "As an AI" / "I am an AI" / "I'm an AI" phrase up to the next
       blank line (or the end of the text).
    2. A ``{...}`` span containing ``"system`` followed by a ``:``. This is a
       JSON-like structure that might carry injected instructions.
    """
    filtered = re.sub(
        r"(As an AI|I am an AI|I'm an AI).+?(?=\n\n|$)",
        "[Filtered]",
        output,
        flags=re.DOTALL,
    )
    # [^{}]*? before "system: the match must start at the brace that opens the
    # object containing "system. The previous `\{.*?` could start at an
    # earlier, unrelated "{" and swallow all the ordinary text in between.
    filtered = re.sub(
        r"\{[^{}]*?\"system.*?:.*?\}",
        "[Filtered]",
        filtered,
        flags=re.DOTALL,
    )
    return filtered
6. 速率限制¶
6.1 令牌桶算法¶
from fastapi import HTTPException, Request
import asyncio
import time
from collections import deque
from datetime import datetime, timedelta


class RateLimiter:
    """Sliding-window rate limiter kept in process memory.

    Allows at most ``max_calls`` accepted calls per key within any
    ``time_window``-second window. Rejected calls are not recorded.
    State is per process: with several workers or replicas each one enforces
    its own limit (use a shared store such as Redis for a global limit).
    """

    def __init__(self, max_calls: int, time_window: int):
        self.max_calls = max_calls  # maximum accepted calls per window
        self.time_window = time_window  # window length in seconds
        # key -> deque of monotonic timestamps of accepted calls, oldest first
        self.calls = {}
        self._last_sweep = time.monotonic()

    def _prune(self, timestamps, now):
        """Drop timestamps that have left the window (they are oldest-first)."""
        while timestamps and now - timestamps[0] >= self.time_window:
            timestamps.popleft()

    def _sweep(self, now):
        """Forget keys with no calls left in the window, so memory stays bounded."""
        for key in list(self.calls):
            timestamps = self.calls[key]
            self._prune(timestamps, now)
            if not timestamps:
                del self.calls[key]
        self._last_sweep = now

    async def is_allowed(self, key: str) -> bool:
        """Record and allow a call for ``key`` unless its window is already full.

        There is no ``await`` in this body, so on a single event loop the
        check-and-record is not interleaved with other coroutines.
        """
        # Monotonic clock: immune to wall-clock jumps (NTP, manual changes).
        now = time.monotonic()
        # At most one full sweep per window keeps the amortised cost O(1).
        if now - self._last_sweep >= self.time_window:
            self._sweep(now)
        timestamps = self.calls.setdefault(key, deque())
        self._prune(timestamps, now)
        if len(timestamps) >= self.max_calls:
            return False
        timestamps.append(now)
        return True
# Global (per-process) rate limiter: 100 requests per 60 seconds per client IP.
rate_limiter = RateLimiter(max_calls=100, time_window=60)


@app.post("/query")
async def query(request: Request, text: str):
    """Query endpoint limited per client IP address.

    NOTE(review): behind a reverse proxy, ``request.client.host`` is the
    proxy's address (one shared bucket for everyone) unless the server is set
    up to trust forwarded headers.

    Raises:
        HTTPException: 429 with a ``Retry-After`` header when over the limit.
    """
    # request.client can be None when the server does not supply the peer
    # address; fall back to one shared bucket instead of raising AttributeError.
    client_ip = request.client.host if request.client else "unknown"
    if not await rate_limiter.is_allowed(client_ip):
        raise HTTPException(
            status_code=429,
            detail="Too many requests",
            # Upper bound: the oldest call leaves the window within this time.
            headers={"Retry-After": str(rate_limiter.time_window)},
        )
    result = await rag_query(text)
    return result
6.2 Redis速率限制¶
import redis
redis_client = redis.Redis(host='localhost', port=6379, db=0)


async def check_rate_limit_redis(
    key: str,
    max_calls: int,
    time_window: int
) -> bool:
    """Sliding-window rate limit shared by all processes through a Redis sorted set.

    Each call is stored as a member scored by its Unix timestamp. Every call is
    recorded, including rejected ones, so a client that keeps retrying stays
    limited until it slows down.

    NOTE(review): ``redis.Redis`` is the synchronous client, so each call blocks
    the event loop. Consider ``redis.asyncio.Redis`` in production.

    Returns:
        True if fewer than ``max_calls`` calls were recorded in the last
        ``time_window`` seconds (not counting this one).
    """
    import time
    import uuid

    # time.time() is a true Unix timestamp. datetime.utcnow().timestamp()
    # treats the naive UTC value as *local* time, shifting every score by the
    # host's UTC offset (wrong across hosts in different zones and across DST).
    now = time.time()
    # Unique member: calls with the same timestamp must not collapse into one
    # sorted-set entry, which would under-count them.
    member = f"{now}:{uuid.uuid4().hex}"

    # redis-py's pipeline() defaults to transaction=True (MULTI/EXEC) — confirm
    # for your client version.
    pipe = redis_client.pipeline()
    # Remove entries that have left the window.
    pipe.zremrangebyscore(key, 0, now - time_window)
    # Count the remaining entries (result index 1).
    pipe.zcard(key)
    # Record this call.
    pipe.zadd(key, {member: now})
    # Let idle keys expire on their own.
    pipe.expire(key, time_window)
    results = pipe.execute()
    current_count = results[1]
    return current_count < max_calls
7. 安全扫描¶
7.1 依赖漏洞扫描¶
# Scan Python dependencies for known vulnerabilities with pip-audit.
# Without arguments it audits the active environment; use
# `pip-audit -r requirements.txt` to audit a requirements file instead.
pip install pip-audit
pip-audit
# Scan with Safety.
# NOTE(review): newer Safety releases deprecate `safety check` in favour of
# `safety scan` — confirm against the installed version.
pip install safety
safety check
7.2 Docker镜像扫描¶
# Scan the locally built rag-api:latest image with Trivy.
trivy image rag-api:latest
# CI/CD integration
# .github/workflows/security.yml (this is one entry of a job's `steps:` list)
# NOTE: @master is a moving branch. Pin the action to a release tag or, better,
# a full commit SHA so upstream changes cannot silently alter your pipeline.
- name: Run Trivy vulnerability scanner
  uses: aquasecurity/trivy-action@master
  with:
    scan-type: 'image'
    image-ref: 'rag-api:latest'
    format: 'sarif'
    output: 'trivy-results.sarif'
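7.3 代码安全扫描¶
可以使用Bandit对Python源码做静态安全扫描:执行 `pip install bandit` 后运行 `bandit -r .`,即可发现硬编码密码、`eval` 调用、不安全的反序列化等问题。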
8. 实战练习¶
练习1:实施认证¶
任务:
1. 添加JWT认证
2. 保护敏感端点
3. 实现token刷新
验证:
# Request without a token
curl http://localhost:8000/protected
# Expected: 401
# NOTE(review): in some FastAPI versions HTTPBearer answers a *missing*
# Authorization header with 403 instead of 401 — confirm locally.
# Request with a token (obtain one from POST /token first)
curl -H "Authorization: Bearer <token>" http://localhost:8000/protected
# Expected: 200
练习2:敏感数据加密¶
任务:
1. 加密敏感字段
2. 实现密钥管理
3. 测试加解密流程
验证:
练习3:速率限制¶
任务:
1. 实现速率限制
2. 测试限制效果
3. 配置不同限制策略
验证:
9. 安全最佳实践¶
9.1 安全检查清单¶
- ✅ 所有API端点都有认证
- ✅ 敏感数据已加密
- ✅ 输入已验证和清理
- ✅ 实施了速率限制
- ✅ 定期更新依赖
- ✅ 启用了日志和监控
- ✅ 配置了CORS策略
- ✅ 使用HTTPS
- ✅ 定期安全扫描
- ✅ 应急响应预案
9.2 安全配置¶
# CORS: only the listed origins may call this API from a browser.
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
    CORSMiddleware,
    # Explicit origins; "*" must not be combined with allow_credentials=True.
    allow_origins=["https://yourdomain.com"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    # Only the request headers this API actually reads (bearer token, JSON
    # bodies, X-API-Key) instead of allowing every header with "*".
    allow_headers=["Authorization", "Content-Type", "X-API-Key"],
)

# Host-header validation: reject requests whose Host header is not listed.
# (This does not add security response headers such as HSTS or CSP; those
# need a separate middleware or the reverse proxy.)
from fastapi.middleware.trustedhost import TrustedHostMiddleware
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["yourdomain.com", "*.yourdomain.com"],
)
10. 总结¶
关键要点¶
- 认证授权
  - JWT认证
  - API密钥
  - OAuth2
- 数据保护
  - 敏感数据加密
  - Secret管理
  - 安全存储
- 防护措施
  - 输入验证
  - 速率限制
  - 防注入攻击
下一步¶
- 学习最佳实践(第23章)
恭喜完成第22章! 🎉
你已经掌握RAG系统安全的各项技能!
下一步:第23章 - 最佳实践和案例分析