Files
ailine/tools/test/test_rag_pipeline.py
root c9bf21be0e
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
fix: 修复 RAG 无限循环问题和导入错误
主要修复:
1. 修复 RAG 推理无限循环问题(大小写不匹配 + 缺少已检索结果检查)
2. 修复 intent_classifier.py 的绝对导入错误
3. 删除旧的 start.sh 脚本,添加新的启动脚本
4. 优化路由逻辑和状态管理
2026-05-04 18:59:15 +08:00

146 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
完整的 RAG Pipeline 测试
测试从查询改写 → 检索 → RRF融合 → 重排序 → 格式化输出的整个流程
"""
import asyncio
from backend.app.rag.pipeline import RAGPipeline, create_rag_pipeline
from backend.app.rag.tools import create_rag_tool
async def test_rag_pipeline_direct():
"""测试 1: 直接使用 RAGPipeline默认用小模型做查询改写"""
print("="*80)
print("测试 1: 直接使用 RAGPipeline默认用小模型做查询改写")
print("="*80)
# 创建 pipeline默认用小模型
pipeline = create_rag_pipeline(
collection_name="rag_documents",
num_queries=3,
rerank_top_n=5
)
query = "吕布的经历"
print(f"\n用户查询: {query}")
print("-" * 80)
# 执行检索
docs = await pipeline.aretrieve(query)
if docs:
print(f"\n✓ 找到 {len(docs)} 个相关文档")
print("-" * 80)
for i, doc in enumerate(docs, 1):
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
print(f" 内容:\n{doc.page_content}")
print("-" * 80)
# 格式化输出
print("\n" + "="*80)
print("格式化后的上下文:")
print("="*80)
formatted_context = pipeline.format_context(docs)
print(formatted_context)
else:
print("\n✗ 未找到相关文档")
print("\n" + "="*80)
async def test_rag_tool():
"""测试 2: 使用 RAG Tool默认用小模型做查询改写"""
print("\n"+"="*80)
print("测试 2: 使用 RAG Tool默认用小模型做查询改写")
print("="*80)
# 创建 tool默认用小模型
rag_tool = create_rag_tool(
collection_name="rag_documents",
num_queries=3,
rerank_top_n=5
)
query = "吕布的经历"
print(f"\n用户查询: {query}")
print("-" * 80)
# 使用 tool (异步调用 ainvoke)
result = await rag_tool.ainvoke(query)
print("\nTool 返回结果:")
print("="*80)
print(result)
print("="*80)
async def test_custom_pipeline():
"""测试 3: 自定义参数的 RAGPipeline默认用小模型"""
print("\n"+"="*80)
print("测试 3: 自定义参数的 RAGPipeline默认用小模型")
print("="*80)
# 自定义参数(默认用小模型)
pipeline = RAGPipeline(
collection_name="rag_documents",
num_queries=2, # 只生成 2 个查询变体
rerank_top_n=3 # 只返回前 3 个最相关文档
)
query = "吕布的经历"
print(f"\n用户查询: {query}")
print(f"配置: num_queries=2, rerank_top_n=3")
print("-" * 80)
docs = await pipeline.aretrieve(query)
if docs:
print(f"\n✓ 找到 {len(docs)} 个相关文档")
print("-" * 80)
for i, doc in enumerate(docs, 1):
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
preview = doc.page_content[:200].strip()
if len(doc.page_content) > 200:
preview += "..."
print(f" 内容预览: {preview}")
print("\n" + "="*80)
print("格式化后的上下文:")
print("="*80)
print(pipeline.format_context(docs))
else:
print("\n✗ 未找到相关文档")
print("\n" + "="*80)
async def main():
"""主测试函数"""
print("\n" + "="*80)
print("完整 RAG Pipeline 测试")
print("查询: '吕布的经历'")
print("="*80)
# 测试 1: 直接使用 pipeline
await test_rag_pipeline_direct()
# 测试 2: 使用 tool
await test_rag_tool()
# 测试 3: 自定义参数
await test_custom_pipeline()
print("\n" + "="*80)
print("🎉 所有 RAG Pipeline 测试完成!")
print("="*80)
if __name__ == "__main__":
asyncio.run(main())