refactor: 重构RAG核心组件,简化代码结构和测试文件
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
This commit is contained in:
145
tools/test/test_rag_pipeline.py
Normal file
145
tools/test/test_rag_pipeline.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
完整的 RAG Pipeline 测试
|
||||
测试从查询改写 → 检索 → RRF融合 → 重排序 → 格式化输出的整个流程
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from backend.app.rag.pipeline import RAGPipeline, create_rag_pipeline
|
||||
from backend.app.rag.tools import create_rag_tool
|
||||
|
||||
|
||||
async def test_rag_pipeline_direct():
|
||||
"""测试 1: 直接使用 RAGPipeline(默认用小模型做查询改写)"""
|
||||
print("="*80)
|
||||
print("测试 1: 直接使用 RAGPipeline(默认用小模型做查询改写)")
|
||||
print("="*80)
|
||||
|
||||
# 创建 pipeline(默认用小模型)
|
||||
pipeline = create_rag_pipeline(
|
||||
collection_name="rag_documents",
|
||||
num_queries=3,
|
||||
rerank_top_n=5
|
||||
)
|
||||
|
||||
query = "黄双银的经历"
|
||||
|
||||
print(f"\n用户查询: {query}")
|
||||
print("-" * 80)
|
||||
|
||||
# 执行检索
|
||||
docs = await pipeline.aretrieve(query)
|
||||
|
||||
if docs:
|
||||
print(f"\n✓ 找到 {len(docs)} 个相关文档")
|
||||
print("-" * 80)
|
||||
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||||
print(f" 内容:\n{doc.page_content}")
|
||||
print("-" * 80)
|
||||
|
||||
# 格式化输出
|
||||
print("\n" + "="*80)
|
||||
print("格式化后的上下文:")
|
||||
print("="*80)
|
||||
formatted_context = pipeline.format_context(docs)
|
||||
print(formatted_context)
|
||||
else:
|
||||
print("\n✗ 未找到相关文档")
|
||||
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
async def test_rag_tool():
|
||||
"""测试 2: 使用 RAG Tool(默认用小模型做查询改写)"""
|
||||
print("\n"+"="*80)
|
||||
print("测试 2: 使用 RAG Tool(默认用小模型做查询改写)")
|
||||
print("="*80)
|
||||
|
||||
# 创建 tool(默认用小模型)
|
||||
rag_tool = create_rag_tool(
|
||||
collection_name="rag_documents",
|
||||
num_queries=3,
|
||||
rerank_top_n=5
|
||||
)
|
||||
|
||||
query = "黄双银的经历"
|
||||
|
||||
print(f"\n用户查询: {query}")
|
||||
print("-" * 80)
|
||||
|
||||
# 使用 tool (异步调用 ainvoke)
|
||||
result = await rag_tool.ainvoke(query)
|
||||
|
||||
print("\nTool 返回结果:")
|
||||
print("="*80)
|
||||
print(result)
|
||||
print("="*80)
|
||||
|
||||
|
||||
async def test_custom_pipeline():
|
||||
"""测试 3: 自定义参数的 RAGPipeline(默认用小模型)"""
|
||||
print("\n"+"="*80)
|
||||
print("测试 3: 自定义参数的 RAGPipeline(默认用小模型)")
|
||||
print("="*80)
|
||||
|
||||
# 自定义参数(默认用小模型)
|
||||
pipeline = RAGPipeline(
|
||||
collection_name="rag_documents",
|
||||
num_queries=2, # 只生成 2 个查询变体
|
||||
rerank_top_n=3 # 只返回前 3 个最相关文档
|
||||
)
|
||||
|
||||
query = "黄双银的经历"
|
||||
|
||||
print(f"\n用户查询: {query}")
|
||||
print(f"配置: num_queries=2, rerank_top_n=3")
|
||||
print("-" * 80)
|
||||
|
||||
docs = await pipeline.aretrieve(query)
|
||||
|
||||
if docs:
|
||||
print(f"\n✓ 找到 {len(docs)} 个相关文档")
|
||||
print("-" * 80)
|
||||
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||||
preview = doc.page_content[:200].strip()
|
||||
if len(doc.page_content) > 200:
|
||||
preview += "..."
|
||||
print(f" 内容预览: {preview}")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("格式化后的上下文:")
|
||||
print("="*80)
|
||||
print(pipeline.format_context(docs))
|
||||
else:
|
||||
print("\n✗ 未找到相关文档")
|
||||
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("\n" + "="*80)
|
||||
print("完整 RAG Pipeline 测试")
|
||||
print("查询: '黄双银的经历'")
|
||||
print("="*80)
|
||||
|
||||
# 测试 1: 直接使用 pipeline
|
||||
await test_rag_pipeline_direct()
|
||||
|
||||
# 测试 2: 使用 tool
|
||||
await test_rag_tool()
|
||||
|
||||
# 测试 3: 自定义参数
|
||||
await test_custom_pipeline()
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("🎉 所有 RAG Pipeline 测试完成!")
|
||||
print("="*80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user