1. 环境准备
安装依赖、导入必要的库,并配置中文绘图字体。
In [ ]:
Copied!
# Install dependencies into the interpreter that runs this notebook.
# `sys.executable -m pip` (instead of the shell escape `!pip`) guarantees the
# packages land in the kernel's own environment and keeps this cell valid
# Python outside IPython. check=False keeps it best-effort, like `!pip`.
# For Chinese NER also run: python -m spacy download zh_core_web_sm
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install",
     "networkx", "matplotlib", "sentence-transformers", "spacy", "-q"],
    check=False,
)

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Tuple, Set, Optional, Any
from dataclasses import dataclass
from collections import defaultdict
import re
import json

# Fonts able to render Chinese labels; 'DejaVu Sans' (bundled with
# matplotlib) is a last-resort fallback so Latin text always renders.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'DejaVu Sans']
# Draw the minus sign as ASCII '-' (many CJK fonts lack U+2212).
plt.rcParams['axes.unicode_minus'] = False

print("✅ 环境准备完成")
print(f"📦 NetworkX版本: {nx.__version__}")
# Install dependencies into the interpreter that runs this notebook.
# `sys.executable -m pip` (instead of the shell escape `!pip`) guarantees the
# packages land in the kernel's own environment and keeps this cell valid
# Python outside IPython. check=False keeps it best-effort, like `!pip`.
# For Chinese NER also run: python -m spacy download zh_core_web_sm
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install",
     "networkx", "matplotlib", "sentence-transformers", "spacy", "-q"],
    check=False,
)

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Tuple, Set, Optional, Any
from dataclasses import dataclass
from collections import defaultdict
import re
import json

# Fonts able to render Chinese labels; 'DejaVu Sans' (bundled with
# matplotlib) is a last-resort fallback so Latin text always renders.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'DejaVu Sans']
# Draw the minus sign as ASCII '-' (many CJK fonts lack U+2212).
plt.rcParams['axes.unicode_minus'] = False

print("✅ 环境准备完成")
print(f"📦 NetworkX版本: {nx.__version__}")
2. 知识图谱基础
2.1 什么是知识图谱?
知识图谱 = 实体(Entity)+ 关系(Relation)
示例:技术领域知识图谱
实体:
- Python(编程语言)
- JavaScript(编程语言)
- Guido van Rossum(人物)
- Brendan Eich(人物)
- Data Science(领域)
关系:
- Guido van Rossum --[创造者]--> Python
- Brendan Eich --[创造者]--> JavaScript
- Python --[应用于]--> Data Science
- Python --[竞争者]--> JavaScript
2.2 GraphRAG的优势
传统RAG:
- 基于向量相似度
- 无法捕获结构关系
- 难以处理多跳推理
GraphRAG:
- ✅ 捕获实体间关系
- ✅ 支持多跳推理
- ✅ 提供结构化上下文
- ✅ 解释推理路径
In [ ]:
Copied!
@dataclass
class Entity:
    """A node of the knowledge graph.

    Identity is defined by ``id`` alone: entities sharing an id compare
    equal and hash identically, whatever their name/type/description.
    """
    id: str
    name: str
    type: str  # e.g. Person, Technology, Company, Concept
    description: str = ""

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        return isinstance(other, Entity) and other.id == self.id
@dataclass
class Relation:
    """A directed, typed edge: ``source --[relation_type]--> target``.

    Equality and hashing use the (source, target, relation_type) triple;
    ``weight`` does not take part in identity.
    """
    source: str  # id of the source entity
    target: str  # id of the target entity
    relation_type: str  # relation label
    weight: float = 1.0  # relation weight

    def _key(self):
        """Identity triple shared by __hash__ and __eq__."""
        return (self.source, self.target, self.relation_type)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        if isinstance(other, Relation):
            return self._key() == other._key()
        return False
