测试修改

This commit is contained in:
2026-04-21 20:49:10 +08:00
parent 37e86f3bb1
commit 5e9bbd519f
6 changed files with 18 additions and 53 deletions

142
test/test_rag.py Normal file
View File

@@ -0,0 +1,142 @@
#!/usr/bin/env python3
"""
RAG 系统使用示例(重构版)
演示:
1. 使用 IndexBuilder 获取父子块检索器
2. 创建固定流程的 RAGPipeline多路改写 → RRF融合 → 重排序 → 返回父文档)
3. 将流水线封装为 LangChain 工具,供 Agent 调用
"""
import asyncio
import sys
import os
from dotenv import load_dotenv
# 加载环境变量Qdrant URL、PostgreSQL 连接等)
load_dotenv()
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from pydantic import SecretStr
from langchain_openai import ChatOpenAI
from rag_indexer.index_builder import IndexBuilderConfig
from rag_indexer.splitters import SplitterType
from backend.app.rag.pipeline import RAGPipeline
from backend.app.rag.tools import create_rag_tool_sync
from backend.rag_core.retriever_factory import create_parent_retriever
def create_llm():
"""创建本地 vLLM 服务 LLM"""
vllm_base_url = os.getenv(
"VLLM_BASE_URL",
"http://127.0.0.1:8081/v1"
)
return ChatOpenAI(
base_url=vllm_base_url,
api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
model="gemma-4-E2B-it",
timeout=60.0, # 请求超时时间(秒)
max_retries=2, # 失败后自动重试次数
streaming=True, # 确保开启流式输出
)
async def demonstrate_full_pipeline():
"""
完整流水线演示:
- 从 IndexBuilder 获取 ParentDocumentRetriever
- 创建 RAGPipeline
- 执行检索并打印结果
"""
print("=" * 60)
print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
print("=" * 60)
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
if retriever is None:
print("错误:检索器未初始化,请确保索引已构建。")
return
# 3. 创建 LLM 用于查询改写
llm = create_llm()
# 4. 创建 RAGPipeline固定流程
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=3, # 生成 3 个查询变体
rerank_top_n=5, # 最终返回 5 个父文档
)
# 5. 执行检索
query = "打虎英雄是谁?"
print(f"\n查询: {query}")
print("-" * 40)
try:
documents = await pipeline.aretrieve(query)
print(f"返回 {len(documents)} 个父文档\n")
# 打印结果预览
for i, doc in enumerate(documents, 1):
content_preview = doc.page_content.replace("\n", " ")[:150]
source = doc.metadata.get("source", "未知来源")
print(f"{i}. 【来源:{source}")
print(f" {content_preview}...\n")
# 可选:格式化完整上下文
# context = pipeline.format_context(documents)
# print(context)
except Exception as e:
print(f"检索失败: {e}")
import traceback
traceback.print_exc()
async def demonstrate_tool_creation():
"""
演示创建 RAG 工具(供 Agent 使用)
"""
print("\n" + "=" * 60)
print("演示:创建 RAG 工具(供 LangGraph Agent 调用)")
print("=" * 60)
# 1. 获取检索器(同上)
config = IndexBuilderConfig(
collection_name="rag_documents",
splitter_type=SplitterType.PARENT_CHILD,
)
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
# 2. 创建 LLM
llm = create_llm()
# 3. 创建工具
rag_tool = create_rag_tool_sync(
retriever=retriever,
llm=llm,
num_queries=3,
rerank_top_n=5,
collection_name="rag_documents",
)
print(f"工具名称: {rag_tool.name}")
print(f"工具描述: {rag_tool.description[:100]}...")
# 4. 模拟 Agent 调用工具
query = "请告诉我 打虎英雄是谁?"
print(f"\n模拟调用: {query}")
print("-" * 40)
result = await rag_tool.ainvoke({"query": query})
print(result[:800] + "..." if len(result) > 800 else result)
async def main():
await demonstrate_full_pipeline()
await demonstrate_tool_creation()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -40,42 +40,7 @@ async def test_index_builder():
print(f"集合信息: {info}")
else:
print(f"测试文件不存在: {test_file}")
# 测试搜索功能
print("\n测试搜索功能...")
try:
results = builder.search("吕布", k=3)
print(f"搜索结果数量: {len(results)}")
for i, result in enumerate(results):
print(f"\n结果 {i+1}:")
print(f"内容: {result.page_content[:100]}...")
except Exception as e:
print(f"搜索测试失败: {e}")
# 测试带父块上下文的搜索
print("\n测试带父块上下文的搜索...")
try:
results = await builder.search_with_parent_context("吕布", k=3)
print(f"搜索结果数量: {len(results)}")
for i, result in enumerate(results):
print(f"\n结果 {i+1}:")
print(f"内容: {result.page_content[:100]}...")
except Exception as e:
print(f"带父块上下文的搜索测试失败: {e}")
# 测试统一检索接口
print("\n测试统一检索接口...")
try:
# 返回父块
results_parent = await builder.retrieve("吕布", return_parent=True)
print(f"返回父块的结果数量: {len(results_parent)}")
# 返回子块
results_child = await builder.retrieve("吕布", return_parent=False)
print(f"返回子块的结果数量: {len(results_child)}")
except Exception as e:
print(f"统一检索接口测试失败: {e}")
# 关闭资源
builder.close()
print("\n测试完成")