1.2 CLIP模型简介¶
CLIP (Contrastive Language-Image Pre-training) by OpenAI
核心思想:
- 图像和文本映射到同一个向量空间
- 相似的图像和文本在空间中距离更近
- 实现跨模态语义理解
# CLIP工作原理
Image Encoder: 图像 → 512维向量
Text Encoder: 文本 → 512维向量
# 相似度计算
similarity = cosine_similarity(image_vector, text_vector)
1.3 环境配置¶
安装必要的库:
In [ ]:
Copied!
# 安装依赖
!pip install -q torch torchvision transformers pillow matplotlib
!pip install -q openai clip-by-openai # CLIP模型
!pip install -q sentence-transformers # 备选方案
# 安装依赖
!pip install -q torch torchvision transformers pillow matplotlib
!pip install -q openai clip-by-openai # CLIP模型
!pip install -q sentence-transformers # 备选方案
In [ ]:
Copied!
# 导入库
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict
import os
# 检查GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# 导入库
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict
import os
# 检查GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
In [ ]:
Copied!
import clip
# 加载CLIP模型
model, preprocess = clip.load("ViT-B/32", device=device)
print(f"Model loaded: ViT-B/32")
print(f"Image input resolution: {model.visual.input_resolution}")
print(f"Context length: {model.context_length}")
print(f"Vocabulary size: {model.vocab_size}")
import clip
# 加载CLIP模型
model, preprocess = clip.load("ViT-B/32", device=device)
print(f"Model loaded: ViT-B/32")
print(f"Image input resolution: {model.visual.input_resolution}")
print(f"Context length: {model.context_length}")
print(f"Vocabulary size: {model.vocab_size}")
2.2 图像和文本编码¶
In [ ]:
Copied!
def encode_image(image_path: str) -> np.ndarray:
    """Encode an image into a CLIP embedding.

    Args:
        image_path: Path to the image file.

    Returns:
        Image embedding vector (512-dim numpy array for ViT-B/32).
    """
    # Preprocess to the model's input resolution, add a batch dim,
    # and move to the active device.
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
    return image_features.cpu().numpy()[0]


def encode_text(text: str) -> np.ndarray:
    """Encode text into a CLIP embedding.

    Args:
        text: Input text.

    Returns:
        Text embedding vector (512-dim numpy array for ViT-B/32).
    """
    text_tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
    return text_features.cpu().numpy()[0]


