1. 环境准备
In [ ]:
Copied!
# 导入必要的库
import numpy as np
from typing import List, Dict, Any
from dataclasses import dataclass
# Report the runtime environment before running the examples.
print("检查环境...")
print("NumPy版本:", np.__version__)
print("\n环境准备完成!")
# 导入必要的库
import numpy as np
from typing import List, Dict, Any
from dataclasses import dataclass
# Report the runtime environment before running the examples.
print("检查环境...")
print("NumPy版本:", np.__version__)
print("\n环境准备完成!")
In [ ]:
Copied!
# Example user queries that need enhancement before retrieval: a vague one,
# one missing context, and a multi-hop comparison that should be decomposed.
problematic_queries = [
    {"query": "那个修电脑的东西怎么用?", "problem": "表达模糊,缺少主体", "intent": "如何使用螺丝刀"},
    {"query": "提高代码性能", "problem": "信息缺失,缺少上下文", "intent": "如何提高Python代码的性能?"},
    {
        "query": "比较Python和JavaScript在Web开发中的差异",
        "problem": "复杂多跳,需要分解",
        "intent": "需要分解为3个子查询",
    },
]

print("有问题的用户查询:")
print("=" * 80)
for number, case in enumerate(problematic_queries, start=1):
    print(f"\n例子 {number}:")
    for label, field in (("查询", "query"), ("问题", "problem"), ("意图", "intent")):
        print(f" {label}: {case[field]}")
# Example user queries that need enhancement before retrieval: a vague one,
# one missing context, and a multi-hop comparison that should be decomposed.
problematic_queries = [
    {"query": "那个修电脑的东西怎么用?", "problem": "表达模糊,缺少主体", "intent": "如何使用螺丝刀"},
    {"query": "提高代码性能", "problem": "信息缺失,缺少上下文", "intent": "如何提高Python代码的性能?"},
    {
        "query": "比较Python和JavaScript在Web开发中的差异",
        "problem": "复杂多跳,需要分解",
        "intent": "需要分解为3个子查询",
    },
]

print("有问题的用户查询:")
print("=" * 80)
for number, case in enumerate(problematic_queries, start=1):
    print(f"\n例子 {number}:")
    for label, field in (("查询", "query"), ("问题", "problem"), ("意图", "intent")):
        print(f" {label}: {case[field]}")
In [ ]:
Copied!
class HyDEQueryEnhancer:
    """
    HyDE (Hypothetical Document Embeddings) query enhancer.

    Generates a hypothetical answer for the query and returns it as the text
    to retrieve with, alongside the original query.
    """

    def __init__(self, llm_generator=None):
        """
        Args:
            llm_generator: Callable taking the query string and returning the
                hypothetical answer text (simplified LLM interface). When it is
                None (or otherwise falsy) the built-in demo generator is used.
        """
        self.llm_generator = llm_generator or self._default_llm

    def _default_llm(self, query: str) -> str:
        """
        Simplified stand-in for an LLM, for demonstration only.

        Returns the canned answer of the first topic whose keywords ALL occur
        in the lower-cased query; otherwise fills a generic template with the
        query.
        """
        # A real application should call an actual LLM here.
        # Topics are keyed by keyword tuples rather than one fixed phrase, so a
        # query such as "如何优化Python代码?" (keywords not adjacent) still matches.
        hypothetical_docs = {
            ("python", "优化"): (
                "优化Python代码的方法包括:使用列表推导式替代循环、"
                "使用生成器处理大数据、避免不必要的全局变量、"
                "选择合适的数据结构、使用内置函数和库等。"
            ),
            ("机器学习",): (
                "机器学习是人工智能的分支,通过算法让计算机从数据中学习。"
                "主要包括监督学习、无监督学习和强化学习。常用算法有"
                "线性回归、决策树、神经网络等。"
            ),
        }
        # Kept out of the lookup table so a query that happens to contain the
        # word "default" can never return the unformatted "{query}" placeholder.
        default_template = (
            "这是一个关于{query}的详细回答。包含了相关概念、"
            "实现方法、最佳实践和注意事项。"
        )
        normalized = query.lower()
        for keywords, doc in hypothetical_docs.items():
            if all(keyword in normalized for keyword in keywords):
                return doc
        return default_template.format(query=query)

    def enhance(self, query: str) -> Dict[str, str]:
        """
        Enhance a query with a hypothetical answer.

        Args:
            query: The original query.

        Returns:
            Dict with the original query ("original_query"), the generated
            hypothetical answer ("hypothetical_answer") and the text to use
            for retrieval ("enhanced_query"), which is the hypothetical answer.
        """
        hypothetical_answer = self.llm_generator(query)
        return {
            "original_query": query,
            "hypothetical_answer": hypothetical_answer,
            "enhanced_query": hypothetical_answer,  # used for retrieval
        }
