refactor: 重构RAG核心组件,简化代码结构和测试文件
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
This commit is contained in:
@@ -1,305 +1,130 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试重构后的 IndexBuilder 和 RAG 检索
|
||||
包括:索引构建、稠密检索、稀疏检索、混合检索、父子文档检索
|
||||
简单的 RAG 检索测试
|
||||
使用 app/rag/retriever 提供的功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from rag_indexer.index_builder import IndexBuilder
|
||||
from rag_indexer.splitters import SplitterType
|
||||
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
|
||||
from backend.app.model_services import get_embedding_service
|
||||
from qdrant_client import models
|
||||
from backend.app.rag.retriever import (
|
||||
create_parent_hybrid_retriever,
|
||||
create_hybrid_retriever
|
||||
)
|
||||
from backend.rag_core import QdrantHybridStore
|
||||
|
||||
|
||||
async def test_index_builder():
|
||||
"""测试索引构建功能"""
|
||||
print("="*70)
|
||||
print("1. 测试索引构建功能...")
|
||||
print("="*70)
|
||||
# 统一的测试查询列表
|
||||
TEST_QUERIES = [
|
||||
"黄双银",
|
||||
]
|
||||
|
||||
|
||||
async def test_simple_vector_store_search():
|
||||
"""测试:直接使用 QdrantHybridStore 的 asimilarity_search"""
|
||||
print("="*80)
|
||||
print("测试 1: QdrantHybridStore.asimilarity_search")
|
||||
print("="*80)
|
||||
|
||||
# 创建 IndexBuilder 实例
|
||||
builder = IndexBuilder(
|
||||
vs = QdrantHybridStore(collection_name="rag_documents")
|
||||
|
||||
for query in TEST_QUERIES:
|
||||
print(f"\n查询: {query}")
|
||||
print("-" * 60)
|
||||
|
||||
docs = await vs.asimilarity_search(query, k=10)
|
||||
|
||||
if docs:
|
||||
print(f"✓ 找到 {len(docs)} 个文档")
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n {i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||||
preview = doc.page_content[:120].strip()
|
||||
if len(doc.page_content) > 120:
|
||||
preview += "..."
|
||||
print(f" 内容: {preview}")
|
||||
else:
|
||||
print("✗ 未找到结果")
|
||||
|
||||
await vs.close_async_client()
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
async def test_hybrid_retriever():
|
||||
"""测试:HybridRetriever(子文档检索)"""
|
||||
print("\n" + "="*80)
|
||||
print("测试 2: HybridRetriever (子文档混合检索)")
|
||||
print("="*80)
|
||||
|
||||
retriever = create_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000,
|
||||
child_chunk_size=200
|
||||
search_k=10
|
||||
)
|
||||
|
||||
# 测试文档路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
|
||||
test_file = os.path.join(project_root, "data", "user_docs", "doublestory.txt")
|
||||
|
||||
if os.path.exists(test_file):
|
||||
# 构建索引
|
||||
print(f"正在为文件 {test_file} 构建索引...")
|
||||
processed = await builder.build_from_file(test_file)
|
||||
print(f"索引构建完成,处理了 {processed} 个文档")
|
||||
for query in TEST_QUERIES:
|
||||
print(f"\n查询: {query}")
|
||||
print("-" * 60)
|
||||
|
||||
# 获取集合信息
|
||||
info = builder.get_collection_info()
|
||||
print(f"集合信息: {info}")
|
||||
else:
|
||||
print(f"测试文件不存在: {test_file}")
|
||||
docs = await retriever.ainvoke(query)
|
||||
|
||||
if docs:
|
||||
print(f"✓ 找到 {len(docs)} 个子文档")
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n {i}. parent_id: {doc.metadata.get('parent_id', 'none')}")
|
||||
preview = doc.page_content[:100].strip()
|
||||
if len(doc.page_content) > 100:
|
||||
preview += "..."
|
||||
print(f" 内容: {preview}")
|
||||
else:
|
||||
print("✗ 未找到结果")
|
||||
|
||||
# 关闭资源
|
||||
builder.close()
|
||||
print("\n索引构建测试完成")
|
||||
return processed
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
def test_dense_retrieval():
|
||||
"""测试稠密检索"""
|
||||
print("\n" + "="*70)
|
||||
print("2. 测试稠密检索...")
|
||||
print("="*70)
|
||||
async def test_parent_hybrid_retriever():
|
||||
"""测试:ParentHybridRetriever(父子文档混合检索)"""
|
||||
print("\n" + "="*80)
|
||||
print("测试 3: ParentHybridRetriever (父子文档混合检索)")
|
||||
print("="*80)
|
||||
|
||||
# 获取嵌入服务
|
||||
embeddings = get_embedding_service()
|
||||
|
||||
# 创建向量存储
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
|
||||
# 测试查询
|
||||
query = "The Ant and the Grasshopper"
|
||||
print(f"查询: {query}")
|
||||
|
||||
results = vs.similarity_search(query, k=3)
|
||||
|
||||
print(f"\n找到 {len(results)} 个结果:")
|
||||
for i, doc in enumerate(results, 1):
|
||||
print(f"\n{i}. (来源: {doc.metadata.get('source', 'unknown')})")
|
||||
print(f" 元数据: {doc.metadata}")
|
||||
content = doc.page_content.strip()
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
print(f" 内容: {content}")
|
||||
|
||||
|
||||
def test_sparse_retrieval_simple():
|
||||
"""简单测试稀疏检索"""
|
||||
print("\n" + "="*70)
|
||||
print("3. 测试稀疏检索(BM25)...")
|
||||
print("="*70)
|
||||
|
||||
# 获取嵌入服务和稀疏嵌入器
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
# 测试查询 - 用关键词
|
||||
query = "winter work food"
|
||||
print(f"查询关键词: {query}")
|
||||
|
||||
# 生成稀疏查询向量
|
||||
sparse_query = sparse_embedder.embed_query(query)
|
||||
|
||||
# 包装成 SparseVector 对象
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
# 直接查询稀疏向量
|
||||
response = client.query_points(
|
||||
retriever = create_parent_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
query=sparse_vec,
|
||||
using="sparse",
|
||||
limit=3,
|
||||
with_payload=True
|
||||
search_k=10
|
||||
)
|
||||
|
||||
print(f"\n找到 {len(response.points)} 个结果:")
|
||||
for i, point in enumerate(response.points, 1):
|
||||
print(f"\n{i}. (分数: {point.score:.4f})")
|
||||
text = point.payload.get("text", "")
|
||||
metadata = {k: v for k, v in point.payload.items() if k != "text"}
|
||||
print(f" 元数据: {metadata}")
|
||||
content = text.strip()
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
print(f" 内容: {content}")
|
||||
|
||||
|
||||
def test_hybrid_retrieval_simple():
    """Run one hybrid (dense + sparse) query with RRF fusion and print the hits.

    Builds a dense vector and a sparse vector for a fixed English query,
    sends both to the ``rag_documents`` collection as prefetch branches
    (limit 3 each), lets Qdrant fuse them with Reciprocal Rank Fusion, and
    prints the score, non-text payload fields and a text preview (at most
    200 characters) for each of the up-to-3 fused results.

    Returns:
        None. Output goes to stdout only.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("4. 测试混合检索(稠密+稀疏 RRF 融合)...")
    print(banner)

    # Dense embedding service, the store built on top of it, the raw Qdrant
    # client exposed by the store, and the sparse embedder.
    embeddings = get_embedding_service()
    vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
    client = vs.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()

    # Fixed test query.
    query = "Ant and Grasshopper story"
    print(f"查询: {query}")

    # Produce both query representations.
    dense_vector = embeddings.embed_query(query)
    sparse_raw = sparse_embedder.embed_query(query)
    # The sparse embedder result is read via its "indices" / "values" entries.
    sparse_vector = models.SparseVector(
        indices=sparse_raw["indices"],
        values=sparse_raw["values"],
    )

    # One prefetch branch per named vector; each contributes its top 3.
    prefetch_branches = [
        models.Prefetch(query=dense_vector, using="dense", limit=3),
        models.Prefetch(query=sparse_vector, using="sparse", limit=3),
    ]

    # Hybrid retrieval: Qdrant fuses the prefetch branches with RRF.
    response = client.query_points(
        collection_name="rag_documents",
        prefetch=prefetch_branches,
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=3,
        with_payload=True,
    )

    hits = response.points
    print(f"\n找到 {len(hits)} 个结果:")
    for rank, hit in enumerate(hits, 1):
        print(f"\n{rank}. (RRF 融合分数: {hit.score:.4f})")

        payload = hit.payload
        raw_text = payload.get("text", "")

        # Everything except the document text counts as metadata.
        metadata = dict(payload.items())
        metadata.pop("text", None)
        print(f" 元数据: {metadata}")

        # Preview: stripped text, truncated to 200 characters with an ellipsis.
        stripped = raw_text.strip()
        content = stripped if len(stripped) <= 200 else stripped[:200] + "..."
        print(f" 内容: {content}")
|
||||
|
||||
|
||||
def test_parent_child_retrieval_simple():
|
||||
"""简单测试父子文档检索"""
|
||||
print("\n" + "="*70)
|
||||
print("5. 测试父子文档混合检索...")
|
||||
print("="*70)
|
||||
|
||||
# 获取嵌入服务和稀疏嵌入器
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
# 测试查询
|
||||
query = "The Ant and the Grasshopper story moral"
|
||||
print(f"查询: {query}")
|
||||
|
||||
# 生成双向量
|
||||
dense_query = embeddings.embed_query(query)
|
||||
sparse_query = sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
# 先做混合检索找到子文档
|
||||
response = client.query_points(
|
||||
collection_name="rag_documents",
|
||||
prefetch=[
|
||||
models.Prefetch(
|
||||
query=dense_query,
|
||||
using="dense",
|
||||
limit=5
|
||||
),
|
||||
models.Prefetch(
|
||||
query=sparse_vec,
|
||||
using="sparse",
|
||||
limit=5
|
||||
)
|
||||
],
|
||||
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||
limit=5,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# 收集 parent_id
|
||||
parent_score_map = {}
|
||||
child_points = {}
|
||||
for point in response.points:
|
||||
parent_id = point.payload.get("parent_id", point.id)
|
||||
score = point.score
|
||||
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
|
||||
parent_score_map[parent_id] = score
|
||||
child_points[parent_id] = point
|
||||
|
||||
parent_ids = list(parent_score_map.keys())
|
||||
|
||||
print(f"\n找到 {len(parent_ids)} 个不同的 parent_id:")
|
||||
|
||||
# 查找父文档
|
||||
if parent_ids:
|
||||
parent_docs = client.retrieve(
|
||||
collection_name="rag_documents",
|
||||
ids=parent_ids,
|
||||
with_payload=True
|
||||
)
|
||||
for query in TEST_QUERIES:
|
||||
print(f"\n查询: {query}")
|
||||
print("-" * 60)
|
||||
|
||||
found_parent_ids = {p.id for p in parent_docs}
|
||||
docs = await retriever.ainvoke(query)
|
||||
|
||||
# 准备结果列表
|
||||
results = []
|
||||
for p in parent_docs:
|
||||
score = parent_score_map[p.id]
|
||||
results.append((p, score))
|
||||
|
||||
# 处理没找到父文档的情况 - 用子文档代替
|
||||
missing = set(parent_ids) - found_parent_ids
|
||||
for parent_id in missing:
|
||||
child_point = child_points[parent_id]
|
||||
print(f"\n注意: parent_id {parent_id} 未找到,使用子文档代替")
|
||||
results.append((child_point, parent_score_map[parent_id]))
|
||||
|
||||
# 按分数排序
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 显示
|
||||
print(f"\n共 {len(results)} 个结果(去重后):")
|
||||
for i, (point, score) in enumerate(results[:3], 1):
|
||||
print(f"\n{i}. (分数: {score:.4f})")
|
||||
text = point.payload.get("text", "")
|
||||
metadata = {k: v for k, v in point.payload.items() if k != "text"}
|
||||
print(f" 元数据: {metadata}")
|
||||
content = text.strip()
|
||||
if len(content) > 400:
|
||||
content = content[:400] + "..."
|
||||
print(f" 内容: {content}")
|
||||
else:
|
||||
print("\n未找到结果")
|
||||
if docs:
|
||||
print(f"✓ 找到 {len(docs)} 个父文档")
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n {i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||||
preview = doc.page_content[:150].strip()
|
||||
if len(doc.page_content) > 150:
|
||||
preview += "..."
|
||||
print(f" 内容:\n {preview}")
|
||||
else:
|
||||
print("✗ 未找到结果")
|
||||
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
# 1. 先构建索引
|
||||
await test_index_builder()
|
||||
print("\n" + "="*80)
|
||||
print("RAG 检索功能测试")
|
||||
print("="*80)
|
||||
|
||||
# 2. 测试稠密检索
|
||||
test_dense_retrieval()
|
||||
# 测试 1: 直接使用 vector store
|
||||
await test_simple_vector_store_search()
|
||||
|
||||
# 3. 测试稀疏检索
|
||||
test_sparse_retrieval_simple()
|
||||
# 测试 2: HybridRetriever
|
||||
await test_hybrid_retriever()
|
||||
|
||||
# 4. 测试混合检索
|
||||
test_hybrid_retrieval_simple()
|
||||
# 测试 3: ParentHybridRetriever
|
||||
await test_parent_hybrid_retriever()
|
||||
|
||||
# 5. 测试父子文档检索
|
||||
test_parent_child_retrieval_simple()
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("所有测试完成!")
|
||||
print("="*70)
|
||||
print("\n🎉 所有测试完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user