146 lines
3.9 KiB
Python
146 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
完整的 RAG Pipeline 测试
|
||
测试从查询改写 → 检索 → RRF融合 → 重排序 → 格式化输出的整个流程
|
||
"""
|
||
|
||
import asyncio
|
||
from backend.app.rag.pipeline import RAGPipeline, create_rag_pipeline
|
||
from backend.app.rag.tools import create_rag_tool
|
||
|
||
|
||
async def test_rag_pipeline_direct():
|
||
"""测试 1: 直接使用 RAGPipeline(默认用小模型做查询改写)"""
|
||
print("="*80)
|
||
print("测试 1: 直接使用 RAGPipeline(默认用小模型做查询改写)")
|
||
print("="*80)
|
||
|
||
# 创建 pipeline(默认用小模型)
|
||
pipeline = create_rag_pipeline(
|
||
collection_name="rag_documents",
|
||
num_queries=3,
|
||
rerank_top_n=5
|
||
)
|
||
|
||
query = "黄双银的经历"
|
||
|
||
print(f"\n用户查询: {query}")
|
||
print("-" * 80)
|
||
|
||
# 执行检索
|
||
docs = await pipeline.aretrieve(query)
|
||
|
||
if docs:
|
||
print(f"\n✓ 找到 {len(docs)} 个相关文档")
|
||
print("-" * 80)
|
||
|
||
for i, doc in enumerate(docs, 1):
|
||
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||
print(f" 内容:\n{doc.page_content}")
|
||
print("-" * 80)
|
||
|
||
# 格式化输出
|
||
print("\n" + "="*80)
|
||
print("格式化后的上下文:")
|
||
print("="*80)
|
||
formatted_context = pipeline.format_context(docs)
|
||
print(formatted_context)
|
||
else:
|
||
print("\n✗ 未找到相关文档")
|
||
|
||
print("\n" + "="*80)
|
||
|
||
|
||
async def test_rag_tool():
|
||
"""测试 2: 使用 RAG Tool(默认用小模型做查询改写)"""
|
||
print("\n"+"="*80)
|
||
print("测试 2: 使用 RAG Tool(默认用小模型做查询改写)")
|
||
print("="*80)
|
||
|
||
# 创建 tool(默认用小模型)
|
||
rag_tool = create_rag_tool(
|
||
collection_name="rag_documents",
|
||
num_queries=3,
|
||
rerank_top_n=5
|
||
)
|
||
|
||
query = "黄双银的经历"
|
||
|
||
print(f"\n用户查询: {query}")
|
||
print("-" * 80)
|
||
|
||
# 使用 tool (异步调用 ainvoke)
|
||
result = await rag_tool.ainvoke(query)
|
||
|
||
print("\nTool 返回结果:")
|
||
print("="*80)
|
||
print(result)
|
||
print("="*80)
|
||
|
||
|
||
async def test_custom_pipeline():
|
||
"""测试 3: 自定义参数的 RAGPipeline(默认用小模型)"""
|
||
print("\n"+"="*80)
|
||
print("测试 3: 自定义参数的 RAGPipeline(默认用小模型)")
|
||
print("="*80)
|
||
|
||
# 自定义参数(默认用小模型)
|
||
pipeline = RAGPipeline(
|
||
collection_name="rag_documents",
|
||
num_queries=2, # 只生成 2 个查询变体
|
||
rerank_top_n=3 # 只返回前 3 个最相关文档
|
||
)
|
||
|
||
query = "黄双银的经历"
|
||
|
||
print(f"\n用户查询: {query}")
|
||
print(f"配置: num_queries=2, rerank_top_n=3")
|
||
print("-" * 80)
|
||
|
||
docs = await pipeline.aretrieve(query)
|
||
|
||
if docs:
|
||
print(f"\n✓ 找到 {len(docs)} 个相关文档")
|
||
print("-" * 80)
|
||
|
||
for i, doc in enumerate(docs, 1):
|
||
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||
preview = doc.page_content[:200].strip()
|
||
if len(doc.page_content) > 200:
|
||
preview += "..."
|
||
print(f" 内容预览: {preview}")
|
||
|
||
print("\n" + "="*80)
|
||
print("格式化后的上下文:")
|
||
print("="*80)
|
||
print(pipeline.format_context(docs))
|
||
else:
|
||
print("\n✗ 未找到相关文档")
|
||
|
||
print("\n" + "="*80)
|
||
|
||
|
||
async def main():
|
||
"""主测试函数"""
|
||
print("\n" + "="*80)
|
||
print("完整 RAG Pipeline 测试")
|
||
print("查询: '黄双银的经历'")
|
||
print("="*80)
|
||
|
||
# 测试 1: 直接使用 pipeline
|
||
await test_rag_pipeline_direct()
|
||
|
||
# 测试 2: 使用 tool
|
||
await test_rag_tool()
|
||
|
||
# 测试 3: 自定义参数
|
||
await test_custom_pipeline()
|
||
|
||
print("\n" + "="*80)
|
||
print("🎉 所有 RAG Pipeline 测试完成!")
|
||
print("="*80)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
asyncio.run(main())
|