230 lines
6.5 KiB
Python
230 lines
6.5 KiB
Python
"""
|
||
RAG 工具包装
|
||
|
||
将 RAG 流水线包装成 LangChain Tool,供 Agent 调用。
|
||
"""
|
||
|
||
from typing import Optional, Dict, Any
|
||
from langchain.tools import tool
|
||
from langchain_core.tools import Tool
|
||
from langchain_core.embeddings import Embeddings
|
||
from langchain_core.language_models import BaseLanguageModel
|
||
|
||
from .pipeline import RAGPipeline, RAGConfig, RAGLevel
|
||
|
||
|
||
class RAGTool:
    """Wrap a RAG pipeline as a tool that an Agent can call.

    The LangChain tool is built once, at construction time. Its single
    ``query`` argument is forwarded to ``pipeline.retrieve`` and the retrieved
    documents are rendered with ``pipeline.format_context``.
    """

    def __init__(
        self,
        pipeline: RAGPipeline,
        tool_name: str = "search_knowledge_base",
        tool_description: Optional[str] = None,
        *,
        max_context_length: int = 4000,
    ):
        """Initialise the RAG tool.

        Args:
            pipeline: RAG pipeline instance used for retrieval and for
                formatting the retrieved documents.
            tool_name: Name given to the generated tool.
            tool_description: Description attached to the tool. The built-in
                default below is used when this is ``None`` or empty.
            max_context_length: Value passed as ``max_length`` to
                ``pipeline.format_context``. Defaults to 4000, the limit that
                was previously hard-coded.
        """
        self.pipeline = pipeline
        self.tool_name = tool_name
        self.max_context_length = max_context_length

        # Default tool description. This is runtime text attached to the tool,
        # so it is kept verbatim.
        self.tool_description = tool_description or (
            "在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
            "专业知识或需要基于已知信息回答的问题时使用此工具。"
            "输入应为要搜索的查询文本。"
        )

        # Build the LangChain tool up front so get_tool() / __call__ reuse it.
        self._tool = self._create_tool()

    def _create_tool(self) -> Tool:
        """Build the LangChain tool that performs the knowledge-base search."""

        # NOTE(review): the inner function's docstring is presumably read by the
        # ``@tool`` decorator at runtime (LangChain derives tool metadata from
        # it), so it is deliberately left in its original language.
        @tool(self.tool_name, args_schema=None)
        def search_knowledge_base(query: str) -> str:
            """
            在知识库中搜索相关信息

            Args:
                query: 搜索查询

            Returns:
                格式化后的搜索结果
            """
            try:
                # Run retrieval.
                result = self.pipeline.retrieve(query)

                if not result.documents:
                    return "在知识库中未找到相关信息。"

                # Render the documents, capped at the configured length.
                context = self.pipeline.format_context(
                    result.documents,
                    max_length=self.max_context_length,
                )

                # Assemble the response: hit count, context, elapsed time.
                return (
                    f"🔍 在知识库中找到了 {len(result.documents)} 条相关信息:\n\n"
                    f"{context}\n\n"
                    f"⏱️ 检索耗时: {result.query_time:.2f}秒"
                )

            except Exception as e:
                # Deliberate best-effort boundary: the failure is returned as
                # the tool's output text instead of being raised.
                error_msg = f"检索过程中发生错误: {e}"
                if self.pipeline.config.verbose:
                    print(f"RAG 工具错误: {error_msg}")
                return error_msg

        # Override the description with the configured one.
        search_knowledge_base.description = self.tool_description

        return search_knowledge_base

    def get_tool(self) -> Tool:
        """Return the LangChain tool built at construction time."""
        return self._tool

    def __call__(self, query: str) -> str:
        """Invoke the tool directly with ``query`` and return its output."""
        return self._tool.invoke({"query": query})