# Demo: show the hypothetical answer HyDE produces for a few questions.
hyde = HyDEQueryEnhancer()
test_queries = ["如何优化Python代码?", "什么是机器学习?"]

print("HyDE查询增强示例:")
print("=" * 80)
for query in test_queries:
    result = hyde.enhance(query)
    answer = result["hypothetical_answer"]
    print(f"\n原始查询: {result['original_query']}")
    print(f"\n假设答案:\n{answer}")
    print("-" * 80)
class HyDEQueryEnhancer:
    """
    HyDE (Hypothetical Document Embeddings) query enhancer.

    Generates a hypothetical answer for the query and returns it as the text
    to retrieve with, alongside the original query.
    """

    def __init__(self, llm_generator=None):
        """
        Args:
            llm_generator: Callable taking the query string and returning the
                hypothetical answer text (simplified LLM interface). When it is
                None (or otherwise falsy) the built-in demo generator is used.
        """
        self.llm_generator = llm_generator or self._default_llm

    def _default_llm(self, query: str) -> str:
        """
        Simplified stand-in for an LLM, for demonstration only.

        Returns the canned answer of the first topic whose keywords ALL occur
        in the lower-cased query; otherwise fills a generic template with the
        query.
        """
        # A real application should call an actual LLM here.
        # Topics are keyed by keyword tuples rather than one fixed phrase, so a
        # query such as "如何优化Python代码?" (keywords not adjacent) still matches.
        hypothetical_docs = {
            ("python", "优化"): (
                "优化Python代码的方法包括:使用列表推导式替代循环、"
                "使用生成器处理大数据、避免不必要的全局变量、"
                "选择合适的数据结构、使用内置函数和库等。"
            ),
            ("机器学习",): (
                "机器学习是人工智能的分支,通过算法让计算机从数据中学习。"
                "主要包括监督学习、无监督学习和强化学习。常用算法有"
                "线性回归、决策树、神经网络等。"
            ),
        }
        # Kept out of the lookup table so a query that happens to contain the
        # word "default" can never return the unformatted "{query}" placeholder.
        default_template = (
            "这是一个关于{query}的详细回答。包含了相关概念、"
            "实现方法、最佳实践和注意事项。"
        )
        normalized = query.lower()
        for keywords, doc in hypothetical_docs.items():
            if all(keyword in normalized for keyword in keywords):
                return doc
        return default_template.format(query=query)

    def enhance(self, query: str) -> Dict[str, str]:
        """
        Enhance a query with a hypothetical answer.

        Args:
            query: The original query.

        Returns:
            Dict with the original query ("original_query"), the generated
            hypothetical answer ("hypothetical_answer") and the text to use
            for retrieval ("enhanced_query"), which is the hypothetical answer.
        """
        hypothetical_answer = self.llm_generator(query)
        return {
            "original_query": query,
            "hypothetical_answer": hypothetical_answer,
            "enhanced_query": hypothetical_answer,  # used for retrieval
        }