class KnowledgeGraph:
    """In-memory knowledge graph backed by a NetworkX ``MultiDiGraph``.

    Entities are kept in ``self.entities`` (id -> Entity) and as graph nodes;
    relations are kept in ``self.relations`` and as directed graph edges.
    Because the graph is a multigraph, one (source, target) pair may carry
    several edges with different relation types.
    """

    def __init__(self):
        # NetworkX graph used for traversal (neighbours, paths, statistics).
        self.graph = nx.MultiDiGraph()
        # Entity and relation indexes.
        self.entities: Dict[str, Entity] = {}
        self.relations: List[Relation] = []

    def add_entity(self, entity: Entity):
        """Add an entity (overwriting one with the same id) and its node."""
        self.entities[entity.id] = entity
        self.graph.add_node(
            entity.id,
            name=entity.name,
            type=entity.type,
            description=entity.description
        )

    def add_relation(self, relation: Relation):
        """Add a relation as a directed edge.

        Raises:
            ValueError: if the source or target entity has not been added.
        """
        if relation.source not in self.entities:
            raise ValueError(f"源实体 {relation.source} 不存在")
        if relation.target not in self.entities:
            raise ValueError(f"目标实体 {relation.target} 不存在")
        self.relations.append(relation)
        self.graph.add_edge(
            relation.source,
            relation.target,
            relation_type=relation.relation_type,
            weight=relation.weight
        )

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Return the entity with ``entity_id``, or None if unknown."""
        return self.entities.get(entity_id)

    def get_neighbors(self, entity_id: str,
                      relation_type: str = None) -> List[Entity]:
        """Return the outgoing neighbours of ``entity_id``.

        Args:
            entity_id: id of the centre entity.
            relation_type: if given, keep only neighbours connected by at
                least one edge of this relation type.

        Returns:
            Neighbour entities (each listed once); [] for an unknown id.
        """
        if entity_id not in self.graph:
            return []
        neighbors = []
        for neighbor in self.graph.neighbors(entity_id):
            if relation_type:
                # On a MultiDiGraph get_edge_data(u, v) returns
                # {edge_key: attr_dict}, so every parallel edge is inspected.
                edge_data = self.graph.get_edge_data(entity_id, neighbor) or {}
                if not any(attrs.get('relation_type') == relation_type
                           for attrs in edge_data.values()):
                    continue
            neighbors.append(self.entities[neighbor])
        return neighbors

    def find_path(self, source_id: str, target_id: str) -> List[str]:
        """Return the shortest directed path of entity ids between two entities.

        Returns [] when no path exists or either id is not in the graph.
        """
        try:
            return nx.shortest_path(self.graph, source_id, target_id)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return []

    def get_subgraph(self, entity_ids: List[str],
                     hops: int = 1) -> 'KnowledgeGraph':
        """Extract the subgraph reachable from the given entities.

        Args:
            entity_ids: ids of the centre entities.
            hops: number of outgoing hops to follow.

        Returns:
            A new KnowledgeGraph with the reached entities (in this graph's
            entity order) and every relation between them.
        """
        subgraph = KnowledgeGraph()

        # Breadth-first expansion: only the newest frontier is expanded on
        # each hop instead of re-expanding every node collected so far.
        nodes_to_include = set(entity_ids)
        frontier = set(entity_ids)
        for _ in range(hops):
            next_frontier = set()
            for node_id in frontier:
                if node_id in self.graph:
                    next_frontier.update(self.graph.neighbors(node_id))
            next_frontier -= nodes_to_include
            if not next_frontier:
                break
            nodes_to_include |= next_frontier
            frontier = next_frontier

        # Copy entities in insertion order so results are deterministic.
        for node_id, entity in self.entities.items():
            if node_id in nodes_to_include:
                subgraph.add_entity(entity)
        for relation in self.relations:
            if (relation.source in nodes_to_include and
                    relation.target in nodes_to_include):
                subgraph.add_relation(relation)
        return subgraph

    def get_stats(self) -> Dict:
        """Return basic graph statistics.

        Note: density counts parallel edges, so on a multigraph it can
        exceed 1.
        """
        return {
            '实体数': len(self.entities),
            '关系数': len(self.relations),
            '节点数': self.graph.number_of_nodes(),
            '边数': self.graph.number_of_edges(),
            '密度': nx.density(self.graph),
            '连通分量': nx.number_weakly_connected_components(self.graph)
        }


print("✅ 知识图谱类定义完成")
@dataclass
class Entity:
    """A node of the knowledge graph.

    Identity is defined by ``id`` alone: entities sharing an id compare
    equal and hash identically, whatever their name/type/description.
    """
    id: str
    name: str
    type: str  # e.g. Person, Technology, Company, Concept
    description: str = ""

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        return isinstance(other, Entity) and other.id == self.id
@dataclass
class Relation:
    """A directed, typed edge: ``source --[relation_type]--> target``.

    Equality and hashing use the (source, target, relation_type) triple;
    ``weight`` does not take part in identity.
    """
    source: str  # id of the source entity
    target: str  # id of the target entity
    relation_type: str  # relation label
    weight: float = 1.0  # relation weight

    def _key(self):
        """Identity triple shared by __hash__ and __eq__."""
        return (self.source, self.target, self.relation_type)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        if isinstance(other, Relation):
            return self._key() == other._key()
        return False
class KnowledgeGraph:
    """In-memory knowledge graph backed by a NetworkX ``MultiDiGraph``.

    Entities are kept in ``self.entities`` (id -> Entity) and as graph nodes;
    relations are kept in ``self.relations`` and as directed graph edges.
    Because the graph is a multigraph, one (source, target) pair may carry
    several edges with different relation types.
    """

    def __init__(self):
        # NetworkX graph used for traversal (neighbours, paths, statistics).
        self.graph = nx.MultiDiGraph()
        # Entity and relation indexes.
        self.entities: Dict[str, Entity] = {}
        self.relations: List[Relation] = []

    def add_entity(self, entity: Entity):
        """Add an entity (overwriting one with the same id) and its node."""
        self.entities[entity.id] = entity
        self.graph.add_node(
            entity.id,
            name=entity.name,
            type=entity.type,
            description=entity.description
        )

    def add_relation(self, relation: Relation):
        """Add a relation as a directed edge.

        Raises:
            ValueError: if the source or target entity has not been added.
        """
        if relation.source not in self.entities:
            raise ValueError(f"源实体 {relation.source} 不存在")
        if relation.target not in self.entities:
            raise ValueError(f"目标实体 {relation.target} 不存在")
        self.relations.append(relation)
        self.graph.add_edge(
            relation.source,
            relation.target,
            relation_type=relation.relation_type,
            weight=relation.weight
        )

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Return the entity with ``entity_id``, or None if unknown."""
        return self.entities.get(entity_id)

    def get_neighbors(self, entity_id: str,
                      relation_type: str = None) -> List[Entity]:
        """Return the outgoing neighbours of ``entity_id``.

        Args:
            entity_id: id of the centre entity.
            relation_type: if given, keep only neighbours connected by at
                least one edge of this relation type.

        Returns:
            Neighbour entities (each listed once); [] for an unknown id.
        """
        if entity_id not in self.graph:
            return []
        neighbors = []
        for neighbor in self.graph.neighbors(entity_id):
            if relation_type:
                # On a MultiDiGraph get_edge_data(u, v) returns
                # {edge_key: attr_dict}, so every parallel edge is inspected.
                edge_data = self.graph.get_edge_data(entity_id, neighbor) or {}
                if not any(attrs.get('relation_type') == relation_type
                           for attrs in edge_data.values()):
                    continue
            neighbors.append(self.entities[neighbor])
        return neighbors

    def find_path(self, source_id: str, target_id: str) -> List[str]:
        """Return the shortest directed path of entity ids between two entities.

        Returns [] when no path exists or either id is not in the graph.
        """
        try:
            return nx.shortest_path(self.graph, source_id, target_id)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return []

    def get_subgraph(self, entity_ids: List[str],
                     hops: int = 1) -> 'KnowledgeGraph':
        """Extract the subgraph reachable from the given entities.

        Args:
            entity_ids: ids of the centre entities.
            hops: number of outgoing hops to follow.

        Returns:
            A new KnowledgeGraph with the reached entities (in this graph's
            entity order) and every relation between them.
        """
        subgraph = KnowledgeGraph()

        # Breadth-first expansion: only the newest frontier is expanded on
        # each hop instead of re-expanding every node collected so far.
        nodes_to_include = set(entity_ids)
        frontier = set(entity_ids)
        for _ in range(hops):
            next_frontier = set()
            for node_id in frontier:
                if node_id in self.graph:
                    next_frontier.update(self.graph.neighbors(node_id))
            next_frontier -= nodes_to_include
            if not next_frontier:
                break
            nodes_to_include |= next_frontier
            frontier = next_frontier

        # Copy entities in insertion order so results are deterministic.
        for node_id, entity in self.entities.items():
            if node_id in nodes_to_include:
                subgraph.add_entity(entity)
        for relation in self.relations:
            if (relation.source in nodes_to_include and
                    relation.target in nodes_to_include):
                subgraph.add_relation(relation)
        return subgraph

    def get_stats(self) -> Dict:
        """Return basic graph statistics.

        Note: density counts parallel edges, so on a multigraph it can
        exceed 1.
        """
        return {
            '实体数': len(self.entities),
            '关系数': len(self.relations),
            '节点数': self.graph.number_of_nodes(),
            '边数': self.graph.number_of_edges(),
            '密度': nx.density(self.graph),
            '连通分量': nx.number_weakly_connected_components(self.graph)
        }


print("✅ 知识图谱类定义完成")
3.2 构建示例知识图谱
In [ ]:
Copied!
# Build the demo technology-domain knowledge graph.
kg = KnowledgeGraph()

# Entities given as (id, name, type, description) rows.
entities_data = [
    Entity(*row)
    for row in (
        ("E1", "Python", "Technology", "一种高级编程语言"),
        ("E2", "JavaScript", "Technology", "Web开发脚本语言"),
        ("E3", "Java", "Technology", "通用编程语言"),
        ("E4", "Guido van Rossum", "Person", "Python创造者"),
        ("E5", "Brendan Eich", "Person", "JavaScript创造者"),
        ("E6", "James Gosling", "Person", "Java创造者"),
        ("E7", "Data Science", "Field", "数据科学"),
        ("E8", "Web Development", "Field", "Web开发"),
        ("E9", "Machine Learning", "Field", "机器学习"),
        ("E10", "Backend", "Field", "后端开发"),
        ("E11", "RAG", "Concept", "检索增强生成"),
        ("E12", "Vector DB", "Concept", "向量数据库"),
        ("E13", "LLM", "Concept", "大语言模型"),
        ("E14", "Transformer", "Concept", "Transformer架构"),
        ("E15", "Attention", "Concept", "注意力机制"),
    )
]
for entity in entities_data:
    kg.add_entity(entity)

# Relations given as (source id, target id, relation type); weight is 1.0.
relations_data = [
    Relation(*row)
    for row in (
        # creators
        ("E4", "E1", "创造者"),
        ("E5", "E2", "创造者"),
        ("E6", "E3", "创造者"),
        # application fields
        ("E1", "E7", "应用于"),
        ("E1", "E9", "应用于"),
        ("E2", "E8", "应用于"),
        ("E3", "E10", "应用于"),
        # technology dependencies
        ("E11", "E12", "依赖"),
        ("E11", "E13", "依赖"),
        ("E13", "E14", "基于"),
        ("E14", "E15", "包含"),
        # competitors
        ("E1", "E2", "竞争者"),
        ("E1", "E3", "竞争者"),
        ("E2", "E3", "竞争者"),
        # language strengths (parallel to some "应用于" edges)
        ("E1", "E7", "适合"),
        ("E2", "E8", "适合"),
    )
]
for relation in relations_data:
    kg.add_relation(relation)

print("✅ 知识图谱构建完成")
print(f"📊 图谱统计: {kg.get_stats()}")
# Build the demo technology-domain knowledge graph.
kg = KnowledgeGraph()

