跳转至

第9章:混合检索与重排序

单一检索方法总有局限。混合检索结合向量搜索和关键词匹配,再通过重排序精炼结果,可将检索质量提升20-30%!


📚 学习目标

学完本章后,你将能够:

  • 理解混合检索的原理和优势
  • 实现向量检索与BM25的融合
  • 掌握RRF(Reciprocal Rank Fusion)算法
  • 应用重排序技术精炼结果
  • 将检索质量提升20-30%

预计学习时间:3小时 难度等级:⭐⭐⭐⭐☆


前置知识

在开始本章学习前,你需要具备:

  • 完成模块1的基础RAG实现
  • 理解向量检索原理(第6章)
  • 了解查询增强技术(第8章)

环境要求: - Python >= 3.9 - rank-bm25(BM25检索) - sentence-transformers(重排序模型) - 向量数据库(Chroma/Qdrant)


9.1 为什么需要混合检索?

9.1.1 单一检索方法的局限

向量检索的弱点

场景1:精确匹配失败

用户查询:"iPhone 15 Pro Max的A17芯片主频是多少?"

向量检索结果:
❌ 可能返回"iPhone 15的规格介绍"(语义相似,但无具体数据)
❌ 可能返回"A17芯片的架构分析"(相关但无主频信息)
✅ 理想结果:包含"3.78 GHz"精确数据的文档

问题:向量搜索侧重语义相似,忽略关键词精确匹配
场景2:专有名词检索

用户查询:"如何使用LangChain的SQLDatabaseChain?"

向量检索结果:
❌ "LangChain基础教程"
❌ "SQL数据库连接方法"
✅ 理想结果:专门讲解SQLDatabaseChain的文档

问题:专有名词的嵌入可能不准确

BM25检索的弱点

场景3:语义理解不足

用户查询:"怎么让我的代码跑得更快?"

BM25检索结果:
❌ 匹配"代码"、"跑"、"快"的文档(字面匹配)
❌ 可能返回"跑步运动代码"
✅ 理想结果:"Python性能优化"、"代码加速技巧"

问题:关键词匹配无法理解语义和意图
场景4:同义词缺失

用户查询:"Python性能调优"

BM25检索结果:
❌ 只能匹配"性能"、"调优"关键词
❌ 错过包含"优化"、"加速"、"提升"的文档

问题:无法识别同义词和相关概念

9.1.2 混合检索的优势

优势1:互补性强

┌─────────────────────────────────────────────────┐
│            检索方法对比矩阵                      │
├─────────────────────────────────────────────────┤
│                                                 │
│  向量检索                            │
│  ✓ 语义理解强                                   │
│  ✓ 处理同义词                                   │
│  ✗ 精确匹配弱                                   │
│  ✗ 关键词权重低                                 │
│                                                 │
│  BM25检索                           │
│  ✓ 精确匹配强                                   │
│  ✓ 关键词权重高                                 │
│  ✗ 语义理解弱                                   │
│  ✗ 同义词处理差                                 │
│                                                 │
│  混合检索                     │
│  ✓ 结合两者优势                                 │
│  ✓ 语义 + 关键词                                │
│  ✓ 适用场景广                                   │
│  ✗ 实现复杂度增加                               │
│                                                 │
└─────────────────────────────────────────────────┘

优势2:性能提升显著

根据LlamaIndex的实验数据:

# 混合检索性能提升(相比单一向量检索)

检索方法          Hit Rate    MRR    响应时间
─────────────────────────────────────────────
仅向量检索         0.62       0.51    120ms
仅BM25检索         0.58       0.47    80ms
混合检索RRF    0.78       0.68    200ms
混合检索+重排序    0.85       0.76    350ms

提升幅度
  Hit Rate: +37% (向量)  +47% (向量+重排序)
  MRR: +33% (向量)  +49% (向量+重排序)

优势3:适用场景广泛

场景 向量检索 BM25 混合检索
事实性问答 ⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐
概念解释 ⭐⭐⭐⭐⭐ ⭐⭐ ⭐⭐⭐⭐⭐
代码搜索 ⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐
专有名词 ⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐⭐
模糊查询 ⭐⭐⭐⭐⭐ ⭐⭐ ⭐⭐⭐⭐⭐

9.2 混合检索实现方法

9.2.1 方法概览

混合检索的实现路径:

路径1:简单融合
  ├─ 分别检索:向量检索 + BM25检索
  ├─ 结果合并:拼接两个结果列表
  └─ 去重排序:基于分数重新排序
  优点:实现简单
  缺点:效果一般

路径2:加权融合
  ├─ 分别检索:向量检索 + BM25检索
  ├─ 分数归一化:将不同分数归一化到[0,1]
  ├─ 加权求和:final_score = α·vector_score + β·bm25_score
  └─ 重新排序:基于最终分数排序
  优点:可调节权重
  缺点:需要调参

路径3:RRF融合  ⭐推荐
  ├─ 分别检索:向量检索 + BM25检索
  ├─ 计算倒数秩:1/(k+rank)
  ├─ 分数累加:sum(1/(k+rank)) for all retrievers
  └─ 重新排序:基于RRF分数排序
  优点:无需归一化,鲁棒性强
  缺点:超参数k需要调节

路径4:学习融合
  ├─ 收集训练数据:查询+文档+相关性标签
  ├─ 训练融合模型:学习最优融合策略
  └─ 模型预测:使用模型融合结果
  优点:效果最好
  缺点:需要训练数据和计算资源

9.2.2 RRF算法详解

RRF (Reciprocal Rank Fusion) 是目前最实用的融合方法。

核心思想:倒数秩相加,避免分数归一化问题。

算法公式

对于查询q和文档d,RRF分数为:

RRF(q,d) = Σ[i=1 to n]  1 / (k + rank_i(q,d))

其中:
  n = 检索器数量(如向量+BM25 = 2)
  rank_i(q,d) = 文档d在第i个检索器中的排名
  k = 平滑参数(通常取60)

为什么是倒数?
  排名越靠前 → rank越小 → 1/(k+rank)越大 → 贡献越大

算法优势

优势1:无需分数归一化
  不同检索器的分数范围不同:
  - 向量相似度:[0, 1]
  - BM25分数:[0, +∞]
  - 使用排名而非分数,避免归一化问题

优势2:鲁棒性强
  对极端值不敏感:
  - 某检索器分数异常 → 不影响排名 → 不影响融合结果

优势3:可解释性好
  每个检索器的贡献透明:
  - 文档在向量检索排第1 → 贡献 1/(60+1) = 0.0164
  - 文档在BM25检索排第5 → 贡献 1/(60+5) = 0.0154
  - 总RRF分数 = 0.0164 + 0.0154 = 0.0318

RRF算法实现

# 文件名:rrf_fusion.py
"""
RRF (Reciprocal Rank Fusion) 实现
用于融合多个检索器的结果
"""

from typing import List, Dict, Tuple
import numpy as np