# Sanity-check the text encoder.
sample_text = "一只可爱的小猫"
text_embedding = encode_text(sample_text)
print(f"Text: {sample_text}")
print(f"Embedding shape: {text_embedding.shape}")
print(f"Embedding norm: {np.linalg.norm(text_embedding):.4f}")
2.3 计算跨模态相似度¶
In [ ]:
Copied!
def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Compute the cosine similarity between two vectors.

    Args:
        vec1, vec2: Input vectors of equal length.

    Returns:
        Similarity score in [-1, 1]; 0.0 if either vector is all-zero.
    """
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        # Guard against division by zero for degenerate (all-zero) vectors.
        return 0.0
    return float(np.dot(vec1, vec2) / norm_product)
# Example: rank a small text corpus by similarity to a text query.
texts = [
    "一只可爱的小猫",
    "一只可爱的小狗",
    "今天的天气很好",
    "笔记本电脑"
]
query_text = "小猫在玩耍"
query_embedding = encode_text(query_text)
print(f"\nQuery: {query_text}\n")
print("相似度排名:")
print("-" * 50)
similarities = []
for text in texts:
    text_embedding = encode_text(text)
    sim = cosine_similarity(query_embedding, text_embedding)
    similarities.append((text, sim))
# Rank candidates by similarity, highest first.
similarities.sort(key=lambda x: x[1], reverse=True)
for text, sim in similarities:
    print(f"{text:30s} {sim:.4f}")
In [ ]:
Copied!
# Build a small sample image set for indexing demos.
import urllib.request
from io import BytesIO

# Public demo image URLs (Unsplash, resized to 400px wide).
sample_images = {
    "cat": "https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=400",
    "dog": "https://images.unsplash.com/photo-1587300003388-59208cc962cb?w=400",
    "car": "https://images.unsplash.com/photo-1503376780353-7e6692767b70?w=400",
    "food": "https://images.unsplash.com/photo-1476224203421-9ac39bcb3327?w=400",
    "landscape": "https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=400"
}

# Download each image into data/images/; failures are reported but
# do not abort the loop, so image_paths only contains successes.
os.makedirs("data/images", exist_ok=True)
image_paths = {}
for name, url in sample_images.items():
    try:
        urllib.request.urlretrieve(url, f"data/images/{name}.jpg")
        image_paths[name] = f"data/images/{name}.jpg"
        print(f"Downloaded: {name}.jpg")
    except Exception as e:
        print(f"Failed to download {name}: {e}")
3.2 构建图像索引¶
In [ ]:
Copied!
class ImageIndex:
    """In-memory index of CLIP image embeddings for similarity search."""

    def __init__(self):
        self.embeddings = []  # one CLIP embedding per indexed image
        self.paths = []       # image file path, parallel to embeddings
        self.metadata = []    # per-image metadata dict, parallel to embeddings

    def add_image(self, image_path: str, metadata: dict = None):
        """Encode an image and append it (with metadata) to the index."""
        embedding = encode_image(image_path)
        self.embeddings.append(embedding)
        self.paths.append(image_path)
        self.metadata.append(metadata or {})

    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[dict]:
        """Return the top_k indexed images most similar to the query.

        Args:
            query_embedding: CLIP embedding of the query (image or text).
            top_k: Number of results to return.

        Returns:
            List of dicts with 'path', 'similarity', and 'metadata',
            sorted by similarity descending.
        """
        similarities = []
        for i, emb in enumerate(self.embeddings):
            sim = cosine_similarity(query_embedding, emb)
            similarities.append({
                'path': self.paths[i],
                'similarity': sim,
                'metadata': self.metadata[i]
            })
        similarities.sort(key=lambda x: x['similarity'], reverse=True)
        return similarities[:top_k]


# Build the image index from the downloaded sample set.
image_index = ImageIndex()
for name, path in image_paths.items():
    image_index.add_image(path, metadata={'category': name})
    print(f"Indexed: {name}")
print(f"\nTotal images indexed: {len(image_index.embeddings)}")
3.3 以图搜图¶
In [ ]:
Copied!
def search_by_image(query_image_path: str, top_k: int = 3):
    """Search the indexed images using a query image.

    Args:
        query_image_path: Path of the query image.
        top_k: Number of results to return.

    Returns:
        The retrieval results from the image index.
    """
    # Encode the query image and search the global index.
    query_embedding = encode_image(query_image_path)
    results = image_index.search(query_embedding, top_k=top_k)
    # Display the query and its matches side by side
    # (top_k + 1 panels, so there are always at least two axes).
    fig, axes = plt.subplots(1, top_k + 1, figsize=(15, 3))
    query_img = Image.open(query_image_path)
    axes[0].imshow(query_img)
    axes[0].set_title("Query Image")
    axes[0].axis('off')
    for i, result in enumerate(results):
        img = Image.open(result['path'])
        axes[i + 1].imshow(img)
        axes[i + 1].set_title(f"Sim: {result['similarity']:.3f}")
        axes[i + 1].axis('off')
    plt.tight_layout()
    plt.show()
    return results


# Test: query with the first downloaded image.
if image_paths:
    query_path = list(image_paths.values())[0]
    results = search_by_image(query_path, top_k=3)
    print("\n搜索结果:")
    for r in results:
        print(f"  {r['path']}: {r['similarity']:.4f}")
In [ ]:
Copied!
def search_by_text(query_text: str, top_k: int = 3):
    """Search the indexed images with a text query.

    Args:
        query_text: Query text.
        top_k: Number of results to return.

    Returns:
        The retrieval results from the image index.
    """
    # Encode the query text and search the global index.
    query_embedding = encode_text(query_text)
    results = image_index.search(query_embedding, top_k=top_k)
    # Display the results. np.atleast_1d handles top_k == 1, where
    # plt.subplots returns a bare Axes instead of an array.
    fig, axes = plt.subplots(1, top_k, figsize=(12, 4))
    axes = np.atleast_1d(axes)
    for i, result in enumerate(results):
        img = Image.open(result['path'])
        axes[i].imshow(img)
        axes[i].set_title(f"{result['metadata']['category']}\nSim: {result['similarity']:.3f}")
        axes[i].axis('off')
    plt.suptitle(f"Query: {query_text}", fontsize=14)
    plt.tight_layout()
    plt.show()
    return results


# Test a few text queries.
queries = [
    "一只小猫",
    "汽车",
    "美食",
    "自然风景"
]
for query in queries:
    print(f"\n查询: {query}")
    results = search_by_text(query, top_k=3)
In [ ]:
Copied!
from openai import OpenAI
import base64

# Initialize the client. Read the key from the environment rather than
# hard-coding a secret in the notebook.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))


def encode_image_to_base64(image_path: str) -> str:
    """Encode an image file as a base64 string (for data: URLs)."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def analyze_image_with_gpt4v(image_path: str, question: str) -> str:
    """Analyze an image with a GPT-4 vision model.

    Args:
        image_path: Path to the image.
        question: Question about the image.

    Returns:
        The model's answer text.
    """
    base64_image = encode_image_to_base64(image_path)
    response = client.chat.completions.create(
        # "gpt-4-vision-preview" has been retired; "gpt-4o" is the
        # current vision-capable replacement.
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    }
                ]
            }
        ],
        max_tokens=300
    )
    return response.choices[0].message.content