# Entities given as (id, name, type, description) rows.
entities_data = [
    Entity(*row)
    for row in (
        ("E1", "Python", "Technology", "一种高级编程语言"),
        ("E2", "JavaScript", "Technology", "Web开发脚本语言"),
        ("E3", "Java", "Technology", "通用编程语言"),
        ("E4", "Guido van Rossum", "Person", "Python创造者"),
        ("E5", "Brendan Eich", "Person", "JavaScript创造者"),
        ("E6", "James Gosling", "Person", "Java创造者"),
        ("E7", "Data Science", "Field", "数据科学"),
        ("E8", "Web Development", "Field", "Web开发"),
        ("E9", "Machine Learning", "Field", "机器学习"),
        ("E10", "Backend", "Field", "后端开发"),
        ("E11", "RAG", "Concept", "检索增强生成"),
        ("E12", "Vector DB", "Concept", "向量数据库"),
        ("E13", "LLM", "Concept", "大语言模型"),
        ("E14", "Transformer", "Concept", "Transformer架构"),
        ("E15", "Attention", "Concept", "注意力机制"),
    )
]
for entity in entities_data:
    kg.add_entity(entity)

# Relations given as (source id, target id, relation type); weight is 1.0.
relations_data = [
    Relation(*row)
    for row in (
        # creators
        ("E4", "E1", "创造者"),
        ("E5", "E2", "创造者"),
        ("E6", "E3", "创造者"),
        # application fields
        ("E1", "E7", "应用于"),
        ("E1", "E9", "应用于"),
        ("E2", "E8", "应用于"),
        ("E3", "E10", "应用于"),
        # technology dependencies
        ("E11", "E12", "依赖"),
        ("E11", "E13", "依赖"),
        ("E13", "E14", "基于"),
        ("E14", "E15", "包含"),
        # competitors
        ("E1", "E2", "竞争者"),
        ("E1", "E3", "竞争者"),
        ("E2", "E3", "竞争者"),
        # language strengths (parallel to some "应用于" edges)
        ("E1", "E7", "适合"),
        ("E2", "E8", "适合"),
    )
]
for relation in relations_data:
    kg.add_relation(relation)

print("✅ 知识图谱构建完成")
print(f"📊 图谱统计: {kg.get_stats()}")
3.3 可视化知识图谱
In [ ]:
Copied!
def visualize_graph(kg: KnowledgeGraph,
                    figsize=(14, 10),
                    highlight_nodes: List[str] = None):
    """Draw a knowledge graph with nodes coloured by entity type.

    Args:
        kg: the graph to draw.
        figsize: matplotlib figure size in inches.
        highlight_nodes: optional entity ids overlaid with a yellow halo;
            ids that are not nodes of ``kg`` are ignored.
    """
    plt.figure(figsize=figsize)

    # Spring layout with a fixed seed so the picture is reproducible.
    pos = nx.spring_layout(kg.graph, k=2, iterations=50, seed=42)

    # One colour per entity type; unknown types are grey.
    color_map = {
        'Technology': '#FF6B6B',
        'Person': '#4ECDC4',
        'Field': '#45B7D1',
        'Concept': '#96CEB4'
    }
    node_colors = [color_map.get(kg.entities[node_id].type, '#CCCCCC')
                   for node_id in kg.graph.nodes()]

    nx.draw_networkx_nodes(kg.graph, pos,
                           node_color=node_colors,
                           node_size=1500,
                           alpha=0.8)

    # Highlight overlay. `nodelist` must be explicit: without it NetworkX
    # draws every node of the graph and fails on nodes missing from `pos`.
    if highlight_nodes:
        nodelist = [node_id for node_id in highlight_nodes if node_id in pos]
        if nodelist:
            nx.draw_networkx_nodes(kg.graph, pos,
                                   nodelist=nodelist,
                                   node_color='yellow',
                                   node_size=2000,
                                   alpha=0.5)

    nx.draw_networkx_edges(kg.graph, pos,
                           edge_color='gray',
                           arrows=True,
                           arrowsize=20,
                           width=1.5,
                           alpha=0.6)

    labels = {node_id: kg.entities[node_id].name
              for node_id in kg.graph.nodes()}
    nx.draw_networkx_labels(kg.graph, pos,
                            labels=labels,
                            font_size=10,
                            font_weight='bold')

    # Edge labels are keyed by (source, target); parallel multigraph edges
    # would overwrite each other, so their relation types are joined.
    grouped = {}
    for source, target, data in kg.graph.edges(data=True):
        grouped.setdefault((source, target), []).append(
            data.get('relation_type', ''))
    edge_labels = {edge: " / ".join(rel_types)
                   for edge, rel_types in grouped.items()}
    nx.draw_networkx_edge_labels(kg.graph, pos,
                                 edge_labels=edge_labels,
                                 font_size=8)

    # Legend: one marker per entity type.
    legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                                  label=entity_type,
                                  markersize=15,
                                  markerfacecolor=color)
                       for entity_type, color in color_map.items()]
    plt.legend(handles=legend_elements, loc='upper left')

    plt.title("知识图谱可视化", fontsize=16, weight='bold')
    plt.axis('off')
    plt.tight_layout()
    plt.show()


# Visualise the full graph.
visualize_graph(kg)
def visualize_graph(kg: KnowledgeGraph,
                    figsize=(14, 10),
                    highlight_nodes: List[str] = None):
    """Draw a knowledge graph with nodes coloured by entity type.

    Args:
        kg: the graph to draw.
        figsize: matplotlib figure size in inches.
        highlight_nodes: optional entity ids overlaid with a yellow halo;
            ids that are not nodes of ``kg`` are ignored.
    """
    plt.figure(figsize=figsize)

    # Spring layout with a fixed seed so the picture is reproducible.
    pos = nx.spring_layout(kg.graph, k=2, iterations=50, seed=42)

    # One colour per entity type; unknown types are grey.
    color_map = {
        'Technology': '#FF6B6B',
        'Person': '#4ECDC4',
        'Field': '#45B7D1',
        'Concept': '#96CEB4'
    }
    node_colors = [color_map.get(kg.entities[node_id].type, '#CCCCCC')
                   for node_id in kg.graph.nodes()]

    nx.draw_networkx_nodes(kg.graph, pos,
                           node_color=node_colors,
                           node_size=1500,
                           alpha=0.8)

    # Highlight overlay. `nodelist` must be explicit: without it NetworkX
    # draws every node of the graph and fails on nodes missing from `pos`.
    if highlight_nodes:
        nodelist = [node_id for node_id in highlight_nodes if node_id in pos]
        if nodelist:
            nx.draw_networkx_nodes(kg.graph, pos,
                                   nodelist=nodelist,
                                   node_color='yellow',
                                   node_size=2000,
                                   alpha=0.5)

    nx.draw_networkx_edges(kg.graph, pos,
                           edge_color='gray',
                           arrows=True,
                           arrowsize=20,
                           width=1.5,
                           alpha=0.6)

    labels = {node_id: kg.entities[node_id].name
              for node_id in kg.graph.nodes()}
    nx.draw_networkx_labels(kg.graph, pos,
                            labels=labels,
                            font_size=10,
                            font_weight='bold')

    # Edge labels are keyed by (source, target); parallel multigraph edges
    # would overwrite each other, so their relation types are joined.
    grouped = {}
    for source, target, data in kg.graph.edges(data=True):
        grouped.setdefault((source, target), []).append(
            data.get('relation_type', ''))
    edge_labels = {edge: " / ".join(rel_types)
                   for edge, rel_types in grouped.items()}
    nx.draw_networkx_edge_labels(kg.graph, pos,
                                 edge_labels=edge_labels,
                                 font_size=8)

    # Legend: one marker per entity type.
    legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                                  label=entity_type,
                                  markersize=15,
                                  markerfacecolor=color)
                       for entity_type, color in color_map.items()]
    plt.legend(handles=legend_elements, loc='upper left')

    plt.title("知识图谱可视化", fontsize=16, weight='bold')
    plt.axis('off')
    plt.tight_layout()
    plt.show()