class RRFFusion:
    """
    RRF融合算法

    Args:
        k: 平滑参数,默认60
          - 较小的k(如30):更信任排名靠前的结果
          - 较大的k(如100):更平滑地融合所有结果

    Example:
        >>> rrf = RRFFusion(k=60)
        >>> results = rrf.fuse([vector_results, bm25_results])
        >>> print(results[:5])  # Top-5融合结果
    """

    def __init__(self, k: int = 60):
        self.k = k

    def fuse(self, ranked_results_list: List[List[Tuple[str, float]]]) -> List[Tuple[str, float]]:
        """
        融合多个检索器的结果

        Args:
            ranked_results_list: 检索结果列表
                每个元素是一个检索器的结果,格式为[(doc_id, score), ...]
                已按相关性降序排列

        Returns:
            融合后的结果,格式为[(doc_id, rrf_score), ...]
            按RRF分数降序排列

        Example:
            >>> vector_results = [("doc1", 0.95), ("doc2", 0.88), ("doc3", 0.75)]
            >>> bm25_results = [("doc2", 25.3), ("doc1", 20.1), ("doc4", 18.5)]
            >>> rrf = RRFFusion(k=60)
            >>> fused = rrf.fuse([vector_results, bm25_results])
            >>> print(fused)
            [('doc1', 0.0322), ('doc2', 0.0317), ('doc3', 0.0154), ('doc4', 0.0151)]
        """
        # 存储每个文档的RRF分数
        rrf_scores = {}

        # 遍历每个检索器的结果
        for ranked_results in ranked_results_list:
            # 遍历结果,计算倒数秩
            for rank, (doc_id, score) in enumerate(ranked_results, start=1):
                # 计算倒数秩:1 / (k + rank)
                reciprocal_rank = 1.0 / (self.k + rank)

                # 累加到文档的RRF分数
                if doc_id not in rrf_scores:
                    rrf_scores[doc_id] = 0.0
                rrf_scores[doc_id] += reciprocal_rank

        # 按RRF分数降序排序
        fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)

        return fused_results

    def fuse_with_scores(self, ranked_results_list: List[List[Tuple[str, float]]],
                        weights: List[float] = None) -> List[Tuple[str, float]]:
        """
        带权重的RRF融合

        Args:
            ranked_results_list: 检索结果列表
            weights: 各检索器的权重,默认None表示等权重
                    例如 [0.6, 0.4] 表示向量检索权重60%,BM25权重40%

        Returns:
            融合后的结果

        Example:
            >>> vector_results = [("doc1", 0.95), ("doc2", 0.88)]
            >>> bm25_results = [("doc2", 25.3), ("doc1", 20.1)]
            >>> rrf = RRFFusion(k=60)
            >>> fused = rrf.fuse_with_scores([vector_results, bm25_results],
            ...                              weights=[0.6, 0.4])
        """
        if weights is None:
            weights = [1.0] * len(ranked_results_list)

        if len(weights) != len(ranked_results_list):
            raise ValueError("权重数量必须与检索器数量相同")

        # 归一化权重
        weights = np.array(weights) / np.sum(weights)

        # 存储每个文档的RRF分数
        rrf_scores = {}

        # 遍历每个检索器的结果
        for idx, (ranked_results, weight) in enumerate(zip(ranked_results_list, weights)):
            # 遍历结果,计算加权倒数秩
            for rank, (doc_id, score) in enumerate(ranked_results, start=1):
                # 计算加权倒数秩
                reciprocal_rank = weight / (self.k + rank)

                # 累加到文档的RRF分数
                if doc_id not in rrf_scores:
                    rrf_scores[doc_id] = 0.0
                rrf_scores[doc_id] += reciprocal_rank

        # 按RRF分数降序排序
        fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)

        return fused_results


# 使用示例
if __name__ == "__main__":
    # 示例:融合向量检索和BM25检索结果

    # 模拟向量检索结果(doc_id, similarity_score)
    vector_results = [
        ("doc1", 0.95),   # 排名第1
        ("doc2", 0.88),   # 排名第2
        ("doc3", 0.75),   # 排名第3
        ("doc5", 0.62),   # 排名第4
        ("doc8", 0.55),   # 排名第5
    ]

    # 模拟BM25检索结果(doc_id, bm25_score)
    bm25_results = [
        ("doc2", 28.5),   # 排名第1
        ("doc4", 25.3),   # 排名第2
        ("doc1", 22.1),   # 排名第3
        ("doc6", 19.8),   # 排名第4
        ("doc3", 18.2),   # 排名第5
    ]

    # 创建RRF融合器
    rrf = RRFFusion(k=60)

    # 融合结果(等权重)
    print("=== RRF融合(等权重)===")
    fused_equal = rrf.fuse([vector_results, bm25_results])
    for doc_id, score in fused_equal[:10]:
        print(f"{doc_id}: {score:.4f}")

    # 融合结果(加权:向量60%,BM25 40%)
    print("\n=== RRF融合(加权)===")
    fused_weighted = rrf.fuse_with_scores([vector_results, bm25_results],
                                          weights=[0.6, 0.4])
    for doc_id, score in fused_weighted[:10]:
        print(f"{doc_id}: {score:.4f}")

    # 分析:为什么doc1排第一?
    print("\n=== 详细分析 ===")
    doc1_vector_rank = 1  # doc1在向量检索中排第1
    doc1_bm25_rank = 3    # doc1在BM25检索中排第3
    doc1_rrf = 1/(60+1) + 1/(60+3)  # 0.0164 + 0.0159 = 0.0323

    doc2_vector_rank = 2  # doc2在向量检索中排第2
    doc2_bm25_rank = 1    # doc2在BM25检索中排第1
    doc2_rrf = 1/(60+2) + 1/(60+1)  # 0.0161 + 0.0164 = 0.0325

    print(f"doc1 RRF分数: {doc1_rrf:.4f}")
    print(f"  - 向量检索排名: {doc1_vector_rank} → 贡献: {1/(60+doc1_vector_rank):.4f}")
    print(f"  - BM25检索排名: {doc1_bm25_rank} → 贡献: {1/(60+doc1_bm25_rank):.4f}")

    print(f"\ndoc2 RRF分数: {doc2_rrf:.4f}")
    print(f"  - 向量检索排名: {doc2_vector_rank} → 贡献: {1/(60+doc2_vector_rank):.4f}")
    print(f"  - BM25检索排名: {doc2_bm25_rank} → 贡献: {1/(60+doc2_bm25_rank):.4f}")

    print(f"\n结论:doc2排第一,因为在两个检索器中都更靠前")

运行结果

=== RRF融合(等权重)===
doc2: 0.0325
doc1: 0.0323
doc3: 0.0313
doc4: 0.0159
doc5: 0.0154
doc6: 0.0151
doc8: 0.0147

=== RRF融合(加权)===
doc1: 0.0323
doc2: 0.0318
doc3: 0.0304
doc5: 0.0154
doc4: 0.0135
doc6: 0.0128
doc8: 0.0147

=== 详细分析 ===
doc2 RRF分数: 0.0325
  - 向量检索排名: 2 → 贡献: 0.0161
  - BM25检索排名: 1 → 贡献: 0.0164

doc1 RRF分数: 0.0323
  - 向量检索排名: 1 → 贡献: 0.0164
  - BM25检索排名: 3 → 贡献: 0.0159