# Example usage:
# if image_paths:
#     test_image = list(image_paths.values())[0]
#     answer = analyze_image_with_gpt4v(
#         test_image,
#         "这张图片里有什么?请详细描述。"
#     )
#     print(f"GPT-4V回答: {answer}")
5.2 视觉问答系统¶
In [ ]:
Copied!
class VisualQASystem:
    """Visual question answering: CLIP retrieval + GPT-4 vision analysis."""

    def __init__(self, image_index: ImageIndex):
        self.image_index = image_index
        # Read the API key from the environment; avoid hard-coded secrets.
        self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))

    def query(self, text_query: str, top_k: int = 3) -> dict:
        """Retrieve images relevant to a text query and answer with them.

        Args:
            text_query: Text query.
            top_k: Number of images to retrieve.

        Returns:
            Dict with the query, per-image analyses, and a combined answer.
        """
        # 1. Retrieve relevant images with CLIP.
        query_embedding = encode_text(text_query)
        image_results = self.image_index.search(query_embedding, top_k=top_k)
        # 2. Analyze each retrieved image with the vision model.
        analysis_results = []
        for result in image_results:
            analysis = self._analyze_image(
                result['path'],
                f"用户查询:{text_query}。这张图片与查询相关吗?请说明原因。"
            )
            analysis_results.append({
                'image_path': result['path'],
                'similarity': result['similarity'],
                'analysis': analysis
            })
        # 3. Produce a combined answer from the analyses.
        answer = self._generate_answer(text_query, analysis_results)
        return {
            'query': text_query,
            'retrieved_images': analysis_results,
            'answer': answer
        }

    def _analyze_image(self, image_path: str, question: str) -> str:
        """Analyze a single image; returns an error string on failure."""
        try:
            base64_image = encode_image_to_base64(image_path)
            response = self.client.chat.completions.create(
                # "gpt-4-vision-preview" is retired; use "gpt-4o".
                model="gpt-4o",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": question},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{base64_image}"
                                }
                            }
                        ]
                    }
                ],
                max_tokens=200
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"分析失败: {str(e)}"

    def _generate_answer(self, query: str, results: list) -> str:
        """Build a summary answer (simplified: report the best match only)."""
        if results:
            best_result = max(results, key=lambda x: x['similarity'])
            return f"找到 {len(results)} 张相关图片。最相关的是:{best_result['analysis']}"
        return "未找到相关图片。"


