This commit is contained in:
@@ -1,198 +1,88 @@
|
|||||||
"""
|
"""
|
||||||
AI Agent 服务类 - 完全简化版本!
|
AI Agent 服务类
|
||||||
按照指南实现,不用 stream_mode="messages" 避免重复 token!
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
from typing import AsyncGenerator, Dict, Any
|
||||||
import asyncio
|
|
||||||
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
|
|
||||||
|
|
||||||
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
|
|
||||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||||
|
|
||||||
# 本地模块
|
|
||||||
from backend.app.model_services import get_cached_chat_services
|
from backend.app.model_services import get_cached_chat_services
|
||||||
from backend.app.main_graph.main_graph_builder import build_agent_graph
|
from backend.app.main_graph.main_graph_builder import build_agent_graph
|
||||||
from backend.app.logger import debug, info, warning, error
|
from backend.app.logger import info
|
||||||
from backend.app.main_graph.state import AgentState
|
from backend.app.memory.mem0_client import Mem0Client
|
||||||
from .stream_context import set_stream_queue
|
|
||||||
|
from .service_config import ServiceConfig
|
||||||
|
from .stream_handler import run_graph_stream
|
||||||
|
|
||||||
|
|
||||||
class AIAgentService:
|
class AIAgentService:
|
||||||
def __init__(self, checkpointer):
|
def __init__(self, checkpointer):
|
||||||
self.checkpointer = checkpointer
|
self.checkpointer = checkpointer
|
||||||
self.graph = None
|
self.graph = None
|
||||||
self.chat_services = None
|
self.config: ServiceConfig = None
|
||||||
# Mem0 客户端
|
|
||||||
self.mem0_client = None
|
self.mem0_client = None
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self) -> "AIAgentService":
|
||||||
# 0. 初始化 Mem0 客户端
|
"""初始化 Agent 服务"""
|
||||||
from ..memory.mem0_client import Mem0Client
|
|
||||||
self.mem0_client = Mem0Client()
|
self.mem0_client = Mem0Client()
|
||||||
|
|
||||||
# 1. 获取缓存的模型字典
|
|
||||||
self.chat_services = get_cached_chat_services()
|
self.chat_services = get_cached_chat_services()
|
||||||
info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
|
info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
|
||||||
|
|
||||||
# 2. 构建图
|
|
||||||
info(f"🔄 构建 Agent 图...")
|
|
||||||
graph_builder = build_agent_graph(
|
graph_builder = build_agent_graph(
|
||||||
chat_services=self.chat_services,
|
chat_services=self.chat_services,
|
||||||
mem0_client=self.mem0_client
|
mem0_client=self.mem0_client
|
||||||
)
|
)
|
||||||
|
|
||||||
# 编译图
|
|
||||||
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
|
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
|
||||||
|
|
||||||
|
self.config = ServiceConfig(self.chat_services)
|
||||||
info(f"✅ Agent 图初始化完成")
|
info(f"✅ Agent 图初始化完成")
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _resolve_model(self, model: str) -> str:
|
def _resolve_and_build(
|
||||||
"""
|
|
||||||
解析并验证模型名称,不可用时回退到第一个可用模型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: 目标模型名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
实际使用的模型名称
|
|
||||||
"""
|
|
||||||
if not model or model not in self.chat_services:
|
|
||||||
fallback = next(iter(self.chat_services.keys()))
|
|
||||||
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
|
|
||||||
return fallback
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _build_invocation(
|
|
||||||
self, message: str, thread_id: str, model: str, user_id: str
|
self, message: str, thread_id: str, model: str, user_id: str
|
||||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
):
|
||||||
"""
|
"""解析模型并构建调用参数"""
|
||||||
构建图调用所需的 config 和 input_state
|
resolved_model = self.config.resolve_model(model)
|
||||||
|
return resolved_model, self.config.build_invocation(
|
||||||
Args:
|
message, thread_id, resolved_model, user_id
|
||||||
message: 用户消息
|
)
|
||||||
thread_id: 会话 ID
|
|
||||||
model: 模型名称
|
|
||||||
user_id: 用户 ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(config, input_state) 元组
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": thread_id,
|
|
||||||
},
|
|
||||||
"metadata": {"user_id": user_id}
|
|
||||||
}
|
|
||||||
|
|
||||||
input_state = {
|
|
||||||
"messages": [HumanMessage(content=message)],
|
|
||||||
"user_id": user_id,
|
|
||||||
}
|
|
||||||
return config, input_state
|
|
||||||
|
|
||||||
async def process_message(
|
async def process_message(
|
||||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
||||||
# 解析模型名称
|
resolved_model, (config, input_state) = self._resolve_and_build(
|
||||||
resolved_model = self._resolve_model(model)
|
message, thread_id, model, user_id
|
||||||
|
)
|
||||||
# 构建调用参数
|
|
||||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
|
||||||
|
|
||||||
result = await self.graph.ainvoke(input_state, config=config)
|
result = await self.graph.ainvoke(input_state, config=config)
|
||||||
|
|
||||||
# 优先使用 final_reply(finalize 节点返回)
|
|
||||||
reply = result.get("final_reply", "")
|
reply = result.get("final_reply", "")
|
||||||
if not reply and result.get("messages"):
|
if not reply and result.get("messages"):
|
||||||
reply = result["messages"][-1].content
|
reply = result["messages"][-1].content
|
||||||
|
|
||||||
token_usage = result.get("last_token_usage", {})
|
|
||||||
elapsed_time = result.get("last_elapsed_time", 0.0)
|
|
||||||
|
|
||||||
# 获取元数据
|
|
||||||
metadata = result.get("metadata", {})
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"reply": reply,
|
"reply": reply,
|
||||||
"token_usage": token_usage,
|
"token_usage": result.get("last_token_usage", {}),
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": result.get("last_elapsed_time", 0.0),
|
||||||
"model_used": resolved_model,
|
"model_used": resolved_model,
|
||||||
"metadata": metadata
|
"metadata": result.get("metadata", {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def process_message_stream(
|
async def process_message_stream(
|
||||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
"""流式处理消息 - 完全简化!"""
|
"""流式处理消息"""
|
||||||
# 解析模型名称
|
resolved_model, (config, input_state) = self._resolve_and_build(
|
||||||
resolved_model = self._resolve_model(model)
|
message, thread_id, model, user_id
|
||||||
|
)
|
||||||
# 构建调用参数
|
|
||||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
|
||||||
|
|
||||||
info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")
|
info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")
|
||||||
actual_model_used = resolved_model
|
|
||||||
|
|
||||||
# 创建 token 队列
|
async for event in run_graph_stream(self.graph, input_state, config):
|
||||||
queue = asyncio.Queue()
|
if event.get("type") != "done":
|
||||||
set_stream_queue(queue) # 设置上下文变量
|
|
||||||
|
|
||||||
async def run_graph():
|
|
||||||
"""后台任务:运行 graph,流式事件都从 agent 节点内部发送!"""
|
|
||||||
try:
|
|
||||||
info(f"📡 开始调用 graph.astream()...")
|
|
||||||
|
|
||||||
# 注意:只用 stream_mode=["updates"],不要 "messages"!避免重复 token!
|
|
||||||
async for _ in self.graph.astream(
|
|
||||||
input_state,
|
|
||||||
config=config,
|
|
||||||
stream_mode=["updates"],
|
|
||||||
version="v2",
|
|
||||||
subgraphs=True
|
|
||||||
):
|
|
||||||
# 流式事件都从 agent.py 节点内部通过队列发送了
|
|
||||||
# 这里不需要再发送任何事件
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
error(f"❌ 执行图时出错: {e}")
|
|
||||||
import traceback
|
|
||||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
|
||||||
await queue.put({"type": "error", "message": str(e)})
|
|
||||||
finally:
|
|
||||||
await queue.put(None) # 结束哨兵
|
|
||||||
|
|
||||||
# 启动后台任务
|
|
||||||
bg_task = asyncio.create_task(run_graph())
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
event = await queue.get()
|
|
||||||
if event is None:
|
|
||||||
break
|
|
||||||
yield event
|
yield event
|
||||||
|
else:
|
||||||
except GeneratorExit:
|
yield {**event, "model_used": resolved_model}
|
||||||
# 客户端断开连接,取消后台任务
|
|
||||||
info("⚠️ GeneratorExit,取消后台任务")
|
|
||||||
bg_task.cancel()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
# 保证任务被清理
|
|
||||||
if not bg_task.done():
|
|
||||||
info("⏹️ 清理后台任务")
|
|
||||||
bg_task.cancel()
|
|
||||||
try:
|
|
||||||
await bg_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
info("✅ 后台任务已取消")
|
|
||||||
|
|
||||||
# 发送结束事件,保证前端平稳关闭
|
|
||||||
yield {
|
|
||||||
"type": "done",
|
|
||||||
"model_used": actual_model_used
|
|
||||||
}
|
|
||||||
|
|||||||
46
backend/app/agent/service_config.py
Normal file
46
backend/app/agent/service_config.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""
|
||||||
|
Agent Service 配置模块 - 配置构建和解析
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Tuple, Optional
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from backend.app.logger import warning
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceConfig:
|
||||||
|
"""配置构建器"""
|
||||||
|
|
||||||
|
def __init__(self, chat_services: dict):
|
||||||
|
self.chat_services = chat_services
|
||||||
|
|
||||||
|
def resolve_model(self, model: Optional[str]) -> str:
|
||||||
|
"""
|
||||||
|
解析并验证模型名称,不可用时回退到第一个可用模型
|
||||||
|
"""
|
||||||
|
if not model or model not in self.chat_services:
|
||||||
|
fallback = next(iter(self.chat_services.keys()))
|
||||||
|
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
|
||||||
|
return fallback
|
||||||
|
return model
|
||||||
|
|
||||||
|
def build_invocation(
|
||||||
|
self, message: str, thread_id: str, model: str, user_id: str
|
||||||
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
构建图调用所需的 config 和 input_state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(config, input_state) 元组
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
},
|
||||||
|
"metadata": {"user_id": user_id}
|
||||||
|
}
|
||||||
|
|
||||||
|
input_state = {
|
||||||
|
"messages": [HumanMessage(content=message)],
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
return config, input_state
|
||||||
78
backend/app/agent/stream_handler.py
Normal file
78
backend/app/agent/stream_handler.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
流式处理模块 - 处理 Agent 执行的流式输出
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import AsyncGenerator, Dict, Any
|
||||||
|
|
||||||
|
from backend.app.logger import info, error
|
||||||
|
from .stream_context import set_stream_queue
|
||||||
|
|
||||||
|
|
||||||
|
async def run_graph_stream(
|
||||||
|
graph,
|
||||||
|
input_state: Dict[str, Any],
|
||||||
|
config: Dict[str, Any],
|
||||||
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
|
"""
|
||||||
|
运行图并通过队列流式输出事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph: 编译后的 LangGraph
|
||||||
|
input_state: 输入状态
|
||||||
|
config: 配置
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
流式事件
|
||||||
|
"""
|
||||||
|
queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
set_stream_queue(queue)
|
||||||
|
|
||||||
|
async def run_graph():
|
||||||
|
"""后台任务:运行 graph"""
|
||||||
|
try:
|
||||||
|
info(f"📡 开始调用 graph.astream()...")
|
||||||
|
async for _ in graph.astream(
|
||||||
|
input_state,
|
||||||
|
config=config,
|
||||||
|
stream_mode=["updates"],
|
||||||
|
version="v2",
|
||||||
|
subgraphs=True
|
||||||
|
):
|
||||||
|
# 流式事件都从 agent.py 节点内部通过队列发送了
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
error(f"❌ 执行图时出错: {e}")
|
||||||
|
import traceback
|
||||||
|
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||||
|
await queue.put({"type": "error", "message": str(e)})
|
||||||
|
finally:
|
||||||
|
await queue.put(None) # 结束哨兵
|
||||||
|
|
||||||
|
# 启动后台任务
|
||||||
|
bg_task = asyncio.create_task(run_graph())
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
event = await queue.get()
|
||||||
|
if event is None:
|
||||||
|
break
|
||||||
|
yield event
|
||||||
|
|
||||||
|
except GeneratorExit:
|
||||||
|
info("⚠️ GeneratorExit,取消后台任务")
|
||||||
|
bg_task.cancel()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await _cleanup_task(bg_task)
|
||||||
|
|
||||||
|
|
||||||
|
async def _cleanup_task(bg_task: asyncio.Task) -> None:
|
||||||
|
"""清理后台任务"""
|
||||||
|
if not bg_task.done():
|
||||||
|
info("⏹️ 清理后台任务")
|
||||||
|
bg_task.cancel()
|
||||||
|
try:
|
||||||
|
await bg_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
info("✅ 后台任务已取消")
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
# app/rag_initializer.py
|
|
||||||
from ...rag.tools import create_rag_tool
|
|
||||||
from ...rag.retriever import create_parent_hybrid_retriever
|
|
||||||
from ...model_services import get_embedding_service
|
|
||||||
from backend.app.logger import info, warning
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# 全局 RAG 工具
|
|
||||||
_rag_tool = None
|
|
||||||
_initialized = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_rag_tool() -> callable:
|
|
||||||
"""获取全局 RAG 工具"""
|
|
||||||
return _rag_tool
|
|
||||||
|
|
||||||
|
|
||||||
def is_initialized() -> bool:
|
|
||||||
"""检查是否已初始化"""
|
|
||||||
return _initialized
|
|
||||||
|
|
||||||
|
|
||||||
async def init_rag_tool(force: bool = False):
|
|
||||||
"""
|
|
||||||
初始化 RAG 工具(注册到模块级变量,内部获取所需服务)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
force: 是否强制重新初始化
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
RAG 工具(@tool 装饰函数)或 None
|
|
||||||
"""
|
|
||||||
global _rag_tool, _initialized
|
|
||||||
|
|
||||||
# 防止重复初始化
|
|
||||||
if _initialized and not force:
|
|
||||||
info("[RAG] 已初始化,跳过")
|
|
||||||
return _rag_tool
|
|
||||||
|
|
||||||
try:
|
|
||||||
from backend.app.model_services.chat_services import get_chat_service
|
|
||||||
|
|
||||||
info("🔄 正在初始化 RAG 检索系统...")
|
|
||||||
embeddings = get_embedding_service()
|
|
||||||
retriever = create_parent_hybrid_retriever(
|
|
||||||
collection_name="rag_documents",
|
|
||||||
search_k=5,
|
|
||||||
embeddings=embeddings,
|
|
||||||
)
|
|
||||||
rewrite_llm = get_chat_service()
|
|
||||||
|
|
||||||
rag_tool = create_rag_tool(
|
|
||||||
retriever=retriever,
|
|
||||||
llm=rewrite_llm,
|
|
||||||
num_queries=3,
|
|
||||||
rerank_top_n=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
_rag_tool = rag_tool
|
|
||||||
_initialized = True
|
|
||||||
info(f"✅ RAG 检索工具初始化成功 (id={id(rag_tool)})")
|
|
||||||
return rag_tool
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
warning(f"⚠️ RAG 检索工具初始化失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def reset():
|
|
||||||
"""重置(用于测试)"""
|
|
||||||
global _rag_tool, _initialized
|
|
||||||
_rag_tool = None
|
|
||||||
_initialized = False
|
|
||||||
@@ -3,12 +3,11 @@
|
|||||||
Web Search Public Utility - Free, no API Key, using DuckDuckGo
|
Web Search Public Utility - Free, no API Key, using DuckDuckGo
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import requests
|
|
||||||
import warnings
|
from backend.app.logger import info
|
||||||
import re
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -44,47 +43,31 @@ class WebSearchTool:
|
|||||||
"""
|
"""
|
||||||
num_results = max_results or self.max_results
|
num_results = max_results or self.max_results
|
||||||
|
|
||||||
# 方式 1: Tavily (需要 API Key,质量最高)
|
# 尝试搜索方式,按优先级
|
||||||
|
result = self._try_tavily(query, num_results)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
result = self._try_ddgs(query, num_results)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 兜底方案
|
||||||
|
return self._get_mock_results(query, num_results)
|
||||||
|
|
||||||
|
def _try_tavily(self, query: str, max_results: int) -> Optional[List[SearchResult]]:
|
||||||
|
"""尝试 Tavily API 搜索"""
|
||||||
try:
|
try:
|
||||||
return self._search_tavily(query, num_results)
|
return self._search_tavily(query, max_results)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("[WebSearch] tavily 未安装,尝试其他搜索方式")
|
info("[WebSearch] tavily 未安装")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "API_KEY" in str(e) or "未配置" in str(e):
|
error_msg = str(e)
|
||||||
print(f"[WebSearch] Tavily API Key 未配置: {e}")
|
if "API_KEY" in error_msg or "未配置" in error_msg:
|
||||||
|
info(f"[WebSearch] Tavily API Key 未配置")
|
||||||
else:
|
else:
|
||||||
print(f"[WebSearch] Tavily 搜索失败: {e}")
|
info(f"[WebSearch] Tavily 搜索失败: {e}")
|
||||||
|
return None
|
||||||
# 方式 2: 尝试用 ddgs 包
|
|
||||||
try:
|
|
||||||
from ddgs import DDGS
|
|
||||||
print(f"[WebSearch] 使用 ddgs 搜索: {query}")
|
|
||||||
with DDGS() as ddgs:
|
|
||||||
results = list(ddgs.text(query, max_results=num_results))
|
|
||||||
if results:
|
|
||||||
search_results = []
|
|
||||||
for r in results:
|
|
||||||
search_results.append(SearchResult(
|
|
||||||
title=r.get("title", ""),
|
|
||||||
url=r.get("href", ""),
|
|
||||||
snippet=r.get("body", ""),
|
|
||||||
source="DuckDuckGo"
|
|
||||||
))
|
|
||||||
print(f"[WebSearch] ddgs 返回 {len(search_results)} 条结果")
|
|
||||||
return search_results
|
|
||||||
except ImportError:
|
|
||||||
print("[WebSearch] ddgs 未安装,尝试 duckduckgo-search")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[WebSearch] ddgs 搜索失败: {e}")
|
|
||||||
|
|
||||||
# 方式 3: 尝试用简单 HTTP 请求
|
|
||||||
try:
|
|
||||||
return self._search_http(query, num_results)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[WebSearch] HTTP 搜索也失败: {e}")
|
|
||||||
|
|
||||||
# 方式 4: 返回模拟数据作为最后兜底
|
|
||||||
return self._search_mock(query, num_results)
|
|
||||||
|
|
||||||
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
|
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
|
||||||
"""使用 Tavily API 搜索"""
|
"""使用 Tavily API 搜索"""
|
||||||
@@ -111,56 +94,40 @@ class WebSearchTool:
|
|||||||
source="Tavily"
|
source="Tavily"
|
||||||
))
|
))
|
||||||
|
|
||||||
print(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
info(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _search_http(self, query: str, max_results: int) -> List[SearchResult]:
|
def _try_ddgs(self, query: str, max_results: int) -> Optional[List[SearchResult]]:
|
||||||
"""用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源"""
|
"""尝试 DuckDuckGo 搜索"""
|
||||||
print(f"[WebSearch] 尝试 HTTP 搜索")
|
|
||||||
|
|
||||||
# 方式 1: 尝试百度搜索(简单方式)
|
|
||||||
try:
|
try:
|
||||||
return self._search_baidu(query, max_results)
|
from ddgs import DDGS
|
||||||
except Exception as e:
|
|
||||||
print(f"[WebSearch] 百度搜索失败: {e}")
|
|
||||||
|
|
||||||
# 方式 2: 返回模拟数据
|
|
||||||
return self._search_mock(query, max_results)
|
|
||||||
|
|
||||||
def _search_baidu(self, query: str, max_results: int) -> List[SearchResult]:
|
|
||||||
"""尝试百度搜索"""
|
|
||||||
import requests
|
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
url = f"https://www.baidu.com/s?wd={quote(query)}"
|
|
||||||
headers = {
|
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.get(url, headers=headers, timeout=10)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
# 简单解析百度搜索结果(简化版)
|
|
||||||
results = []
|
results = []
|
||||||
# 这里只是示意,真实百度搜索需要更复杂的解析
|
with DDGS() as ddgs:
|
||||||
results.append(SearchResult(
|
for r in ddgs.text(query, max_results=max_results):
|
||||||
title=f"百度搜索: {query}",
|
results.append(SearchResult(
|
||||||
url=url,
|
title=r.get("title", ""),
|
||||||
snippet="如需要真实搜索结果,请考虑使用百度搜索 API",
|
url=r.get("href", ""),
|
||||||
source="百度"
|
snippet=r.get("body", ""),
|
||||||
))
|
source="DuckDuckGo"
|
||||||
return results
|
))
|
||||||
except Exception as e:
|
|
||||||
print(f"[WebSearch] 百度搜索也失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _search_mock(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
if results:
|
||||||
"""模拟搜索结果(兜底方案)"""
|
info(f"[WebSearch] ddgs 返回 {len(results)} 条结果")
|
||||||
print(f"[WebSearch] 使用模拟搜索结果 (查询: {query})")
|
return results
|
||||||
|
|
||||||
# 根据查询内容生成更有意义的模拟结果
|
except ImportError:
|
||||||
mock_templates = [
|
info("[WebSearch] ddgs 未安装")
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[WebSearch] ddgs 搜索失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_mock_results(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
||||||
|
"""获取模拟搜索结果(兜底方案)"""
|
||||||
|
info(f"[WebSearch] 使用模拟搜索结果")
|
||||||
|
|
||||||
|
templates = [
|
||||||
{
|
{
|
||||||
"title": f"关于「{query}」的相关介绍",
|
"title": f"关于「{query}」的相关介绍",
|
||||||
"snippet": "这是模拟结果。如需真实搜索,请检查容器网络连接或配置代理。",
|
"snippet": "这是模拟结果。如需真实搜索,请检查容器网络连接或配置代理。",
|
||||||
@@ -177,50 +144,48 @@ class WebSearchTool:
|
|||||||
"url": "https://example.com/more"
|
"url": "https://example.com/more"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
num = max_results or self.max_results
|
num = max_results or self.max_results
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for i, template in enumerate(mock_templates[:num]):
|
for template in templates[:num]:
|
||||||
results.append(SearchResult(
|
results.append(SearchResult(
|
||||||
title=template["title"],
|
title=template["title"],
|
||||||
url=template["url"],
|
url=template["url"],
|
||||||
snippet=template["snippet"],
|
snippet=template["snippet"],
|
||||||
source="模拟数据"
|
source="模拟数据"
|
||||||
))
|
))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def format_search_results(self, results: List[SearchResult]) -> str:
|
def format_search_results(self, results: List[SearchResult]) -> str:
|
||||||
"""
|
"""
|
||||||
格式化搜索结果(带引用溯源)
|
格式化搜索结果(带引用溯源)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
results: 搜索结果列表
|
results: 搜索结果列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
格式化后的 Markdown 文本
|
格式化后的 Markdown 文本
|
||||||
"""
|
"""
|
||||||
if not results:
|
if not results:
|
||||||
return "未找到相关搜索结果"
|
return "未找到相关搜索结果"
|
||||||
|
|
||||||
lines = []
|
lines = ["## 🔍 联网搜索结果\n"]
|
||||||
lines.append("## 🔍 联网搜索结果\n")
|
|
||||||
|
|
||||||
for idx, result in enumerate(results, 1):
|
for idx, result in enumerate(results, 1):
|
||||||
lines.append(f"### [{idx}] {result.title}")
|
lines.append(f"### [{idx}] {result.title}")
|
||||||
lines.append(f"- 🔗 来源:[{result.url}]({result.url})")
|
lines.append(f"- 🔗 来源:[{result.url}]({result.url})")
|
||||||
lines.append(f"- 📝 摘要:{result.snippet}")
|
lines.append(f"- 📝 摘要:{result.snippet}")
|
||||||
lines.append(f"- 📅 时间:{result.timestamp.strftime('%Y-%m-%d %H:%M:%S')}")
|
lines.append(f"- 📅 时间:{result.timestamp.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
|
|
||||||
# 添加引用溯源说明
|
|
||||||
lines.append("---")
|
lines.append("---")
|
||||||
lines.append("💡 **引用溯源说明**:")
|
lines.append("💡 **引用溯源说明**:")
|
||||||
lines.append("- 以上搜索结果均标注了来源链接")
|
lines.append("- 以上搜索结果均标注了来源链接")
|
||||||
lines.append("- 使用方括号数字标识引用(如 [1]、[2])")
|
lines.append("- 使用方括号数字标识引用(如 [1]、[2])")
|
||||||
lines.append("- 可通过链接追溯原始信息")
|
lines.append("- 可通过链接追溯原始信息")
|
||||||
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@@ -239,11 +204,11 @@ def get_web_search_tool() -> WebSearchTool:
|
|||||||
def web_search(query: str, max_results: int = 5) -> str:
|
def web_search(query: str, max_results: int = 5) -> str:
|
||||||
"""
|
"""
|
||||||
便捷函数:联网搜索并返回格式化结果
|
便捷函数:联网搜索并返回格式化结果
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 搜索关键词
|
query: 搜索关键词
|
||||||
max_results: 返回结果数量
|
max_results: 返回结果数量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
格式化后的搜索结果文本
|
格式化后的搜索结果文本
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
"""
|
|
||||||
RAG 工具模块(完全异步)
|
|
||||||
|
|
||||||
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
|
||||||
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
|
|
||||||
|
|
||||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
|
||||||
"""
|
|
||||||
from typing import Callable, Optional
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
|
||||||
from langchain_core.retrievers import BaseRetriever
|
|
||||||
from ..rag.pipeline import RAGPipeline, create_rag_pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def create_rag_tool(
|
|
||||||
retriever: Optional[BaseRetriever] = None,
|
|
||||||
llm: Optional[BaseLanguageModel] = "default_small",
|
|
||||||
num_queries: int = 3,
|
|
||||||
rerank_top_n: int = 5,
|
|
||||||
collection_name: str = "rag_documents",
|
|
||||||
) -> Callable:
|
|
||||||
"""
|
|
||||||
创建一个配置好的 RAG 检索工具(完全异步)。
|
|
||||||
|
|
||||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
retriever: 基础检索器对象(可选,不提供则自动创建)
|
|
||||||
llm: 用于生成多路查询的语言模型。
|
|
||||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
|
||||||
- None / False: 不做查询改写
|
|
||||||
- BaseLanguageModel 实例: 自定义模型
|
|
||||||
num_queries: 生成的查询变体数量
|
|
||||||
rerank_top_n: 最终返回的文档数量
|
|
||||||
collection_name: Qdrant 集合名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Async LangChain Tool 函数
|
|
||||||
"""
|
|
||||||
pipeline = RAGPipeline(
|
|
||||||
retriever=retriever,
|
|
||||||
llm=llm,
|
|
||||||
num_queries=num_queries,
|
|
||||||
rerank_top_n=rerank_top_n,
|
|
||||||
collection_name=collection_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def search_knowledge_base(query: str) -> str:
|
|
||||||
"""
|
|
||||||
在知识库中搜索与查询相关的文档片段(完全异步)。
|
|
||||||
|
|
||||||
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
|
|
||||||
检索效果最优。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户提出的问题或查询字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
格式化后的相关文档内容
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
documents = await pipeline.aretrieve(query)
|
|
||||||
if not documents:
|
|
||||||
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
|
|
||||||
|
|
||||||
context = pipeline.format_context(documents)
|
|
||||||
return context
|
|
||||||
except Exception as e:
|
|
||||||
return f"检索过程中发生错误: {str(e)}"
|
|
||||||
|
|
||||||
return search_knowledge_base
|
|
||||||
@@ -2,124 +2,17 @@
|
|||||||
Agent Tools - 所有工具统一定义
|
Agent Tools - 所有工具统一定义
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from .rag import rag_search
|
||||||
from backend.app.logger import info
|
from .web_search import web_search
|
||||||
|
from .subgraph import contact_lookup, dictionary_lookup, news_analysis
|
||||||
# ========== RAG ==========
|
|
||||||
|
|
||||||
_rag_pipeline = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_rag_pipeline():
|
|
||||||
global _rag_pipeline
|
|
||||||
if _rag_pipeline is None:
|
|
||||||
from backend.app.rag.pipeline import RAGPipeline
|
|
||||||
_rag_pipeline = RAGPipeline(
|
|
||||||
num_queries=3,
|
|
||||||
rerank_top_n=5,
|
|
||||||
use_rerank=True,
|
|
||||||
return_parent_docs=True,
|
|
||||||
)
|
|
||||||
return _rag_pipeline
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def rag_search(query: str) -> str:
|
|
||||||
"""
|
|
||||||
检索知识库获取相关信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含检索结果和置信度的结构化回复,格式:
|
|
||||||
- 内容:检索到的相关信息
|
|
||||||
- 置信度评估:基于向量相似度、重排分数、LLM判断的综合评分
|
|
||||||
"""
|
|
||||||
info(f"[Tool] rag_search: {query[:30]}...")
|
|
||||||
try:
|
|
||||||
pipeline = _get_rag_pipeline()
|
|
||||||
# 使用带置信度的检索
|
|
||||||
result = await pipeline.aretrieve_with_confidence(query, original_query=query)
|
|
||||||
|
|
||||||
if not result.content:
|
|
||||||
return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度:0.0\n建议:可尝试联网搜索获取信息。"
|
|
||||||
|
|
||||||
# 构建包含置信度的回复
|
|
||||||
confidence_desc = "高"
|
|
||||||
if result.confidence < 0.4:
|
|
||||||
confidence_desc = "低"
|
|
||||||
elif result.confidence < 0.6:
|
|
||||||
confidence_desc = "中"
|
|
||||||
|
|
||||||
response = f"""【RAG检索结果】
|
|
||||||
{result.content}
|
|
||||||
|
|
||||||
【置信度评估】
|
|
||||||
- 综合置信度:{result.confidence:.2f}({confidence_desc})
|
|
||||||
- 向量相似度:{result.scores['embedding']:.2f}
|
|
||||||
- 重排分数:{result.scores['rerank']:.2f}
|
|
||||||
- LLM评估:{result.scores['llm']:.2f}
|
|
||||||
|
|
||||||
{'✅ 检索结果可信,可直接使用' if result.is_useful else '⚠️ 检索结果置信度较低,可能需要联网搜索补充'}"""
|
|
||||||
|
|
||||||
info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}")
|
|
||||||
return response
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
info(f"[Tool] rag_search 失败: {e}")
|
|
||||||
return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索"
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 联网搜索 ==========
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def web_search(query: str) -> str:
|
|
||||||
"""联网搜索获取最新信息"""
|
|
||||||
info(f"[Tool] web_search: {query[:30]}...")
|
|
||||||
try:
|
|
||||||
from backend.app.core.web_search import web_search as search_fn
|
|
||||||
return search_fn(query, max_results=5)
|
|
||||||
except Exception as e:
|
|
||||||
info(f"[Tool] web_search 失败: {e}")
|
|
||||||
return f"联网搜索失败: {str(e)}"
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 子图工具 ==========
|
|
||||||
|
|
||||||
async def _call_subgraph(builder_fn, state_cls, query: str) -> str:
|
|
||||||
"""通用子图调用"""
|
|
||||||
try:
|
|
||||||
graph = builder_fn().compile()
|
|
||||||
state = state_cls(user_query=query)
|
|
||||||
result = await graph.ainvoke(state)
|
|
||||||
return result.get("final_result", "执行完成")
|
|
||||||
except Exception as e:
|
|
||||||
info(f"[Tool] 子图调用失败: {e}")
|
|
||||||
return f"执行失败: {str(e)}"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def contact_lookup(query: str) -> str:
|
|
||||||
"""查询通讯录"""
|
|
||||||
from backend.app.subgraphs.contact.graph import build_contact_subgraph
|
|
||||||
from backend.app.subgraphs.contact.state import ContactState
|
|
||||||
return await _call_subgraph(build_contact_subgraph, ContactState, query)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def dictionary_lookup(word: str) -> str:
|
|
||||||
"""查询词典/翻译"""
|
|
||||||
from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph
|
|
||||||
from backend.app.subgraphs.dictionary.state import DictionaryState
|
|
||||||
return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def news_analysis(topic: str) -> str:
|
|
||||||
"""分析新闻热点"""
|
|
||||||
from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph
|
|
||||||
from backend.app.subgraphs.news_analysis.state import NewsAnalysisState
|
|
||||||
return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic)
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 导出 ==========
|
|
||||||
|
|
||||||
ALL_TOOLS = [rag_search, web_search, contact_lookup, dictionary_lookup, news_analysis]
|
ALL_TOOLS = [rag_search, web_search, contact_lookup, dictionary_lookup, news_analysis]
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"rag_search",
|
||||||
|
"web_search",
|
||||||
|
"contact_lookup",
|
||||||
|
"dictionary_lookup",
|
||||||
|
"news_analysis",
|
||||||
|
"ALL_TOOLS",
|
||||||
|
]
|
||||||
|
|||||||
8
backend/app/tools/base.py
Normal file
8
backend/app/tools/base.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""
|
||||||
|
工具模块配置
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
__all__ = ["tool", "info"]
|
||||||
70
backend/app/tools/rag.py
Normal file
70
backend/app/tools/rag.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""
|
||||||
|
RAG 检索工具
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
|
||||||
|
_rag_pipeline = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rag_pipeline():
|
||||||
|
"""获取或创建 RAG pipeline 单例"""
|
||||||
|
global _rag_pipeline
|
||||||
|
if _rag_pipeline is None:
|
||||||
|
from backend.app.rag.pipeline import RAGPipeline
|
||||||
|
_rag_pipeline = RAGPipeline(
|
||||||
|
num_queries=3,
|
||||||
|
rerank_top_n=5,
|
||||||
|
use_rerank=True,
|
||||||
|
return_parent_docs=True,
|
||||||
|
)
|
||||||
|
return _rag_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def _format_confidence(result) -> str:
|
||||||
|
"""格式化置信度描述"""
|
||||||
|
if result.confidence < 0.4:
|
||||||
|
return "低"
|
||||||
|
elif result.confidence < 0.6:
|
||||||
|
return "中"
|
||||||
|
return "高"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def rag_search(query: str) -> str:
|
||||||
|
"""
|
||||||
|
检索知识库获取相关信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含检索结果和置信度的结构化回复
|
||||||
|
"""
|
||||||
|
info(f"[Tool] rag_search: {query[:30]}...")
|
||||||
|
try:
|
||||||
|
pipeline = _get_rag_pipeline()
|
||||||
|
result = await pipeline.aretrieve_with_confidence(query, original_query=query)
|
||||||
|
|
||||||
|
if not result.content:
|
||||||
|
return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度:0.0\n建议:可尝试联网搜索获取信息。"
|
||||||
|
|
||||||
|
confidence_desc = _format_confidence(result)
|
||||||
|
is_useful_note = "✅ 检索结果可信,可直接使用" if result.is_useful else "⚠️ 检索结果置信度较低,可能需要联网搜索补充"
|
||||||
|
|
||||||
|
response = f"""【RAG检索结果】
|
||||||
|
{result.content}
|
||||||
|
|
||||||
|
【置信度评估】
|
||||||
|
- 综合置信度:{result.confidence:.2f}({confidence_desc})
|
||||||
|
- 向量相似度:{result.scores['embedding']:.2f}
|
||||||
|
- 重排分数:{result.scores['rerank']:.2f}
|
||||||
|
- LLM评估:{result.scores['llm']:.2f}
|
||||||
|
|
||||||
|
{is_useful_note}"""
|
||||||
|
|
||||||
|
info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}")
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[Tool] rag_search 失败: {e}")
|
||||||
|
return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索"
|
||||||
42
backend/app/tools/subgraph.py
Normal file
42
backend/app/tools/subgraph.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
子图工具 - 通讯录、词典、新闻分析等
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_subgraph(builder_fn, state_cls, query: str) -> str:
|
||||||
|
"""通用子图调用"""
|
||||||
|
try:
|
||||||
|
graph = builder_fn().compile()
|
||||||
|
state = state_cls(user_query=query)
|
||||||
|
result = await graph.ainvoke(state)
|
||||||
|
return result.get("final_result", "执行完成")
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[Tool] 子图调用失败: {e}")
|
||||||
|
return f"执行失败: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def contact_lookup(query: str) -> str:
|
||||||
|
"""查询通讯录"""
|
||||||
|
from backend.app.subgraphs.contact.graph import build_contact_subgraph
|
||||||
|
from backend.app.subgraphs.contact.state import ContactState
|
||||||
|
return await _call_subgraph(build_contact_subgraph, ContactState, query)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def dictionary_lookup(word: str) -> str:
|
||||||
|
"""查询词典/翻译"""
|
||||||
|
from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph
|
||||||
|
from backend.app.subgraphs.dictionary.state import DictionaryState
|
||||||
|
return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def news_analysis(topic: str) -> str:
|
||||||
|
"""分析新闻热点"""
|
||||||
|
from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph
|
||||||
|
from backend.app.subgraphs.news_analysis.state import NewsAnalysisState
|
||||||
|
return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic)
|
||||||
18
backend/app/tools/web_search.py
Normal file
18
backend/app/tools/web_search.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
联网搜索工具
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def web_search(query: str) -> str:
|
||||||
|
"""联网搜索获取最新信息"""
|
||||||
|
info(f"[Tool] web_search: {query[:30]}...")
|
||||||
|
try:
|
||||||
|
from backend.app.core.web_search import web_search as search_fn
|
||||||
|
return search_fn(query, max_results=5)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[Tool] web_search 失败: {e}")
|
||||||
|
return f"联网搜索失败: {str(e)}"
|
||||||
Reference in New Issue
Block a user