# Visualise the full graph.
visualize_graph(kg)
In [ ]:
Copied!
class EntityExtractor:
    """Dictionary-based entity extractor (simplified).

    Recognises entities by searching the query text (case-insensitively)
    for the names of entities that already exist in the knowledge graph.
    """

    def __init__(self, knowledge_graph: "KnowledgeGraph"):
        self.kg = knowledge_graph
        # Lower-cased entity name -> entity id, in the graph's entity order.
        self.entity_names = {e.name.lower(): e.id
                             for e in knowledge_graph.entities.values()}

    @staticmethod
    def _is_latin_char(ch: str) -> bool:
        """Return True for ASCII letters/digits (characters of a Latin token)."""
        return ch.isascii() and ch.isalnum()

    def _is_standalone(self, text: str, start: int, end: int) -> bool:
        """Return False if text[start:end] is glued to a longer Latin token.

        A boundary is enforced only on a side where the match itself starts
        or ends with an ASCII letter/digit, so CJK neighbours (e.g. "和" in
        "Python和JavaScript") never block a match.
        """
        if (self._is_latin_char(text[start]) and start > 0
                and self._is_latin_char(text[start - 1])):
            return False
        if (self._is_latin_char(text[end - 1]) and end < len(text)
                and self._is_latin_char(text[end])):
            return False
        return True

    def extract_entities(self, text: str) -> List[Tuple[str, str]]:
        """Find knowledge-graph entities mentioned in ``text``.

        Rules:
        * case-insensitive substring search over the known entity names;
        * a match must not be part of a longer ASCII word (so "Java" is not
          found inside "JavaScript", nor "RAG" inside "storage");
        * longer names are matched first and claim their characters, so an
          overlapping shorter name is not reported as well.

        Returns:
            [(entity name, entity id)] in the knowledge graph's entity order.
        """
        text_lower = text.lower()
        claimed = [False] * len(text_lower)
        matched_ids = set()

        # Longest names first so that e.g. "javascript" claims its span
        # before "java" is considered.
        by_length = sorted(self.entity_names.items(),
                           key=lambda item: len(item[0]), reverse=True)
        for name, entity_id in by_length:
            if not name:
                continue  # an empty name would match everywhere
            start = text_lower.find(name)
            while start != -1:
                end = start + len(name)
                if (self._is_standalone(text_lower, start, end)
                        and not any(claimed[start:end])):
                    claimed[start:end] = [True] * (end - start)
                    matched_ids.add(entity_id)
                start = text_lower.find(name, start + 1)

        return [(self.kg.entities[entity_id].name, entity_id)
                for entity_id in self.entity_names.values()
                if entity_id in matched_ids]
# Smoke-test entity extraction on a few sample queries.
extractor = EntityExtractor(kg)
test_queries = [
    "Python和JavaScript有什么区别?",
    "RAG使用了哪些技术?",
    "Transformer架构包含什么?"
]
print("🔍 实体识别测试:\n")
for query in test_queries:
    entities = extractor.extract_entities(query)
    print(f"查询: {query}\n识别实体: {[pair[0] for pair in entities]}\n")
class EntityExtractor:
    """Dictionary-based entity extractor (simplified).

    Recognises entities by searching the query text (case-insensitively)
    for the names of entities that already exist in the knowledge graph.
    """

    def __init__(self, knowledge_graph: "KnowledgeGraph"):
        self.kg = knowledge_graph
        # Lower-cased entity name -> entity id, in the graph's entity order.
        self.entity_names = {e.name.lower(): e.id
                             for e in knowledge_graph.entities.values()}

    @staticmethod
    def _is_latin_char(ch: str) -> bool:
        """Return True for ASCII letters/digits (characters of a Latin token)."""
        return ch.isascii() and ch.isalnum()

    def _is_standalone(self, text: str, start: int, end: int) -> bool:
        """Return False if text[start:end] is glued to a longer Latin token.

        A boundary is enforced only on a side where the match itself starts
        or ends with an ASCII letter/digit, so CJK neighbours (e.g. "和" in
        "Python和JavaScript") never block a match.
        """
        if (self._is_latin_char(text[start]) and start > 0
                and self._is_latin_char(text[start - 1])):
            return False
        if (self._is_latin_char(text[end - 1]) and end < len(text)
                and self._is_latin_char(text[end])):
            return False
        return True

    def extract_entities(self, text: str) -> List[Tuple[str, str]]:
        """Find knowledge-graph entities mentioned in ``text``.

        Rules:
        * case-insensitive substring search over the known entity names;
        * a match must not be part of a longer ASCII word (so "Java" is not
          found inside "JavaScript", nor "RAG" inside "storage");
        * longer names are matched first and claim their characters, so an
          overlapping shorter name is not reported as well.

        Returns:
            [(entity name, entity id)] in the knowledge graph's entity order.
        """
        text_lower = text.lower()
        claimed = [False] * len(text_lower)
        matched_ids = set()

        # Longest names first so that e.g. "javascript" claims its span
        # before "java" is considered.
        by_length = sorted(self.entity_names.items(),
                           key=lambda item: len(item[0]), reverse=True)
        for name, entity_id in by_length:
            if not name:
                continue  # an empty name would match everywhere
            start = text_lower.find(name)
            while start != -1:
                end = start + len(name)
                if (self._is_standalone(text_lower, start, end)
                        and not any(claimed[start:end])):
                    claimed[start:end] = [True] * (end - start)
                    matched_ids.add(entity_id)
                start = text_lower.find(name, start + 1)

        return [(self.kg.entities[entity_id].name, entity_id)
                for entity_id in self.entity_names.values()
                if entity_id in matched_ids]
# Smoke-test entity extraction on a few sample queries.
extractor = EntityExtractor(kg)
test_queries = [
    "Python和JavaScript有什么区别?",
    "RAG使用了哪些技术?",
    "Transformer架构包含什么?"
]
print("🔍 实体识别测试:\n")
for query in test_queries:
    entities = extractor.extract_entities(query)
    print(f"查询: {query}\n识别实体: {[pair[0] for pair in entities]}\n")
In [ ]:
Copied!
class GraphRetriever:
    """Multi-hop retriever over a KnowledgeGraph.

    Entities mentioned in a query are found with EntityExtractor. For each
    one, the subgraph reachable within ``max_hops`` outgoing hops is
    collected, together with the shortest path to every entity in it.
    """

    def __init__(self, knowledge_graph: "KnowledgeGraph"):
        self.kg = knowledge_graph
        self.entity_extractor = EntityExtractor(knowledge_graph)

    def retrieve_by_entity(self, entity_id: str,
                           max_hops: int = 2) -> Dict:
        """Multi-hop retrieval centred on a single entity.

        Args:
            entity_id: id of the centre entity.
            max_hops: maximum number of outgoing hops to follow.

        Returns:
            {'entities': List[Entity], 'relations': List[Relation],
             'paths': List[List[str]], 'subgraph': KnowledgeGraph}.
            For an unknown id only the first three keys, all empty.
        """
        if entity_id not in self.kg.entities:
            return {'entities': [], 'relations': [], 'paths': []}

        subgraph = self.kg.get_subgraph([entity_id], hops=max_hops)
        paths = self._find_paths(entity_id, subgraph, max_hops)
        return {
            'entities': list(subgraph.entities.values()),
            'relations': subgraph.relations,
            'paths': paths,
            'subgraph': subgraph
        }

    def retrieve_by_query(self, query: str,
                          max_hops: int = 2,
                          max_paths: int = 10) -> Dict:
        """Retrieve graph context for every entity mentioned in ``query``.

        Args:
            query: free-text question.
            max_hops: maximum hops explored around each matched entity.
            max_paths: keep at most this many paths (shortest first).

        Returns:
            {'entities', 'relations', 'paths', 'matched_entities'}. All four
            keys are always present, even when no entity is matched.
            Entities and relations are de-duplicated across matched entities.
        """
        entities = self.entity_extractor.extract_entities(query)
        if not entities:
            return {'entities': [], 'relations': [], 'paths': [],
                    'matched_entities': []}

        # dicts act as insertion-ordered sets, keeping output deterministic.
        entity_ids = {}
        relations = {}
        paths = []
        for _, entity_id in entities:
            result = self.retrieve_by_entity(entity_id, max_hops)
            for entity in result['entities']:
                entity_ids.setdefault(entity.id, None)
            for relation in result['relations']:
                relations.setdefault(relation, None)
            paths.extend(result['paths'])

        # Keep the most direct paths overall, not just those of the first
        # matched entity (sort is stable, so ties keep discovery order).
        paths.sort(key=len)
        return {
            'entities': [self.kg.entities[eid] for eid in entity_ids],
            'relations': list(relations),
            'paths': paths[:max_paths],
            'matched_entities': entities
        }

    def _find_paths(self, source_id: str,
                    subgraph: "KnowledgeGraph",
                    max_length: int) -> List[List[str]]:
        """Shortest directed path from ``source_id`` to each other entity.

        Only paths with at most ``max_length`` edges are kept; the result is
        sorted shortest first.
        """
        if source_id not in subgraph.graph:
            return []
        # One BFS from the source instead of one shortest_path call per target.
        reachable = nx.single_source_shortest_path(
            subgraph.graph, source_id, cutoff=max_length)
        paths = [reachable[target_id] for target_id in subgraph.entities
                 if target_id != source_id and target_id in reachable]
        paths.sort(key=len)
        return paths