# Demo: show the hypothetical answer HyDE produces for a few questions.
hyde = HyDEQueryEnhancer()
test_queries = ["如何优化Python代码?", "什么是机器学习?"]

print("HyDE查询增强示例:")
print("=" * 80)
for query in test_queries:
    result = hyde.enhance(query)
    answer = result["hypothetical_answer"]
    print(f"\n原始查询: {result['original_query']}")
    print(f"\n假设答案:\n{answer}")
    print("-" * 80)
In [ ]:
Copied!
class QueryRewriter:
    """
    Rule-based query rewriter.

    Applies simple substring rules (vague wording -> concrete wording, adding
    an interrogative, naming the subject) and prefixes "如何" when the result
    still contains neither a question word nor a question mark.
    """

    # If any of these occurs in the query, it already reads as a question.
    QUESTION_WORDS = ('什么', '如何', '怎么', '为什么')
    # ASCII and full-width (U+FF1F) question marks.
    QUESTION_MARKS = ('?', '\uff1f')

    def __init__(self):
        # Preset rewrite rules: substring -> replacement, applied in order.
        self.rewrite_rules = {
            # Vague expressions -> explicit expressions
            "那个": "具体",
            "东西": "工具",
            # Supplement missing information
            "提高": "如何提高",
            "优化": "如何优化",
            # Make the subject explicit
            "修电脑": "维修电脑硬件",
        }

    def rewrite(self, query: str) -> str:
        """
        Rewrite a query using the preset rules.

        Args:
            query: The original query.

        Returns:
            The rewritten query. Text that is already in a rule's expanded
            form (e.g. "如何提高") is left untouched instead of being expanded
            a second time.
        """
        import re

        rewritten = query
        for old, new in self.rewrite_rules.items():
            if old in new:
                # Try the expanded form first at each position so it is kept
                # as-is; only bare occurrences of `old` become `new`. Both
                # alternatives are therefore replaced by `new`.
                pattern = re.compile(re.escape(new) + "|" + re.escape(old))
                rewritten = pattern.sub(lambda _match, repl=new: repl, rewritten)
            else:
                rewritten = rewritten.replace(old, new)

        # Add a question word only if the query has none anywhere and no
        # question mark (ASCII or full-width).
        has_question_word = any(word in rewritten for word in self.QUESTION_WORDS)
        has_question_mark = any(mark in rewritten for mark in self.QUESTION_MARKS)
        if not has_question_word and not has_question_mark:
            rewritten = '如何' + rewritten
        return rewritten
# Demo: rule-based query rewriting on a few raw queries.
rewriter = QueryRewriter()
queries_to_rewrite = ["那个修电脑的东西怎么用", "提高代码性能", "优化数据库查询"]

print("\n查询重写示例:")
print("=" * 60)
for query in queries_to_rewrite:
    rewritten = rewriter.rewrite(query)
    print(f"\n原始: {query}\n重写: {rewritten}")