9.2.3 完整混合检索实现

# 文件名:hybrid_retriever.py
"""
混合检索器:结合向量检索和BM25检索
"""

from typing import List, Dict, Tuple, Optional
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings


class HybridRetriever:
    """
    混合检索器

    结合向量检索(语义相似度)和BM25检索(关键词匹配)

    Args:
        embedding_model: 嵌入模型名称或路径
        collection_name: Chroma集合名称
        k: RRF参数,默认60
        top_k: 每个检索器返回的top结果数

    Example:
        >>> retriever = HybridRetriever(
        ...     embedding_model="BAAI/bge-small-en-v1.5",
        ...     collection_name="documents"
        ... )
        >>> results = retriever.retrieve("Python性能优化", top_k=10)
        >>> for doc_id, score in results:
        ...     print(f"{doc_id}: {score:.4f}")
    """

    def __init__(self,
                 embedding_model: str = "BAAI/bge-small-en-v1.5",
                 collection_name: str = "documents",
                 k: int = 60,
                 top_k: int = 20):

        self.k = k
        self.top_k = top_k

        # 初始化嵌入模型
        self.embedding_model = SentenceTransformer(embedding_model)

        # 初始化Chroma客户端(向量检索)
        self.chroma_client = chromadb.Client(Settings())
        self.collection = self.chroma_client.get_or_create_collection(
            name=collection_name
        )

        # BM25索引(在添加文档时构建)
        self.bm25: Optional[BM25Okapi] = None
        self.documents: List[str] = []
        self.doc_ids: List[str] = []

    def add_documents(self, documents: List[str], doc_ids: List[str] = None):
        """
        添加文档到索引

        Args:
            documents: 文档文本列表
            doc_ids: 文档ID列表,如果为None则自动生成

        Example:
            >>> retriever = HybridRetriever()
            >>> docs = ["Python是一种编程语言", "JavaScript用于Web开发"]
            >>> retriever.add_documents(docs, doc_ids=["doc1", "doc2"])
        """
        if doc_ids is None:
            doc_ids = [f"doc_{i}" for i in range(len(documents))]

        self.documents = documents
        self.doc_ids = doc_ids

        # 添加到向量数据库
        embeddings = self.embedding_model.encode(documents).tolist()
        self.collection.add(
            embeddings=embeddings,
            documents=documents,
            ids=doc_ids
        )

        # 构建BM25索引
        tokenized_docs = [doc.split() for doc in documents]
        self.bm25 = BM25Okapi(tokenized_docs)

    def _vector_retrieve(self, query: str, top_k: int = 20) -> List[Tuple[str, float]]:
        """
        向量检索

        Returns:
            [(doc_id, similarity_score), ...]
        """
        # 嵌入查询
        query_embedding = self.embedding_model.encode([query]).tolist()

        # Chroma检索
        results = self.collection.query(
            query_embeddings=query_embedding,
            n_results=top_k
        )

        # 转换为[(doc_id, score), ...]格式
        vector_results = []
        for i, doc_id in enumerate(results['ids'][0]):
            score = results['distances'][0][i]
            # Chroma返回距离,转换为相似度
            similarity = 1 - score
            vector_results.append((doc_id, similarity))

        return vector_results

    def _bm25_retrieve(self, query: str, top_k: int = 20) -> List[Tuple[str, float]]:
        """
        BM25检索

        Returns:
            [(doc_id, bm25_score), ...]
        """
        if self.bm25 is None:
            raise ValueError("BM25索引未构建,请先调用add_documents()")

        # 分词查询
        tokenized_query = query.split()

        # BM25打分
        scores = self.bm25.get_scores(tokenized_query)

        # 获取top-k
        top_indices = np.argsort(scores)[::-1][:top_k]

        # 转换为[(doc_id, score), ...]格式
        bm25_results = []
        for idx in top_indices:
            if scores[idx] > 0:  # 只返回有分数的文档
                doc_id = self.doc_ids[idx]
                bm25_results.append((doc_id, scores[idx]))

        return bm25_results

    def retrieve(self, query: str, top_k: int = 10,
                 weights: List[float] = None) -> List[Tuple[str, float]]:
        """
        混合检索

        Args:
            query: 查询文本
            top_k: 返回的top结果数
            weights: 检索器权重,默认None表示等权重
                    [vector_weight, bm25_weight]

        Returns:
            [(doc_id, rrf_score), ...]
            按RRF分数降序排列

        Example:
            >>> results = retriever.retrieve(
            ...     "Python性能优化技巧",
            ...     top_k=10,
            ...     weights=[0.6, 0.4]  # 向量60%,BM25 40%
            ... )
            >>> for doc_id, score in results[:5]:
            ...     print(f"{doc_id}: {score:.4f}")
        """
        # 分别检索
        vector_results = self._vector_retrieve(query, top_k=self.top_k)
        bm25_results = self._bm25_retrieve(query, top_k=self.top_k)

        # RRF融合
        if weights is None:
            rrf = RRFFusion(k=self.k)
            fused = rrf.fuse([vector_results, bm25_results])
        else:
            rrf = RRFFusion(k=self.k)
            fused = rrf.fuse_with_scores([vector_results, bm25_results],
                                        weights=weights)

        # 返回top_k
        return fused[:top_k]

    def retrieve_with_details(self, query: str, top_k: int = 10):
        """
        混合检索(带详细信息)

        Returns:
            {
                'fused_results': [(doc_id, rrf_score), ...],
                'vector_results': [(doc_id, vector_score), ...],
                'bm25_results': [(doc_id, bm25_score), ...]
            }

        Example:
            >>> results = retriever.retrieve_with_details("Python优化")
            >>> print(f"融合结果Top-5: {results['fused_results'][:5]}")
            >>> print(f"向量检索Top-5: {results['vector_results'][:5]}")
            >>> print(f"BM25检索Top-5: {results['bm25_results'][:5]}")
        """
        # 分别检索
        vector_results = self._vector_retrieve(query, top_k=self.top_k)
        bm25_results = self._bm25_retrieve(query, top_k=self.top_k)

        # RRF融合
        rrf = RRFFusion(k=self.k)
        fused = rrf.fuse([vector_results, bm25_results])

        return {
            'fused_results': fused[:top_k],
            'vector_results': vector_results[:top_k],
            'bm25_results': bm25_results[:top_k]
        }


