跳转至

第22章:安全实践

保护RAG系统安全,防范常见安全威胁


📚 章节概述

本章将学习如何为RAG系统实施全面的安全措施,保护数据和系统安全。

学习目标

完成本章后,你将能够: - ✅ 理解RAG系统的安全威胁 - ✅ 实施API认证和授权 - ✅ 保护敏感数据 - ✅ 防范常见攻击 - ✅ 实施安全扫描 - ✅ 满足合规要求

预计时间

  • 理论学习:50分钟
  • 实践操作:60-90分钟
  • 总计:约2-3小时

1. 安全威胁分析

1.1 常见威胁

OWASP Top 10: 1. 注入攻击:SQL注入、NoSQL注入 2. 认证失效:弱密码、会话管理不当 3. 敏感数据泄露:未加密的敏感信息 4. XML外部实体(XXE) 5. 访问控制失效 6. 安全配置错误 7. 跨站脚本攻击(XSS) 8. 不安全的反序列化 9. 使用含有已知漏洞的组件 10. **日志记录和监控不足

1.2 RAG系统特定威胁

数据安全: - 用户查询数据泄露 - 文档内容未授权访问 - API密钥泄露

LLM安全: - 提示注入攻击 - 数据提取攻击 - 模型越狱

系统安全: - DoS攻击 - 资源耗尽 - 恶意查询


2. API认证和授权

2.1 JWT认证

from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from datetime import datetime, timedelta
import jwt

app = FastAPI()

# JWT配置
SECRET_KEY = "your-secret-key-here"  # 从环境变量读取
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

security = HTTPBearer()

def create_access_token(data: dict, expires_delta: timedelta = None):
    """创建访问令牌"""
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """验证令牌"""
    token = credentials.credentials
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication credentials",
            )
        return payload
    except jwt.PyJWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
        )

# 登录端点
@app.post("/token")
async def login(username: str, password: str):
    """获取访问令牌"""
    # 验证用户(实际应该查询数据库)
    if username == "admin" and password == "secret":
        access_token = create_access_token(
            data={"sub": username},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        )
        return {"access_token": access_token, "token_type": "bearer"}
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
        )

# 受保护的端点
@app.get("/protected")
async def protected_route(payload: dict = Depends(verify_token)):
    """需要认证的端点"""
    return {"message": f"Hello {payload['sub']}"}

2.2 API密钥认证

from fastapi import Header, HTTPException

async def verify_api_key(x_api_key: str = Header(...)):
    """验证API密钥"""
    valid_keys = os.getenv("VALID_API_KEYS", "").split(",")

    if x_api_key not in valid_keys:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return x_api_key

@app.get("/query")
async def query(
    text: str,
    api_key: str = Depends(verify_api_key)
):
    """需要API密钥的查询端点"""
    return {"result": "answer"}

2.3 OAuth2

from fastapi.security import OAuth2PasswordBearer
from fastapi.security.oauth2 import OAuth2

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

@app.get("/users/me")
async def read_users_me(token: str = Depends(oauth2_scheme)):
    """OAuth2认证"""
    return {"token": token}

3. 数据保护

3.1 敏感数据加密

from cryptography.fernet import Fernet
import os

# 加密密钥(从环境变量读取)
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY")
cipher_suite = Fernet(ENCRYPTION_KEY)

def encrypt_data(data: str) -> bytes:
    """加密数据"""
    return cipher_suite.encrypt(data.encode())

def decrypt_data(encrypted_data: bytes) -> str:
    """解密数据"""
    return cipher_suite.decrypt(encrypted_data).decode()

# 使用示例
@api.post("/store-sensitive-data")
async def store_sensitive_data(data: str):
    """存储敏感数据"""
    # 加密后再存储
    encrypted = encrypt_data(data)

    # 存储到数据库
    await db.execute(
        "INSERT INTO sensitive_data (encrypted_data) VALUES ($1)",
        encrypted
    )

    return {"status": "stored"}

@api.get("/retrieve-sensitive-data")
async def retrieve_sensitive_data(id: int):
    """检索敏感数据"""
    # 从数据库获取加密数据
    encrypted = await db.fetch_one(
        "SELECT encrypted_data FROM sensitive_data WHERE id = $1",
        id
    )

    # 解密
    decrypted = decrypt_data(encrypted['encrypted_data'])

    return {"data": decrypted}

3.2 环境变量管理

from pydantic import BaseSettings

