Files
ailine/backend/app/rag/tools.py
root 3ae9daa01a
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m44s
导入方式修改
2026-05-05 23:17:00 +08:00

74 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 工具模块(完全异步)
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
"""
from typing import Callable, Optional
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from ..rag.pipeline import RAGPipeline, create_rag_pipeline
def create_rag_tool(
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(完全异步)。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
Args:
retriever: 基础检索器对象(可选,不提供则自动创建)
llm: 用于生成多路查询的语言模型。
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
Returns:
Async LangChain Tool 函数
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name,
)
@tool
async def search_knowledge_base(query: str) -> str:
"""
在知识库中搜索与查询相关的文档片段(完全异步)。
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
检索效果最优。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容
"""
try:
documents = await pipeline.aretrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base