# Instantiate the graph retriever over the demo graph.
retriever = GraphRetriever(knowledge_graph=kg)
print("✅ 图谱检索器创建完成")
class GraphRetriever:
    """Multi-hop retriever over a KnowledgeGraph.

    Entities mentioned in a query are found with EntityExtractor. For each
    one, the subgraph reachable within ``max_hops`` outgoing hops is
    collected, together with the shortest path to every entity in it.
    """

    def __init__(self, knowledge_graph: "KnowledgeGraph"):
        self.kg = knowledge_graph
        self.entity_extractor = EntityExtractor(knowledge_graph)

    def retrieve_by_entity(self, entity_id: str,
                           max_hops: int = 2) -> Dict:
        """Multi-hop retrieval centred on a single entity.

        Args:
            entity_id: id of the centre entity.
            max_hops: maximum number of outgoing hops to follow.

        Returns:
            {'entities': List[Entity], 'relations': List[Relation],
             'paths': List[List[str]], 'subgraph': KnowledgeGraph}.
            For an unknown id only the first three keys, all empty.
        """
        if entity_id not in self.kg.entities:
            return {'entities': [], 'relations': [], 'paths': []}

        subgraph = self.kg.get_subgraph([entity_id], hops=max_hops)
        paths = self._find_paths(entity_id, subgraph, max_hops)
        return {
            'entities': list(subgraph.entities.values()),
            'relations': subgraph.relations,
            'paths': paths,
            'subgraph': subgraph
        }

    def retrieve_by_query(self, query: str,
                          max_hops: int = 2,
                          max_paths: int = 10) -> Dict:
        """Retrieve graph context for every entity mentioned in ``query``.

        Args:
            query: free-text question.
            max_hops: maximum hops explored around each matched entity.
            max_paths: keep at most this many paths (shortest first).

        Returns:
            {'entities', 'relations', 'paths', 'matched_entities'}. All four
            keys are always present, even when no entity is matched.
            Entities and relations are de-duplicated across matched entities.
        """
        entities = self.entity_extractor.extract_entities(query)
        if not entities:
            return {'entities': [], 'relations': [], 'paths': [],
                    'matched_entities': []}

        # dicts act as insertion-ordered sets, keeping output deterministic.
        entity_ids = {}
        relations = {}
        paths = []
        for _, entity_id in entities:
            result = self.retrieve_by_entity(entity_id, max_hops)
            for entity in result['entities']:
                entity_ids.setdefault(entity.id, None)
            for relation in result['relations']:
                relations.setdefault(relation, None)
            paths.extend(result['paths'])

        # Keep the most direct paths overall, not just those of the first
        # matched entity (sort is stable, so ties keep discovery order).
        paths.sort(key=len)
        return {
            'entities': [self.kg.entities[eid] for eid in entity_ids],
            'relations': list(relations),
            'paths': paths[:max_paths],
            'matched_entities': entities
        }

    def _find_paths(self, source_id: str,
                    subgraph: "KnowledgeGraph",
                    max_length: int) -> List[List[str]]:
        """Shortest directed path from ``source_id`` to each other entity.

        Only paths with at most ``max_length`` edges are kept; the result is
        sorted shortest first.
        """
        if source_id not in subgraph.graph:
            return []
        # One BFS from the source instead of one shortest_path call per target.
        reachable = nx.single_source_shortest_path(
            subgraph.graph, source_id, cutoff=max_length)
        paths = [reachable[target_id] for target_id in subgraph.entities
                 if target_id != source_id and target_id in reachable]
        paths.sort(key=len)
        return paths
# Instantiate the graph retriever over the demo graph.
retriever = GraphRetriever(knowledge_graph=kg)
print("✅ 图谱检索器创建完成")
5.2 测试图谱检索
In [ ]:
Copied!
# Test 1: entity-centred retrieval around "Python" (E1), up to 2 hops.
print("🔍 测试1: 查找Python相关信息 (2跳)\n")
result1 = retriever.retrieve_by_entity("E1", max_hops=2)
print(
    f"找到实体数: {len(result1['entities'])}\n"
    f"找到关系数: {len(result1['relations'])}\n"
    f"推理路径数: {len(result1['paths'])}"
)
print("\n推理路径示例:")
for i, path in enumerate(result1['paths'][:5], start=1):
    path_str = " → ".join(kg.entities[eid].name for eid in path)
    print(f" {i}. {path_str}")
# Test 1: entity-centred retrieval around "Python" (E1), up to 2 hops.
print("🔍 测试1: 查找Python相关信息 (2跳)\n")
result1 = retriever.retrieve_by_entity("E1", max_hops=2)
print(
    f"找到实体数: {len(result1['entities'])}\n"
    f"找到关系数: {len(result1['relations'])}\n"
    f"推理路径数: {len(result1['paths'])}"
)
print("\n推理路径示例:")
for i, path in enumerate(result1['paths'][:5], start=1):
    path_str = " → ".join(kg.entities[eid].name for eid in path)
    print(f" {i}. {path_str}")
In [ ]:
Copied!
# Test 2: free-text query retrieval, up to 2 hops per matched entity.
print("🔍 测试2: 查询 'RAG依赖哪些技术'\n")
result2 = retriever.retrieve_by_query("RAG依赖哪些技术", max_hops=2)
print(f"匹配实体: {[pair[0] for pair in result2['matched_entities']]}")
print("\n相关实体:")
for entity in result2['entities']:
    print(f" - {entity.name} ({entity.type})")
print("\n推理路径:")
for i, path in enumerate(result2['paths'][:5], start=1):
    path_str = " → ".join(kg.entities[eid].name for eid in path)
    print(f" {i}. {path_str}")
# Test 2: free-text query retrieval, up to 2 hops per matched entity.
print("🔍 测试2: 查询 'RAG依赖哪些技术'\n")
result2 = retriever.retrieve_by_query("RAG依赖哪些技术", max_hops=2)
print(f"匹配实体: {[pair[0] for pair in result2['matched_entities']]}")
print("\n相关实体:")
for entity in result2['entities']:
    print(f" - {entity.name} ({entity.type})")
print("\n推理路径:")
for i, path in enumerate(result2['paths'][:5], start=1):
    path_str = " → ".join(kg.entities[eid].name for eid in path)
    print(f" {i}. {path_str}")