class Settings(BaseSettings):
    """应用配置"""

    # API配置
    API_KEY: str
    API_SECRET: str

    # 数据库配置
    DATABASE_URL: str

    # LLM API配置
    OPENAI_API_KEY: str

    # 加密配置
    ENCRYPTION_KEY: str

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"

settings = Settings()

# 使用
@app.get("/")
async def root():
    return {
        "api_configured": bool(settings.API_KEY),
        "llm_configured": bool(settings.OPENAI_API_KEY)
    }

3.3 Secret管理(K8s)

# secret.yaml
apiVersion: v1
kind: Secret
metadata:
  name: rag-secrets
type: Opaque
stringData:
  api-key: "your-api-key"
  openai-api-key: "sk-xxxxx"
  encryption-key: "your-encryption-key"
# 在K8s中读取Secret
import os

api_key = os.getenv("API_KEY")  # 从Secret挂载的环境变量读取

4. 输入验证和清理

4.1 输入验证

from pydantic import BaseModel, validator, Field
from typing import Optional

class QueryRequest(BaseModel):
    """查询请求模型"""

    text: str = Field(..., min_length=1, max_length=2000)
    top_k: Optional[int] = Field(5, ge=1, le=20)
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)

    @validator('text')
    def validate_text(cls, v):
        """验证文本内容"""
        # 检查危险字符
        dangerous_patterns = ['<script>', 'javascript:', 'onerror=']
        for pattern in dangerous_patterns:
            if pattern.lower() in v.lower():
                raise ValueError(f"Dangerous pattern detected: {pattern}")

        # 限制长度
        if len(v) > 2000:
            raise ValueError("Text too long")

        return v.strip()

@app.post("/query")
async def query(request: QueryRequest):
    """安全的查询端点"""
    result = await rag_query(request.text, request.top_k)
    return result

4.2 SQL注入防护

from sqlalchemy import text

# ❌ 不安全:字符串拼接
async def unsafe_query(user_id: str):
    query = f"SELECT * FROM users WHERE id = '{user_id}'"  # 危险!
    return await db.execute(query)

# ✅ 安全:参数化查询
async def safe_query(user_id: str):
    query = text("SELECT * FROM users WHERE id = :user_id")
    return await db.execute(query, {"user_id": user_id})

4.3 XSS防护

from fastapi.responses import JSONResponse
import html

def sanitize_output(text: str) -> str:
    """清理输出文本"""
    return html.escape(text)

@app.get("/search")
async def search(query: str):
    """搜索端点(防XSS)"""
    results = await search_documents(query)

    # 清理输出
    safe_results = []
    for result in results:
        safe_results.append({
            "title": sanitize_output(result['title']),
            "content": sanitize_output(result['content'][:200])
        })

    return JSONResponse(content=safe_results)

5. 提示注入防护

5.1 识别攻击

import re

def detect_prompt_injection(user_input: str) -> bool:
    """检测提示注入攻击"""

    # 常见注入模式
    injection_patterns = [
        r"ignore (all )?(previous|above) instructions?",
        r"forget (everything|all instructions)",
        r"override (your )?(programming|instructions)",
        r"disregard (previous|above)",
        r"instead of",
        r"new instructions?",
        r"system:",
        r"\[SYSTEM\]",
        r"<SYSTEM>",
        r"<script>",
        r"<iframe>",
    ]

    user_input_lower = user_input.lower()

    for pattern in injection_patterns:
        if re.search(pattern, user_input_lower, re.IGNORECASE):
            return True

    return False

@app.post("/chat")
async def chat(message: str):
    """聊天端点(防注入)"""
    if detect_prompt_injection(message):
        raise HTTPException(
            status_code=400,
            detail="Invalid input detected"
        )

    response = await generate_response(message)
    return response

5.2 输出过滤

def filter_llm_output(output: str) -> str:
    """过滤LLM输出"""

    # 移除系统指令泄露
    filtered = re.sub(
        r"(As an AI|I am an AI|I'm an AI).+?(?=\n\n|$)",
        "[Filtered]",
        output,
        flags=re.DOTALL
    )

    # 移除JSON/代码结构(可能包含指令)
    filtered = re.sub(
        r"\{.*?\"system.*?:.*?\}",
        "[Filtered]",
        filtered,
        flags=re.DOTALL
    )

    return filtered

6. 速率限制

6.1 令牌桶算法

from fastapi import HTTPException, Request
from datetime import datetime, timedelta
import asyncio

