Files
ailine/app/rag/tools.py
root 3c906e91d9
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 35m37s
重排,多路查询
2026-04-20 01:10:18 +08:00

116 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 工具模块
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
"""
from typing import Optional, Callable
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from .pipeline import RAGPipeline
def create_rag_tool(
retriever: BaseRetriever,
llm: BaseLanguageModel,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(异步)。
Args:
retriever: 基础检索器(例如 ParentDocumentRetriever 实例)
llm: 用于多路查询改写的语言模型
num_queries: 生成查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: 集合名称(仅用于日志/描述)
Returns:
LangChain Tool 可调用对象(异步)
"""
# 初始化流水线(所有组件一次创建,后续复用)
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
)
@tool
async def search_knowledge_base(query: str) -> str:
"""在知识库中搜索与查询相关的文档片段。
该工具会:
1. 将用户问题改写成多个不同角度的查询
2. 并行检索每个查询的相关父文档
3. 使用倒数排名融合RRF合并结果
4. 用 Cross-Encoder 重排序模型精选最相关的片段
适用于需要精确、全面答案的事实性问题或背景知识查询。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容,若无结果则返回提示信息。
"""
try:
documents = await pipeline.aretrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base
def create_rag_tool_sync(
retriever: BaseRetriever,
llm: BaseLanguageModel,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(同步版本,用于不支持异步的旧版 Agent
参数同 create_rag_tool。
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
)
@tool
def search_knowledge_base_sync(query: str) -> str:
"""在知识库中搜索与查询相关的文档片段(同步版本)。
功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容。
"""
try:
documents = pipeline.retrieve(query) # 内部调用异步方法并等待
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_sync