Files
ailine/app/rag/tools.py

54 lines
1.7 KiB
Python
Raw Normal View History

2026-04-18 16:31:48 +08:00
"""
2026-04-19 22:01:55 +08:00
RAG 工具模块
2026-04-18 16:31:48 +08:00
2026-04-19 22:01:55 +08:00
将检索功能封装为 LangChain Tool Agent 调用
2026-04-20 01:10:18 +08:00
采用固定流水线多路改写 并行检索 RRF 融合 重排序 返回父文档
2026-04-18 16:31:48 +08:00
"""
from typing import Callable
2026-04-19 22:01:55 +08:00
from langchain_core.tools import tool
2026-04-20 01:10:18 +08:00
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from app.rag.pipeline import RAGPipeline
2026-04-20 01:10:18 +08:00
def create_rag_tool_sync(
retriever: BaseRetriever,
llm: BaseLanguageModel,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具同步版本用于不支持异步的旧版 Agent
参数同 create_rag_tool
2026-04-18 16:31:48 +08:00
"""
2026-04-19 22:01:55 +08:00
pipeline = RAGPipeline(
2026-04-20 01:10:18 +08:00
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
2026-04-18 16:31:48 +08:00
)
2026-04-20 01:10:18 +08:00
@tool
def search_knowledge_base_sync(query: str) -> str:
"""在知识库中搜索与查询相关的文档片段(同步版本)。
功能与异步版本相同多路改写 RRF融合 重排序 返回父文档
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容
"""
try:
documents = pipeline.retrieve(query) # 内部调用异步方法并等待
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_sync