Files
ailine/test/test_rag.py
2026-04-21 20:49:10 +08:00

142 lines
4.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
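RAG system usage example (refactored version).
Demonstrates:
1. Using IndexBuilder to obtain a parent/child-chunk retriever
2. Creating a fixed-flow RAGPipeline (multi-query rewriting → RRF fusion → reranking → return parent documents)
3. Wrapping the pipeline as a LangChain tool for an Agent to call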
"""
import asyncio
import sys
import os
from dotenv import load_dotenv
# Load environment variables (Qdrant URL, PostgreSQL connection, etc.)
load_dotenv()
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from pydantic import SecretStr
from langchain_openai import ChatOpenAI
from rag_indexer.index_builder import IndexBuilderConfig
from rag_indexer.splitters import SplitterType
from backend.app.rag.pipeline import RAGPipeline
from backend.app.rag.tools import create_rag_tool_sync
from backend.rag_core.retriever_factory import create_parent_retriever
def create_llm():
    """Build the chat-model client for the local vLLM service.

    The endpoint is read from the ``VLLM_BASE_URL`` environment variable
    (default ``http://127.0.0.1:8081/v1``) and the API key from
    ``LLAMACPP_API_KEY`` (default ``token-abc123``).

    Returns:
        A ``ChatOpenAI`` instance constructed with streaming enabled.
    """
    endpoint = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
    api_token = SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123"))

    client_options = {
        "base_url": endpoint,
        "api_key": api_token,
        "model": "gemma-4-E2B-it",
        "timeout": 60.0,    # request timeout, in seconds
        "max_retries": 2,   # automatic retries after a failed request
        "streaming": True,  # make sure streaming output is enabled
    }
    return ChatOpenAI(**client_options)
async def demonstrate_full_pipeline():
    """Run the fixed-flow RAG retrieval demo and print the results.

    Steps:
      1. Obtain the retriever for the ``rag_documents`` collection via
         ``create_parent_retriever``.
      2. Build the LLM used for query rewriting.
      3. Build a ``RAGPipeline`` (3 query variants, ``rerank_top_n=5``).
      4. Retrieve for a sample query and print a short preview per document.

    Returns:
        None. Returns early with a message when no retriever is available.
        Any exception raised during retrieval or printing is reported with a
        traceback instead of being propagated.
    """
    preview_chars = 150  # maximum characters shown per document preview

    print("=" * 60)
    print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
    print("=" * 60)

    # 1. Obtain the retriever; bail out if the index has not been built.
    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
    if retriever is None:
        print("错误:检索器未初始化,请确保索引已构建。")
        return

    # 2. Create the LLM used for query rewriting.
    llm = create_llm()

    # 3. Create the fixed-flow RAGPipeline.
    pipeline = RAGPipeline(
        retriever=retriever,
        llm=llm,
        num_queries=3,   # generate 3 query variants
        rerank_top_n=5,  # finally return 5 parent documents
    )

    # 4. Run the retrieval.
    query = "打虎英雄是谁?"
    print(f"\n查询: {query}")
    print("-" * 40)
    try:
        documents = await pipeline.aretrieve(query)
        print(f"返回 {len(documents)} 个父文档\n")

        # Print a one-line preview of every returned document.
        for i, doc in enumerate(documents, 1):
            flattened = doc.page_content.replace("\n", " ")
            content_preview = flattened[:preview_chars]
            # Only mark the preview as truncated when text was actually cut.
            ellipsis = "..." if len(flattened) > preview_chars else ""
            source = doc.metadata.get("source", "未知来源")
            # Close the 【…】 bracket so the source label is well-formed.
            print(f"{i}. 【来源:{source}】")
            print(f" {content_preview}{ellipsis}\n")
    except Exception as e:
        # Demo boundary: report the failure and keep the script running so
        # the remaining demonstrations can still execute.
        print(f"检索失败: {e}")
        import traceback
        traceback.print_exc()
async def demonstrate_tool_creation():
    """Demonstrate wrapping the RAG pipeline as a tool for an Agent.

    Steps:
      1. Obtain the retriever for the ``rag_documents`` collection via
         ``create_parent_retriever``.
      2. Build the LLM.
      3. Build the tool with ``create_rag_tool_sync``.
      4. Invoke the tool asynchronously with a sample query and print at
         most the first 800 characters of the result.

    Returns:
        None. Returns early with a message when no retriever is available,
        matching the guard in ``demonstrate_full_pipeline``.
    """
    max_output_chars = 800  # cap on how much of the tool result is printed

    print("\n" + "=" * 60)
    print("演示:创建 RAG 工具(供 LangGraph Agent 调用)")
    print("=" * 60)

    # 1. Obtain the retriever. The previously constructed (and never used)
    #    IndexBuilderConfig local has been dropped.
    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
    if retriever is None:
        print("错误:检索器未初始化,请确保索引已构建。")
        return

    # 2. Create the LLM.
    llm = create_llm()

    # 3. Create the tool.
    rag_tool = create_rag_tool_sync(
        retriever=retriever,
        llm=llm,
        num_queries=3,
        rerank_top_n=5,
        collection_name="rag_documents",
    )
    print(f"工具名称: {rag_tool.name}")
    print(f"工具描述: {rag_tool.description[:100]}...")

    # 4. Simulate an Agent calling the tool.
    query = "请告诉我 打虎英雄是谁?"
    print(f"\n模拟调用: {query}")
    print("-" * 40)
    result = await rag_tool.ainvoke({"query": query})

    # Truncate long output and mark the truncation with an ellipsis.
    if len(result) > max_output_chars:
        print(result[:max_output_chars] + "...")
    else:
        print(result)
async def main():
    """Run each demonstration in order, awaiting one before starting the next."""
    for demo in (demonstrate_full_pipeline, demonstrate_tool_creation):
        await demo()
# Script entry point: run the demos only when executed directly, not on import.
if __name__ == "__main__":
    entry_coroutine = main()
    asyncio.run(entry_coroutine)