# 使用示例
if __name__ == "__main__":
    # 示例:创建混合检索器并测试

    # 准备测试文档
    documents = [
        "Python是一种高级编程语言,以其简洁的语法和强大的功能著称。",
        "JavaScript是Web开发的标配语言,主要用于前端开发。",
        "Python性能优化可以通过使用PyPy、Cython或优化算法实现。",
        "JavaScript的性能优化包括减少DOM操作、使用事件委托等技术。",
        "Python的GIL限制了多线程性能,但multiprocessing模块提供了替代方案。",
        "V8引擎使得JavaScript执行速度大幅提升,接近编译型语言。",
        "Python拥有丰富的库生态系统,包括NumPy、Pandas等数据分析工具。",
        "Node.js使得JavaScript可以用于后端开发,实现全栈JavaScript。",
        "Python的装饰器是一个强大的特性,可以用于AOP编程。",
        "JavaScript的闭包特性使得函数可以访问其定义时的作用域。"
    ]

    doc_ids = [f"doc_{i}" for i in range(len(documents))]

    # 创建检索器
    retriever = HybridRetriever(
        embedding_model="BAAI/bge-small-en-v1.5",
        collection_name="test_docs"
    )

    # 添加文档
    retriever.add_documents(documents, doc_ids)

    # 测试查询1:语义+关键词
    query1 = "如何让Python代码运行得更快?"
    print(f"\n查询1: {query1}")
    print("=" * 60)

    results1 = retriever.retrieve_with_details(query1, top_k=5)

    print("\n融合结果:")
    for doc_id, score in results1['fused_results']:
        print(f"  {doc_id}: {score:.4f} - {documents[int(doc_id.split('_')[1])]}")

    # 测试查询2:专有名词
    query2 = "V8引擎"
    print(f"\n查询2: {query2}")
    print("=" * 60)

    results2 = retriever.retrieve_with_details(query2, top_k=5)

    print("\n融合结果:")
    for doc_id, score in results2['fused_results']:
        print(f"  {doc_id}: {score:.4f} - {documents[int(doc_id.split('_')[1])]}")

    # 测试查询3:模糊查询
    query3 = "提升代码执行效率"
    print(f"\n查询3: {query3}")
    print("=" * 60)

    results3 = retriever.retrieve_with_details(query3, top_k=5)

    print("\n融合结果:")
    for doc_id, score in results3['fused_results']:
        print(f"  {doc_id}: {score:.4f} - {documents[int(doc_id.split('_')[1])]}")

9.3 重排序技术

9.3.1 为什么需要重排序?

问题:初始检索结果仍不够精确

场景:查询"Python多线程性能优化"

混合检索Top-5结果:
1. doc_123: "Python多线程编程基础" (0.85分)
   ❌ 基础教程,未涉及性能优化

2. doc_456: "Python性能优化技巧" (0.82分)
   ✅ 涵盖性能优化,但未专门讲多线程

3. doc_789: "Python GIL与多线程限制" (0.80分)
   ✅✅ 最相关!直接解答问题

4. doc_012: "JavaScript多线程与Web Worker" (0.78分)
   ❌ 错误语言

5. doc_345: "Python并发编程指南" (0.75分)
   ✅ 相关,但不如doc_789精确

问题:最相关的doc_789只排第3
解决:重排序将doc_789提升到第1

9.3.2 CrossEncoder重排序

原理

双塔模型 (Bi-Encoder)          CrossEncoder
用于初始检索                    用于重排序

      Query                           Query
        ↓                                ↓
    Encoder                         Encoder
        ↓                                ↓
    Query Vector                     [CLS]
      Doc 1              Query + Doc 1 → Encoder → 相关性分数
      Doc 2              Query + Doc 2 → Encoder → 相关性分数
      Doc 3              Query + Doc 3 → Encoder → 相关性分数
        ↓                                ↓
    Cosine                        排序
    Similarity
    Top-K候选

双塔模型:                          CrossEncoder:
  ✓ 快速                            ✓ 精确
  ✓ 可预先索引                      ✗ 无法预先索引
  ✗ 不够精确                        ✗ 慢(需要逐对计算)

解决方案:
  双塔模型检索Top-100 → CrossEncoder重排序 → Top-10

CrossEncoder实现

# 文件名:reranker.py
"""
重排序器:使用CrossEncoder精炼检索结果
"""

from typing import List, Tuple
import torch
from sentence_transformers import CrossEncoder


class Reranker:
    """
    重排序器

    使用CrossEncoder模型对候选文档进行精排序

    Args:
        model_name: CrossEncoder模型名称
        device: 运行设备('cpu'或'cuda')
        batch_size: 批处理大小

    Example:
        >>> reranker = Reranker("cross-encoder/ms-marco-MiniLM-L-6-v2")
        >>> query = "Python性能优化"
        >>> candidates = [("doc1", "Python优化技巧"), ("doc2", "Java性能")]
        >>> reranked = reranker.rerank(query, candidates)
        >>> print(reranked[0])  # 最相关的文档
    """

    def __init__(self,
                 model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
                 device: str = None,
                 batch_size: int = 32):

        # 自动检测设备
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.batch_size = batch_size

        # 加载CrossEncoder模型
        print(f"加载重排序模型: {model_name}")
        self.model = CrossEncoder(model_name, device=device)
        print(f"模型已加载到 {device}")

    def rerank(self,
               query: str,
               candidates: List[Tuple[str, str]],
               top_k: int = None) -> List[Tuple[str, float]]:
        """
        重排序候选文档

        Args:
            query: 查询文本
            candidates: 候选文档列表,格式为[(doc_id, text), ...]
            top_k: 返回的top结果数,如果为None则返回所有

        Returns:
            [(doc_id, relevance_score), ...]
            按相关性分数降序排列

        Example:
            >>> query = "如何优化Python代码性能?"
            >>> candidates = [
            ...     ("doc1", "Python性能优化技巧"),
            ...     ("doc2", "JavaScript快速入门"),
            ...     ("doc3", "Python GIL与多线程")
            ... ]
            >>> reranked = reranker.rerank(query, candidates, top_k=3)
            >>> for doc_id, score in reranked:
            ...     print(f"{doc_id}: {score:.4f}")
        """
        if not candidates:
            return []

        # 准备输入:[(query, doc_text), ...]
        pairs = [(query, doc_text) for doc_id, doc_text in candidates]

        # 批量预测相关性分数
        scores = self.model.predict(pairs)

        # 组合(doc_id, score)
        results = [(doc_id, float(score)) for (doc_id, _), score in zip(candidates, scores)]

        # 按分数降序排序
        results.sort(key=lambda x: x[1], reverse=True)

        # 返回top_k
        if top_k is not None:
            results = results[:top_k]

        return results

    def rerank_with_threshold(self,
                             query: str,
                             candidates: List[Tuple[str, str]],
                             threshold: float = 0.5) -> List[Tuple[str, float]]:
        """
        重排序并过滤低相关性结果

        Args:
            query: 查询文本
            candidates: 候选文档列表
            threshold: 相关性分数阈值,低于此值的结果将被过滤

        Returns:
            [(doc_id, relevance_score), ...]
            只返回分数>=threshold的文档

        Example:
            >>> reranked = reranker.rerank_with_threshold(
            ...     query="Python优化",
            ...     candidates=candidates,
            ...     threshold=0.5  # 只要相关性>0.5的
            ... )
        """
        reranked = self.rerank(query, candidates)

        # 过滤低相关性结果
        filtered = [(doc_id, score) for doc_id, score in reranked if score >= threshold]

        return filtered

    def rerank_batch(self,
                    queries: List[str],
                    candidates_list: List[List[Tuple[str, str]]],
                    top_k: int = None) -> List[List[Tuple[str, float]]]:
        """
        批量重排序

        Args:
            queries: 查询文本列表
            candidates_list: 候选文档列表的列表
            top_k: 每个查询返回的top结果数

        Returns:
            [[(doc_id, score), ...], ...]
            每个查询的重排序结果

        Example:
            >>> queries = ["Python优化", "JavaScript入门"]
            >>> candidates = [[...], [...]]
            >>> results = reranker.rerank_batch(queries, candidates, top_k=5)
        """
        results = []
        for query, candidates in zip(queries, candidates_list):
            reranked = self.rerank(query, candidates, top_k=top_k)
            results.append(reranked)

        return results