|
||
|
||
|
||
def create_rag_tool(
    embeddings: Embeddings,
    llm: Optional[BaseLanguageModel] = None,
    config: Optional[Dict[str, Any]] = None,
    tool_name: str = "search_knowledge_base",
    tool_description: Optional[str] = None,
) -> Tool:
    """Create a RAG tool in a single call (convenience function).

    Args:
        embeddings: Embedding model.
        llm: Language model (used for advanced RAG features).
        config: RAG configuration dictionary.
        tool_name: Tool name.
        tool_description: Tool description.

    Returns:
        The LangChain Tool instance.
    """
    # The pipeline is created first (as the first keyword argument), then
    # wrapped; only the underlying tool is handed back.
    wrapper = RAGTool(
        pipeline=RAGPipeline.create_from_config(
            embeddings=embeddings,
            llm=llm,
            config_dict=config,
        ),
        tool_name=tool_name,
        tool_description=tool_description,
    )
    return wrapper.get_tool()
|
||
|
||
|
||
# Export the convenience function
|
||
# Alias: a second module-level name bound to the same create_rag_tool function.
search_knowledge_base_tool = create_rag_tool
|
||
|
||
|
||
def bind_rag_to_agent(
    agent_llm: BaseLanguageModel,
    embeddings: Embeddings,
    rag_llm: Optional[BaseLanguageModel] = None,
    config: Optional[Dict[str, Any]] = None,
    tool_name: str = "search_knowledge_base",
    tool_description: Optional[str] = None,
) -> BaseLanguageModel:
    """Bind a RAG tool to the Agent model.

    Args:
        agent_llm: Language model used by the Agent. It must provide
            ``bind_tools``, which is called on it below.
        embeddings: Embedding model.
        rag_llm: Language model used by the RAG pipeline (if different from
            ``agent_llm``). When ``None``, ``agent_llm`` is used.
        config: RAG configuration.
        tool_name: Tool name.
        tool_description: Optional tool description, forwarded unchanged to
            ``create_rag_tool``. ``None`` keeps the previous behaviour.

    Returns:
        The result of ``agent_llm.bind_tools`` with the RAG tool attached.
    """
    # Fall back to the Agent LLM when no dedicated RAG LLM is given.
    if rag_llm is None:
        rag_llm = agent_llm

    # Create the RAG tool.
    rag_tool = create_rag_tool(
        embeddings=embeddings,
        llm=rag_llm,
        config=config,
        tool_name=tool_name,
        tool_description=tool_description,
    )

    # Bind the tool to the model.
    return agent_llm.bind_tools([rag_tool])
|
||
|
||
|
||
def create_agentic_rag_pipeline(
    embeddings: Embeddings,
    agent_llm: BaseLanguageModel,
    rag_llm: Optional[BaseLanguageModel] = None,
    config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Create a complete Agentic RAG pipeline (Level 4).

    Args:
        embeddings: Embedding model.
        agent_llm: Agent model; ``bind_tools`` is called on it.
        rag_llm: Model dedicated to RAG. ``agent_llm`` is used when this is
            not provided.
        config: Configuration dictionary. It is copied, never modified;
            ``rag_level`` is forced to the AGENTIC level in the copy.

    Returns:
        A dict with three keys: ``"llm"`` (the model with the tool bound),
        ``"tool"`` (the RAG tool) and ``"pipeline"`` (the RAG pipeline that
        backs that tool).
    """
    # Force the Agentic RAG level on a copy so the caller's dict is untouched.
    agentic_config = {**(config or {}), "rag_level": RAGLevel.AGENTIC.value}

    # Build the pipeline exactly once: the same instance backs the tool and is
    # returned to the caller, so both always refer to the same pipeline.
    pipeline = RAGPipeline.create_from_config(
        embeddings=embeddings,
        llm=rag_llm or agent_llm,
        config_dict=agentic_config,
    )

    # Wrap the pipeline as a tool.
    rag_tool = RAGTool(
        pipeline=pipeline,
        tool_name="search_knowledge_base",
        tool_description=(
            "在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
            "专业知识或需要基于已知信息回答的问题时使用此工具。"
            "Agent 应该先判断是否需要使用此工具,然后调用它获取上下文。"
        ),
    ).get_tool()

    # Bind the tool to the model.
    bound_llm = agent_llm.bind_tools([rag_tool])

    return {
        "llm": bound_llm,
        "tool": rag_tool,
        "pipeline": pipeline,
    }