""" RAG 工具模块 将检索功能封装为 LangChain Tool,供 Agent 调用。 采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 """ from typing import Optional, Callable from langchain_core.tools import tool from langchain_core.language_models import BaseLanguageModel from langchain_core.retrievers import BaseRetriever from .pipeline import RAGPipeline def create_rag_tool( retriever: BaseRetriever, llm: BaseLanguageModel, num_queries: int = 3, rerank_top_n: int = 5, collection_name: str = "rag_documents", ) -> Callable: """ 创建一个配置好的 RAG 检索工具(异步)。 Args: retriever: 基础检索器(例如 ParentDocumentRetriever 实例) llm: 用于多路查询改写的语言模型 num_queries: 生成查询变体数量 rerank_top_n: 最终返回的文档数量 collection_name: 集合名称(仅用于日志/描述) Returns: LangChain Tool 可调用对象(异步) """ # 初始化流水线(所有组件一次创建,后续复用) pipeline = RAGPipeline( retriever=retriever, llm=llm, num_queries=num_queries, rerank_top_n=rerank_top_n, ) @tool async def search_knowledge_base(query: str) -> str: """在知识库中搜索与查询相关的文档片段。 该工具会: 1. 将用户问题改写成多个不同角度的查询 2. 并行检索每个查询的相关父文档 3. 使用倒数排名融合(RRF)合并结果 4. 用 Cross-Encoder 重排序模型精选最相关的片段 适用于需要精确、全面答案的事实性问题或背景知识查询。 Args: query: 用户提出的问题或查询字符串 Returns: 格式化后的相关文档内容,若无结果则返回提示信息。 """ try: documents = await pipeline.aretrieve(query) if not documents: return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" context = pipeline.format_context(documents) return context except Exception as e: return f"检索过程中发生错误: {str(e)}" return search_knowledge_base def create_rag_tool_sync( retriever: BaseRetriever, llm: BaseLanguageModel, num_queries: int = 3, rerank_top_n: int = 5, collection_name: str = "rag_documents", ) -> Callable: """ 创建一个配置好的 RAG 检索工具(同步版本,用于不支持异步的旧版 Agent)。 参数同 create_rag_tool。 """ pipeline = RAGPipeline( retriever=retriever, llm=llm, num_queries=num_queries, rerank_top_n=rerank_top_n, ) @tool def search_knowledge_base_sync(query: str) -> str: """在知识库中搜索与查询相关的文档片段(同步版本)。 功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。 Args: query: 用户提出的问题或查询字符串 Returns: 格式化后的相关文档内容。 """ try: documents = pipeline.retrieve(query) # 内部调用异步方法并等待 if not documents: return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" context = pipeline.format_context(documents) return context except Exception as e: return f"检索过程中发生错误: {str(e)}" return search_knowledge_base_sync