Files
ailine/tools/test/test_rag_pipeline.py
root 9841f47432
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
refactor: 重构RAG核心组件,简化代码结构和测试文件
2026-05-04 17:58:10 +08:00

146 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
完整的 RAG Pipeline 测试
测试从查询改写 → 检索 → RRF融合 → 重排序 → 格式化输出的整个流程
"""
import asyncio
from backend.app.rag.pipeline import RAGPipeline, create_rag_pipeline
from backend.app.rag.tools import create_rag_tool
async def test_rag_pipeline_direct():
    """Test 1: exercise a pipeline built by ``create_rag_pipeline`` directly.

    Retrieves documents for a fixed query via ``aretrieve``, prints each
    hit's ``source`` metadata and full content, then prints the context
    string returned by ``format_context``. Prints a "not found" message
    when retrieval yields nothing.
    """
    banner = "=" * 80
    rule = "-" * 80

    print(banner)
    print("测试 1: 直接使用 RAGPipeline默认用小模型做查询改写")
    print(banner)

    # Build the pipeline through the factory with explicit retrieval settings.
    rag = create_rag_pipeline(
        collection_name="rag_documents",
        num_queries=3,
        rerank_top_n=5,
    )

    question = "黄双银的经历"
    print(f"\n用户查询: {question}")
    print(rule)

    # Run the retrieval.
    hits = await rag.aretrieve(question)

    if not hits:
        print("\n✗ 未找到相关文档")
    else:
        print(f"\n✓ 找到 {len(hits)} 个相关文档")
        print(rule)
        for rank, hit in enumerate(hits, start=1):
            origin = hit.metadata.get('source', 'unknown')
            print(f"\n{rank}. 来源: {origin}")
            print(f" 内容:\n{hit.page_content}")
            print(rule)

        # Show the same hits rendered as a single context string.
        print("\n" + banner)
        print("格式化后的上下文:")
        print(banner)
        context_text = rag.format_context(hits)
        print(context_text)

    print("\n" + banner)
async def test_rag_tool():
    """Test 2: exercise the tool built by ``create_rag_tool``.

    Invokes the tool asynchronously with a fixed query and prints whatever
    it returns between separator banners.
    """
    banner = "=" * 80

    print("\n" + banner)
    print("测试 2: 使用 RAG Tool默认用小模型做查询改写")
    print(banner)

    # Build the tool with the same settings used for the direct pipeline test.
    tool = create_rag_tool(
        collection_name="rag_documents",
        num_queries=3,
        rerank_top_n=5,
    )

    question = "黄双银的经历"
    print(f"\n用户查询: {question}")
    print("-" * 80)

    # The tool is called through its async entry point (ainvoke).
    answer = await tool.ainvoke(question)

    print("\nTool 返回结果:")
    print(banner)
    print(answer)
    print(banner)
async def test_custom_pipeline():
    """Test 3: exercise a ``RAGPipeline`` constructed with custom parameters.

    Builds the pipeline directly (not via the factory) with a smaller number
    of query variants and a smaller rerank cut-off, retrieves documents for a
    fixed query, prints a preview of at most 200 characters for each hit, and
    finally prints the context string returned by ``format_context``.
    """
    print("\n"+"="*80)
    print("测试 3: 自定义参数的 RAGPipeline默认用小模型")
    print("="*80)

    # Bind the custom settings once so the constructor call and the printed
    # configuration line cannot drift apart.
    num_queries = 2  # generate only 2 query variants
    rerank_top_n = 3  # keep only the top 3 documents after reranking
    pipeline = RAGPipeline(
        collection_name="rag_documents",
        num_queries=num_queries,
        rerank_top_n=rerank_top_n,
    )

    query = "黄双银的经历"
    print(f"\n用户查询: {query}")
    print(f"配置: num_queries={num_queries}, rerank_top_n={rerank_top_n}")
    print("-" * 80)

    docs = await pipeline.aretrieve(query)

    if docs:
        print(f"\n✓ 找到 {len(docs)} 个相关文档")
        print("-" * 80)
        for i, doc in enumerate(docs, 1):
            print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
            # Truncate long content to a 200-character preview; the ellipsis
            # marks that the original text was longer than the cut-off.
            preview = doc.page_content[:200].strip()
            if len(doc.page_content) > 200:
                preview += "..."
            print(f" 内容预览: {preview}")

        print("\n" + "="*80)
        print("格式化后的上下文:")
        print("="*80)
        print(pipeline.format_context(docs))
    else:
        print("\n✗ 未找到相关文档")

    print("\n" + "="*80)
async def main():
    """Run every RAG pipeline scenario in order, framed by summary banners."""
    banner = "=" * 80

    print("\n" + banner)
    print("完整 RAG Pipeline 测试")
    print("查询: '黄双银的经历'")
    print(banner)

    # Scenarios run strictly one after another:
    # 1) pipeline from the factory, 2) the tool wrapper, 3) custom parameters.
    scenarios = (
        test_rag_pipeline_direct,
        test_rag_tool,
        test_custom_pipeline,
    )
    for scenario in scenarios:
        await scenario()

    print("\n" + banner)
    print("🎉 所有 RAG Pipeline 测试完成!")
    print(banner)
if __name__ == "__main__":
    # Only drive the scenarios when executed as a script, never on import.
    asyncio.run(main())