Files
ailine/tools/test/test_rag_indexer_result.py
root 8af82f8f7f
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 5m4s
feat: RAG混合检索系统完整实现 + 启动脚本修复
- 实现了稠密+稀疏混合检索,使用 Qdrant 原生 RRF 融合
- 修复了 retriever.py 的 BaseRetriever 继承和稀疏向量包装问题
- 修复了 pipeline.py 的 Optional 导入问题
- 添加了稀疏 embedder 的缓存配置
- 简化了 vector_store.py,移除不必要的逻辑
- 修复了 start.sh 的 PROJECT_DIR 硬编码和端口配置问题
- 完善了 RAG 检索的测试文件
2026-05-04 02:54:37 +08:00

312 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Exercise the refactored IndexBuilder and the RAG retrieval paths.

Covers: index building, dense retrieval, sparse retrieval, hybrid retrieval,
and parent/child document retrieval.
"""
import asyncio
import os
import sys
# Put the project root (two directories above this file) and its "backend"
# subdirectory at the front of the import path. The root is inserted last,
# so it ends up ahead of "backend" in sys.path.
# NOTE(review): presumably this is what lets the rag_indexer / rag_core / app
# imports that follow resolve when the script is run directly - confirm.
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
sys.path.insert(0, project_root)
from rag_indexer.index_builder import IndexBuilder
from rag_indexer.splitters import SplitterType
from rag_core import QdrantVectorStore, get_sparse_embedder
from app.model_services import get_embedding_service
from qdrant_client import models
async def test_index_builder():
    """Build the index for the sample document and print collection info.

    Creates an ``IndexBuilder`` for the ``rag_documents`` collection with the
    parent/child splitter, indexes ``data/user_docs/doublestory.txt`` when
    that file exists, and prints what ``get_collection_info()`` reports.

    Returns:
        The value returned by ``builder.build_from_file`` when the sample
        file exists, otherwise ``0`` (nothing was indexed).
    """
    print("=" * 70)
    print("1. 测试索引构建功能...")
    print("=" * 70)

    builder = IndexBuilder(
        collection_name="rag_documents",
        splitter_type=SplitterType.PARENT_CHILD,
        parent_chunk_size=1000,
        child_chunk_size=200,
    )

    # Bound up front so the final ``return`` cannot hit an unassigned local
    # when the sample file is missing (the original raised UnboundLocalError).
    processed = 0
    test_file = os.path.join(project_root, "data", "user_docs", "doublestory.txt")

    try:
        if os.path.exists(test_file):
            print(f"正在为文件 {test_file} 构建索引...")
            processed = await builder.build_from_file(test_file)
            print(f"索引构建完成,处理了 {processed} 个文档")

            info = builder.get_collection_info()
            print(f"集合信息: {info}")
        else:
            print(f"测试文件不存在: {test_file}")
    finally:
        # Release the builder's resources even if indexing raised.
        builder.close()

    print("\n索引构建测试完成")
    return processed
def test_dense_retrieval():
    """Run a dense similarity search on ``rag_documents`` and print the hits.

    Queries the vector store for the three nearest documents and prints each
    one's source, metadata, and content (truncated to 200 characters).
    """
    rule = "=" * 70
    print("\n" + rule)
    print("2. 测试稠密检索...")
    print(rule)

    # Vector store backed by the shared embedding service.
    store = QdrantVectorStore(
        collection_name="rag_documents",
        embeddings=get_embedding_service(),
    )

    query = "The Ant and the Grasshopper"
    print(f"查询: {query}")
    hits = store.similarity_search(query, k=3)

    print(f"\n找到 {len(hits)} 个结果:")
    for rank, hit in enumerate(hits, start=1):
        print(f"\n{rank}. (来源: {hit.metadata.get('source', 'unknown')})")
        print(f" 元数据: {hit.metadata}")
        snippet = hit.page_content.strip()
        # Keep the console output short for long documents.
        snippet = snippet if len(snippet) <= 200 else snippet[:200] + "..."
        print(f" 内容: {snippet}")
def test_sparse_retrieval_simple():
    """Query the ``sparse`` named vector directly and print the top hits.

    Embeds a keyword query with the sparse embedder, wraps it in a
    ``models.SparseVector``, and asks Qdrant for the three best points.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("3. 测试稀疏检索BM25...")
    print(rule)

    store = QdrantVectorStore(
        collection_name="rag_documents",
        embeddings=get_embedding_service(),
    )
    client = store.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()

    # Keyword-style query for the sparse vector search.
    query = "winter work food"
    print(f"查询关键词: {query}")

    # The embedder output is indexed by "indices"/"values"; Qdrant is given a
    # SparseVector object built from those two entries.
    raw_sparse = sparse_embedder.embed_query(query)
    sparse_vector = models.SparseVector(
        indices=raw_sparse["indices"],
        values=raw_sparse["values"],
    )

    points = client.query_points(
        collection_name="rag_documents",
        query=sparse_vector,
        using="sparse",
        limit=3,
        with_payload=True,
    ).points

    print(f"\n找到 {len(points)} 个结果:")
    for rank, hit in enumerate(points, start=1):
        print(f"\n{rank}. (分数: {hit.score:.4f})")
        body = hit.payload.get("text", "")
        # Everything in the payload except the text itself counts as metadata.
        extras = {key: value for key, value in hit.payload.items() if key != "text"}
        print(f" 元数据: {extras}")
        snippet = body.strip()
        snippet = snippet if len(snippet) <= 200 else snippet[:200] + "..."
        print(f" 内容: {snippet}")