class QueryRewriter:
    """
    Rule-based query rewriter.

    Applies simple substring rules (vague wording -> concrete wording, adding
    an interrogative, naming the subject) and prefixes "如何" when the result
    still contains neither a question word nor a question mark.
    """

    # If any of these occurs in the query, it already reads as a question.
    QUESTION_WORDS = ('什么', '如何', '怎么', '为什么')
    # ASCII and full-width (U+FF1F) question marks.
    QUESTION_MARKS = ('?', '\uff1f')

    def __init__(self):
        # Preset rewrite rules: substring -> replacement, applied in order.
        self.rewrite_rules = {
            # Vague expressions -> explicit expressions
            "那个": "具体",
            "东西": "工具",
            # Supplement missing information
            "提高": "如何提高",
            "优化": "如何优化",
            # Make the subject explicit
            "修电脑": "维修电脑硬件",
        }

    def rewrite(self, query: str) -> str:
        """
        Rewrite a query using the preset rules.

        Args:
            query: The original query.

        Returns:
            The rewritten query. Text that is already in a rule's expanded
            form (e.g. "如何提高") is left untouched instead of being expanded
            a second time.
        """
        import re

        rewritten = query
        for old, new in self.rewrite_rules.items():
            if old in new:
                # Try the expanded form first at each position so it is kept
                # as-is; only bare occurrences of `old` become `new`. Both
                # alternatives are therefore replaced by `new`.
                pattern = re.compile(re.escape(new) + "|" + re.escape(old))
                rewritten = pattern.sub(lambda _match, repl=new: repl, rewritten)
            else:
                rewritten = rewritten.replace(old, new)

        # Add a question word only if the query has none anywhere and no
        # question mark (ASCII or full-width).
        has_question_word = any(word in rewritten for word in self.QUESTION_WORDS)
        has_question_mark = any(mark in rewritten for mark in self.QUESTION_MARKS)
        if not has_question_word and not has_question_mark:
            rewritten = '如何' + rewritten
        return rewritten
# Demo: rule-based query rewriting on a few raw queries.
rewriter = QueryRewriter()
queries_to_rewrite = ["那个修电脑的东西怎么用", "提高代码性能", "优化数据库查询"]

print("\n查询重写示例:")
print("=" * 60)
for query in queries_to_rewrite:
    rewritten = rewriter.rewrite(query)
    print(f"\n原始: {query}\n重写: {rewritten}")
In [ ]:
Copied!
class MultiQueryStrategy:
    """
    Multi-query strategy.

    Splits a complex (comparison) query into simpler sub-queries and produces
    template-based variations of a query, so retrieval can run several queries.
    """

    # Words that mark a query as a comparison.
    COMPARISON_KEYWORDS = ("比较", "对比", "差异")

    def decompose_query(self, query: str) -> List[str]:
        """
        Decompose a complex query into several sub-queries.

        Example: "比较Python和JavaScript在Web开发中的差异"
            -> [
                "Python在Web开发中的特点",
                "JavaScript在Web开发中的特点",
                "Python和JavaScript的对比"
            ]
        Without an "在..." context, e.g. "比较Python和Go", the result is
        ["Python的特点", "Go的特点", "Python和Go的对比"].

        Args:
            query: The user query.

        Returns:
            The sub-queries, or ``[query]`` unchanged when the query is not a
            comparison of exactly two non-empty entities joined by "和".
        """
        # Simplified rule-based implementation.
        # Drop trailing ASCII / full-width question marks so they do not end
        # up inside the extracted context.
        text = query.strip().rstrip("?\uff1f")
        sub_queries = []

        is_comparison = any(keyword in text for keyword in self.COMPARISON_KEYWORDS)
        if is_comparison and text.count("和") == 1:
            left, rest = text.split("和")
            # Strip the comparison verb ("比较" or "对比") from the first entity.
            entity1 = left.replace("比较", "").replace("对比", "").replace("的", "").strip()
            if "在" in rest:
                split_at = rest.index("在")
                entity2 = rest[:split_at].strip()
                # Skip the "在" itself: the templates below re-insert it, so
                # keeping it would produce "...在在..." sub-queries.
                context = rest[split_at + 1:].replace("中的", "").replace("差异", "").strip()
                if entity1 and entity2 and context:
                    sub_queries = [
                        f"{entity1}在{context}中的特点",
                        f"{entity2}在{context}中的特点",
                        f"{entity1}和{entity2}的对比",
                    ]
            else:
                entity2 = rest.replace("差异", "").replace("的", "").strip()
                if entity1 and entity2:
                    sub_queries = [
                        f"{entity1}的特点",
                        f"{entity2}的特点",
                        f"{entity1}和{entity2}的对比",
                    ]

        # Nothing decomposed: fall back to the original query.
        return sub_queries or [query]

    def generate_variations(self, query: str, num_variations: int = 3) -> List[str]:
        """
        Generate several variations of a query from simple templates.

        Args:
            query: The query to vary.
            num_variations: Maximum number of templates to apply; values
                below zero are treated as zero (no slicing from the end).

        Returns:
            The original query followed by the distinct generated variations.
        """
        variations = [query]  # always include the original query
        # Simplified implementation: template based.
        templates = [
            "如何{query}",
            "{query}的方法",
            "{query}的最佳实践"
        ]
        for template in templates[:max(0, num_variations)]:
            variation = template.format(query=query)
            if variation not in variations:
                variations.append(variation)
        return variations
