Files
ailine/backend/app/rag/tools.py
root 60afa86ded
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
feat: 实现 BM25 稀疏 + 稠密向量混合检索功能
2026-05-04 02:01:22 +08:00

97 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 工具模块
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
"""
from typing import Callable, Optional
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
def create_rag_tool_sync(
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(同步版本)。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
Args:
retriever: 基础检索器对象(可选,不提供则自动创建)
llm: 用于生成多路查询的语言模型(可选)
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
Returns:
LangChain Tool 函数
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name,
)
@tool
def search_knowledge_base_sync(query: str) -> str:
"""
在知识库中搜索与查询相关的文档片段。
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
检索效果最优。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容
"""
try:
documents = pipeline.retrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_sync
def create_rag_tool(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
) -> Callable:
"""
创建 RAG 检索工具的便捷函数(同步版本)。
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型(可选)
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
Returns:
LangChain Tool 函数
"""
return create_rag_tool_sync(
collection_name=collection_name,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
)