测试修改
This commit is contained in:
142
test/test_rag.py
Normal file
142
test/test_rag.py
Normal file
@@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RAG 系统使用示例(重构版)
|
||||
|
||||
演示:
|
||||
1. 使用 IndexBuilder 获取父子块检索器
|
||||
2. 创建固定流程的 RAGPipeline(多路改写 → RRF融合 → 重排序 → 返回父文档)
|
||||
3. 将流水线封装为 LangChain 工具,供 Agent 调用
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量(Qdrant URL、PostgreSQL 连接等)
|
||||
load_dotenv()
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from pydantic import SecretStr
|
||||
from langchain_openai import ChatOpenAI
|
||||
from rag_indexer.index_builder import IndexBuilderConfig
|
||||
from rag_indexer.splitters import SplitterType
|
||||
from backend.app.rag.pipeline import RAGPipeline
|
||||
from backend.app.rag.tools import create_rag_tool_sync
|
||||
from backend.rag_core.retriever_factory import create_parent_retriever
|
||||
|
||||
def create_llm():
|
||||
"""创建本地 vLLM 服务 LLM"""
|
||||
vllm_base_url = os.getenv(
|
||||
"VLLM_BASE_URL",
|
||||
"http://127.0.0.1:8081/v1"
|
||||
)
|
||||
|
||||
return ChatOpenAI(
|
||||
base_url=vllm_base_url,
|
||||
api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
|
||||
model="gemma-4-E2B-it",
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
async def demonstrate_full_pipeline():
|
||||
"""
|
||||
完整流水线演示:
|
||||
- 从 IndexBuilder 获取 ParentDocumentRetriever
|
||||
- 创建 RAGPipeline
|
||||
- 执行检索并打印结果
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
|
||||
print("=" * 60)
|
||||
|
||||
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
|
||||
|
||||
if retriever is None:
|
||||
print("错误:检索器未初始化,请确保索引已构建。")
|
||||
return
|
||||
|
||||
# 3. 创建 LLM 用于查询改写
|
||||
llm = create_llm()
|
||||
|
||||
# 4. 创建 RAGPipeline(固定流程)
|
||||
pipeline = RAGPipeline(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
num_queries=3, # 生成 3 个查询变体
|
||||
rerank_top_n=5, # 最终返回 5 个父文档
|
||||
)
|
||||
|
||||
# 5. 执行检索
|
||||
query = "打虎英雄是谁?"
|
||||
print(f"\n查询: {query}")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
documents = await pipeline.aretrieve(query)
|
||||
print(f"返回 {len(documents)} 个父文档\n")
|
||||
|
||||
# 打印结果预览
|
||||
for i, doc in enumerate(documents, 1):
|
||||
content_preview = doc.page_content.replace("\n", " ")[:150]
|
||||
source = doc.metadata.get("source", "未知来源")
|
||||
print(f"{i}. 【来源:{source}】")
|
||||
print(f" {content_preview}...\n")
|
||||
|
||||
# 可选:格式化完整上下文
|
||||
# context = pipeline.format_context(documents)
|
||||
# print(context)
|
||||
|
||||
except Exception as e:
|
||||
print(f"检索失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
async def demonstrate_tool_creation():
|
||||
"""
|
||||
演示创建 RAG 工具(供 Agent 使用)
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("演示:创建 RAG 工具(供 LangGraph Agent 调用)")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 获取检索器(同上)
|
||||
config = IndexBuilderConfig(
|
||||
collection_name="rag_documents",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
)
|
||||
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
|
||||
|
||||
# 2. 创建 LLM
|
||||
llm = create_llm()
|
||||
|
||||
# 3. 创建工具
|
||||
rag_tool = create_rag_tool_sync(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
num_queries=3,
|
||||
rerank_top_n=5,
|
||||
collection_name="rag_documents",
|
||||
)
|
||||
|
||||
print(f"工具名称: {rag_tool.name}")
|
||||
print(f"工具描述: {rag_tool.description[:100]}...")
|
||||
|
||||
# 4. 模拟 Agent 调用工具
|
||||
query = "请告诉我 打虎英雄是谁?"
|
||||
print(f"\n模拟调用: {query}")
|
||||
print("-" * 40)
|
||||
|
||||
result = await rag_tool.ainvoke({"query": query})
|
||||
print(result[:800] + "..." if len(result) > 800 else result)
|
||||
|
||||
async def main():
|
||||
await demonstrate_full_pipeline()
|
||||
await demonstrate_tool_creation()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -40,42 +40,7 @@ async def test_index_builder():
|
||||
print(f"集合信息: {info}")
|
||||
else:
|
||||
print(f"测试文件不存在: {test_file}")
|
||||
|
||||
# 测试搜索功能
|
||||
print("\n测试搜索功能...")
|
||||
try:
|
||||
results = builder.search("吕布", k=3)
|
||||
print(f"搜索结果数量: {len(results)}")
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n结果 {i+1}:")
|
||||
print(f"内容: {result.page_content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"搜索测试失败: {e}")
|
||||
|
||||
# 测试带父块上下文的搜索
|
||||
print("\n测试带父块上下文的搜索...")
|
||||
try:
|
||||
results = await builder.search_with_parent_context("吕布", k=3)
|
||||
print(f"搜索结果数量: {len(results)}")
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n结果 {i+1}:")
|
||||
print(f"内容: {result.page_content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"带父块上下文的搜索测试失败: {e}")
|
||||
|
||||
# 测试统一检索接口
|
||||
print("\n测试统一检索接口...")
|
||||
try:
|
||||
# 返回父块
|
||||
results_parent = await builder.retrieve("吕布", return_parent=True)
|
||||
print(f"返回父块的结果数量: {len(results_parent)}")
|
||||
|
||||
# 返回子块
|
||||
results_child = await builder.retrieve("吕布", return_parent=False)
|
||||
print(f"返回子块的结果数量: {len(results_child)}")
|
||||
except Exception as e:
|
||||
print(f"统一检索接口测试失败: {e}")
|
||||
|
||||
|
||||
# 关闭资源
|
||||
builder.close()
|
||||
print("\n测试完成")
|
||||
|
||||
Reference in New Issue
Block a user