# Demo: decompose a comparison query, then generate variations of a simple one.
multi_query = MultiQueryStrategy()

complex_query = "比较Python和JavaScript在Web开发中的差异"
sub_queries = multi_query.decompose_query(complex_query)
print("\n查询分解示例:")
print("=" * 60)
print(f"\n原始查询: {complex_query}")
print("\n分解后的子查询:")
for position, sub_q in enumerate(sub_queries, start=1):
    print(f" {position}. {sub_q}")

print("\n" + "=" * 60)
simple_query = "优化Python代码"
variations = multi_query.generate_variations(simple_query, num_variations=3)
print(f"\n原始查询: {simple_query}")
print("\n生成的变体:")
for position, variation in enumerate(variations, start=1):
    print(f" {position}. {variation}")
class MultiQueryStrategy:
    """
    Multi-query strategy.

    Splits a complex (comparison) query into simpler sub-queries and produces
    template-based variations of a query, so retrieval can run several queries.
    """

    # Words that mark a query as a comparison.
    COMPARISON_KEYWORDS = ("比较", "对比", "差异")

    def decompose_query(self, query: str) -> List[str]:
        """
        Decompose a complex query into several sub-queries.

        Example: "比较Python和JavaScript在Web开发中的差异"
            -> [
                "Python在Web开发中的特点",
                "JavaScript在Web开发中的特点",
                "Python和JavaScript的对比"
            ]
        Without an "在..." context, e.g. "比较Python和Go", the result is
        ["Python的特点", "Go的特点", "Python和Go的对比"].

        Args:
            query: The user query.

        Returns:
            The sub-queries, or ``[query]`` unchanged when the query is not a
            comparison of exactly two non-empty entities joined by "和".
        """
        # Simplified rule-based implementation.
        # Drop trailing ASCII / full-width question marks so they do not end
        # up inside the extracted context.
        text = query.strip().rstrip("?\uff1f")
        sub_queries = []

        is_comparison = any(keyword in text for keyword in self.COMPARISON_KEYWORDS)
        if is_comparison and text.count("和") == 1:
            left, rest = text.split("和")
            # Strip the comparison verb ("比较" or "对比") from the first entity.
            entity1 = left.replace("比较", "").replace("对比", "").replace("的", "").strip()
            if "在" in rest:
                split_at = rest.index("在")
                entity2 = rest[:split_at].strip()
                # Skip the "在" itself: the templates below re-insert it, so
                # keeping it would produce "...在在..." sub-queries.
                context = rest[split_at + 1:].replace("中的", "").replace("差异", "").strip()
                if entity1 and entity2 and context:
                    sub_queries = [
                        f"{entity1}在{context}中的特点",
                        f"{entity2}在{context}中的特点",
                        f"{entity1}和{entity2}的对比",
                    ]
            else:
                entity2 = rest.replace("差异", "").replace("的", "").strip()
                if entity1 and entity2:
                    sub_queries = [
                        f"{entity1}的特点",
                        f"{entity2}的特点",
                        f"{entity1}和{entity2}的对比",
                    ]

        # Nothing decomposed: fall back to the original query.
        return sub_queries or [query]

    def generate_variations(self, query: str, num_variations: int = 3) -> List[str]:
        """
        Generate several variations of a query from simple templates.

        Args:
            query: The query to vary.
            num_variations: Maximum number of templates to apply; values
                below zero are treated as zero (no slicing from the end).

        Returns:
            The original query followed by the distinct generated variations.
        """
        variations = [query]  # always include the original query
        # Simplified implementation: template based.
        templates = [
            "如何{query}",
            "{query}的方法",
            "{query}的最佳实践"
        ]
        for template in templates[:max(0, num_variations)]:
            variation = template.format(query=query)
            if variation not in variations:
                variations.append(variation)
        return variations
