Files
ailine/backend/app/rag/tools.py

154 lines
4.7 KiB
Python
Raw Normal View History

2026-04-21 11:02:16 +08:00
"""
RAG 工具模块
将检索功能封装为 LangChain Tool Agent 调用
采用固定流水线多路改写 并行检索 RRF 融合 重排序 返回父文档
默认使用混合检索稠密+BM25稀疏+ 父子文档模式
2026-04-21 11:02:16 +08:00
"""
from typing import Callable, Optional
2026-04-21 11:02:16 +08:00
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
2026-04-21 11:02:16 +08:00
def create_rag_tool_sync(
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = None,
2026-04-21 11:02:16 +08:00
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具同步版本
默认使用混合检索稠密+BM25稀疏+ 父子文档模式
Args:
retriever: 基础检索器对象可选不提供则自动创建
llm: 用于生成多路查询的语言模型可选
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
Returns:
LangChain Tool 函数
2026-04-21 11:02:16 +08:00
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name,
2026-04-21 11:02:16 +08:00
)
@tool
def search_knowledge_base_sync(query: str) -> str:
"""
在知识库中搜索与查询相关的文档片段
使用混合检索稠密向量语义 + BM25 关键词+ 父子文档模式
检索效果最优
2026-04-21 11:02:16 +08:00
Args:
query: 用户提出的问题或查询字符串
2026-04-21 11:02:16 +08:00
Returns:
格式化后的相关文档内容
2026-04-21 11:02:16 +08:00
"""
try:
documents = pipeline.retrieve(query)
2026-04-21 11:02:16 +08:00
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
2026-04-21 11:02:16 +08:00
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_sync
2026-04-21 11:02:16 +08:00
2026-05-04 04:28:32 +08:00
def create_rag_tool_async(
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具异步版本
默认使用混合检索稠密+BM25稀疏+ 父子文档模式
Args:
retriever: 基础检索器对象可选不提供则自动创建
llm: 用于生成多路查询的语言模型可选
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
Returns:
Async LangChain Tool 函数
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name,
)
@tool
async def search_knowledge_base_async(query: str) -> str:
"""
在知识库中搜索与查询相关的文档片段异步版本
使用混合检索稠密向量语义 + BM25 关键词+ 父子文档模式
检索效果最优
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容
"""
try:
documents = await pipeline.aretrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_async
def create_rag_tool(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
) -> Callable:
"""
创建 RAG 检索工具的便捷函数同步版本
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型可选
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
Returns:
LangChain Tool 函数
"""
return create_rag_tool_sync(
collection_name=collection_name,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
)