def test_hybrid_retrieval_simple():
    """Run a dense + sparse hybrid query fused with RRF and print the hits.

    Both vectors are prefetched (three candidates each) and combined by
    Qdrant via ``models.FusionQuery(fusion=models.Fusion.RRF)``.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("4. 测试混合检索(稠密+稀疏 RRF 融合)...")
    print(rule)

    embeddings = get_embedding_service()
    store = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
    client = store.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()

    query = "Ant and Grasshopper story"
    print(f"查询: {query}")

    # One dense and one sparse representation of the same query.
    dense_vector = embeddings.embed_query(query)
    raw_sparse = sparse_embedder.embed_query(query)
    sparse_vector = models.SparseVector(
        indices=raw_sparse["indices"],
        values=raw_sparse["values"],
    )

    # Each prefetch pulls candidates from one named vector; the outer fusion
    # query merges the two candidate lists.
    candidate_queries = [
        models.Prefetch(query=dense_vector, using="dense", limit=3),
        models.Prefetch(query=sparse_vector, using="sparse", limit=3),
    ]
    points = client.query_points(
        collection_name="rag_documents",
        prefetch=candidate_queries,
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=3,
        with_payload=True,
    ).points

    print(f"\n找到 {len(points)} 个结果:")
    for rank, hit in enumerate(points, start=1):
        print(f"\n{rank}. (RRF 融合分数: {hit.score:.4f})")
        body = hit.payload.get("text", "")
        extras = {key: value for key, value in hit.payload.items() if key != "text"}
        print(f" 元数据: {extras}")
        snippet = body.strip()
        snippet = snippet if len(snippet) <= 200 else snippet[:200] + "..."
        print(f" 内容: {snippet}")
def test_parent_child_retrieval_simple():
    """Hybrid-search child chunks, then resolve and print their parents.

    Steps:
      1. Run an RRF-fused dense + sparse query (five candidates each).
      2. Group hits by the ``parent_id`` payload field (falling back to the
         point's own id), keeping the best-scoring child per parent.
      3. Retrieve the parent points by id; for any parent that is not
         returned, fall back to its best child.
      4. Print the three highest-scoring results.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("5. 测试父子文档混合检索...")
    print(rule)

    embeddings = get_embedding_service()
    store = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
    client = store.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()

    query = "The Ant and the Grasshopper story moral"
    print(f"查询: {query}")

    dense_vector = embeddings.embed_query(query)
    raw_sparse = sparse_embedder.embed_query(query)
    sparse_vector = models.SparseVector(
        indices=raw_sparse["indices"],
        values=raw_sparse["values"],
    )

    # Hybrid search first, to locate candidate child chunks.
    child_hits = client.query_points(
        collection_name="rag_documents",
        prefetch=[
            models.Prefetch(query=dense_vector, using="dense", limit=5),
            models.Prefetch(query=sparse_vector, using="sparse", limit=5),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=5,
        with_payload=True,
    ).points

    # parent id -> best-scoring child hit; dict order follows first sighting.
    best_child = {}
    for hit in child_hits:
        key = hit.payload.get("parent_id", hit.id)
        incumbent = best_child.get(key)
        if incumbent is None or hit.score > incumbent.score:
            best_child[key] = hit

    parent_ids = list(best_child)
    print(f"\n找到 {len(parent_ids)} 个不同的 parent_id:")

    if not parent_ids:
        print("\n未找到结果")
        return

    parents = client.retrieve(
        collection_name="rag_documents",
        ids=parent_ids,
        with_payload=True,
    )
    resolved_ids = {parent.id for parent in parents}

    # Each parent carries the score of its best child.
    ranked = [(parent, best_child[parent.id].score) for parent in parents]

    # Parents that were not returned are represented by their best child.
    for orphan_id in set(parent_ids) - resolved_ids:
        fallback = best_child[orphan_id]
        print(f"\n注意: parent_id {orphan_id} 未找到,使用子文档代替")
        ranked.append((fallback, fallback.score))

    ranked = sorted(ranked, key=lambda pair: pair[1], reverse=True)

    print(f"\n{len(ranked)} 个结果(去重后):")
    for rank, (hit, score) in enumerate(ranked[:3], start=1):
        print(f"\n{rank}. (分数: {score:.4f})")
        body = hit.payload.get("text", "")
        extras = {key: value for key, value in hit.payload.items() if key != "text"}
        print(f" 元数据: {extras}")
        snippet = body.strip()
        # Parent chunks are larger, so allow a longer preview here.
        snippet = snippet if len(snippet) <= 400 else snippet[:400] + "..."
        print(f" 内容: {snippet}")
async def main():
    """Run every check in order: build the index, then each retrieval mode."""
    # The index must exist before any retrieval check can return results.
    await test_index_builder()

    # Dense, sparse, hybrid, and parent/child retrieval, in that order.
    retrieval_checks = (
        test_dense_retrieval,
        test_sparse_retrieval_simple,
        test_hybrid_retrieval_simple,
        test_parent_child_retrieval_simple,
    )
    for check in retrieval_checks:
        check()

    rule = "=" * 70
    print("\n" + rule)
    print("所有测试完成!")
    print(rule)


# Script entry point: drive the async main() to completion.
if __name__ == "__main__":
    asyncio.run(main())