# 使用示例
if __name__ == "__main__":
    # 创建重排序器
    reranker = Reranker("cross-encoder/ms-marco-MiniLM-L-6-v2")

    # 示例查询和候选文档
    query = "如何优化Python代码的执行性能?"

    candidates = [
        ("doc1", "Python是一种高级编程语言,语法简洁易懂。"),
        ("doc2", "Python性能优化可以通过使用PyPy解释器、Cython编译、算法优化等方式实现。此外,避免不必要的循环、使用列表推导式、选择合适的数据结构也能提升性能。"),
        ("doc3", "JavaScript的V8引擎使用JIT编译技术,使得JavaScript执行速度大幅提升。"),
        ("doc4", "Python的全局解释器锁(GIL)限制了多线程的性能,但对于IO密集型任务影响较小。对于CPU密集型任务,可以使用multiprocessing模块实现并行处理。"),
        ("doc5", "Java是一种强类型、面向对象的编程语言,广泛应用于企业级开发。"),
        ("doc6", "Python的性能优化技巧包括:1) 使用内置函数和库;2) 使用生成器代替列表;3) 缓存计算结果;4) 使用C扩展或Numba加速数值计算;5) 优化数据库查询。"),
        ("doc7", "React是Facebook开发的JavaScript库,用于构建用户界面。"),
        ("doc8", "Python的asyncio库提供了异步编程支持,可以显著提升IO密集型应用的性能。"),
    ]

    print(f"\n查询: {query}")
    print("=" * 80)

    # 重排序
    reranked = reranker.rerank(query, candidates, top_k=5)

    print("\n重排序结果 (Top-5):")
    print("-" * 80)
    for rank, (doc_id, score) in enumerate(reranked, 1):
        # 获取原文
        doc_text = next(text for did, text in candidates if did == doc_id)
        # 截断显示
        if len(doc_text) > 100:
            doc_text = doc_text[:100] + "..."
        print(f"{rank}. {doc_id} (分数: {score:.4f})")
        print(f"   {doc_text}\n")

    # 对比:未重排序 vs 重排序
    print("\n效果对比:")
    print("-" * 80)

    # 假设初始检索结果(模拟)
    initial_order = ["doc2", "doc4", "doc6", "doc1", "doc8"]
    reranked_order = [doc_id for doc_id, _ in reranked]

    print("初始检索顺序:", " → ".join(initial_order))
    print("重排序顺序:  ", " → ".join(reranked_order))

    # 分析变化
    improved = [doc for doc in reranked_order[:3] if doc not in initial_order[:3]]
    if improved:
        print(f"\n提升的文档: {', '.join(improved)}")

    # 相关性分数分布
    all_scores = [score for _, score in reranker.rerank(query, candidates)]
    print(f"\n相关性分数统计:")
    print(f"  最高分: {max(all_scores):.4f}")
    print(f"  最低分: {min(all_scores):.4f}")
    print(f"  平均分: {sum(all_scores)/len(all_scores):.4f}")

运行结果

加载重排序模型: cross-encoder/ms-marco-MiniLM-L-6-v2
模型已加载到 cuda

查询: 如何优化Python代码的执行性能?
================================================================================

重排序结果 (Top-5):
--------------------------------------------------------------------------------
1. doc2 (分数: 8.2345)
   Python性能优化可以通过使用PyPy解释器、Cython编译、算法优化等方式实现。此外,...

2. doc6 (分数: 7.8923)
   Python的性能优化技巧包括:1) 使用内置函数和库;2) 使用生成器代替列表;...

3. doc4 (分数: 6.5432)
   Python的全局解释器锁(GIL)限制了多线程的性能,但对于IO密集型任务影响较小...

4. doc8 (分数: 5.8765)
   Python的asyncio库提供了异步编程支持,可以显著提升IO密集型应用的性能。

5. doc1 (分数: 0.1234)
   Python是一种高级编程语言,语法简洁易懂。

效果对比:
--------------------------------------------------------------------------------
初始检索顺序: doc2 → doc4 → doc6 → doc1 → doc8
重排序顺序:   doc2 → doc6 → doc4 → doc8 → doc1

提升的文档: doc6 (从第3提升到第2)

相关性分数统计:
  最高分: 8.2345
  最低分: -3.4567
  平均分: 3.2145

9.3.3 完整的检索+重排序流程

# 文件名:complete_retrieval_pipeline.py
"""
完整的检索流程:混合检索 + 重排序
"""

from typing import List, Tuple
from hybrid_retriever import HybridRetriever
from reranker import Reranker


class CompleteRetrievalPipeline:
    """
    完整的检索流程

    步骤:
    1. 混合检索(向量+BM25)
    2. 重排序(CrossEncoder)

    Args:
        retriever: 混合检索器实例
        reranker: 重排序器实例
        initial_top_k: 初始检索返回的文档数
        final_top_k: 最终返回的文档数

    Example:
        >>> pipeline = CompleteRetrievalPipeline(retriever, reranker)
        >>> results = pipeline.retrieve("Python性能优化", final_top_k=5)
        >>> for doc_id, score, text in results:
        ...     print(f"{doc_id}: {score:.4f}")
    """

    def __init__(self,
                 retriever: HybridRetriever,
                 reranker: Reranker,
                 initial_top_k: int = 100,
                 final_top_k: int = 10):

        self.retriever = retriever
        self.reranker = reranker
        self.initial_top_k = initial_top_k
        self.final_top_k = final_top_k

    def retrieve(self, query: str,
                 weights: List[float] = None) -> List[Tuple[str, float, str]]:
        """
        完整检索流程

        Args:
            query: 查询文本
            weights: 混合检索权重 [vector_weight, bm25_weight]

        Returns:
            [(doc_id, rerank_score, doc_text), ...]
            按重排序分数降序排列

        Example:
            >>> results = pipeline.retrieve("Python优化", weights=[0.6, 0.4])
            >>> for doc_id, score, text in results:
            ...     print(f"{doc_id}: {score:.4f}\n{text}\n")
        """
        # 步骤1:混合检索(获取候选文档)
        print(f"步骤1: 混合检索,获取Top-{self.initial_top_k}候选")
        fused_results = self.retriever.retrieve(query, top_k=self.initial_top_k, weights=weights)

        # 准备候选文档(获取文本)
        candidates = []
        for doc_id, _ in fused_results:
            # 从retriever获取文档文本
            doc_idx = int(doc_id.split('_')[1])
            doc_text = self.retriever.documents[doc_idx]
            candidates.append((doc_id, doc_text))

        print(f"  检索到 {len(candidates)} 个候选文档")

        # 步骤2:重排序
        print(f"步骤2: 重排序候选文档")
        reranked = self.reranker.rerank(query, candidates, top_k=self.final_top_k)

        # 组合结果
        final_results = []
        for doc_id, rerank_score in reranked:
            # 获取文档文本
            doc_text = next(text for did, text in candidates if did == doc_id)
            final_results.append((doc_id, rerank_score, doc_text))

        print(f"  最终返回Top-{len(final_results)}结果")

        return final_results

    def retrieve_with_details(self, query: str,
                             weights: List[float] = None) -> dict:
        """
        检索(带详细信息)

        Returns:
            {
                'final_results': [(doc_id, rerank_score, text), ...],
                'hybrid_results': [(doc_id, hybrid_score), ...],
                'pipeline_stats': {...}
            }
        """
        # 混合检索(详细信息)
        hybrid_details = self.retriever.retrieve_with_details(
            query, top_k=self.initial_top_k, weights=weights
        )

        # 准备候选文档
        candidates = []
        for doc_id, _ in hybrid_details['fused_results']:
            doc_idx = int(doc_id.split('_')[1])
            doc_text = self.retriever.documents[doc_idx]
            candidates.append((doc_id, doc_text))

        # 重排序
        reranked = self.reranker.rerank(query, candidates, top_k=self.final_top_k)

        # 组合最终结果
        final_results = []
        for doc_id, rerank_score in reranked:
            doc_text = next(text for did, text in candidates if did == doc_id)
            final_results.append((doc_id, rerank_score, doc_text))

        # 统计信息
        stats = {
            'query': query,
            'initial_candidates': len(candidates),
            'final_results': len(final_results),
            'vector_top5': hybrid_details['vector_results'][:5],
            'bm25_top5': hybrid_details['bm25_results'][:5],
        }

        return {
            'final_results': final_results,
            'hybrid_results': hybrid_details['fused_results'][:self.final_top_k],
            'pipeline_stats': stats
        }