In [ ]:
Copied!
class GraphRAG:
    """GraphRAG pipeline: knowledge-graph retrieval + template answer.

    Only graph retrieval is implemented. ``use_graph`` and ``graph_weight``
    are stored (presumably for a future graph + vector hybrid ranker) but
    are not read anywhere in this class.
    """

    def __init__(self,
                 knowledge_graph: "KnowledgeGraph",
                 use_graph: bool = True,
                 graph_weight: float = 0.5):
        self.kg = knowledge_graph
        self.graph_retriever = GraphRetriever(knowledge_graph)
        self.use_graph = use_graph
        self.graph_weight = graph_weight

    def query(self,
              query: str,
              max_hops: int = 2,
              verbose: bool = False) -> Dict:
        """Answer a question from the knowledge graph.

        Args:
            query: question text.
            max_hops: maximum number of hops explored per matched entity.
            verbose: print retrieval statistics.

        Returns:
            {'answer': str, 'graph_context': str, 'entities': list,
             'relations': list, 'paths': list, 'matched_entities': list}
        """
        graph_result = self.graph_retriever.retrieve_by_query(
            query, max_hops=max_hops
        )
        # A retriever may omit 'matched_entities' on an empty result.
        matched = graph_result.get('matched_entities', [])

        if verbose:
            print("\n=== GraphRAG查询 ===")
            print(f"查询: {query}")
            print(f"\n匹配实体: {[e[0] for e in matched]}")
            print(f"找到实体数: {len(graph_result.get('entities', []))}")
            print(f"推理路径数: {len(graph_result.get('paths', []))}")

        graph_context = self._build_graph_context(graph_result)
        # Template-based answer; a real system would call an LLM here.
        answer = self._generate_answer(query, graph_context, graph_result)

        return {
            'answer': answer,
            'graph_context': graph_context,
            'entities': graph_result.get('entities', []),
            'relations': graph_result.get('relations', []),
            'paths': graph_result.get('paths', []),
            'matched_entities': matched
        }

    def _format_path(self, path: List[str]) -> str:
        """Join the entity names along a path with arrows."""
        return " → ".join(self.kg.entities[eid].name for eid in path)

    def _build_graph_context(self, graph_result: Dict) -> str:
        """Render retrieved entities, relations and paths as context text.

        At most 10 entities, 15 relations and 5 paths are included.
        """
        entities = graph_result.get('entities', [])
        if not entities:
            return "未找到相关知识图谱信息"

        context_parts = []

        entities_info = [f"- {entity.name}: {entity.description}"
                         for entity in entities[:10]]
        context_parts.append("相关实体:\n" + "\n".join(entities_info))

        relations = graph_result.get('relations', [])
        if relations:
            relations_info = []
            for rel in relations[:15]:
                source = self.kg.entities[rel.source].name
                target = self.kg.entities[rel.target].name
                relations_info.append(f"- {source} --[{rel.relation_type}]--> {target}")
            context_parts.append("关系:\n" + "\n".join(relations_info))

        paths = graph_result.get('paths', [])
        if paths:
            paths_info = [f"{i}. {self._format_path(path)}"
                          for i, path in enumerate(paths[:5], 1)]
            context_parts.append("推理路径:\n" + "\n".join(paths_info))

        return "\n\n".join(context_parts)

    def _generate_answer(self,
                         query: str,
                         graph_context: str,
                         graph_result: Dict) -> str:
        """Build a simple template answer listing up to 3 reasoning paths.

        ``query`` and ``graph_context`` are unused by this template; they are
        what an LLM-based generator would consume.
        """
        matched = [e[0] for e in graph_result.get('matched_entities', [])]
        if not matched:
            return "抱歉,我无法在知识图谱中找到相关信息。"

        answer_parts = [f"基于知识图谱,我找到了关于 {', '.join(matched)} 的信息:"]
        paths = graph_result.get('paths', [])
        if paths:
            answer_parts.append("\n\n关键发现:")
            answer_parts.extend(f"- {self._format_path(path)}"
                                for path in paths[:3])
        return "\n".join(answer_parts)
# Build the GraphRAG pipeline on top of the demo graph.
graph_rag = GraphRAG(knowledge_graph=kg, use_graph=True, graph_weight=0.5)
print("✅ GraphRAG系统创建完成")
class GraphRAG:
    """GraphRAG pipeline: knowledge-graph retrieval + template answer.

    Only graph retrieval is implemented. ``use_graph`` and ``graph_weight``
    are stored (presumably for a future graph + vector hybrid ranker) but
    are not read anywhere in this class.
    """

    def __init__(self,
                 knowledge_graph: "KnowledgeGraph",
                 use_graph: bool = True,
                 graph_weight: float = 0.5):
        self.kg = knowledge_graph
        self.graph_retriever = GraphRetriever(knowledge_graph)
        self.use_graph = use_graph
        self.graph_weight = graph_weight

    def query(self,
              query: str,
              max_hops: int = 2,
              verbose: bool = False) -> Dict:
        """Answer a question from the knowledge graph.

        Args:
            query: question text.
            max_hops: maximum number of hops explored per matched entity.
            verbose: print retrieval statistics.

        Returns:
            {'answer': str, 'graph_context': str, 'entities': list,
             'relations': list, 'paths': list, 'matched_entities': list}
        """
        graph_result = self.graph_retriever.retrieve_by_query(
            query, max_hops=max_hops
        )
        # A retriever may omit 'matched_entities' on an empty result.
        matched = graph_result.get('matched_entities', [])

        if verbose:
            print("\n=== GraphRAG查询 ===")
            print(f"查询: {query}")
            print(f"\n匹配实体: {[e[0] for e in matched]}")
            print(f"找到实体数: {len(graph_result.get('entities', []))}")
            print(f"推理路径数: {len(graph_result.get('paths', []))}")

        graph_context = self._build_graph_context(graph_result)
        # Template-based answer; a real system would call an LLM here.
        answer = self._generate_answer(query, graph_context, graph_result)

        return {
            'answer': answer,
            'graph_context': graph_context,
            'entities': graph_result.get('entities', []),
            'relations': graph_result.get('relations', []),
            'paths': graph_result.get('paths', []),
            'matched_entities': matched
        }

    def _format_path(self, path: List[str]) -> str:
        """Join the entity names along a path with arrows."""
        return " → ".join(self.kg.entities[eid].name for eid in path)

    def _build_graph_context(self, graph_result: Dict) -> str:
        """Render retrieved entities, relations and paths as context text.

        At most 10 entities, 15 relations and 5 paths are included.
        """
        entities = graph_result.get('entities', [])
        if not entities:
            return "未找到相关知识图谱信息"

        context_parts = []

        entities_info = [f"- {entity.name}: {entity.description}"
                         for entity in entities[:10]]
        context_parts.append("相关实体:\n" + "\n".join(entities_info))

        relations = graph_result.get('relations', [])
        if relations:
            relations_info = []
            for rel in relations[:15]:
                source = self.kg.entities[rel.source].name
                target = self.kg.entities[rel.target].name
                relations_info.append(f"- {source} --[{rel.relation_type}]--> {target}")
            context_parts.append("关系:\n" + "\n".join(relations_info))

        paths = graph_result.get('paths', [])
        if paths:
            paths_info = [f"{i}. {self._format_path(path)}"
                          for i, path in enumerate(paths[:5], 1)]
            context_parts.append("推理路径:\n" + "\n".join(paths_info))

        return "\n\n".join(context_parts)

    def _generate_answer(self,
                         query: str,
                         graph_context: str,
                         graph_result: Dict) -> str:
        """Build a simple template answer listing up to 3 reasoning paths.

        ``query`` and ``graph_context`` are unused by this template; they are
        what an LLM-based generator would consume.
        """
        matched = [e[0] for e in graph_result.get('matched_entities', [])]
        if not matched:
            return "抱歉,我无法在知识图谱中找到相关信息。"

        answer_parts = [f"基于知识图谱,我找到了关于 {', '.join(matched)} 的信息:"]
        paths = graph_result.get('paths', [])
        if paths:
            answer_parts.append("\n\n关键发现:")
            answer_parts.extend(f"- {self._format_path(path)}"
                                for path in paths[:3])
        return "\n".join(answer_parts)
