feat: RAG混合检索系统完整实现 + 启动脚本修复
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 5m4s

- 实现了稠密+稀疏混合检索,使用 Qdrant 原生 RRF 融合
- 修复了 retriever.py 的 BaseRetriever 继承和稀疏向量包装问题
- 修复了 pipeline.py 的 Optional 导入问题
- 添加了稀疏 embedder 的缓存配置
- 简化了 vector_store.py,移除不必要的逻辑
- 修复了 start.sh 的 PROJECT_DIR 硬编码和端口配置问题
- 完善了 RAG 检索的测试文件
This commit is contained in:
2026-05-04 02:54:37 +08:00
parent 54ba2d3457
commit 8af82f8f7f
9 changed files with 461 additions and 157 deletions

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
"""
测试重构后的 IndexBuilder 和 RAGRetriever
测试重构后的 IndexBuilder 和 RAG 检索
包括:索引构建、稠密检索、稀疏检索、混合检索、父子文档检索
"""
import asyncio
@@ -8,15 +9,23 @@ import os
import sys
# Make the project importable: the project root is taken to be two
# directories above this file, and both it and its "backend" sub-directory
# are pushed to the front of sys.path. (A previous assignment that pointed
# one level up was dead code - it was overwritten immediately - and is gone.)
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
sys.path.insert(0, project_root)
from rag_indexer.index_builder import IndexBuilder
from rag_indexer.splitters import SplitterType
from rag_core import QdrantVectorStore, get_sparse_embedder
from app.model_services import get_embedding_service
from qdrant_client import models
async def test_index_builder():
"""测试索引构建功能"""
print("测试索引构建功能...")
print("="*70)
print("1. 测试索引构建功能...")
print("="*70)
# 创建 IndexBuilder 实例
builder = IndexBuilder(
@@ -27,7 +36,7 @@ async def test_index_builder():
)
# 测试文档路径
test_file = os.path.join(os.path.dirname(__file__), "..", "data", "user_docs", "doublestory.txt")
test_file = os.path.join(project_root, "data", "user_docs", "doublestory.txt")
if os.path.exists(test_file):
# 构建索引
@@ -40,10 +49,263 @@ async def test_index_builder():
print(f"集合信息: {info}")
else:
print(f"测试文件不存在: {test_file}")
# 关闭资源
builder.close()
print("\n测试完成")
print("\n索引构建测试完成")
return processed
def test_dense_retrieval():
    """Run a dense similarity search against "rag_documents" and print the hits.

    Builds a QdrantVectorStore on top of the embedding service, asks for the
    top 3 matches of a fixed query and prints, for each hit, its source,
    its metadata and at most the first 200 characters of its content.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("2. 测试稠密检索...")
    print(rule)

    # Vector store backed by the shared embedding service.
    store = QdrantVectorStore(
        collection_name="rag_documents",
        embeddings=get_embedding_service(),
    )

    query = "The Ant and the Grasshopper"
    print(f"查询: {query}")

    hits = store.similarity_search(query, k=3)
    print(f"\n找到 {len(hits)} 个结果:")

    for rank, hit in enumerate(hits, 1):
        meta = hit.metadata
        print(f"\n{rank}. (来源: {meta.get('source', 'unknown')})")
        print(f" 元数据: {meta}")
        body = hit.page_content.strip()
        # Keep console output short: cut long passages at 200 characters.
        snippet = body if len(body) <= 200 else body[:200] + "..."
        print(f" 内容: {snippet}")
def test_sparse_retrieval_simple():
    """Query the "sparse" named vector of "rag_documents" and print the hits.

    The keyword query is encoded with the sparse embedder, wrapped in a
    ``models.SparseVector`` and sent through ``query_points`` with
    ``using="sparse"``. Up to 3 points are printed with their score,
    payload metadata (everything except "text") and a 200-character
    excerpt of the "text" payload field.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("3. 测试稀疏检索BM25...")
    print(rule)

    # Obtain the raw Qdrant client through the vector store wrapper.
    store = QdrantVectorStore(
        collection_name="rag_documents",
        embeddings=get_embedding_service(),
    )
    qdrant = store.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()

    # Keyword-style query.
    query = "winter work food"
    print(f"查询关键词: {query}")

    # The embedder result is indexed by "indices"/"values"; the client is
    # given a SparseVector object built from those two entries.
    encoded = sparse_embedder.embed_query(query)
    sparse_vec = models.SparseVector(
        indices=encoded["indices"],
        values=encoded["values"],
    )

    response = qdrant.query_points(
        collection_name="rag_documents",
        query=sparse_vec,
        using="sparse",
        limit=3,
        with_payload=True,
    )

    hits = response.points
    print(f"\n找到 {len(hits)} 个结果:")
    for rank, hit in enumerate(hits, 1):
        print(f"\n{rank}. (分数: {hit.score:.4f})")
        raw_text = hit.payload.get("text", "")
        metadata = {
            key: value for key, value in hit.payload.items() if key != "text"
        }
        print(f" 元数据: {metadata}")
        body = raw_text.strip()
        snippet = body if len(body) <= 200 else body[:200] + "..."
        print(f" 内容: {snippet}")
