Files
ailine/tools/test/test_rag_pipeline.py

146 lines
3.9 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
完整的 RAG Pipeline 测试
测试从查询改写 检索 RRF融合 重排序 格式化输出的整个流程
"""
import asyncio
from backend.app.rag.pipeline import RAGPipeline, create_rag_pipeline
from backend.app.rag.tools import create_rag_tool
async def test_rag_pipeline_direct():
"""测试 1: 直接使用 RAGPipeline默认用小模型做查询改写"""
print("="*80)
print("测试 1: 直接使用 RAGPipeline默认用小模型做查询改写")
print("="*80)
# 创建 pipeline默认用小模型
pipeline = create_rag_pipeline(
collection_name="rag_documents",
num_queries=3,
rerank_top_n=5
)
query = "吕布的经历"
print(f"\n用户查询: {query}")
print("-" * 80)
# 执行检索
docs = await pipeline.aretrieve(query)
if docs:
print(f"\n✓ 找到 {len(docs)} 个相关文档")
print("-" * 80)
for i, doc in enumerate(docs, 1):
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
print(f" 内容:\n{doc.page_content}")
print("-" * 80)
# 格式化输出
print("\n" + "="*80)
print("格式化后的上下文:")
print("="*80)
formatted_context = pipeline.format_context(docs)
print(formatted_context)
else:
print("\n✗ 未找到相关文档")
print("\n" + "="*80)
async def test_rag_tool():
"""测试 2: 使用 RAG Tool默认用小模型做查询改写"""
print("\n"+"="*80)
print("测试 2: 使用 RAG Tool默认用小模型做查询改写")
print("="*80)
# 创建 tool默认用小模型
rag_tool = create_rag_tool(
collection_name="rag_documents",
num_queries=3,
rerank_top_n=5
)
query = "吕布的经历"
print(f"\n用户查询: {query}")
print("-" * 80)
# 使用 tool (异步调用 ainvoke)
result = await rag_tool.ainvoke(query)
print("\nTool 返回结果:")
print("="*80)
print(result)
print("="*80)
async def test_custom_pipeline():
"""测试 3: 自定义参数的 RAGPipeline默认用小模型"""
print("\n"+"="*80)
print("测试 3: 自定义参数的 RAGPipeline默认用小模型")
print("="*80)
# 自定义参数(默认用小模型)
pipeline = RAGPipeline(
collection_name="rag_documents",
num_queries=2, # 只生成 2 个查询变体
rerank_top_n=3 # 只返回前 3 个最相关文档
)
query = "吕布的经历"
print(f"\n用户查询: {query}")
print(f"配置: num_queries=2, rerank_top_n=3")
print("-" * 80)
docs = await pipeline.aretrieve(query)
if docs:
print(f"\n✓ 找到 {len(docs)} 个相关文档")
print("-" * 80)
for i, doc in enumerate(docs, 1):
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
preview = doc.page_content[:200].strip()
if len(doc.page_content) > 200:
preview += "..."
print(f" 内容预览: {preview}")
print("\n" + "="*80)
print("格式化后的上下文:")
print("="*80)
print(pipeline.format_context(docs))
else:
print("\n✗ 未找到相关文档")
print("\n" + "="*80)
async def main():
"""主测试函数"""
print("\n" + "="*80)
print("完整 RAG Pipeline 测试")
print("查询: '吕布的经历'")
print("="*80)
# 测试 1: 直接使用 pipeline
await test_rag_pipeline_direct()
# 测试 2: 使用 tool
await test_rag_tool()
# 测试 3: 自定义参数
await test_custom_pipeline()
print("\n" + "="*80)
print("🎉 所有 RAG Pipeline 测试完成!")
print("="*80)
if __name__ == "__main__":
asyncio.run(main())