# Demo: decompose a comparison query, then generate variations of a simple one.
multi_query = MultiQueryStrategy()

complex_query = "比较Python和JavaScript在Web开发中的差异"
sub_queries = multi_query.decompose_query(complex_query)
print("\n查询分解示例:")
print("=" * 60)
print(f"\n原始查询: {complex_query}")
print("\n分解后的子查询:")
for position, sub_q in enumerate(sub_queries, start=1):
    print(f" {position}. {sub_q}")

print("\n" + "=" * 60)
simple_query = "优化Python代码"
variations = multi_query.generate_variations(simple_query, num_variations=3)
print(f"\n原始查询: {simple_query}")
print("\n生成的变体:")
for position, variation in enumerate(variations, start=1):
    print(f" {position}. {variation}")
6. 完整的查询增强流程
In [ ]:
Copied!
class QueryEnhancementPipeline:
    """
    Complete query-enhancement pipeline combining query rewriting, HyDE and
    multi-query decomposition.
    """

    def __init__(self):
        self.hyde = HyDEQueryEnhancer()
        self.rewriter = QueryRewriter()
        self.multi_query = MultiQueryStrategy()

    def enhance(
        self,
        query: str,
        use_hyde: bool = True,
        use_rewrite: bool = True,
        use_multi_query: bool = False
    ) -> Dict[str, Any]:
        """
        Run the enabled enhancement steps on ``query``.

        Args:
            query: The original query.
            use_hyde: Whether to generate a hypothetical answer (HyDE).
            use_rewrite: Whether to rewrite the query first.
            use_multi_query: Whether to decompose the query into sub-queries.

        Returns:
            Dict with "original_query", "enhanced_queries" (deduplicated, in
            the order they were produced) and "techniques_used", plus
            "rewritten_query", "hypothetical_answer" and "sub_queries" when
            the corresponding step ran.
        """
        result = {
            "original_query": query,
            "enhanced_queries": [],
            "techniques_used": []
        }
        enhanced_queries = result["enhanced_queries"]

        def add_query(candidate):
            # Identical queries would only trigger duplicate retrievals.
            if candidate not in enhanced_queries:
                enhanced_queries.append(candidate)

        # Query that later steps build on: the rewritten one when enabled.
        base_query = query

        # Step 1: query rewriting
        if use_rewrite:
            base_query = self.rewriter.rewrite(query)
            result["rewritten_query"] = base_query
            add_query(base_query)
            result["techniques_used"].append("query_rewrite")

        # Step 2: HyDE
        if use_hyde:
            hyde_result = self.hyde.enhance(base_query)
            result["hypothetical_answer"] = hyde_result["hypothetical_answer"]
            add_query(hyde_result["enhanced_query"])
            result["techniques_used"].append("hyde")

        # Step 3: multi-query
        if use_multi_query:
            # Decompose the raw query: the rewriter may prepend "如何", which
            # the rule-based decomposer would fold into the first entity.
            # When nothing could be split, keep the (possibly rewritten) base query.
            sub_queries = self.multi_query.decompose_query(query)
            if sub_queries == [query]:
                sub_queries = [base_query]
            result["sub_queries"] = sub_queries
            for sub_query in sub_queries:
                add_query(sub_query)
            result["techniques_used"].append("multi_query")

        return result
# Demo: run the full pipeline (rewrite + HyDE, no decomposition) and summarize it.
pipeline = QueryEnhancementPipeline()
test_query = "提高Python代码性能"

