本地RAG尝试

This commit is contained in:
2026-04-18 16:31:48 +08:00
parent 6042d4a476
commit 0470afce13
12 changed files with 1587 additions and 4 deletions

230
app/rag/tools.py Normal file
View File

@@ -0,0 +1,230 @@
"""
RAG 工具包装
将 RAG 流水线包装成 LangChain Tool供 Agent 调用。
"""
from typing import Optional, Dict, Any
from langchain.tools import tool
from langchain_core.tools import Tool
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from .pipeline import RAGPipeline, RAGConfig, RAGLevel
class RAGTool:
"""
RAG 工具包装器
将 RAG 流水线包装成 Agent 可调用的工具。
"""
def __init__(
self,
pipeline: RAGPipeline,
tool_name: str = "search_knowledge_base",
tool_description: str = None,
):
"""
初始化 RAG 工具
Args:
pipeline: RAG 流水线实例
tool_name: 工具名称
tool_description: 工具描述
"""
self.pipeline = pipeline
self.tool_name = tool_name
# 默认工具描述
self.tool_description = tool_description or (
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
"专业知识或需要基于已知信息回答的问题时使用此工具。"
"输入应为要搜索的查询文本。"
)
# 创建 LangChain 工具
self._tool = self._create_tool()
def _create_tool(self) -> Tool:
"""创建 LangChain 工具"""
@tool(self.tool_name, args_schema=None)
def search_knowledge_base(query: str) -> str:
"""
在知识库中搜索相关信息
Args:
query: 搜索查询
Returns:
格式化后的搜索结果
"""
try:
# 执行检索
result = self.pipeline.retrieve(query)
if not result.documents:
return "在知识库中未找到相关信息。"
# 格式化上下文
context = self.pipeline.format_context(
result.documents,
max_length=4000, # 限制上下文长度
)
# 构建响应
response = (
f"🔍 在知识库中找到了 {len(result.documents)} 条相关信息:\n\n"
f"{context}\n\n"
f"⏱️ 检索耗时: {result.query_time:.2f}"
)
return response
except Exception as e:
error_msg = f"检索过程中发生错误: {str(e)}"
if self.pipeline.config.verbose:
print(f"RAG 工具错误: {error_msg}")
return error_msg
# 设置工具描述
search_knowledge_base.description = self.tool_description
return search_knowledge_base
def get_tool(self) -> Tool:
"""获取 LangChain 工具"""
return self._tool
def __call__(self, query: str) -> str:
"""直接调用工具"""
return self._tool.invoke({"query": query})
def create_rag_tool(
embeddings: Embeddings,
llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
tool_description: Optional[str] = None,
) -> Tool:
"""
创建 RAG 工具(便捷函数)
Args:
embeddings: 嵌入模型
llm: 语言模型(用于高级 RAG 功能)
config: RAG 配置字典
tool_name: 工具名称
tool_description: 工具描述
Returns:
LangChain Tool 实例
"""
# 创建 RAG 流水线
pipeline = RAGPipeline.create_from_config(
embeddings=embeddings,
llm=llm,
config_dict=config,
)
# 创建工具包装器
rag_tool = RAGTool(
pipeline=pipeline,
tool_name=tool_name,
tool_description=tool_description,
)
return rag_tool.get_tool()
# 导出便捷函数
search_knowledge_base_tool = create_rag_tool
def bind_rag_to_agent(
agent_llm: BaseLanguageModel,
embeddings: Embeddings,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
tool_name: str = "search_knowledge_base",
) -> BaseLanguageModel:
"""
将 RAG 工具绑定到 Agent 模型
Args:
agent_llm: Agent 使用的语言模型
embeddings: 嵌入模型
rag_llm: RAG 流水线使用的语言模型(如果与 agent_llm 不同)
config: RAG 配置
tool_name: 工具名称
Returns:
绑定工具后的模型
"""
# 如果未指定 RAG LLM使用 Agent LLM
if rag_llm is None:
rag_llm = agent_llm
# 创建 RAG 工具
rag_tool = create_rag_tool(
embeddings=embeddings,
llm=rag_llm,
config=config,
tool_name=tool_name,
)
# 绑定工具到模型
return agent_llm.bind_tools([rag_tool])
def create_agentic_rag_pipeline(
embeddings: Embeddings,
agent_llm: BaseLanguageModel,
rag_llm: Optional[BaseLanguageModel] = None,
config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
创建完整的 Agentic RAG 流水线Level 4
Args:
embeddings: 嵌入模型
agent_llm: Agent 模型
rag_llm: RAG 专用模型
config: 配置
Returns:
包含模型和工具的字典
"""
# 配置 Agentic RAG 级别
if config is None:
config = {}
config["rag_level"] = RAGLevel.AGENTIC.value
# 创建 RAG 工具
rag_tool = create_rag_tool(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config=config,
tool_name="search_knowledge_base",
tool_description=(
"在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、"
"专业知识或需要基于已知信息回答的问题时使用此工具。"
"Agent 应该先判断是否需要使用此工具,然后调用它获取上下文。"
),
)
# 绑定工具到模型
bound_llm = agent_llm.bind_tools([rag_tool])
return {
"llm": bound_llm,
"tool": rag_tool,
"pipeline": RAGPipeline.create_from_config(
embeddings=embeddings,
llm=rag_llm or agent_llm,
config_dict=config,
),
}