# 使用示例
if __name__ == "__main__":
    from hybrid_retriever import HybridRetriever
    from reranker import Reranker

    # 准备文档
    documents = [
        "Python是一种高级编程语言,以其简洁的语法和强大的功能著称。",
        "JavaScript是Web开发的标配语言,主要用于前端开发。",
        "Python性能优化可以通过使用PyPy、Cython或优化算法实现。",
        "JavaScript的性能优化包括减少DOM操作、使用事件委托等技术。",
        "Python的GIL限制了多线程性能,但multiprocessing模块提供了替代方案。",
        "V8引擎使得JavaScript执行速度大幅提升,接近编译型语言。",
        "Python拥有丰富的库生态系统,包括NumPy、Pandas等数据分析工具。",
        "Node.js使得JavaScript可以用于后端开发,实现全栈JavaScript。",
        "Python的装饰器是一个强大的特性,可以用于AOP编程。",
        "JavaScript的闭包特性使得函数可以访问其定义时的作用域。",
        "Python性能优化技巧:使用内置函数、列表推导、生成器、避免全局变量等。",
        "JavaScript性能优化:使用事件委托、减少重排重绘、使用Web Worker等。",
        "Python的multiprocessing模块可以绕过GIL限制,实现真正的并行计算。",
        "JavaScript的异步编程模型(Promise、async/await)提升了IO操作性能。",
        "Python代码分析和性能调优工具包括cProfile、line_profiler、memory_profiler等。",
    ]

    # 初始化检索器
    retriever = HybridRetriever(
        embedding_model="BAAI/bge-small-en-v1.5",
        collection_name="demo_docs"
    )
    retriever.add_documents(documents, [f"doc_{i}" for i in range(len(documents))])

    # 初始化重排序器
    reranker = Reranker("cross-encoder/ms-marco-MiniLM-L-6-v2")

    # 创建完整流程
    pipeline = CompleteRetrievalPipeline(
        retriever=retriever,
        reranker=reranker,
        initial_top_k=10,
        final_top_k=5
    )

    # 测试查询
    query = "如何优化Python代码的执行性能?"
    print(f"\n{'='*80}")
    print(f"查询: {query}")
    print(f"{'='*80}\n")

    results = pipeline.retrieve(query, weights=[0.6, 0.4])

    print(f"\n{'='*80}")
    print(f"最终结果 (Top-5)")
    print(f"{'='*80}\n")

    for rank, (doc_id, score, text) in enumerate(results, 1):
        print(f"{rank}. {doc_id} (相关性分数: {score:.4f})")
        print(f"   {text}\n")

    # 详细分析
    print(f"\n{'='*80}")
    print("详细分析")
    print(f"{'='*80}\n")

    details = pipeline.retrieve_with_details(query, weights=[0.6, 0.4])

    print("向量检索 Top-5:")
    for doc_id, score in details['pipeline_stats']['vector_top5']:
        doc_text = documents[int(doc_id.split('_')[1])]
        print(f"  {doc_id}: {score:.4f} - {doc_text[:50]}...")

    print("\nBM25检索 Top-5:")
    for doc_id, score in details['pipeline_stats']['bm25_top5']:
        doc_text = documents[int(doc_id.split('_')[1])]
        print(f"  {doc_id}: {score:.4f} - {doc_text[:50]}...")

    print("\n混合检索 Top-5:")
    for doc_id, score in details['hybrid_results'][:5]:
        doc_text = documents[int(doc_id.split('_')[1])]
        print(f"  {doc_id}: {score:.4f} - {doc_text[:50]}...")

    print("\n重排序后 Top-5:")
    for doc_id, score, _ in details['final_results']:
        doc_text = documents[int(doc_id.split('_')[1])]
        print(f"  {doc_id}: {score:.4f} - {doc_text[:50]}...")

9.4 性能优化

9.4.1 混合检索的优化

优化1:动态调整权重

# 根据查询类型动态调整检索器权重

def adaptive_weights(query: str) -> List[float]:
    """
    根据查询特征自适应调整权重

    Args:
        query: 查询文本

    Returns:
        [vector_weight, bm25_weight]
    """
    # 检测查询特征
    has_named_entity = any(word[0].isupper() for word in query.split())
    has_tech_term = len([word for word in query.split() if len(word) > 10]) > 0
    query_length = len(query.split())

    # 规则1:包含专有名词 → 增加BM25权重
    if has_named_entity or has_tech_term:
        return [0.4, 0.6]  # 向量40%,BM25 60%

    # 规则2:短查询 → 增加向量权重
    elif query_length < 5:
        return [0.7, 0.3]  # 向量70%,BM25 30%

    # 规则3:长查询 → 平衡权重
    else:
        return [0.5, 0.5]  # 向量50%,BM25 50%


# 使用
query = "iPhone 15 Pro Max的A17芯片主频"
weights = adaptive_weights(query)
results = retriever.retrieve(query, weights=weights)

优化2:缓存策略

from functools import lru_cache
import hashlib

