Files
ailine/app/rag/tools.py
2026-04-18 16:31:48 +08:00

230 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 工具包装
将 RAG 流水线包装成 LangChain Tool供 Agent 调用。
"""
from typing import Optional, Dict, Any
from langchain.tools import tool
from langchain_core.tools import Tool
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .pipeline import RAGPipeline, RAGConfig, RAGLevel
class RAGTool:
"""
RAG 工具包装器
将 RAG 流水线包装成 Agent 可调用的工具。
"""
def __init__(
self,
pipeline: RAGPipeline,
tool_name: str = "search_knowledge_base",
tool_description: str = None,
):
"""
初始化 RAG 工具
Args:
pipeline: RAG 流水线实例
tool_name: 工具名称
tool_description: 工具描述
"""
self.pipeline = pipeline
self.tool_name = tool_name
# 默认工具描述
self.tool_description = tool_description or (
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
"专业知识或需要基于已知信息回答的问题时使用此工具。"
"输入应为要搜索的查询文本。"
)
# 创建 LangChain 工具
self._tool = self._create_tool()
def _create_tool(self) -> Tool:
"""创建 LangChain 工具"""
@tool(self.tool_name, args_schema=None)
def search_knowledge_base(query: str) -> str:
"""
在知识库中搜索相关信息
Args:
query: 搜索查询
Returns:
格式化后的搜索结果
"""
try:
# 执行检索
result = self.pipeline.retrieve(query)
if not result.documents:
return "在知识库中未找到相关信息。"
# 格式化上下文
context = self.pipeline.format_context(
result.documents,
max_length=4000, # 限制上下文长度
)
# 构建响应
response = (
f"🔍 在知识库中找到了 {len(result.documents)} 条相关信息:\n\n"
f"{context}\n\n"
f"⏱️ 检索耗时: {result.query_time:.2f}"
)
return response
except Exception as e:
error_msg = f"检索过程中发生错误: {str(e)}"
if self.pipeline.config.verbose:
print(f"RAG 工具错误: {error_msg}")
return error_msg
# 设置工具描述
search_knowledge_base.description = self.tool_description
return search_knowledge_base
def get_tool(self) -> Tool:
"""获取 LangChain 工具"""
return self._tool
def __call__(self, query: str) -> str:
"""直接调用工具"""
return self._tool.invoke({"query": query})
def create_rag_tool(
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
tool_description: Optional[str] = None,
) -> Tool:
"""
创建 RAG 工具(便捷函数)
Args:
embeddings: 嵌入模型
llm: 语言模型(用于高级 RAG 功能)
config: RAG 配置字典
tool_name: 工具名称
tool_description: 工具描述
Returns:
LangChain Tool 实例
"""
# 创建 RAG 流水线
pipeline = RAGPipeline.create_from_config(
embeddings=embeddings,
llm=llm,
config_dict=config,
)
# 创建工具包装器
rag_tool = RAGTool(
pipeline=pipeline,
tool_name=tool_name,
tool_description=tool_description,
)
return rag_tool.get_tool()
# 导出便捷函数
search_knowledge_base_tool = create_rag_tool
def bind_rag_to_agent(
agent_llm: BaseLanguageModel,
embeddings: Embeddings,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
) -> BaseLanguageModel:
"""
将 RAG 工具绑定到 Agent 模型
Args:
agent_llm: Agent 使用的语言模型
embeddings: 嵌入模型
rag_llm: RAG 流水线使用的语言模型(如果与 agent_llm 不同)
config: RAG 配置
tool_name: 工具名称
Returns:
绑定工具后的模型
"""
# 如果未指定 RAG LLM使用 Agent LLM
if rag_llm is None:
rag_llm = agent_llm
# 创建 RAG 工具
rag_tool = create_rag_tool(
embeddings=embeddings,
llm=rag_llm,
config=config,
tool_name=tool_name,
)
# 绑定工具到模型
return agent_llm.bind_tools([rag_tool])
def create_agentic_rag_pipeline(
embeddings: Embeddings,
agent_llm: BaseLanguageModel,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
创建完整的 Agentic RAG 流水线Level 4
Args:
embeddings: 嵌入模型
agent_llm: Agent 模型
rag_llm: RAG 专用模型
config: 配置
Returns:
包含模型和工具的字典
"""
# 配置 Agentic RAG 级别
if config is None:
config = {}
config["rag_level"] = RAGLevel.AGENTIC.value
# 创建 RAG 工具
rag_tool = create_rag_tool(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config=config,
tool_name="search_knowledge_base",
tool_description=(
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
"专业知识或需要基于已知信息回答的问题时使用此工具。"
"Agent 应该先判断是否需要使用此工具,然后调用它获取上下文。"
),
)
# 绑定工具到模型
bound_llm = agent_llm.bind_tools([rag_tool])
return {
"llm": bound_llm,
"tool": rag_tool,
"pipeline": RAGPipeline.create_from_config(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config_dict=config,
),
}