# Create the VQA system:
# vqa_system = VisualQASystem(image_index)
# Test:
# result = vqa_system.query("可爱的小动物")
# print(f"查询: {result['query']}")
# print(f"回答: {result['answer']}")
In [ ]:
Copied!
class MultimodalRAG:
    """Complete multimodal RAG system.

    Two query modes:
    1. Text-to-image: text query -> image retrieval -> vision analysis -> answer
    2. Image-to-image: image query -> image retrieval -> results
    """

    def __init__(self, use_vision: bool = True):
        """
        Args:
            use_vision: Whether to enhance results with a GPT-4 vision model.
        """
        self.image_index = ImageIndex()
        self.use_vision = use_vision
        if use_vision:
            # Read the API key from the environment; avoid hard-coded secrets.
            self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))

    def index_images(self, image_paths: List[str], metadata: List[dict] = None):
        """Index a batch of images with optional per-image metadata."""
        for i, path in enumerate(image_paths):
            meta = metadata[i] if metadata else {}
            self.image_index.add_image(path, meta)
        print(f"Indexed {len(image_paths)} images")

    def query_with_text(self, query: str, top_k: int = 3) -> dict:
        """Text query: retrieve images, optionally analyze, return results."""
        # 1. Retrieve.
        query_embedding = encode_text(query)
        results = self.image_index.search(query_embedding, top_k=top_k)
        # 2. Optional vision enhancement.
        if self.use_vision:
            enhanced_results = self._enhance_with_vision(query, results)
        else:
            enhanced_results = results
        # 3. Package.
        return {
            'query': query,
            'mode': 'text-to-image',
            'results': enhanced_results
        }

    def query_with_image(self, query_image_path: str, top_k: int = 3) -> dict:
        """Image query: retrieve by visual similarity."""
        query_embedding = encode_image(query_image_path)
        results = self.image_index.search(query_embedding, top_k=top_k)
        return {
            'query': query_image_path,
            'mode': 'image-to-image',
            'results': results
        }

    def _enhance_with_vision(self, query: str, results: list) -> list:
        """Attach a vision-model analysis to each retrieval result."""
        enhanced = []
        for result in results:
            analysis = self._analyze_image(result['path'], query)
            enhanced.append({
                **result,
                'analysis': analysis
            })
        return enhanced

    def _analyze_image(self, image_path: str, query: str) -> str:
        """Analyze one image; returns an error string on failure."""
        try:
            base64_image = encode_image_to_base64(image_path)
            response = self.client.chat.completions.create(
                # "gpt-4-vision-preview" is retired; use "gpt-4o".
                model="gpt-4o",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"用户查询:{query}\n请分析这张图片与查询的关系。"
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{base64_image}"
                                }
                            }
                        ]
                    }
                ],
                max_tokens=150
            )
            return response.choices[0].message.content
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            return "视觉分析失败"


# Create the system (vision analysis disabled for now).
rag_system = MultimodalRAG(use_vision=False)

# Index the sample images.
if image_paths:
    paths = list(image_paths.values())
    metadatas = [{'category': k} for k in image_paths.keys()]
    rag_system.index_images(paths, metadatas)
6.2 测试系统¶
In [ ]:
Copied!
# Exercise the RAG system with a few text queries.
test_queries = [
    "可爱的宠物",
    "交通工具",
    "美食"
]
for query in test_queries:
    result = rag_system.query_with_text(query, top_k=2)
    print(f"\n{'='*60}")
    print(f"查询: {result['query']}")
    print(f"模式: {result['mode']}")
    print(f"\n结果:")
    for i, r in enumerate(result['results'], 1):
        print(f"  {i}. {r['metadata']['category']}: {r['similarity']:.4f}")
        # 'analysis' is only present when vision enhancement is enabled.
        if 'analysis' in r:
            print(f"     分析: {r['analysis'][:100]}...")