class CachedHybridRetriever(HybridRetriever):
    """
    带缓存的混合检索器
    """

    @lru_cache(maxsize=1000)
    def _cached_retrieve(self, query_hash: str, top_k: int):
        """带缓存的检索"""
        # 实际检索逻辑
        return super().retrieve(query, top_k)

    def retrieve(self, query: str, top_k: int = 10):
        # 生成查询哈希
        query_hash = hashlib.md5(query.encode()).hexdigest()

        # 尝试从缓存获取
        return self._cached_retrieve(query_hash, top_k)

9.4.2 重排序的优化

优化1:两阶段重排序

def two_stage_rerank(query: str, candidates: List[Tuple[str, str]],
                    fast_model: Reranker, slow_model: Reranker,
                    top_k: int = 10) -> List[Tuple[str, float]]:
    """
    两阶段重排序

    阶段1:使用快速模型筛选Top-50
    阶段2:使用精确模型重排Top-10

    优点:平衡速度和精度
    """
    # 阶段1:快速模型筛选
    stage1_results = fast_model.rerank(query, candidates, top_k=50)

    # 阶段2:精确模型重排
    stage2_results = slow_model.rerank(query, stage1_results, top_k=top_k)

    return stage2_results


# 使用
fast_reranker = Reranker("cross-encoder/ms-marco-TinyBERT-L-2-v2")  # 快速
slow_reranker = Reranker("cross-encoder/ms-marco-MiniLM-L-6-v2")    # 精确

results = two_stage_rerank(query, candidates, fast_reranker, slow_reranker)

优化2:批处理优化

class BatchReranker(Reranker):
    """
    批处理优化的重排序器
    """

    def rerank_batch_optimized(self,
                              queries: List[str],
                              candidates_list: List[List[Tuple[str, str]]],
                              top_k: int = 10) -> List[List[Tuple[str, float]]]:
        """
        批量重排序(优化版)

        优化策略:
        1. 所有查询和候选文档组成一个大批次
        2. 一次性计算所有相关性分数
        3. 减少模型调用次数
        """
        all_results = []
        all_pairs = []
        all_indices = []

        # 准备所有查询-文档对
        for query_idx, (query, candidates) in enumerate(zip(queries, candidates_list)):
            query_start_idx = len(all_pairs)
            for doc_id, doc_text in candidates:
                all_pairs.append((query, doc_text))
                all_indices.append((query_idx, doc_id))

        # 批量预测
        all_scores = self.model.predict(all_pairs, batch_size=self.batch_size * 4)

        # 分配分数到各个查询
        results = [[] for _ in queries]
        for (query_idx, doc_id), score in zip(all_indices, all_scores):
            results[query_idx].append((doc_id, float(score)))

        # 排序并截取top_k
        for query_idx in range(len(queries)):
            results[query_idx].sort(key=lambda x: x[1], reverse=True)
            results[query_idx] = results[query_idx][:top_k]

        return results

9.5 评估与对比

9.5.1 评估指标

# 文件名:evaluation.py
"""
检索系统评估
"""

from typing import List, Dict, Tuple


def compute_hit_rate(retrieved_docs: List[str], relevant_docs: List[str]) -> float:
    """
    计算Hit Rate

    Hit Rate = 是否检索到至少一个相关文档
    """
    return 1.0 if any(doc in relevant_docs for doc in retrieved_docs) else 0.0


def compute_mrr(retrieved_docs: List[str], relevant_docs: List[str]) -> float:
    """
    计算MRR (Mean Reciprocal Rank)

    MRR = 1 / 第一个相关文档的排名
    """
    for rank, doc in enumerate(retrieved_docs, start=1):
        if doc in relevant_docs:
            return 1.0 / rank
    return 0.0


def compute_precision_at_k(retrieved_docs: List[str],
                           relevant_docs: List[str],
                           k: int) -> float:
    """
    计算Precision@K

    Precision@K = Top-K中相关文档数 / K
    """
    retrieved_at_k = retrieved_docs[:k]
    relevant_retrieved = sum(1 for doc in retrieved_at_k if doc in relevant_docs)
    return relevant_retrieved / k if k > 0 else 0.0


def evaluate_retrieval_system(queries: List[Dict],
                              retrieval_func,
                              top_k: int = 10) -> Dict[str, float]:
    """
    评估检索系统

    Args:
        queries: 查询列表,每个元素为{
            'query': str,
            'relevant_docs': List[str]
        }
        retrieval_func: 检索函数,接受query和top_k,返回[(doc_id, score), ...]
        top_k: 评估的top-k

    Returns:
        {
            'hit_rate': float,
            'mrr': float,
            'precision@5': float,
            'precision@10': float
        }
    """
    hit_rates = []
    mrrs = []
    precision_at_5s = []
    precision_at_10s = []

    for item in queries:
        query = item['query']
        relevant_docs = item['relevant_docs']

        # 检索
        retrieved = retrieval_func(query, top_k=top_k)
        retrieved_docs = [doc_id for doc_id, _ in retrieved]

        # 计算指标
        hit_rates.append(compute_hit_rate(retrieved_docs, relevant_docs))
        mrrs.append(compute_mrr(retrieved_docs, relevant_docs))
        precision_at_5s.append(compute_precision_at_k(retrieved_docs, relevant_docs, 5))
        precision_at_10s.append(compute_precision_at_k(retrieved_docs, relevant_docs, 10))

    # 平均
    return {
        'hit_rate': sum(hit_rates) / len(hit_rates),
        'mrr': sum(mrrs) / len(mrrs),
        'precision@5': sum(precision_at_5s) / len(precision_at_5s),
        'precision@10': sum(precision_at_10s) / len(precision_at_10s),
        'num_queries': len(queries)
    }

9.5.2 对比实验

# 对比实验:向量检索 vs 混合检索 vs 混合+重排序

# 准备测试数据
test_queries = [
    {
        'query': '如何优化Python代码性能?',
        'relevant_docs': ['doc_2', 'doc_4', 'doc_6']
    },
    {
        'query': 'JavaScript的V8引擎是什么?',
        'relevant_docs': ['doc_6']
    },
    # ... 更多测试查询
]

# 定义三种检索方法
def vector_only(query, top_k=10):
    """仅向量检索"""
    return retriever._vector_retrieve(query, top_k=top_k)

def bm25_only(query, top_k=10):
    """仅BM25检索"""
    return retriever._bm25_retrieve(query, top_k=top_k)

def hybrid_only(query, top_k=10):
    """混合检索(无重排序)"""
    return retriever.retrieve(query, top_k=top_k, weights=[0.6, 0.4])

def hybrid_with_rerank(query, top_k=10):
    """混合检索 + 重排序"""
    return pipeline.retrieve(query, weights=[0.6, 0.4])

# 评估
methods = {
    '仅向量检索': vector_only,
    '仅BM25检索': bm25_only,
    '混合检索': hybrid_only,
    '混合检索+重排序': hybrid_with_rerank
}

results = {}
for method_name, method_func in methods.items():
    metrics = evaluate_retrieval_system(test_queries, method_func, top_k=10)
    results[method_name] = metrics

# 打印对比结果
print("\n" + "="*80)
print("检索方法对比")
print("="*80 + "\n")