# Build the GraphRAG pipeline on top of the demo graph.
graph_rag = GraphRAG(knowledge_graph=kg, use_graph=True, graph_weight=0.5)
print("✅ GraphRAG系统创建完成")
6.2 GraphRAG查询测试
In [ ]:
Copied!
# Query 1: which technologies does RAG use?
result1 = graph_rag.query("RAG使用了哪些技术?", max_hops=2, verbose=True)
print("\n" + "=" * 50, "📝 答案:", result1['answer'], sep="\n")
# Query 1: which technologies does RAG use?
result1 = graph_rag.query("RAG使用了哪些技术?", max_hops=2, verbose=True)
print("\n" + "=" * 50, "📝 答案:", result1['answer'], sep="\n")
In [ ]:
Copied!
# Query 2: which fields is Python suited for?
result2 = graph_rag.query("Python适合什么领域?", max_hops=2, verbose=True)
print("\n" + "=" * 50, "📝 答案:", result2['answer'], sep="\n")
# Query 2: which fields is Python suited for?
result2 = graph_rag.query("Python适合什么领域?", max_hops=2, verbose=True)
print("\n" + "=" * 50, "📝 答案:", result2['answer'], sep="\n")
In [ ]:
Copied!
# Query 3: multi-hop reasoning between Transformer and RAG (3 hops).
result3 = graph_rag.query("Transformer和RAG有什么关系?", max_hops=3, verbose=True)
print("\n" + "=" * 50, "📝 答案:", result3['answer'], sep="\n")
# Query 3: multi-hop reasoning between Transformer and RAG (3 hops).
result3 = graph_rag.query("Transformer和RAG有什么关系?", max_hops=3, verbose=True)
print("\n" + "=" * 50, "📝 答案:", result3['answer'], sep="\n")
In [ ]:
Copied!
def visualize_reasoning_path(kg: KnowledgeGraph,
                             path: List[str],
                             figsize=(12, 6)):
    """Draw one reasoning path as a left-to-right chain with relation labels.

    Args:
        kg: graph the path was found in; supplies entity names and the
            relation types of edges between consecutive ids.
        path: ordered entity ids (e.g. one entry of a retriever's 'paths').
        figsize: matplotlib figure size in inches.
    """
    plt.figure(figsize=figsize)

    path_graph = nx.DiGraph()

    # Nodes sit on the x axis in path order.
    for i, node_id in enumerate(path):
        entity = kg.entities[node_id]
        path_graph.add_node(
            node_id,
            name=entity.name,
            type=entity.type,
            pos=(i, 0)
        )

    for source, target in zip(path, path[1:]):
        # On the MultiDiGraph get_edge_data returns {edge_key: attrs} or
        # None. Keys need not include 0, and parallel edges may carry
        # different relation types, so all of them are collected.
        edge_data = kg.graph.get_edge_data(source, target) or {}
        rel_types = [attrs.get('relation_type', 'related')
                     for attrs in edge_data.values()]
        if rel_types:
            path_graph.add_edge(source, target,
                                relation_type=" / ".join(rel_types))

    pos = nx.get_node_attributes(path_graph, 'pos')

    nx.draw_networkx_nodes(path_graph, pos,
                           node_color='lightblue',
                           node_size=2000,
                           alpha=0.8)
    nx.draw_networkx_edges(path_graph, pos,
                           edge_color='gray',
                           arrows=True,
                           arrowsize=30,
                           width=2)

    labels = nx.get_node_attributes(path_graph, 'name')
    nx.draw_networkx_labels(path_graph, pos,
                            labels=labels,
                            font_size=10,
                            font_weight='bold')

    edge_labels = nx.get_edge_attributes(path_graph, 'relation_type')
    nx.draw_networkx_edge_labels(path_graph, pos,
                                 edge_labels=edge_labels,
                                 font_size=9)

    plt.title("推理路径可视化", fontsize=14, weight='bold')
    plt.axis('off')
    plt.tight_layout()
    plt.show()


# Visualise the first reasoning path found for query 3.
if result3['paths']:
    visualize_reasoning_path(kg, result3['paths'][0])
def visualize_reasoning_path(kg: KnowledgeGraph,
                             path: List[str],
                             figsize=(12, 6)):
    """Draw one reasoning path as a left-to-right chain with relation labels.

    Args:
        kg: graph the path was found in; supplies entity names and the
            relation types of edges between consecutive ids.
        path: ordered entity ids (e.g. one entry of a retriever's 'paths').
        figsize: matplotlib figure size in inches.
    """
    plt.figure(figsize=figsize)

    path_graph = nx.DiGraph()

    # Nodes sit on the x axis in path order.
    for i, node_id in enumerate(path):
        entity = kg.entities[node_id]
        path_graph.add_node(
            node_id,
            name=entity.name,
            type=entity.type,
            pos=(i, 0)
        )

    for source, target in zip(path, path[1:]):
        # On the MultiDiGraph get_edge_data returns {edge_key: attrs} or
        # None. Keys need not include 0, and parallel edges may carry
        # different relation types, so all of them are collected.
        edge_data = kg.graph.get_edge_data(source, target) or {}
        rel_types = [attrs.get('relation_type', 'related')
                     for attrs in edge_data.values()]
        if rel_types:
            path_graph.add_edge(source, target,
                                relation_type=" / ".join(rel_types))

    pos = nx.get_node_attributes(path_graph, 'pos')

    nx.draw_networkx_nodes(path_graph, pos,
                           node_color='lightblue',
                           node_size=2000,
                           alpha=0.8)
    nx.draw_networkx_edges(path_graph, pos,
                           edge_color='gray',
                           arrows=True,
                           arrowsize=30,
                           width=2)

    labels = nx.get_node_attributes(path_graph, 'name')
    nx.draw_networkx_labels(path_graph, pos,
                            labels=labels,
                            font_size=10,
                            font_weight='bold')

    edge_labels = nx.get_edge_attributes(path_graph, 'relation_type')
    nx.draw_networkx_edge_labels(path_graph, pos,
                                 edge_labels=edge_labels,
                                 font_size=9)

    plt.title("推理路径可视化", fontsize=14, weight='bold')
    plt.axis('off')
    plt.tight_layout()
    plt.show()


# Visualise the first reasoning path found for query 3.
if result3['paths']:
    visualize_reasoning_path(kg, result3['paths'][0])
7.2 子图可视化
In [ ]:
Copied!
# Visualise the 2-hop neighbourhood of the entities matched by query 3,
# highlighting the matched entities themselves.
if result3['matched_entities']:
    entity_ids = [pair[1] for pair in result3['matched_entities']]
    subgraph = kg.get_subgraph(entity_ids, hops=2)
    visualize_graph(subgraph, figsize=(12, 8), highlight_nodes=entity_ids)
# Visualise the 2-hop neighbourhood of the entities matched by query 3,
# highlighting the matched entities themselves.
if result3['matched_entities']:
    entity_ids = [pair[1] for pair in result3['matched_entities']]
    subgraph = kg.get_subgraph(entity_ids, hops=2)
    visualize_graph(subgraph, figsize=(12, 8), highlight_nodes=entity_ids)
In [ ]:
Copied!
# Compare how much each hop limit retrieves for a few queries.
test_queries = [
    "RAG依赖什么?",
    "Python的应用领域?",
    "Transformer包含什么?"
]
hop_counts = [1, 2, 3]
results_summary = []
for query in test_queries:
    for hops in hop_counts:
        result = graph_rag.query(query, max_hops=hops, verbose=False)
        results_summary.append(dict(zip(
            ('查询', '跳数', '实体数', '关系数', '路径数'),
            (query, hops, len(result['entities']),
             len(result['relations']), len(result['paths'])),
        )))

