优化查询代码,优化工具代码
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m33s

This commit is contained in:
2026-05-08 22:30:26 +08:00
parent bfb2ddbe76
commit b30f7b00a7
11 changed files with 375 additions and 511 deletions

View File

@@ -1,198 +1,88 @@
"""
AI Agent 服务类 - 完全简化版本!
按照指南实现,不用 stream_mode="messages" 避免重复 token
AI Agent 服务类
"""
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
from typing import AsyncGenerator, Dict, Any
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
# 本地模块
from backend.app.model_services import get_cached_chat_services
from backend.app.main_graph.main_graph_builder import build_agent_graph
from backend.app.logger import debug, info, warning, error
from backend.app.main_graph.state import AgentState
from .stream_context import set_stream_queue
from backend.app.logger import info
from backend.app.memory.mem0_client import Mem0Client
from .service_config import ServiceConfig
from .stream_handler import run_graph_stream
class AIAgentService:
def __init__(self, checkpointer):
self.checkpointer = checkpointer
self.graph = None
self.chat_services = None
# Mem0 客户端
self.config: ServiceConfig = None
self.mem0_client = None
async def initialize(self):
# 0. 初始化 Mem0 客户端
from ..memory.mem0_client import Mem0Client
async def initialize(self) -> "AIAgentService":
"""初始化 Agent 服务"""
self.mem0_client = Mem0Client()
# 1. 获取缓存的模型字典
self.chat_services = get_cached_chat_services()
info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
# 2. 构建图
info(f"🔄 构建 Agent 图...")
graph_builder = build_agent_graph(
chat_services=self.chat_services,
mem0_client=self.mem0_client
)
# 编译图
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
self.config = ServiceConfig(self.chat_services)
info(f"✅ Agent 图初始化完成")
return self
def _resolve_model(self, model: str) -> str:
"""
解析并验证模型名称,不可用时回退到第一个可用模型
Args:
model: 目标模型名称
Returns:
实际使用的模型名称
"""
if not model or model not in self.chat_services:
fallback = next(iter(self.chat_services.keys()))
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
return fallback
return model
def _build_invocation(
def _resolve_and_build(
self, message: str, thread_id: str, model: str, user_id: str
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
构建图调用所需的 config 和 input_state
Args:
message: 用户消息
thread_id: 会话 ID
model: 模型名称
user_id: 用户 ID
Returns:
(config, input_state) 元组
"""
from langchain_core.messages import HumanMessage
config = {
"configurable": {
"thread_id": thread_id,
},
"metadata": {"user_id": user_id}
}
input_state = {
"messages": [HumanMessage(content=message)],
"user_id": user_id,
}
return config, input_state
):
"""解析模型并构建调用参数"""
resolved_model = self.config.resolve_model(model)
return resolved_model, self.config.build_invocation(
message, thread_id, resolved_model, user_id
)
async def process_message(
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
) -> dict:
"""处理用户消息返回包含回复、token统计和耗时的字典"""
# 解析模型名称
resolved_model = self._resolve_model(model)
# 构建调用参数
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
resolved_model, (config, input_state) = self._resolve_and_build(
message, thread_id, model, user_id
)
result = await self.graph.ainvoke(input_state, config=config)
# 优先使用 final_replyfinalize 节点返回)
reply = result.get("final_reply", "")
if not reply and result.get("messages"):
reply = result["messages"][-1].content
token_usage = result.get("last_token_usage", {})
elapsed_time = result.get("last_elapsed_time", 0.0)
# 获取元数据
metadata = result.get("metadata", {})
return {
"reply": reply,
"token_usage": token_usage,
"elapsed_time": elapsed_time,
"token_usage": result.get("last_token_usage", {}),
"elapsed_time": result.get("last_elapsed_time", 0.0),
"model_used": resolved_model,
"metadata": metadata
"metadata": result.get("metadata", {}),
}
async def process_message_stream(
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
) -> AsyncGenerator[Dict[str, Any], None]:
"""流式处理消息 - 完全简化!"""
# 解析模型名称
resolved_model = self._resolve_model(model)
# 构建调用参数
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
"""流式处理消息"""
resolved_model, (config, input_state) = self._resolve_and_build(
message, thread_id, model, user_id
)
info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")
actual_model_used = resolved_model
# 创建 token 队列
queue = asyncio.Queue()
set_stream_queue(queue) # 设置上下文变量
async def run_graph():
"""后台任务:运行 graph流式事件都从 agent 节点内部发送!"""
try:
info(f"📡 开始调用 graph.astream()...")
# 注意:只用 stream_mode=["updates"],不要 "messages"!避免重复 token
async for _ in self.graph.astream(
input_state,
config=config,
stream_mode=["updates"],
version="v2",
subgraphs=True
):
# 流式事件都从 agent.py 节点内部通过队列发送了
# 这里不需要再发送任何事件
pass
except Exception as e:
error(f"❌ 执行图时出错: {e}")
import traceback
error(f"📋 堆栈: {traceback.format_exc()}")
await queue.put({"type": "error", "message": str(e)})
finally:
await queue.put(None) # 结束哨兵
# 启动后台任务
bg_task = asyncio.create_task(run_graph())
try:
while True:
event = await queue.get()
if event is None:
break
async for event in run_graph_stream(self.graph, input_state, config):
if event.get("type") != "done":
yield event
except GeneratorExit:
# 客户端断开连接,取消后台任务
info("⚠️ GeneratorExit取消后台任务")
bg_task.cancel()
raise
finally:
# 保证任务被清理
if not bg_task.done():
info("⏹️ 清理后台任务")
bg_task.cancel()
try:
await bg_task
except asyncio.CancelledError:
info("✅ 后台任务已取消")
# 发送结束事件,保证前端平稳关闭
yield {
"type": "done",
"model_used": actual_model_used
}
else:
yield {**event, "model_used": resolved_model}

View File

@@ -0,0 +1,46 @@
"""
Agent Service 配置模块 - 配置构建和解析
"""
from typing import Dict, Any, Tuple, Optional
from langchain_core.messages import HumanMessage
from backend.app.logger import warning
class ServiceConfig:
"""配置构建器"""
def __init__(self, chat_services: dict):
self.chat_services = chat_services
def resolve_model(self, model: Optional[str]) -> str:
"""
解析并验证模型名称,不可用时回退到第一个可用模型
"""
if not model or model not in self.chat_services:
fallback = next(iter(self.chat_services.keys()))
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
return fallback
return model
def build_invocation(
self, message: str, thread_id: str, model: str, user_id: str
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
构建图调用所需的 config 和 input_state
Returns:
(config, input_state) 元组
"""
config = {
"configurable": {
"thread_id": thread_id,
},
"metadata": {"user_id": user_id}
}
input_state = {
"messages": [HumanMessage(content=message)],
"user_id": user_id,
}
return config, input_state

View File

@@ -0,0 +1,78 @@
"""
流式处理模块 - 处理 Agent 执行的流式输出
"""
import asyncio
from typing import AsyncGenerator, Dict, Any
from backend.app.logger import info, error
from .stream_context import set_stream_queue
async def run_graph_stream(
graph,
input_state: Dict[str, Any],
config: Dict[str, Any],
) -> AsyncGenerator[Dict[str, Any], None]:
"""
运行图并通过队列流式输出事件
Args:
graph: 编译后的 LangGraph
input_state: 输入状态
config: 配置
Yields:
流式事件
"""
queue: asyncio.Queue = asyncio.Queue()
set_stream_queue(queue)
async def run_graph():
"""后台任务:运行 graph"""
try:
info(f"📡 开始调用 graph.astream()...")
async for _ in graph.astream(
input_state,
config=config,
stream_mode=["updates"],
version="v2",
subgraphs=True
):
# 流式事件都从 agent.py 节点内部通过队列发送了
pass
except Exception as e:
error(f"❌ 执行图时出错: {e}")
import traceback
error(f"📋 堆栈: {traceback.format_exc()}")
await queue.put({"type": "error", "message": str(e)})
finally:
await queue.put(None) # 结束哨兵
# 启动后台任务
bg_task = asyncio.create_task(run_graph())
try:
while True:
event = await queue.get()
if event is None:
break
yield event
except GeneratorExit:
info("⚠️ GeneratorExit取消后台任务")
bg_task.cancel()
raise
finally:
await _cleanup_task(bg_task)
async def _cleanup_task(bg_task: asyncio.Task) -> None:
"""清理后台任务"""
if not bg_task.done():
info("⏹️ 清理后台任务")
bg_task.cancel()
try:
await bg_task
except asyncio.CancelledError:
info("✅ 后台任务已取消")

View File

@@ -1,73 +0,0 @@
# app/rag_initializer.py
from ...rag.tools import create_rag_tool
from ...rag.retriever import create_parent_hybrid_retriever
from ...model_services import get_embedding_service
from backend.app.logger import info, warning
import sys
# 全局 RAG 工具
_rag_tool = None
_initialized = False
def get_rag_tool() -> callable:
"""获取全局 RAG 工具"""
return _rag_tool
def is_initialized() -> bool:
"""检查是否已初始化"""
return _initialized
async def init_rag_tool(force: bool = False):
"""
初始化 RAG 工具(注册到模块级变量,内部获取所需服务)
Args:
force: 是否强制重新初始化
Returns:
RAG 工具(@tool 装饰函数)或 None
"""
global _rag_tool, _initialized
# 防止重复初始化
if _initialized and not force:
info("[RAG] 已初始化,跳过")
return _rag_tool
try:
from backend.app.model_services.chat_services import get_chat_service
info("🔄 正在初始化 RAG 检索系统...")
embeddings = get_embedding_service()
retriever = create_parent_hybrid_retriever(
collection_name="rag_documents",
search_k=5,
embeddings=embeddings,
)
rewrite_llm = get_chat_service()
rag_tool = create_rag_tool(
retriever=retriever,
llm=rewrite_llm,
num_queries=3,
rerank_top_n=5,
)
_rag_tool = rag_tool
_initialized = True
info(f"✅ RAG 检索工具初始化成功 (id={id(rag_tool)})")
return rag_tool
except Exception as e:
warning(f"⚠️ RAG 检索工具初始化失败: {e}")
return None
def reset():
"""重置(用于测试)"""
global _rag_tool, _initialized
_rag_tool = None
_initialized = False

View File

@@ -3,12 +3,11 @@
Web Search Public Utility - Free, no API Key, using DuckDuckGo
"""
from typing import List, Dict, Any, Optional
from typing import List, Optional
from dataclasses import dataclass
from datetime import datetime
import requests
import warnings
import re
from backend.app.logger import info
@dataclass
@@ -44,47 +43,31 @@ class WebSearchTool:
"""
num_results = max_results or self.max_results
# 方式 1: Tavily (需要 API Key质量最高)
# 尝试搜索方式,按优先级
result = self._try_tavily(query, num_results)
if result is not None:
return result
result = self._try_ddgs(query, num_results)
if result is not None:
return result
# 兜底方案
return self._get_mock_results(query, num_results)
def _try_tavily(self, query: str, max_results: int) -> Optional[List[SearchResult]]:
"""尝试 Tavily API 搜索"""
try:
return self._search_tavily(query, num_results)
return self._search_tavily(query, max_results)
except ImportError:
print("[WebSearch] tavily 未安装,尝试其他搜索方式")
info("[WebSearch] tavily 未安装")
except Exception as e:
if "API_KEY" in str(e) or "未配置" in str(e):
print(f"[WebSearch] Tavily API Key 未配置: {e}")
error_msg = str(e)
if "API_KEY" in error_msg or "未配置" in error_msg:
info(f"[WebSearch] Tavily API Key 未配置")
else:
print(f"[WebSearch] Tavily 搜索失败: {e}")
# 方式 2: 尝试用 ddgs 包
try:
from ddgs import DDGS
print(f"[WebSearch] 使用 ddgs 搜索: {query}")
with DDGS() as ddgs:
results = list(ddgs.text(query, max_results=num_results))
if results:
search_results = []
for r in results:
search_results.append(SearchResult(
title=r.get("title", ""),
url=r.get("href", ""),
snippet=r.get("body", ""),
source="DuckDuckGo"
))
print(f"[WebSearch] ddgs 返回 {len(search_results)} 条结果")
return search_results
except ImportError:
print("[WebSearch] ddgs 未安装,尝试 duckduckgo-search")
except Exception as e:
print(f"[WebSearch] ddgs 搜索失败: {e}")
# 方式 3: 尝试用简单 HTTP 请求
try:
return self._search_http(query, num_results)
except Exception as e:
print(f"[WebSearch] HTTP 搜索也失败: {e}")
# 方式 4: 返回模拟数据作为最后兜底
return self._search_mock(query, num_results)
info(f"[WebSearch] Tavily 搜索失败: {e}")
return None
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
"""使用 Tavily API 搜索"""
@@ -111,56 +94,40 @@ class WebSearchTool:
source="Tavily"
))
print(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
info(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
return results
def _search_http(self, query: str, max_results: int) -> List[SearchResult]:
"""用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源"""
print(f"[WebSearch] 尝试 HTTP 搜索")
# 方式 1: 尝试百度搜索(简单方式)
def _try_ddgs(self, query: str, max_results: int) -> Optional[List[SearchResult]]:
"""尝试 DuckDuckGo 搜索"""
try:
return self._search_baidu(query, max_results)
except Exception as e:
print(f"[WebSearch] 百度搜索失败: {e}")
# 方式 2: 返回模拟数据
return self._search_mock(query, max_results)
from ddgs import DDGS
def _search_baidu(self, query: str, max_results: int) -> List[SearchResult]:
"""尝试百度搜索"""
import requests
from urllib.parse import quote
url = f"https://www.baidu.com/s?wd={quote(query)}"
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
}
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
# 简单解析百度搜索结果(简化版)
results = []
# 这里只是示意,真实百度搜索需要更复杂的解析
results.append(SearchResult(
title=f"百度搜索: {query}",
url=url,
snippet="如需要真实搜索结果,请考虑使用百度搜索 API",
source="百度"
))
return results
except Exception as e:
print(f"[WebSearch] 百度搜索也失败: {e}")
raise
with DDGS() as ddgs:
for r in ddgs.text(query, max_results=max_results):
results.append(SearchResult(
title=r.get("title", ""),
url=r.get("href", ""),
snippet=r.get("body", ""),
source="DuckDuckGo"
))
def _search_mock(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
"""模拟搜索结果(兜底方案)"""
print(f"[WebSearch] 使用模拟搜索结果 (查询: {query})")
# 根据查询内容生成更有意义的模拟结果
mock_templates = [
if results:
info(f"[WebSearch] ddgs 返回 {len(results)} 条结果")
return results
except ImportError:
info("[WebSearch] ddgs 未安装")
except Exception as e:
info(f"[WebSearch] ddgs 搜索失败: {e}")
return None
def _get_mock_results(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
"""获取模拟搜索结果(兜底方案)"""
info(f"[WebSearch] 使用模拟搜索结果")
templates = [
{
"title": f"关于「{query}」的相关介绍",
"snippet": "这是模拟结果。如需真实搜索,请检查容器网络连接或配置代理。",
@@ -177,50 +144,48 @@ class WebSearchTool:
"url": "https://example.com/more"
}
]
num = max_results or self.max_results
results = []
for i, template in enumerate(mock_templates[:num]):
for template in templates[:num]:
results.append(SearchResult(
title=template["title"],
url=template["url"],
snippet=template["snippet"],
source="模拟数据"
))
return results
def format_search_results(self, results: List[SearchResult]) -> str:
"""
格式化搜索结果(带引用溯源)
Args:
results: 搜索结果列表
Returns:
格式化后的 Markdown 文本
"""
if not results:
return "未找到相关搜索结果"
lines = []
lines.append("## 🔍 联网搜索结果\n")
lines = ["## 🔍 联网搜索结果\n"]
for idx, result in enumerate(results, 1):
lines.append(f"### [{idx}] {result.title}")
lines.append(f"- 🔗 来源:[{result.url}]({result.url})")
lines.append(f"- 📝 摘要:{result.snippet}")
lines.append(f"- 📅 时间:{result.timestamp.strftime('%Y-%m-%d %H:%M:%S')}")
lines.append("")
# 添加引用溯源说明
lines.append("---")
lines.append("💡 **引用溯源说明**")
lines.append("- 以上搜索结果均标注了来源链接")
lines.append("- 使用方括号数字标识引用(如 [1]、[2]")
lines.append("- 可通过链接追溯原始信息")
return "\n".join(lines)
@@ -239,11 +204,11 @@ def get_web_search_tool() -> WebSearchTool:
def web_search(query: str, max_results: int = 5) -> str:
"""
便捷函数:联网搜索并返回格式化结果
Args:
query: 搜索关键词
max_results: 返回结果数量
Returns:
格式化后的搜索结果文本
"""

View File

@@ -1,73 +0,0 @@
"""
RAG 工具模块(完全异步)
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
"""
from typing import Callable, Optional
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from ..rag.pipeline import RAGPipeline, create_rag_pipeline
def create_rag_tool(
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(完全异步)。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
Args:
retriever: 基础检索器对象(可选,不提供则自动创建)
llm: 用于生成多路查询的语言模型。
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
Returns:
Async LangChain Tool 函数
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name,
)
@tool
async def search_knowledge_base(query: str) -> str:
"""
在知识库中搜索与查询相关的文档片段(完全异步)。
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
检索效果最优。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容
"""
try:
documents = await pipeline.aretrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base

View File

@@ -2,124 +2,17 @@
Agent Tools - 所有工具统一定义
"""
from langchain_core.tools import tool
from backend.app.logger import info
# ========== RAG ==========
_rag_pipeline = None
def _get_rag_pipeline():
global _rag_pipeline
if _rag_pipeline is None:
from backend.app.rag.pipeline import RAGPipeline
_rag_pipeline = RAGPipeline(
num_queries=3,
rerank_top_n=5,
use_rerank=True,
return_parent_docs=True,
)
return _rag_pipeline
@tool
async def rag_search(query: str) -> str:
"""
检索知识库获取相关信息
Returns:
包含检索结果和置信度的结构化回复,格式:
- 内容:检索到的相关信息
- 置信度评估基于向量相似度、重排分数、LLM判断的综合评分
"""
info(f"[Tool] rag_search: {query[:30]}...")
try:
pipeline = _get_rag_pipeline()
# 使用带置信度的检索
result = await pipeline.aretrieve_with_confidence(query, original_query=query)
if not result.content:
return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度0.0\n建议:可尝试联网搜索获取信息。"
# 构建包含置信度的回复
confidence_desc = ""
if result.confidence < 0.4:
confidence_desc = ""
elif result.confidence < 0.6:
confidence_desc = ""
response = f"""【RAG检索结果】
{result.content}
【置信度评估】
- 综合置信度:{result.confidence:.2f}{confidence_desc}
- 向量相似度:{result.scores['embedding']:.2f}
- 重排分数:{result.scores['rerank']:.2f}
- LLM评估{result.scores['llm']:.2f}
{'✅ 检索结果可信,可直接使用' if result.is_useful else '⚠️ 检索结果置信度较低,可能需要联网搜索补充'}"""
info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}")
return response
except Exception as e:
info(f"[Tool] rag_search 失败: {e}")
return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索"
# ========== 联网搜索 ==========
@tool
def web_search(query: str) -> str:
"""联网搜索获取最新信息"""
info(f"[Tool] web_search: {query[:30]}...")
try:
from backend.app.core.web_search import web_search as search_fn
return search_fn(query, max_results=5)
except Exception as e:
info(f"[Tool] web_search 失败: {e}")
return f"联网搜索失败: {str(e)}"
# ========== 子图工具 ==========
async def _call_subgraph(builder_fn, state_cls, query: str) -> str:
"""通用子图调用"""
try:
graph = builder_fn().compile()
state = state_cls(user_query=query)
result = await graph.ainvoke(state)
return result.get("final_result", "执行完成")
except Exception as e:
info(f"[Tool] 子图调用失败: {e}")
return f"执行失败: {str(e)}"
@tool
async def contact_lookup(query: str) -> str:
"""查询通讯录"""
from backend.app.subgraphs.contact.graph import build_contact_subgraph
from backend.app.subgraphs.contact.state import ContactState
return await _call_subgraph(build_contact_subgraph, ContactState, query)
@tool
async def dictionary_lookup(word: str) -> str:
"""查询词典/翻译"""
from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph
from backend.app.subgraphs.dictionary.state import DictionaryState
return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word)
@tool
async def news_analysis(topic: str) -> str:
"""分析新闻热点"""
from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph
from backend.app.subgraphs.news_analysis.state import NewsAnalysisState
return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic)
# ========== 导出 ==========
from .rag import rag_search
from .web_search import web_search
from .subgraph import contact_lookup, dictionary_lookup, news_analysis
ALL_TOOLS = [rag_search, web_search, contact_lookup, dictionary_lookup, news_analysis]
__all__ = [
"rag_search",
"web_search",
"contact_lookup",
"dictionary_lookup",
"news_analysis",
"ALL_TOOLS",
]

View File

@@ -0,0 +1,8 @@
"""
工具模块配置
"""
from langchain_core.tools import tool
from backend.app.logger import info
__all__ = ["tool", "info"]

70
backend/app/tools/rag.py Normal file
View File

@@ -0,0 +1,70 @@
"""
RAG 检索工具
"""
from langchain_core.tools import tool
from backend.app.logger import info
_rag_pipeline = None
def _get_rag_pipeline():
"""获取或创建 RAG pipeline 单例"""
global _rag_pipeline
if _rag_pipeline is None:
from backend.app.rag.pipeline import RAGPipeline
_rag_pipeline = RAGPipeline(
num_queries=3,
rerank_top_n=5,
use_rerank=True,
return_parent_docs=True,
)
return _rag_pipeline
def _format_confidence(result) -> str:
"""格式化置信度描述"""
if result.confidence < 0.4:
return ""
elif result.confidence < 0.6:
return ""
return ""
@tool
async def rag_search(query: str) -> str:
"""
检索知识库获取相关信息
Returns:
包含检索结果和置信度的结构化回复
"""
info(f"[Tool] rag_search: {query[:30]}...")
try:
pipeline = _get_rag_pipeline()
result = await pipeline.aretrieve_with_confidence(query, original_query=query)
if not result.content:
return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度0.0\n建议:可尝试联网搜索获取信息。"
confidence_desc = _format_confidence(result)
is_useful_note = "✅ 检索结果可信,可直接使用" if result.is_useful else "⚠️ 检索结果置信度较低,可能需要联网搜索补充"
response = f"""【RAG检索结果】
{result.content}
【置信度评估】
- 综合置信度:{result.confidence:.2f}{confidence_desc}
- 向量相似度:{result.scores['embedding']:.2f}
- 重排分数:{result.scores['rerank']:.2f}
- LLM评估{result.scores['llm']:.2f}
{is_useful_note}"""
info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}")
return response
except Exception as e:
info(f"[Tool] rag_search 失败: {e}")
return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索"

View File

@@ -0,0 +1,42 @@
"""
子图工具 - 通讯录、词典、新闻分析等
"""
from langchain_core.tools import tool
from backend.app.logger import info
async def _call_subgraph(builder_fn, state_cls, query: str) -> str:
"""通用子图调用"""
try:
graph = builder_fn().compile()
state = state_cls(user_query=query)
result = await graph.ainvoke(state)
return result.get("final_result", "执行完成")
except Exception as e:
info(f"[Tool] 子图调用失败: {e}")
return f"执行失败: {str(e)}"
@tool
async def contact_lookup(query: str) -> str:
"""查询通讯录"""
from backend.app.subgraphs.contact.graph import build_contact_subgraph
from backend.app.subgraphs.contact.state import ContactState
return await _call_subgraph(build_contact_subgraph, ContactState, query)
@tool
async def dictionary_lookup(word: str) -> str:
"""查询词典/翻译"""
from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph
from backend.app.subgraphs.dictionary.state import DictionaryState
return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word)
@tool
async def news_analysis(topic: str) -> str:
"""分析新闻热点"""
from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph
from backend.app.subgraphs.news_analysis.state import NewsAnalysisState
return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic)

View File

@@ -0,0 +1,18 @@
"""
联网搜索工具
"""
from langchain_core.tools import tool
from backend.app.logger import info
@tool
def web_search(query: str) -> str:
"""联网搜索获取最新信息"""
info(f"[Tool] web_search: {query[:30]}...")
try:
from backend.app.core.web_search import web_search as search_fn
return search_fn(query, max_results=5)
except Exception as e:
info(f"[Tool] web_search 失败: {e}")
return f"联网搜索失败: {str(e)}"