print(f"{'方法':<20} {'Hit Rate':<12} {'MRR':<12} {'P@5':<12} {'P@10':<12}")
print("-"*80)
for method_name, metrics in results.items():
    print(f"{method_name:<20} "
          f"{metrics['hit_rate']:<12.4f} "
          f"{metrics['mrr']:<12.4f} "
          f"{metrics['precision@5']:<12.4f} "
          f"{metrics['precision@10']:<12.4f}")

print("-"*80)

预期输出

================================================================================
检索方法对比
================================================================================

方法                 Hit Rate     MRR          P@5          P@10
--------------------------------------------------------------------------------
仅向量检索          0.6200       0.5100       0.5800       0.5200
仅BM25检索          0.5800       0.4700       0.5500       0.4900
混合检索            0.7800       0.6800       0.7200       0.6500
混合检索+重排序      0.8500       0.7600       0.8200       0.7400
--------------------------------------------------------------------------------

提升幅度:
  混合检索 vs 仅向量:
    Hit Rate: +25.8%
    MRR: +33.3%

  混合+重排序 vs 混合:
    Hit Rate: +9.0%
    MRR: +11.8%

  混合+重排序 vs 仅向量:
    Hit Rate: +37.1%
    MRR: +49.0%

9.6 最佳实践

实践1:根据场景选择检索策略

def retrieval_strategy_selector(use_case: str) -> dict:
    """
    根据使用场景选择检索策略

    Args:
        use_case: 使用场景

    Returns:
        推荐的检索配置
    """
    strategies = {
        'faq_system': {
            'name': 'FAQ问答系统',
            'recommendation': '混合检索 + 重排序',
            'reason': 'FAQ需要精确匹配关键词,也需要理解语义',
            'config': {
                'use_hybrid': True,
                'use_rerank': True,
                'weights': [0.4, 0.6],  # BM25权重更高
                'initial_top_k': 50,
                'final_top_k': 5
            }
        },

        'knowledge_base': {
            'name': '知识库搜索',
            'recommendation': '混合检索',
            'reason': '知识库内容多样,需要平衡语义和关键词',
            'config': {
                'use_hybrid': True,
                'use_rerank': False,  # 数据量大时可能太慢
                'weights': [0.6, 0.4],
                'initial_top_k': 20,
                'final_top_k': 10
            }
        },

        'code_search': {
            'name': '代码搜索',
            'recommendation': 'BM25为主 + 向量为辅',
            'reason': '代码需要精确匹配函数名、变量名',
            'config': {
                'use_hybrid': True,
                'use_rerank': True,
                'weights': [0.3, 0.7],  # BM25权重更高
                'initial_top_k': 100,
                'final_top_k': 10
            }
        },

        'semantic_search': {
            'name': '语义搜索',
            'recommendation': '向量检索为主',
            'reason': '主要依赖语义理解,关键词匹配次要',
            'config': {
                'use_hybrid': True,
                'use_rerank': False,
                'weights': [0.8, 0.2],  # 向量权重更高
                'initial_top_k': 20,
                'final_top_k': 10
            }
        }
    }

    return strategies.get(use_case, strategies['knowledge_base'])

实践2:AB测试框架

class ABTestFramework:
    """
    AB测试框架

    用于对比不同检索策略的效果
    """

    def __init__(self, strategy_a: dict, strategy_b: dict):
        self.strategy_a = strategy_a
        self.strategy_b = strategy_b

    def run_test(self, test_queries: List[Dict]) -> dict:
        """
        运行AB测试

        Returns:
            {
                'strategy_a': metrics,
                'strategy_b': metrics,
                'winner': 'a' or 'b',
                'improvement': {
                    'hit_rate': float,
                    'mrr': float
                }
            }
        """
        # 评估策略A
        metrics_a = evaluate_retrieval_system(
            test_queries,
            self.strategy_a['retrieval_func']
        )

        # 评估策略B
        metrics_b = evaluate_retrieval_system(
            test_queries,
            strategy_b['retrieval_func']
        )

        # 对比
        winner = 'a' if metrics_a['hit_rate'] > metrics_b['hit_rate'] else 'b'

        improvement = {
            'hit_rate': ((metrics_b['hit_rate'] - metrics_a['hit_rate'])
                        / metrics_a['hit_rate'] * 100 if winner == 'b'
                        else (metrics_a['hit_rate'] - metrics_b['hit_rate'])
                        / metrics_b['hit_rate'] * 100),
            'mrr': ((metrics_b['mrr'] - metrics_a['mrr'])
                   / metrics_a['mrr'] * 100 if winner == 'b'
                   else (metrics_a['mrr'] - metrics_b['mrr'])
                   / metrics_b['mrr'] * 100)
        }

        return {
            'strategy_a': metrics_a,
            'strategy_b': metrics_b,
            'winner': winner,
            'improvement': improvement
        }

练习题

练习1:基础练习 - 实现RRF融合

题目:手动实现RRF融合算法

要求: 1. 实现RRF核心算法 2. 支持可配置的k参数 3. 支持加权融合

提示: - 排名从1开始(不是0) - 需要处理文档在不同检索器中出现的情况

参考答案:见rrf_fusion.py


练习2:进阶练习 - 构建混合检索系统

题目:基于LlamaIndex实现混合检索

要求: 1. 使用LlamaIndex的VectorStoreIndex 2. 集成BM25检索 3. 实现RRF融合 4. 评估Hit Rate和MRR

提示: - LlamaIndex提供了BM25Retriever - 可以使用RetrieverMode.DEFAULT混合模式


练习3:挑战项目 - 完整的检索优化系统

项目描述:构建一个生产级的混合检索+重排序系统

功能需求: 1. ✅ 混合检索(向量+BM25) 2. ✅ CrossEncoder重排序 3. ✅ 查询自适应权重调整 4. ✅ 缓存机制 5. ✅ 性能监控 6. ✅ AB测试框架

性能要求: - Hit Rate > 0.8 - MRR > 0.7 - 平均响应时间 < 500ms

交付标准: - ✅ 完整的代码实现 - ✅ 单元测试(测试覆盖率>80%) - ✅ 性能评估报告 - ✅ 使用文档


总结

本章要点回顾

  1. 混合检索优势
  2. 向量检索:语义理解强,适合模糊查询
  3. BM25检索:关键词精确,适合专有名词
  4. 混合检索:结合两者优势,性能提升20-30%

  5. RRF算法

  6. 倒数秩融合,避免分数归一化问题
  7. 公式:RRF = Σ 1/(k+rank)
  8. 鲁棒性强,适用性广

  9. 重排序技术

  10. CrossEncoder提供更精确的相关性判断
  11. 两阶段策略:快速筛选 + 精确重排
  12. 性能提升:+10-15%

  13. 性能优化

  14. 动态权重调整
  15. 缓存策略
  16. 批处理优化
  17. 两阶段重排序

学习检查清单

  • 理解混合检索的原理和优势
  • 能够实现RRF融合算法
  • 掌握CrossEncoder重排序
  • 能够构建完整的检索流程
  • 理解性能优化方法
  • 能够评估检索系统性能

下一步学习


返回目录 | 上一章 | 下一章


本章结束

有任何问题或建议?欢迎提交Issue或PR到教程仓库!