2026-05-04 17:58:10 +08:00
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
"""
|
|
|
|
|
|
完整的 RAG Pipeline 测试
|
|
|
|
|
|
测试从查询改写 → 检索 → RRF融合 → 重排序 → 格式化输出的整个流程
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
from backend.app.rag.pipeline import RAGPipeline, create_rag_pipeline
|
|
|
|
|
|
from backend.app.rag.tools import create_rag_tool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_rag_pipeline_direct():
    """Test 1: drive a pipeline built by ``create_rag_pipeline`` directly.

    Builds the pipeline, awaits ``aretrieve`` for a fixed query, prints the
    source and full content of every returned document, and finally prints
    the result of ``format_context``. When nothing is returned, a
    "not found" notice is printed instead.

    NOTE(review): the original banner says query rewriting uses a small
    model by default; that is decided inside the pipeline factory and is
    not visible from this file.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print(heavy_rule)
    print("测试 1: 直接使用 RAGPipeline(默认用小模型做查询改写)")
    print(heavy_rule)

    # Factory defaults are used for everything not listed here.
    settings = {
        "collection_name": "rag_documents",
        "num_queries": 3,
        "rerank_top_n": 5,
    }
    rag = create_rag_pipeline(**settings)

    query = "吕布的经历"
    print(f"\n用户查询: {query}")
    print(light_rule)

    # Run the retrieval.
    hits = await rag.aretrieve(query)

    if not hits:
        print("\n✗ 未找到相关文档")
    else:
        print(f"\n✓ 找到 {len(hits)} 个相关文档")
        print(light_rule)

        for rank, hit in enumerate(hits, start=1):
            origin = hit.metadata.get("source", "unknown")
            print(f"\n{rank}. 来源: {origin}")
            print(f" 内容:\n{hit.page_content}")
            print(light_rule)

        # Formatted output.
        print(f"\n{heavy_rule}")
        print("格式化后的上下文:")
        print(heavy_rule)
        print(rag.format_context(hits))

    print(f"\n{heavy_rule}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_rag_tool():
    """Test 2: call the tool built by ``create_rag_tool``.

    Creates the tool, awaits ``ainvoke`` with a fixed query string, and
    prints whatever the tool returns between separator lines.

    NOTE(review): the original banner says query rewriting uses a small
    model by default; that is decided inside the tool factory and is not
    visible from this file.
    """
    heavy_rule = "=" * 80

    print(f"\n{heavy_rule}")
    print("测试 2: 使用 RAG Tool(默认用小模型做查询改写)")
    print(heavy_rule)

    # Factory defaults are used for everything not listed here.
    tool = create_rag_tool(
        collection_name="rag_documents",
        num_queries=3,
        rerank_top_n=5,
    )

    query = "吕布的经历"
    print(f"\n用户查询: {query}")
    print("-" * 80)

    # The tool is called through its async entry point.
    answer = await tool.ainvoke(query)

    print("\nTool 返回结果:")
    for line in (heavy_rule, answer, heavy_rule):
        print(line)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_custom_pipeline():
    """Test 3: construct ``RAGPipeline`` directly with custom parameters.

    Builds the pipeline with a reduced number of query variants and a
    smaller rerank cut-off, awaits ``aretrieve`` for a fixed query, prints a
    preview (at most 200 characters) of each returned document, and then
    prints the result of ``format_context``. When nothing is returned, a
    "not found" notice is printed instead.

    The two custom values are bound once and reused for both the constructor
    call and the printed configuration line, so the message cannot drift
    from the configuration actually in effect.

    NOTE(review): the original banner says query rewriting uses a small
    model by default; that is decided inside ``RAGPipeline`` and is not
    visible from this file.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80
    preview_limit = 200

    num_queries = 2  # generate only 2 query variants
    rerank_top_n = 3  # keep only the 3 most relevant documents

    print("\n" + heavy_rule)
    print("测试 3: 自定义参数的 RAGPipeline(默认用小模型)")
    print(heavy_rule)

    pipeline = RAGPipeline(
        collection_name="rag_documents",
        num_queries=num_queries,
        rerank_top_n=rerank_top_n,
    )

    query = "吕布的经历"
    print(f"\n用户查询: {query}")
    print(f"配置: num_queries={num_queries}, rerank_top_n={rerank_top_n}")
    print(light_rule)

    docs = await pipeline.aretrieve(query)

    if docs:
        print(f"\n✓ 找到 {len(docs)} 个相关文档")
        print(light_rule)

        for i, doc in enumerate(docs, 1):
            print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
            # Truncate first, then strip, and flag truncation with an ellipsis.
            preview = doc.page_content[:preview_limit].strip()
            if len(doc.page_content) > preview_limit:
                preview += "..."
            print(f" 内容预览: {preview}")

        print("\n" + heavy_rule)
        print("格式化后的上下文:")
        print(heavy_rule)
        print(pipeline.format_context(docs))
    else:
        print("\n✗ 未找到相关文档")

    print("\n" + heavy_rule)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main():
    """Run the three RAG pipeline tests in sequence with banner output."""
    heavy_rule = "=" * 80

    print(f"\n{heavy_rule}")
    print("完整 RAG Pipeline 测试")
    print("查询: '吕布的经历'")
    print(heavy_rule)

    # Order matters for the printed report: direct pipeline, then the tool,
    # then the custom-parameter pipeline.
    scenarios = (
        test_rag_pipeline_direct,
        test_rag_tool,
        test_custom_pipeline,
    )
    for scenario in scenarios:
        await scenario()

    print(f"\n{heavy_rule}")
    print("🎉 所有 RAG Pipeline 测试完成!")
    print(heavy_rule)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async test driver on a fresh event loop.
    entry_point = main()
    asyncio.run(entry_point)
|