In [ ]:
Copied!
# Exercise 1: image retrieval with metadata filtering.
class ImageIndexWithFilter(ImageIndex):
    """ImageIndex variant whose search supports metadata filters."""

    def search(self, query_embedding: np.ndarray, top_k: int = 5,
               filters: dict = None) -> List[dict]:
        """Search with optional metadata filtering.

        Args:
            query_embedding: CLIP embedding of the query.
            top_k: Number of results to return.
            filters: Filter conditions, e.g. {'category': 'animal'};
                an image matches only if every key/value pair matches
                its metadata exactly.

        Returns:
            Top-k matching results, most similar first.
        """
        # Rank everything first, then filter, so we never return fewer
        # than top_k matches when enough filtered candidates exist.
        results = super().search(query_embedding, top_k=len(self.embeddings))
        if filters:
            results = [
                r for r in results
                if all(r['metadata'].get(key) == value
                       for key, value in filters.items())
            ]
        return results[:top_k]


# Test:
# filtered_index = ImageIndexWithFilter()
# ... add images with metadata
# results = filtered_index.search(query_emb, filters={'category': 'animal'})
练习2:图文混合检索¶
实现一个同时接受图像+文本作为输入的混合检索系统。
In [ ]:
Copied!
# Exercise 2: hybrid image+text retrieval via weighted embedding fusion.
def hybrid_query(rag_system: MultimodalRAG,
                 image_path: str,
                 text: str,
                 image_weight: float = 0.5,
                 top_k: int = 3) -> dict:
    """Hybrid query combining an image and a text description.

    Args:
        rag_system: Multimodal RAG system whose index is searched.
        image_path: Reference image.
        text: Text description.
        image_weight: Weight of the image embedding (0-1); the text
            embedding gets (1 - image_weight).
        top_k: Number of results to return.

    Returns:
        Dict with the query inputs, mode, and retrieval results.
    """
    # 1. Encode both modalities.
    image_emb = encode_image(image_path)
    text_emb = encode_text(text)
    # 2. L2-normalize before fusing so neither modality dominates by
    #    norm, then blend with the requested weight.
    image_emb = image_emb / (np.linalg.norm(image_emb) or 1.0)
    text_emb = text_emb / (np.linalg.norm(text_emb) or 1.0)
    hybrid_emb = image_emb * image_weight + text_emb * (1 - image_weight)
    # 3. Search the shared image index with the fused embedding.
    results = rag_system.image_index.search(hybrid_emb, top_k=top_k)
    return {
        'query': {'image': image_path, 'text': text},
        'mode': 'hybrid',
        'results': results
    }
练习3:产品推荐系统¶
为电商平台构建一个产品图像推荐系统。
In [ ]:
Copied!
# Exercise 3: product recommendation built on the multimodal RAG system.
class ProductRecommender:
    """Product recommendation: search by image, by text, and compare."""

    def __init__(self):
        self.rag = MultimodalRAG()

    def index_products(self, image_paths: List[str], metadata: List[dict] = None):
        """Index the product catalog images (with optional metadata)."""
        self.rag.index_images(image_paths, metadata)

    def find_similar_products(self, product_image: str, top_k: int = 5):
        """Find products visually similar to a reference product image."""
        return self.rag.query_with_image(product_image, top_k=top_k)

    def search_by_description(self, description: str, top_k: int = 5):
        """Search products by a text description."""
        return self.rag.query_with_text(description, top_k=top_k)

    def compare_products(self, product_images: List[str]) -> dict:
        """Pairwise CLIP similarity between the given product images.

        Returns:
            Dict mapping each (path_a, path_b) pair to its cosine similarity.
        """
        # Encode each image once, then compare every unordered pair.
        embeddings = {path: encode_image(path) for path in product_images}
        comparisons = {}
        for i, path_a in enumerate(product_images):
            for path_b in product_images[i + 1:]:
                comparisons[(path_a, path_b)] = float(
                    cosine_similarity(embeddings[path_a], embeddings[path_b])
                )
        return comparisons