print("\n完整的查询增强流程:")
print("=" * 80)
result = pipeline.enhance(
    test_query,
    use_hyde=True,
    use_rewrite=True,
    use_multi_query=False,
)

print(f"\n原始查询: {result['original_query']}")
print(f"\n重写查询: {result.get('rewritten_query', 'N/A')}")
answer_preview = result.get("hypothetical_answer", "N/A")[:100]
print(f"\n假设答案: {answer_preview}...")
techniques = ", ".join(result["techniques_used"])
print(f"\n使用的技术: {techniques}")
print(f"\n增强后的查询数量: {len(result['enhanced_queries'])}")
class QueryEnhancementPipeline:
    """
    Complete query-enhancement pipeline combining query rewriting, HyDE and
    multi-query decomposition.
    """

    def __init__(self):
        self.hyde = HyDEQueryEnhancer()
        self.rewriter = QueryRewriter()
        self.multi_query = MultiQueryStrategy()

    def enhance(
        self,
        query: str,
        use_hyde: bool = True,
        use_rewrite: bool = True,
        use_multi_query: bool = False
    ) -> Dict[str, Any]:
        """
        Run the enabled enhancement steps on ``query``.

        Args:
            query: The original query.
            use_hyde: Whether to generate a hypothetical answer (HyDE).
            use_rewrite: Whether to rewrite the query first.
            use_multi_query: Whether to decompose the query into sub-queries.

        Returns:
            Dict with "original_query", "enhanced_queries" (deduplicated, in
            the order they were produced) and "techniques_used", plus
            "rewritten_query", "hypothetical_answer" and "sub_queries" when
            the corresponding step ran.
        """
        result = {
            "original_query": query,
            "enhanced_queries": [],
            "techniques_used": []
        }
        enhanced_queries = result["enhanced_queries"]

        def add_query(candidate):
            # Identical queries would only trigger duplicate retrievals.
            if candidate not in enhanced_queries:
                enhanced_queries.append(candidate)

        # Query that later steps build on: the rewritten one when enabled.
        base_query = query

        # Step 1: query rewriting
        if use_rewrite:
            base_query = self.rewriter.rewrite(query)
            result["rewritten_query"] = base_query
            add_query(base_query)
            result["techniques_used"].append("query_rewrite")

        # Step 2: HyDE
        if use_hyde:
            hyde_result = self.hyde.enhance(base_query)
            result["hypothetical_answer"] = hyde_result["hypothetical_answer"]
            add_query(hyde_result["enhanced_query"])
            result["techniques_used"].append("hyde")

        # Step 3: multi-query
        if use_multi_query:
            # Decompose the raw query: the rewriter may prepend "如何", which
            # the rule-based decomposer would fold into the first entity.
            # When nothing could be split, keep the (possibly rewritten) base query.
            sub_queries = self.multi_query.decompose_query(query)
            if sub_queries == [query]:
                sub_queries = [base_query]
            result["sub_queries"] = sub_queries
            for sub_query in sub_queries:
                add_query(sub_query)
            result["techniques_used"].append("multi_query")

        return result
# Demo: run the full pipeline (rewrite + HyDE, no decomposition) and summarize it.
pipeline = QueryEnhancementPipeline()
test_query = "提高Python代码性能"

print("\n完整的查询增强流程:")
print("=" * 80)
result = pipeline.enhance(
    test_query,
    use_hyde=True,
    use_rewrite=True,
    use_multi_query=False,
)

print(f"\n原始查询: {result['original_query']}")
print(f"\n重写查询: {result.get('rewritten_query', 'N/A')}")
answer_preview = result.get("hypothetical_answer", "N/A")[:100]
print(f"\n假设答案: {answer_preview}...")
techniques = ", ".join(result["techniques_used"])
print(f"\n使用的技术: {techniques}")
print(f"\n增强后的查询数量: {len(result['enhanced_queries'])}")