def test_hybrid_retrieval_simple():
    """Run a dense + sparse hybrid query fused with RRF and print the hits.

    Two prefetch branches (named vectors "dense" and "sparse", 3 candidates
    each) are combined by ``models.FusionQuery(fusion=models.Fusion.RRF)``.
    The top 3 fused points are printed with their fused score, payload
    metadata (everything except "text") and a 200-character excerpt.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("4. 测试混合检索(稠密+稀疏 RRF 融合)...")
    print(rule)

    embeddings = get_embedding_service()
    store = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
    qdrant = store.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()

    query = "Ant and Grasshopper story"
    print(f"查询: {query}")

    # Encode the query once per representation.
    dense_vec = embeddings.embed_query(query)
    encoded = sparse_embedder.embed_query(query)
    sparse_vec = models.SparseVector(
        indices=encoded["indices"],
        values=encoded["values"],
    )

    # One prefetch per named vector; the outer query only fuses them.
    branches = [
        models.Prefetch(query=dense_vec, using="dense", limit=3),
        models.Prefetch(query=sparse_vec, using="sparse", limit=3),
    ]
    response = qdrant.query_points(
        collection_name="rag_documents",
        prefetch=branches,
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=3,
        with_payload=True,
    )

    hits = response.points
    print(f"\n找到 {len(hits)} 个结果:")
    for rank, hit in enumerate(hits, 1):
        print(f"\n{rank}. (RRF 融合分数: {hit.score:.4f})")
        raw_text = hit.payload.get("text", "")
        metadata = {
            key: value for key, value in hit.payload.items() if key != "text"
        }
        print(f" 元数据: {metadata}")
        body = raw_text.strip()
        snippet = body if len(body) <= 200 else body[:200] + "..."
        print(f" 内容: {snippet}")
def test_parent_child_retrieval_simple():
    """Hybrid-search child chunks, then resolve and print their parents.

    Steps:
      1. Run an RRF-fused dense + sparse query (5 candidates per branch,
         5 fused results).
      2. Group the hits by the payload's "parent_id" (falling back to the
         point's own id), keeping the best-scoring hit per group.
      3. Fetch the points with those ids via ``client.retrieve``; any id that
         is not returned is represented by its best child hit instead.
      4. Sort by score, descending, and print the top 3 with a
         400-character excerpt of the "text" payload field.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("5. 测试父子文档混合检索...")
    print(rule)

    embeddings = get_embedding_service()
    store = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
    qdrant = store.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()

    query = "The Ant and the Grasshopper story moral"
    print(f"查询: {query}")

    # Encode the query once per representation.
    dense_vec = embeddings.embed_query(query)
    encoded = sparse_embedder.embed_query(query)
    sparse_vec = models.SparseVector(
        indices=encoded["indices"],
        values=encoded["values"],
    )

    # Step 1: hybrid search.
    branches = [
        models.Prefetch(query=dense_vec, using="dense", limit=5),
        models.Prefetch(query=sparse_vec, using="sparse", limit=5),
    ]
    response = qdrant.query_points(
        collection_name="rag_documents",
        prefetch=branches,
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=5,
        with_payload=True,
    )

    # Step 2: best score (and the hit that produced it) per parent id.
    best_score = {}
    best_child = {}
    for hit in response.points:
        group_id = hit.payload.get("parent_id", hit.id)
        if group_id in best_score and not hit.score > best_score[group_id]:
            continue
        best_score[group_id] = hit.score
        best_child[group_id] = hit

    parent_ids = list(best_score)
    print(f"\n找到 {len(parent_ids)} 个不同的 parent_id:")

    if not parent_ids:
        print("\n未找到结果")
        return

    # Step 3: look the parents up by id.
    parent_docs = qdrant.retrieve(
        collection_name="rag_documents",
        ids=parent_ids,
        with_payload=True,
    )
    found_parent_ids = {parent.id for parent in parent_docs}
    results = [(parent, best_score[parent.id]) for parent in parent_docs]

    # Ids that retrieve() did not return: show the best child hit instead.
    for lost_id in set(parent_ids) - found_parent_ids:
        fallback = best_child[lost_id]
        print(f"\n注意: parent_id {lost_id} 未找到,使用子文档代替")
        results.append((fallback, best_score[lost_id]))

    # Step 4: highest score first.
    ranked = sorted(results, key=lambda pair: pair[1], reverse=True)

    print(f"\n{len(ranked)} 个结果(去重后):")
    for rank, (entry, score) in enumerate(ranked[:3], 1):
        print(f"\n{rank}. (分数: {score:.4f})")
        raw_text = entry.payload.get("text", "")
        metadata = {
            key: value for key, value in entry.payload.items() if key != "text"
        }
        print(f" 元数据: {metadata}")
        body = raw_text.strip()
        snippet = body if len(body) <= 400 else body[:400] + "..."
        print(f" 内容: {snippet}")
async def main():
    """Run all test stages in order and print a closing banner.

    The index is built first (the only awaited stage); the four synchronous
    retrieval checks follow: dense, sparse, hybrid, parent/child.
    """
    # Stage 1: build the index.
    await test_index_builder()

    # Stages 2-5: retrieval checks, executed in this exact order.
    retrieval_stages = (
        test_dense_retrieval,
        test_sparse_retrieval_simple,
        test_hybrid_retrieval_simple,
        test_parent_child_retrieval_simple,
    )
    for run_stage in retrieval_stages:
        run_stage()

    rule = "=" * 70
    print("\n" + rule)
    print("所有测试完成!")
    print(rule)
if __name__ == "__main__":
    # main() already awaits test_index_builder() as its first stage, so a
    # separate asyncio.run(test_index_builder()) beforehand would build the
    # index twice. Run the full suite once.
    asyncio.run(main())