# Print the comparison table.
print("\n📊 不同跳数的检索效果对比:\n")
print(f"{'查询':<25} {'跳数':<6} {'实体数':<8} {'关系数':<8} {'路径数':<8}")
print("-" * 60)
for r in results_summary:
    print(f"{r['查询']:<25} {r['跳数']:<6} {r['实体数']:<8} "
          f"{r['关系数']:<8} {r['路径数']:<8}")
# Compare how much each hop limit retrieves for a few queries.
test_queries = [
    "RAG依赖什么?",
    "Python的应用领域?",
    "Transformer包含什么?"
]
hop_counts = [1, 2, 3]
results_summary = []
for query in test_queries:
    for hops in hop_counts:
        result = graph_rag.query(query, max_hops=hops, verbose=False)
        results_summary.append(dict(zip(
            ('查询', '跳数', '实体数', '关系数', '路径数'),
            (query, hops, len(result['entities']),
             len(result['relations']), len(result['paths'])),
        )))

# Print the comparison table.
print("\n📊 不同跳数的检索效果对比:\n")
print(f"{'查询':<25} {'跳数':<6} {'实体数':<8} {'关系数':<8} {'路径数':<8}")
print("-" * 60)
for r in results_summary:
    print(f"{r['查询']:<25} {r['跳数']:<6} {r['实体数']:<8} "
          f"{r['关系数']:<8} {r['路径数']:<8}")
8.2 图谱统计信息
In [ ]:
Copied!
# Overall graph statistics.
stats = kg.get_stats()
print("\n📊 知识图谱统计信息:\n")
for key in stats:
    value = stats[key]
    print(f" {key}: {value}")

# Number of entities per entity type, sorted by type name.
print("\n📦 实体类型分布:")
entity_types = defaultdict(int)
for entity in kg.entities.values():
    entity_types[entity.type] += 1
for entity_type in sorted(entity_types):
    count = entity_types[entity_type]
    print(f" {entity_type}: {count}")

# Number of relations per relation type, sorted by type name.
print("\n🔗 关系类型分布:")
relation_types = defaultdict(int)
for relation in kg.relations:
    relation_types[relation.relation_type] += 1
for rel_type in sorted(relation_types):
    count = relation_types[rel_type]
    print(f" {rel_type}: {count}")
# Overall graph statistics.
stats = kg.get_stats()
print("\n📊 知识图谱统计信息:\n")
for key in stats:
    value = stats[key]
    print(f" {key}: {value}")

# Number of entities per entity type, sorted by type name.
print("\n📦 实体类型分布:")
entity_types = defaultdict(int)
for entity in kg.entities.values():
    entity_types[entity.type] += 1
for entity_type in sorted(entity_types):
    count = entity_types[entity_type]
    print(f" {entity_type}: {count}")

# Number of relations per relation type, sorted by type name.
print("\n🔗 关系类型分布:")
relation_types = defaultdict(int)
for relation in kg.relations:
    relation_types[relation.relation_type] += 1
for rel_type in sorted(relation_types):
    count = relation_types[rel_type]
    print(f" {rel_type}: {count}")
In [ ]:
Copied!
def rank_paths(kg: "KnowledgeGraph",
               paths: List[List[str]],
               query: str = "") -> List[Tuple[List[str], float]]:
    """Score reasoning paths and sort them best first.

    score = 10 / len(path)                 (shorter paths score higher)
          + sum of edge weights along the path
          + sum of entity-type importance of every node on the path

    When consecutive nodes are joined by several parallel edges, the
    largest weight among them is used.

    Args:
        kg: graph the paths were found in.
        paths: paths as lists of entity ids; empty paths are skipped.
        query: currently unused; reserved for query-aware scoring.

    Returns:
        [(path, score)] sorted by descending score (ties keep input order).
    """
    # Importance per entity type; unknown types count 0.5.
    type_importance = {
        'Technology': 3,
        'Concept': 2,
        'Field': 1.5,
        'Person': 1
    }

    scored_paths = []
    for path in paths:
        if not path:
            continue  # would divide by zero below
        score = 10.0 / len(path)

        for source, target in zip(path, path[1:]):
            # MultiDiGraph returns {edge_key: attrs}; keys need not include 0.
            edge_data = kg.graph.get_edge_data(source, target)
            if edge_data:
                score += max(attrs.get('weight', 1.0)
                             for attrs in edge_data.values())

        for node_id in path:
            entity = kg.entities.get(node_id)
            if entity:
                score += type_importance.get(entity.type, 0.5)

        scored_paths.append((path, score))

    scored_paths.sort(key=lambda item: item[1], reverse=True)
    return scored_paths
# Rank the reasoning paths found for query 3 and show the top five.
if result3['paths']:
    ranked = rank_paths(kg, result3['paths'])
    print("\n🏆 推理路径排序:\n")
    for i, (path, score) in enumerate(ranked[:5], start=1):
        path_str = " → ".join(kg.entities[eid].name for eid in path)
        print(f"{i}. (得分: {score:.2f}) {path_str}")
def rank_paths(kg: "KnowledgeGraph",
               paths: List[List[str]],
               query: str = "") -> List[Tuple[List[str], float]]:
    """Score reasoning paths and sort them best first.

    score = 10 / len(path)                 (shorter paths score higher)
          + sum of edge weights along the path
          + sum of entity-type importance of every node on the path

    When consecutive nodes are joined by several parallel edges, the
    largest weight among them is used.

    Args:
        kg: graph the paths were found in.
        paths: paths as lists of entity ids; empty paths are skipped.
        query: currently unused; reserved for query-aware scoring.

    Returns:
        [(path, score)] sorted by descending score (ties keep input order).
    """
    # Importance per entity type; unknown types count 0.5.
    type_importance = {
        'Technology': 3,
        'Concept': 2,
        'Field': 1.5,
        'Person': 1
    }

    scored_paths = []
    for path in paths:
        if not path:
            continue  # would divide by zero below
        score = 10.0 / len(path)

        for source, target in zip(path, path[1:]):
            # MultiDiGraph returns {edge_key: attrs}; keys need not include 0.
            edge_data = kg.graph.get_edge_data(source, target)
            if edge_data:
                score += max(attrs.get('weight', 1.0)
                             for attrs in edge_data.values())

        for node_id in path:
            entity = kg.entities.get(node_id)
            if entity:
                score += type_importance.get(entity.type, 0.5)

        scored_paths.append((path, score))

    scored_paths.sort(key=lambda item: item[1], reverse=True)
    return scored_paths
# Rank the reasoning paths found for query 3 and show the top five.
if result3['paths']:
    ranked = rank_paths(kg, result3['paths'])
    print("\n🏆 推理路径排序:\n")
    for i, (path, score) in enumerate(ranked[:5], start=1):
        path_str = " → ".join(kg.entities[eid].name for eid in path)
        print(f"{i}. (得分: {score:.2f}) {path_str}")
10. 总结
你已经学会:
✅ 知识图谱基础
- 实体和关系的定义
- NetworkX图谱构建
- 图谱可视化
✅ 实体识别
- 规则匹配方法(基于图谱中已有实体名称的字符串匹配)
- 实体链接(将匹配到的名称映射为图谱实体ID)
- 歧义处理的思路(本教程未实现,见下方练习)
✅ 图谱检索
- 多跳推理
- 子图提取
- 路径查找
✅ GraphRAG系统
- 图谱上下文构建
- 推理路径生成
- 答案生成(基于模板,实际应用中可替换为LLM)
✅ 可视化与评估
- 推理路径可视化
- 子图可视化
- 检索效果评估(不同跳数的对比)
下一步:
- 🚀 使用真实NER模型(spaCy)
- 📊 集成向量检索
- 🎯 优化路径评分算法
- 🔗 添加更多实体和关系
- 🌓 构建更大规模的知识图谱
练习:
- 为你的领域构建一个小型知识图谱
- 实现基于LLM的实体和关系抽取
- 添加实体消歧功能
- 实现更复杂的路径排序算法
- 将图谱检索与向量检索融合
🎉 恭喜完成知识图谱RAG实践!