class RateLimiter:
    """速率限制器"""

    def __init__(self, max_calls: int, time_window: int):
        self.max_calls = max_calls  # 最大调用次数
        self.time_window = time_window  # 时间窗口(秒)
        self.calls = {}  # {key: [timestamp1, timestamp2, ...]}

    async def is_allowed(self, key: str) -> bool:
        """检查是否允许调用"""
        now = datetime.utcnow()

        # 获取该key的调用记录
        if key not in self.calls:
            self.calls[key] = []

        # 移除时间窗口外的记录
        self.calls[key] = [
            call_time for call_time in self.calls[key]
            if (now - call_time).total_seconds() < self.time_window
        ]

        # 检查是否超过限制
        if len(self.calls[key]) >= self.max_calls:
            return False

        # 记录本次调用
        self.calls[key].append(now)
        return True

# 全局速率限制器
rate_limiter = RateLimiter(max_calls=100, time_window=60)

@app.post("/query")
async def query(request: Request, text: str):
    """带速率限制的查询"""
    client_ip = request.client.host

    if not await rate_limiter.is_allowed(client_ip):
        raise HTTPException(
            status_code=429,
            detail="Too many requests"
        )

    result = await rag_query(text)
    return result

6.2 Redis速率限制

import redis

redis_client = redis.Redis(host='localhost', port=6379, db=0)

async def check_rate_limit_redis(
    key: str,
    max_calls: int,
    time_window: int
) -> bool:
    """使用Redis检查速率限制"""

    pipe = redis_client.pipeline()
    now = datetime.utcnow().timestamp()

    # 移除时间窗口外的记录
    pipe.zremrangebyscore(key, 0, now - time_window)

    # 获取当前计数
    pipe.zcard(key)

    # 添加本次调用
    pipe.zadd(key, {str(now): now})

    # 设置过期时间
    pipe.expire(key, time_window)

    results = pipe.execute()
    current_count = results[1]

    return current_count < max_calls

7. 安全扫描

7.1 依赖漏洞扫描

# 使用pip-audit扫描Python依赖
pip install pip-audit
pip-audit

# 使用Safety扫描
pip install safety
safety check

7.2 Docker镜像扫描

# 使用Trivy扫描镜像
trivy image rag-api:latest

# 在CI/CD中集成
# .github/workflows/security.yml
- name: Run Trivy vulnerability scanner
  uses: aquasecurity/trivy-action@master
  with:
    scan-type: 'image'
    image-ref: 'rag-api:latest'
    format: 'sarif'
    output: 'trivy-results.sarif'

7.3 代码安全扫描

# 使用Bandit扫描Python代码
pip install bandit
bandit -r app/

# 使用Semgrep
semgrep --config=auto app/

8. 实战练习

练习1:实施认证

任务: 1. 添加JWT认证 2. 保护敏感端点 3. 实现token刷新

验证

# 无token访问
curl http://localhost:8000/protected
# 返回401

# 有token访问
curl -H "Authorization: Bearer <token>" http://localhost:8000/protected
# 返回200


练习2:敏感数据加密

任务: 1. 加密敏感字段 2. 实现密钥管理 3. 测试加解密流程

验证

# 测试加密
encrypted = encrypt_data("secret")
assert decrypt_data(encrypted) == "secret"


练习3:速率限制

任务: 1. 实现速率限制 2. 测试限制效果 3. 配置不同限制策略

验证

# 发送100个请求
for i in {1..100}; do curl http://localhost:8000/query; done

# 第101个请求应该返回429


9. 安全最佳实践

9.1 安全检查清单

  • ✅ 所有API端点都有认证
  • ✅ 敏感数据已加密
  • ✅ 输入已验证和清理
  • ✅ 实施了速率限制
  • ✅ 定期更新依赖
  • ✅ 启用了日志和监控
  • ✅ 配置了CORS策略
  • ✅ 使用HTTPS
  • ✅ 定期安全扫描
  • ✅ 应急响应预案

9.2 安全配置

# CORS配置
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://yourdomain.com"],  # 限制域名
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# 安全头
from fastapi.middleware.trustedhost import TrustedHostMiddleware

app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["yourdomain.com", "*.yourdomain.com"]
)

10. 总结

关键要点

  1. 认证授权
  2. JWT认证
  3. API密钥
  4. OAuth2

  5. 数据保护

  6. 敏感数据加密
  7. Secret管理
  8. 安全存储

  9. 防护措施

  10. 输入验证
  11. 速率限制
  12. 防注入攻击

下一步

  • 学习最佳实践(第23章)

恭喜完成第22章! 🎉

你已经掌握RAG系统安全的各项技能!

下一步:第23章 - 最佳实践和案例分析