Compare commits
9 Commits
6dfa9f572e
...
b30f7b00a7
| Author | SHA1 | Date | |
|---|---|---|---|
| b30f7b00a7 | |||
| bfb2ddbe76 | |||
| 5c2380e31c | |||
| d16ad6185e | |||
| ef07b05c22 | |||
| e851e40763 | |||
| ce4d7515d9 | |||
| 527d7a0b1d | |||
| 46cd7abcc6 |
@@ -1,193 +1,88 @@
|
|||||||
"""
|
"""
|
||||||
AI Agent 服务类 - 完全简化版本!
|
AI Agent 服务类
|
||||||
按照指南实现,不用 stream_mode="messages" 避免重复 token!
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
from typing import AsyncGenerator, Dict, Any
|
||||||
import asyncio
|
|
||||||
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
|
|
||||||
|
|
||||||
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
|
|
||||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||||
|
|
||||||
# 本地模块
|
|
||||||
from backend.app.model_services import get_cached_chat_services
|
from backend.app.model_services import get_cached_chat_services
|
||||||
from backend.app.main_graph.main_graph_builder import build_agent_graph
|
from backend.app.main_graph.main_graph_builder import build_agent_graph
|
||||||
from backend.app.logger import debug, info, warning, error
|
from backend.app.logger import info
|
||||||
from backend.app.main_graph.state import AgentState
|
from backend.app.memory.mem0_client import Mem0Client
|
||||||
from .stream_context import set_stream_queue
|
|
||||||
|
from .service_config import ServiceConfig
|
||||||
|
from .stream_handler import run_graph_stream
|
||||||
|
|
||||||
|
|
||||||
class AIAgentService:
|
class AIAgentService:
|
||||||
def __init__(self, checkpointer):
|
def __init__(self, checkpointer):
|
||||||
self.checkpointer = checkpointer
|
self.checkpointer = checkpointer
|
||||||
self.graph = None
|
self.graph = None
|
||||||
self.chat_services = None
|
self.config: ServiceConfig = None
|
||||||
# Mem0 客户端
|
|
||||||
self.mem0_client = None
|
self.mem0_client = None
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self) -> "AIAgentService":
|
||||||
# 0. 初始化 Mem0 客户端
|
"""初始化 Agent 服务"""
|
||||||
from ..memory.mem0_client import Mem0Client
|
|
||||||
self.mem0_client = Mem0Client()
|
self.mem0_client = Mem0Client()
|
||||||
|
|
||||||
# 1. 获取缓存的模型字典
|
|
||||||
self.chat_services = get_cached_chat_services()
|
self.chat_services = get_cached_chat_services()
|
||||||
info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
|
info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
|
||||||
|
|
||||||
# 2. 构建图
|
|
||||||
info(f"🔄 构建 Agent 图...")
|
|
||||||
graph_builder = build_agent_graph(
|
graph_builder = build_agent_graph(
|
||||||
chat_services=self.chat_services,
|
chat_services=self.chat_services,
|
||||||
mem0_client=self.mem0_client
|
mem0_client=self.mem0_client
|
||||||
)
|
)
|
||||||
|
|
||||||
# 编译图
|
|
||||||
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
|
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
|
||||||
|
|
||||||
|
self.config = ServiceConfig(self.chat_services)
|
||||||
info(f"✅ Agent 图初始化完成")
|
info(f"✅ Agent 图初始化完成")
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _resolve_model(self, model: str) -> str:
|
def _resolve_and_build(
|
||||||
"""
|
|
||||||
解析并验证模型名称,不可用时回退到第一个可用模型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: 目标模型名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
实际使用的模型名称
|
|
||||||
"""
|
|
||||||
if not model or model not in self.chat_services:
|
|
||||||
fallback = next(iter(self.chat_services.keys()))
|
|
||||||
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
|
|
||||||
return fallback
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _build_invocation(
|
|
||||||
self, message: str, thread_id: str, model: str, user_id: str
|
self, message: str, thread_id: str, model: str, user_id: str
|
||||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
):
|
||||||
"""
|
"""解析模型并构建调用参数"""
|
||||||
构建图调用所需的 config 和 input_state
|
resolved_model = self.config.resolve_model(model)
|
||||||
|
return resolved_model, self.config.build_invocation(
|
||||||
Args:
|
message, thread_id, resolved_model, user_id
|
||||||
message: 用户消息
|
)
|
||||||
thread_id: 会话 ID
|
|
||||||
model: 模型名称
|
|
||||||
user_id: 用户 ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(config, input_state) 元组
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": thread_id,
|
|
||||||
},
|
|
||||||
"metadata": {"user_id": user_id}
|
|
||||||
}
|
|
||||||
|
|
||||||
input_state = {
|
|
||||||
"messages": [HumanMessage(content=message)],
|
|
||||||
"user_id": user_id,
|
|
||||||
}
|
|
||||||
return config, input_state
|
|
||||||
|
|
||||||
async def process_message(
|
async def process_message(
|
||||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
||||||
# 解析模型名称
|
resolved_model, (config, input_state) = self._resolve_and_build(
|
||||||
resolved_model = self._resolve_model(model)
|
message, thread_id, model, user_id
|
||||||
|
)
|
||||||
# 构建调用参数
|
|
||||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
|
||||||
|
|
||||||
result = await self.graph.ainvoke(input_state, config=config)
|
result = await self.graph.ainvoke(input_state, config=config)
|
||||||
|
|
||||||
reply = ""
|
reply = result.get("final_reply", "")
|
||||||
if result.get("messages"):
|
if not reply and result.get("messages"):
|
||||||
reply = result["messages"][-1].content
|
reply = result["messages"][-1].content
|
||||||
|
|
||||||
token_usage = result.get("last_token_usage", {})
|
|
||||||
elapsed_time = result.get("last_elapsed_time", 0.0)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"reply": reply,
|
"reply": reply,
|
||||||
"token_usage": token_usage,
|
"token_usage": result.get("last_token_usage", {}),
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": result.get("last_elapsed_time", 0.0),
|
||||||
"model_used": resolved_model
|
"model_used": resolved_model,
|
||||||
|
"metadata": result.get("metadata", {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def process_message_stream(
|
async def process_message_stream(
|
||||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
"""流式处理消息 - 完全简化!"""
|
"""流式处理消息"""
|
||||||
# 解析模型名称
|
resolved_model, (config, input_state) = self._resolve_and_build(
|
||||||
resolved_model = self._resolve_model(model)
|
message, thread_id, model, user_id
|
||||||
|
)
|
||||||
# 构建调用参数
|
|
||||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
|
||||||
|
|
||||||
info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")
|
info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")
|
||||||
actual_model_used = resolved_model
|
|
||||||
|
|
||||||
# 创建 token 队列
|
async for event in run_graph_stream(self.graph, input_state, config):
|
||||||
queue = asyncio.Queue()
|
if event.get("type") != "done":
|
||||||
set_stream_queue(queue) # 设置上下文变量
|
|
||||||
|
|
||||||
async def run_graph():
|
|
||||||
"""后台任务:运行 graph,流式事件都从 agent 节点内部发送!"""
|
|
||||||
try:
|
|
||||||
info(f"📡 开始调用 graph.astream()...")
|
|
||||||
|
|
||||||
# 注意:只用 stream_mode=["updates"],不要 "messages"!避免重复 token!
|
|
||||||
async for _ in self.graph.astream(
|
|
||||||
input_state,
|
|
||||||
config=config,
|
|
||||||
stream_mode=["updates"],
|
|
||||||
version="v2",
|
|
||||||
subgraphs=True
|
|
||||||
):
|
|
||||||
# 流式事件都从 agent.py 节点内部通过队列发送了
|
|
||||||
# 这里不需要再发送任何事件
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
error(f"❌ 执行图时出错: {e}")
|
|
||||||
import traceback
|
|
||||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
|
||||||
await queue.put({"type": "error", "message": str(e)})
|
|
||||||
finally:
|
|
||||||
await queue.put(None) # 结束哨兵
|
|
||||||
|
|
||||||
# 启动后台任务
|
|
||||||
bg_task = asyncio.create_task(run_graph())
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
event = await queue.get()
|
|
||||||
if event is None:
|
|
||||||
break
|
|
||||||
yield event
|
yield event
|
||||||
|
else:
|
||||||
except GeneratorExit:
|
yield {**event, "model_used": resolved_model}
|
||||||
# 客户端断开连接,取消后台任务
|
|
||||||
info("⚠️ GeneratorExit,取消后台任务")
|
|
||||||
bg_task.cancel()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
# 保证任务被清理
|
|
||||||
if not bg_task.done():
|
|
||||||
info("⏹️ 清理后台任务")
|
|
||||||
bg_task.cancel()
|
|
||||||
try:
|
|
||||||
await bg_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
info("✅ 后台任务已取消")
|
|
||||||
|
|
||||||
# 发送结束事件,保证前端平稳关闭
|
|
||||||
yield {
|
|
||||||
"type": "done",
|
|
||||||
"model_used": actual_model_used
|
|
||||||
}
|
|
||||||
|
|||||||
46
backend/app/agent/service_config.py
Normal file
46
backend/app/agent/service_config.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""
|
||||||
|
Agent Service 配置模块 - 配置构建和解析
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Tuple, Optional
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from backend.app.logger import warning
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceConfig:
|
||||||
|
"""配置构建器"""
|
||||||
|
|
||||||
|
def __init__(self, chat_services: dict):
|
||||||
|
self.chat_services = chat_services
|
||||||
|
|
||||||
|
def resolve_model(self, model: Optional[str]) -> str:
|
||||||
|
"""
|
||||||
|
解析并验证模型名称,不可用时回退到第一个可用模型
|
||||||
|
"""
|
||||||
|
if not model or model not in self.chat_services:
|
||||||
|
fallback = next(iter(self.chat_services.keys()))
|
||||||
|
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
|
||||||
|
return fallback
|
||||||
|
return model
|
||||||
|
|
||||||
|
def build_invocation(
|
||||||
|
self, message: str, thread_id: str, model: str, user_id: str
|
||||||
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
构建图调用所需的 config 和 input_state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(config, input_state) 元组
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
},
|
||||||
|
"metadata": {"user_id": user_id}
|
||||||
|
}
|
||||||
|
|
||||||
|
input_state = {
|
||||||
|
"messages": [HumanMessage(content=message)],
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
return config, input_state
|
||||||
78
backend/app/agent/stream_handler.py
Normal file
78
backend/app/agent/stream_handler.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
流式处理模块 - 处理 Agent 执行的流式输出
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import AsyncGenerator, Dict, Any
|
||||||
|
|
||||||
|
from backend.app.logger import info, error
|
||||||
|
from .stream_context import set_stream_queue
|
||||||
|
|
||||||
|
|
||||||
|
async def run_graph_stream(
|
||||||
|
graph,
|
||||||
|
input_state: Dict[str, Any],
|
||||||
|
config: Dict[str, Any],
|
||||||
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
|
"""
|
||||||
|
运行图并通过队列流式输出事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph: 编译后的 LangGraph
|
||||||
|
input_state: 输入状态
|
||||||
|
config: 配置
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
流式事件
|
||||||
|
"""
|
||||||
|
queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
set_stream_queue(queue)
|
||||||
|
|
||||||
|
async def run_graph():
|
||||||
|
"""后台任务:运行 graph"""
|
||||||
|
try:
|
||||||
|
info(f"📡 开始调用 graph.astream()...")
|
||||||
|
async for _ in graph.astream(
|
||||||
|
input_state,
|
||||||
|
config=config,
|
||||||
|
stream_mode=["updates"],
|
||||||
|
version="v2",
|
||||||
|
subgraphs=True
|
||||||
|
):
|
||||||
|
# 流式事件都从 agent.py 节点内部通过队列发送了
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
error(f"❌ 执行图时出错: {e}")
|
||||||
|
import traceback
|
||||||
|
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||||
|
await queue.put({"type": "error", "message": str(e)})
|
||||||
|
finally:
|
||||||
|
await queue.put(None) # 结束哨兵
|
||||||
|
|
||||||
|
# 启动后台任务
|
||||||
|
bg_task = asyncio.create_task(run_graph())
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
event = await queue.get()
|
||||||
|
if event is None:
|
||||||
|
break
|
||||||
|
yield event
|
||||||
|
|
||||||
|
except GeneratorExit:
|
||||||
|
info("⚠️ GeneratorExit,取消后台任务")
|
||||||
|
bg_task.cancel()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await _cleanup_task(bg_task)
|
||||||
|
|
||||||
|
|
||||||
|
async def _cleanup_task(bg_task: asyncio.Task) -> None:
|
||||||
|
"""清理后台任务"""
|
||||||
|
if not bg_task.done():
|
||||||
|
info("⏹️ 清理后台任务")
|
||||||
|
bg_task.cancel()
|
||||||
|
try:
|
||||||
|
await bg_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
info("✅ 后台任务已取消")
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
# app/rag_initializer.py
|
|
||||||
from ...rag.tools import create_rag_tool
|
|
||||||
from ...rag.retriever import create_parent_hybrid_retriever
|
|
||||||
from ...model_services import get_embedding_service
|
|
||||||
from backend.app.logger import info, warning
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# 全局 RAG 工具
|
|
||||||
_rag_tool = None
|
|
||||||
_initialized = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_rag_tool() -> callable:
|
|
||||||
"""获取全局 RAG 工具"""
|
|
||||||
return _rag_tool
|
|
||||||
|
|
||||||
|
|
||||||
def is_initialized() -> bool:
|
|
||||||
"""检查是否已初始化"""
|
|
||||||
return _initialized
|
|
||||||
|
|
||||||
|
|
||||||
async def init_rag_tool(force: bool = False):
|
|
||||||
"""
|
|
||||||
初始化 RAG 工具(注册到模块级变量,内部获取所需服务)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
force: 是否强制重新初始化
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
RAG 工具(@tool 装饰函数)或 None
|
|
||||||
"""
|
|
||||||
global _rag_tool, _initialized
|
|
||||||
|
|
||||||
# 防止重复初始化
|
|
||||||
if _initialized and not force:
|
|
||||||
info("[RAG] 已初始化,跳过")
|
|
||||||
return _rag_tool
|
|
||||||
|
|
||||||
try:
|
|
||||||
from backend.app.model_services.chat_services import get_chat_service
|
|
||||||
|
|
||||||
info("🔄 正在初始化 RAG 检索系统...")
|
|
||||||
embeddings = get_embedding_service()
|
|
||||||
retriever = create_parent_hybrid_retriever(
|
|
||||||
collection_name="rag_documents",
|
|
||||||
search_k=5,
|
|
||||||
embeddings=embeddings,
|
|
||||||
)
|
|
||||||
rewrite_llm = get_chat_service()
|
|
||||||
|
|
||||||
rag_tool = create_rag_tool(
|
|
||||||
retriever=retriever,
|
|
||||||
llm=rewrite_llm,
|
|
||||||
num_queries=3,
|
|
||||||
rerank_top_n=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
_rag_tool = rag_tool
|
|
||||||
_initialized = True
|
|
||||||
info(f"✅ RAG 检索工具初始化成功 (id={id(rag_tool)})")
|
|
||||||
return rag_tool
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
warning(f"⚠️ RAG 检索工具初始化失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def reset():
|
|
||||||
"""重置(用于测试)"""
|
|
||||||
global _rag_tool, _initialized
|
|
||||||
_rag_tool = None
|
|
||||||
_initialized = False
|
|
||||||
@@ -3,12 +3,11 @@
|
|||||||
Web Search Public Utility - Free, no API Key, using DuckDuckGo
|
Web Search Public Utility - Free, no API Key, using DuckDuckGo
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import requests
|
|
||||||
import warnings
|
from backend.app.logger import info
|
||||||
import re
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -44,47 +43,31 @@ class WebSearchTool:
|
|||||||
"""
|
"""
|
||||||
num_results = max_results or self.max_results
|
num_results = max_results or self.max_results
|
||||||
|
|
||||||
# 方式 1: Tavily (需要 API Key,质量最高)
|
# 尝试搜索方式,按优先级
|
||||||
|
result = self._try_tavily(query, num_results)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
result = self._try_ddgs(query, num_results)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 兜底方案
|
||||||
|
return self._get_mock_results(query, num_results)
|
||||||
|
|
||||||
|
def _try_tavily(self, query: str, max_results: int) -> Optional[List[SearchResult]]:
|
||||||
|
"""尝试 Tavily API 搜索"""
|
||||||
try:
|
try:
|
||||||
return self._search_tavily(query, num_results)
|
return self._search_tavily(query, max_results)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("[WebSearch] tavily 未安装,尝试其他搜索方式")
|
info("[WebSearch] tavily 未安装")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "API_KEY" in str(e) or "未配置" in str(e):
|
error_msg = str(e)
|
||||||
print(f"[WebSearch] Tavily API Key 未配置: {e}")
|
if "API_KEY" in error_msg or "未配置" in error_msg:
|
||||||
|
info(f"[WebSearch] Tavily API Key 未配置")
|
||||||
else:
|
else:
|
||||||
print(f"[WebSearch] Tavily 搜索失败: {e}")
|
info(f"[WebSearch] Tavily 搜索失败: {e}")
|
||||||
|
return None
|
||||||
# 方式 2: 尝试用 ddgs 包
|
|
||||||
try:
|
|
||||||
from ddgs import DDGS
|
|
||||||
print(f"[WebSearch] 使用 ddgs 搜索: {query}")
|
|
||||||
with DDGS() as ddgs:
|
|
||||||
results = list(ddgs.text(query, max_results=num_results))
|
|
||||||
if results:
|
|
||||||
search_results = []
|
|
||||||
for r in results:
|
|
||||||
search_results.append(SearchResult(
|
|
||||||
title=r.get("title", ""),
|
|
||||||
url=r.get("href", ""),
|
|
||||||
snippet=r.get("body", ""),
|
|
||||||
source="DuckDuckGo"
|
|
||||||
))
|
|
||||||
print(f"[WebSearch] ddgs 返回 {len(search_results)} 条结果")
|
|
||||||
return search_results
|
|
||||||
except ImportError:
|
|
||||||
print("[WebSearch] ddgs 未安装,尝试 duckduckgo-search")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[WebSearch] ddgs 搜索失败: {e}")
|
|
||||||
|
|
||||||
# 方式 3: 尝试用简单 HTTP 请求
|
|
||||||
try:
|
|
||||||
return self._search_http(query, num_results)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[WebSearch] HTTP 搜索也失败: {e}")
|
|
||||||
|
|
||||||
# 方式 4: 返回模拟数据作为最后兜底
|
|
||||||
return self._search_mock(query, num_results)
|
|
||||||
|
|
||||||
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
|
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
|
||||||
"""使用 Tavily API 搜索"""
|
"""使用 Tavily API 搜索"""
|
||||||
@@ -111,56 +94,40 @@ class WebSearchTool:
|
|||||||
source="Tavily"
|
source="Tavily"
|
||||||
))
|
))
|
||||||
|
|
||||||
print(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
info(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _search_http(self, query: str, max_results: int) -> List[SearchResult]:
|
def _try_ddgs(self, query: str, max_results: int) -> Optional[List[SearchResult]]:
|
||||||
"""用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源"""
|
"""尝试 DuckDuckGo 搜索"""
|
||||||
print(f"[WebSearch] 尝试 HTTP 搜索")
|
|
||||||
|
|
||||||
# 方式 1: 尝试百度搜索(简单方式)
|
|
||||||
try:
|
try:
|
||||||
return self._search_baidu(query, max_results)
|
from ddgs import DDGS
|
||||||
except Exception as e:
|
|
||||||
print(f"[WebSearch] 百度搜索失败: {e}")
|
|
||||||
|
|
||||||
# 方式 2: 返回模拟数据
|
|
||||||
return self._search_mock(query, max_results)
|
|
||||||
|
|
||||||
def _search_baidu(self, query: str, max_results: int) -> List[SearchResult]:
|
|
||||||
"""尝试百度搜索"""
|
|
||||||
import requests
|
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
url = f"https://www.baidu.com/s?wd={quote(query)}"
|
|
||||||
headers = {
|
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.get(url, headers=headers, timeout=10)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
# 简单解析百度搜索结果(简化版)
|
|
||||||
results = []
|
results = []
|
||||||
# 这里只是示意,真实百度搜索需要更复杂的解析
|
with DDGS() as ddgs:
|
||||||
|
for r in ddgs.text(query, max_results=max_results):
|
||||||
results.append(SearchResult(
|
results.append(SearchResult(
|
||||||
title=f"百度搜索: {query}",
|
title=r.get("title", ""),
|
||||||
url=url,
|
url=r.get("href", ""),
|
||||||
snippet="如需要真实搜索结果,请考虑使用百度搜索 API",
|
snippet=r.get("body", ""),
|
||||||
source="百度"
|
source="DuckDuckGo"
|
||||||
))
|
))
|
||||||
|
|
||||||
|
if results:
|
||||||
|
info(f"[WebSearch] ddgs 返回 {len(results)} 条结果")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
info("[WebSearch] ddgs 未安装")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[WebSearch] 百度搜索也失败: {e}")
|
info(f"[WebSearch] ddgs 搜索失败: {e}")
|
||||||
raise
|
|
||||||
|
|
||||||
def _search_mock(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
return None
|
||||||
"""模拟搜索结果(兜底方案)"""
|
|
||||||
print(f"[WebSearch] 使用模拟搜索结果 (查询: {query})")
|
|
||||||
|
|
||||||
# 根据查询内容生成更有意义的模拟结果
|
def _get_mock_results(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
||||||
mock_templates = [
|
"""获取模拟搜索结果(兜底方案)"""
|
||||||
|
info(f"[WebSearch] 使用模拟搜索结果")
|
||||||
|
|
||||||
|
templates = [
|
||||||
{
|
{
|
||||||
"title": f"关于「{query}」的相关介绍",
|
"title": f"关于「{query}」的相关介绍",
|
||||||
"snippet": "这是模拟结果。如需真实搜索,请检查容器网络连接或配置代理。",
|
"snippet": "这是模拟结果。如需真实搜索,请检查容器网络连接或配置代理。",
|
||||||
@@ -181,7 +148,7 @@ class WebSearchTool:
|
|||||||
num = max_results or self.max_results
|
num = max_results or self.max_results
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for i, template in enumerate(mock_templates[:num]):
|
for template in templates[:num]:
|
||||||
results.append(SearchResult(
|
results.append(SearchResult(
|
||||||
title=template["title"],
|
title=template["title"],
|
||||||
url=template["url"],
|
url=template["url"],
|
||||||
@@ -204,8 +171,7 @@ class WebSearchTool:
|
|||||||
if not results:
|
if not results:
|
||||||
return "未找到相关搜索结果"
|
return "未找到相关搜索结果"
|
||||||
|
|
||||||
lines = []
|
lines = ["## 🔍 联网搜索结果\n"]
|
||||||
lines.append("## 🔍 联网搜索结果\n")
|
|
||||||
|
|
||||||
for idx, result in enumerate(results, 1):
|
for idx, result in enumerate(results, 1):
|
||||||
lines.append(f"### [{idx}] {result.title}")
|
lines.append(f"### [{idx}] {result.title}")
|
||||||
@@ -214,7 +180,6 @@ class WebSearchTool:
|
|||||||
lines.append(f"- 📅 时间:{result.timestamp.strftime('%Y-%m-%d %H:%M:%S')}")
|
lines.append(f"- 📅 时间:{result.timestamp.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
lines.append("")
|
lines.append("")
|
||||||
|
|
||||||
# 添加引用溯源说明
|
|
||||||
lines.append("---")
|
lines.append("---")
|
||||||
lines.append("💡 **引用溯源说明**:")
|
lines.append("💡 **引用溯源说明**:")
|
||||||
lines.append("- 以上搜索结果均标注了来源链接")
|
lines.append("- 以上搜索结果均标注了来源链接")
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
极简 Agent 主图 - 简化版本!
|
Agent 主图 - 标准 LangGraph 结构
|
||||||
因为完整的 ReAct 循环已经在 agent.py 里了!
|
将 ReAct 循环分离为独立的 agent/tools 节点
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
from backend.app.main_graph.state import AgentState
|
from backend.app.main_graph.state import AgentState
|
||||||
from backend.app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
from backend.app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||||
from backend.app.main_graph.nodes.agent import create_agent_node
|
from backend.app.main_graph.nodes.agent import create_agent_node
|
||||||
|
from backend.app.main_graph.nodes.tools import tools_node
|
||||||
|
from backend.app.main_graph.nodes.finalize import finalize_node
|
||||||
from backend.app.logger import info
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
|
||||||
@@ -16,7 +19,20 @@ def build_agent_graph(
|
|||||||
max_steps: int = 10
|
max_steps: int = 10
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
构建简化的 Agent 图(ReAct 循环在 agent 节点内)
|
构建标准 Agent 图
|
||||||
|
|
||||||
|
节点:
|
||||||
|
- init_state: 初始化状态
|
||||||
|
- memory_trigger: 记忆触发
|
||||||
|
- agent: 单步推理
|
||||||
|
- tools: 工具执行
|
||||||
|
- finalize: 后处理
|
||||||
|
|
||||||
|
边:
|
||||||
|
- START -> init_state -> memory_trigger -> agent
|
||||||
|
- agent -> (条件边) -> tools 或 finalize
|
||||||
|
- tools -> agent (循环)
|
||||||
|
- finalize -> END
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_services: 模型服务字典
|
chat_services: 模型服务字典
|
||||||
@@ -26,57 +42,70 @@ def build_agent_graph(
|
|||||||
Returns:
|
Returns:
|
||||||
构建好的 StateGraph(未编译)
|
构建好的 StateGraph(未编译)
|
||||||
"""
|
"""
|
||||||
# ========== 设置全局客户端 ==========
|
# 设置全局客户端
|
||||||
if mem0_client:
|
if mem0_client:
|
||||||
set_mem0_client(mem0_client)
|
set_mem0_client(mem0_client)
|
||||||
|
|
||||||
# ========== 1. 初始化节点:重置步数 ==========
|
# ========== 1. init_state 节点 ==========
|
||||||
async def init_state_node(state: AgentState):
|
async def init_state_node(state: AgentState):
|
||||||
info("[Init State] 初始化状态,重置步数")
|
info("[Init State] 初始化状态")
|
||||||
return {
|
return {
|
||||||
"current_step": 0,
|
"current_step": 0,
|
||||||
"max_steps": max_steps
|
"max_steps": max_steps,
|
||||||
|
"tool_call_history": [],
|
||||||
|
"tool_result_history": [],
|
||||||
|
"tools_used": [],
|
||||||
|
"stop": False,
|
||||||
|
"stop_reason": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
# ========== 2. 记忆节点(可选) ==========
|
# ========== 2. Agent 节点 ==========
|
||||||
retrieve_memory_node = None
|
|
||||||
if mem0_client:
|
|
||||||
try:
|
|
||||||
from ..nodes.retrieve_memory import create_retrieve_memory_node
|
|
||||||
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
|
|
||||||
except Exception as e:
|
|
||||||
info(f"[Graph Builder] 记忆节点初始化失败: {e}")
|
|
||||||
|
|
||||||
# ========== 3. Agent 节点(包含完整 ReAct 循环,支持动态模型切换) ==========
|
|
||||||
agent_node_fn = create_agent_node(chat_services)
|
agent_node_fn = create_agent_node(chat_services)
|
||||||
|
|
||||||
# ========== 4. 完成节点 ==========
|
# ========== 3. 条件边函数 ==========
|
||||||
async def finalize_node_simple(state: AgentState):
|
def should_continue(state: AgentState) -> Literal["tools", "finalize"]:
|
||||||
info("[Finalize] 进入完成节点")
|
"""根据 agent 节点输出决定下一步"""
|
||||||
return {}
|
# 手动停止标志
|
||||||
|
if getattr(state, "stop", False):
|
||||||
|
return "finalize"
|
||||||
|
|
||||||
# ========== 5. 构建图 ==========
|
# 检查是否有工具调用
|
||||||
|
last_msg = state.messages[-1]
|
||||||
|
if hasattr(last_msg, 'tool_calls') and last_msg.tool_calls:
|
||||||
|
return "tools"
|
||||||
|
|
||||||
|
return "finalize"
|
||||||
|
|
||||||
|
# ========== 4. 构建图 ==========
|
||||||
graph = StateGraph(AgentState)
|
graph = StateGraph(AgentState)
|
||||||
|
|
||||||
|
# 添加节点
|
||||||
graph.add_node("init_state", init_state_node)
|
graph.add_node("init_state", init_state_node)
|
||||||
if retrieve_memory_node:
|
|
||||||
graph.add_node("retrieve_memory", retrieve_memory_node)
|
|
||||||
graph.add_node("memory_trigger", memory_trigger_node)
|
graph.add_node("memory_trigger", memory_trigger_node)
|
||||||
graph.add_node("agent", agent_node_fn)
|
graph.add_node("agent", agent_node_fn)
|
||||||
graph.add_node("finalize", finalize_node_simple)
|
graph.add_node("tools", tools_node)
|
||||||
|
graph.add_node("finalize", finalize_node)
|
||||||
|
|
||||||
# ========== 6. 边的连接 ==========
|
# 边的连接
|
||||||
graph.add_edge(START, "init_state")
|
graph.add_edge(START, "init_state")
|
||||||
|
|
||||||
if retrieve_memory_node:
|
|
||||||
graph.add_edge("init_state", "retrieve_memory")
|
|
||||||
graph.add_edge("retrieve_memory", "memory_trigger")
|
|
||||||
else:
|
|
||||||
graph.add_edge("init_state", "memory_trigger")
|
graph.add_edge("init_state", "memory_trigger")
|
||||||
|
|
||||||
graph.add_edge("memory_trigger", "agent")
|
graph.add_edge("memory_trigger", "agent")
|
||||||
graph.add_edge("agent", "finalize")
|
|
||||||
|
# 条件边: agent -> tools 或 finalize
|
||||||
|
graph.add_conditional_edges(
|
||||||
|
"agent",
|
||||||
|
should_continue,
|
||||||
|
{
|
||||||
|
"tools": "tools",
|
||||||
|
"finalize": "finalize"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 循环回边: tools -> agent
|
||||||
|
graph.add_edge("tools", "agent")
|
||||||
|
|
||||||
|
# 结束
|
||||||
graph.add_edge("finalize", END)
|
graph.add_edge("finalize", END)
|
||||||
|
|
||||||
info("✅ [Graph Builder] 简化 Agent 图构建完成(ReAct 在节点内)")
|
info("✅ [Graph Builder] 标准 Agent 图构建完成")
|
||||||
return graph
|
return graph
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
Agent 节点 - 简化版本
|
Agent 节点 - 简化版本(单步推理)
|
||||||
直接定义 agent_node 函数,支持动态模型切换和循环检测
|
只负责一次 LLM 调用,不执行工具
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import json
|
||||||
from typing import Dict, Any, Optional, List
|
from typing import Dict, Any, Optional, List
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage
|
||||||
|
|
||||||
from backend.app.main_graph.state import AgentState
|
from backend.app.main_graph.state import AgentState
|
||||||
from backend.app.logger import info, error
|
from backend.app.logger import info, error, debug
|
||||||
from backend.app.tools import ALL_TOOLS
|
from backend.app.tools import ALL_TOOLS
|
||||||
from backend.app.agent.stream_context import get_stream_queue
|
from backend.app.agent.stream_context import get_stream_queue
|
||||||
from backend.app.agent.prompts import SYSTEM_PROMPT
|
from backend.app.agent.prompts import SYSTEM_PROMPT
|
||||||
@@ -28,7 +28,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
|||||||
latest = results[-1]
|
latest = results[-1]
|
||||||
prev = results[-2]
|
prev = results[-2]
|
||||||
|
|
||||||
# 长度差异太大,不算相似
|
|
||||||
if len(latest) == 0 or len(prev) == 0:
|
if len(latest) == 0 or len(prev) == 0:
|
||||||
return len(latest) == len(prev)
|
return len(latest) == len(prev)
|
||||||
|
|
||||||
@@ -36,7 +35,6 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
|||||||
if len_ratio < 0.5:
|
if len_ratio < 0.5:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查内容重复度(简单:前100字符)
|
|
||||||
common_len = 0
|
common_len = 0
|
||||||
for a, b in zip(latest[:100], prev[:100]):
|
for a, b in zip(latest[:100], prev[:100]):
|
||||||
if a == b:
|
if a == b:
|
||||||
@@ -50,27 +48,23 @@ def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
|
|||||||
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
|
def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
|
||||||
"""
|
"""
|
||||||
检测是否应该停止(循环检测)
|
检测是否应该停止(循环检测)
|
||||||
|
|
||||||
条件:连续2次调用相同工具 + 参数相似 + 结果相似
|
条件:连续2次调用相同工具 + 参数相似 + 结果相似
|
||||||
"""
|
"""
|
||||||
if len(tool_calls) < 2:
|
if len(tool_calls) < 2:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查最近的工具调用是否相同
|
|
||||||
last_tc = tool_calls[-1]
|
last_tc = tool_calls[-1]
|
||||||
prev_tc = tool_calls[-2]
|
prev_tc = tool_calls[-2]
|
||||||
|
|
||||||
if last_tc["name"] != prev_tc["name"]:
|
if last_tc["name"] != prev_tc["name"]:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 参数是否相似
|
|
||||||
last_args = _normalize_args(last_tc["args"])
|
last_args = _normalize_args(last_tc["args"])
|
||||||
prev_args = _normalize_args(prev_tc["args"])
|
prev_args = _normalize_args(prev_tc["args"])
|
||||||
|
|
||||||
if last_args != prev_args:
|
if last_args != prev_args:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 结果是否相似
|
|
||||||
if len(tool_results) >= 2:
|
if len(tool_results) >= 2:
|
||||||
return _is_similar_result(tool_results[-2:])
|
return _is_similar_result(tool_results[-2:])
|
||||||
|
|
||||||
@@ -79,22 +73,43 @@ def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bo
|
|||||||
|
|
||||||
def create_agent_node(chat_services: dict):
|
def create_agent_node(chat_services: dict):
|
||||||
"""
|
"""
|
||||||
创建 Agent 节点 - 支持动态模型切换
|
创建 Agent 节点 - 单步推理版本
|
||||||
|
|
||||||
简化设计:
|
设计:
|
||||||
- 直接返回 async 函数,无需工厂包装
|
- 只做一次 LLM 调用
|
||||||
- 从 config 中获取模型名称,运行时动态切换
|
- 不执行工具(工具执行由 tools 节点负责)
|
||||||
|
- 返回 AIMessage(可能包含 tool_calls)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||||
"""Agent 节点:完整的 ReAct 循环"""
|
"""Agent 节点:单步 LLM 调用"""
|
||||||
queue = get_stream_queue()
|
queue = get_stream_queue()
|
||||||
is_streaming = queue is not None
|
is_streaming = queue is not None
|
||||||
|
|
||||||
# 获取步数
|
# 获取步数
|
||||||
current_step = getattr(state, "current_step", 0)
|
current_step = getattr(state, "current_step", 0)
|
||||||
max_steps = getattr(state, "max_steps", 10)
|
max_steps = getattr(state, "max_steps", 10)
|
||||||
info(f"[Agent] 从第 {current_step} 步开始,最大步数: {max_steps},流式: {is_streaming}")
|
info(f"[Agent] 第 {current_step + 1} 步开始")
|
||||||
|
|
||||||
|
# 步数已达上限
|
||||||
|
if current_step >= max_steps:
|
||||||
|
info("[Agent] 达到步数上限,强制结束")
|
||||||
|
return {
|
||||||
|
"messages": [AIMessage(content="[系统] 已达到最大步数限制。")],
|
||||||
|
"stop": True,
|
||||||
|
"stop_reason": "max_steps",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 循环检测
|
||||||
|
tool_history = getattr(state, "tool_call_history", [])
|
||||||
|
result_history = getattr(state, "tool_result_history", [])
|
||||||
|
if _should_stop_for_loop(tool_history, result_history):
|
||||||
|
info("[Agent] 检测到循环,终止推理")
|
||||||
|
return {
|
||||||
|
"messages": [AIMessage(content="[系统] 检测到工具调用循环,已终止。")],
|
||||||
|
"stop": True,
|
||||||
|
"stop_reason": "loop_detected",
|
||||||
|
}
|
||||||
|
|
||||||
# 动态获取模型
|
# 动态获取模型
|
||||||
model_name = "primary"
|
model_name = "primary"
|
||||||
@@ -111,22 +126,15 @@ def create_agent_node(chat_services: dict):
|
|||||||
|
|
||||||
# 获取记忆上下文
|
# 获取记忆上下文
|
||||||
memory_context = getattr(state, "memory_context", "暂无用户背景信息")
|
memory_context = getattr(state, "memory_context", "暂无用户背景信息")
|
||||||
|
|
||||||
# 组装消息(注入记忆上下文到提示词)
|
|
||||||
prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
|
prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
|
||||||
messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)
|
messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)
|
||||||
turn = current_step
|
|
||||||
|
|
||||||
try:
|
|
||||||
while turn < max_steps:
|
|
||||||
turn += 1
|
|
||||||
info(f"[Agent] 第 {turn} 轮思考")
|
|
||||||
|
|
||||||
|
# 发送节点开始事件
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
await queue.put({"type": "node_start", "node": "agent"})
|
await queue.put({"type": "node_start", "node": "agent"})
|
||||||
|
|
||||||
# 选择 LLM(最后一轮不带工具)
|
# 选择 LLM(最后一轮不带工具)
|
||||||
if turn >= max_steps:
|
if current_step + 1 >= max_steps:
|
||||||
current_llm = llm.bind_tools([])
|
current_llm = llm.bind_tools([])
|
||||||
info(f"[Agent] 达到步数上限,使用无工具模型")
|
info(f"[Agent] 达到步数上限,使用无工具模型")
|
||||||
else:
|
else:
|
||||||
@@ -138,10 +146,7 @@ def create_agent_node(chat_services: dict):
|
|||||||
pending_tool_calls = {}
|
pending_tool_calls = {}
|
||||||
final_tool_calls = []
|
final_tool_calls = []
|
||||||
|
|
||||||
# 循环检测:记录历史调用
|
try:
|
||||||
tool_call_history: List[dict] = []
|
|
||||||
tool_result_history: List[str] = []
|
|
||||||
|
|
||||||
# 调用 LLM
|
# 调用 LLM
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
async for chunk in current_llm.astream(messages):
|
async for chunk in current_llm.astream(messages):
|
||||||
@@ -181,7 +186,6 @@ def create_agent_node(chat_services: dict):
|
|||||||
if isinstance(args_val, str):
|
if isinstance(args_val, str):
|
||||||
pending_tool_calls[idx]["args"] += args_val
|
pending_tool_calls[idx]["args"] += args_val
|
||||||
else:
|
else:
|
||||||
import json
|
|
||||||
pending_tool_calls[idx]["args"] += json.dumps(args_val)
|
pending_tool_calls[idx]["args"] += json.dumps(args_val)
|
||||||
else:
|
else:
|
||||||
result = await current_llm.ainvoke(messages)
|
result = await current_llm.ainvoke(messages)
|
||||||
@@ -199,7 +203,6 @@ def create_agent_node(chat_services: dict):
|
|||||||
args = {}
|
args = {}
|
||||||
if tc_data["args"]:
|
if tc_data["args"]:
|
||||||
try:
|
try:
|
||||||
import json
|
|
||||||
args = json.loads(tc_data["args"])
|
args = json.loads(tc_data["args"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
info(f"[Agent] 解析参数失败: {e}")
|
info(f"[Agent] 解析参数失败: {e}")
|
||||||
@@ -209,62 +212,9 @@ def create_agent_node(chat_services: dict):
|
|||||||
"args": args
|
"args": args
|
||||||
})
|
})
|
||||||
|
|
||||||
# 执行工具
|
# 发送节点结束事件
|
||||||
if final_tool_calls:
|
|
||||||
info(f"[Agent] 第 {turn} 轮:调用 {len(final_tool_calls)} 个工具")
|
|
||||||
new_messages = []
|
|
||||||
|
|
||||||
for tc in final_tool_calls:
|
|
||||||
tool_name = tc["name"]
|
|
||||||
tool_args = tc["args"]
|
|
||||||
tool_id = tc["id"]
|
|
||||||
|
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
await queue.put({
|
await queue.put({"type": "node_end", "node": "agent"})
|
||||||
"type": "custom",
|
|
||||||
"data": {"type": "tool_start", "tool": tool_name, "args": tool_args, "id": tool_id}
|
|
||||||
})
|
|
||||||
|
|
||||||
# 查找并执行工具
|
|
||||||
tool_result = ""
|
|
||||||
tool_found = False
|
|
||||||
for tool in ALL_TOOLS:
|
|
||||||
if tool.name == tool_name:
|
|
||||||
tool_found = True
|
|
||||||
try:
|
|
||||||
tool_result = await tool.ainvoke(tool_args)
|
|
||||||
except Exception as e:
|
|
||||||
tool_result = f"工具调用出错: {str(e)}"
|
|
||||||
error(f"[Agent] 工具 {tool_name} 调用出错: {e}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not tool_found:
|
|
||||||
tool_result = f"未找到工具: {tool_name}"
|
|
||||||
|
|
||||||
if is_streaming:
|
|
||||||
await queue.put({
|
|
||||||
"type": "custom",
|
|
||||||
"data": {"type": "tool_end", "tool": tool_name, "id": tool_id, "result": str(tool_result)}
|
|
||||||
})
|
|
||||||
|
|
||||||
# 记录历史(用于循环检测)
|
|
||||||
tool_call_history.append({"name": tool_name, "args": tool_args})
|
|
||||||
tool_result_history.append(str(tool_result))
|
|
||||||
|
|
||||||
new_messages.append(ToolMessage(content=str(tool_result), tool_call_id=tool_id, name=tool_name))
|
|
||||||
|
|
||||||
# 循环检测:相同工具 + 相似参数 + 相似结果 → 终止
|
|
||||||
if _should_stop_for_loop(tool_call_history, tool_result_history):
|
|
||||||
info(f"[Agent] ⚠️ 检测到循环,强制终止")
|
|
||||||
# 添加一条终止消息
|
|
||||||
messages.append(AIMessage(content="[系统] 检测到工具调用循环,已终止。"))
|
|
||||||
break
|
|
||||||
|
|
||||||
messages.extend(new_messages)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
info(f"[Agent] 第 {turn} 轮:完成,无工具调用")
|
|
||||||
break
|
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
response_kwargs = {"content": full_content}
|
response_kwargs = {"content": full_content}
|
||||||
@@ -274,14 +224,15 @@ def create_agent_node(chat_services: dict):
|
|||||||
if full_reasoning_content:
|
if full_reasoning_content:
|
||||||
response.additional_kwargs["reasoning_content"] = full_reasoning_content
|
response.additional_kwargs["reasoning_content"] = full_reasoning_content
|
||||||
|
|
||||||
|
info(f"[Agent] 完成 - content长度: {len(full_content)}, tool_calls: {len(final_tool_calls)}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [response],
|
"messages": [response],
|
||||||
"current_step": turn,
|
|
||||||
"llm_calls": getattr(state, "llm_calls", 0) + 1
|
"llm_calls": getattr(state, "llm_calls", 0) + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"[Agent] ❌ 第 {turn} 轮出错: {e}")
|
error(f"[Agent] 执行出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
error(f"[Agent] 堆栈: {traceback.format_exc()}")
|
error(f"[Agent] 堆栈: {traceback.format_exc()}")
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
|
|||||||
54
backend/app/main_graph/nodes/finalize.py
Normal file
54
backend/app/main_graph/nodes/finalize.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""
|
||||||
|
Finalize 节点 - 轻量后处理
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from backend.app.main_graph.state import AgentState
|
||||||
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
|
||||||
|
async def finalize_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Finalize 节点:轻量后处理
|
||||||
|
|
||||||
|
职责:
|
||||||
|
1. 从 messages 中提取最后一条 AIMessage.content 作为最终回复
|
||||||
|
2. 汇总元数据:步数、使用的工具、停止原因
|
||||||
|
3. 如果 final_reply 为空且有 stop_reason,生成说明文本
|
||||||
|
"""
|
||||||
|
info("[Finalize] 执行后处理")
|
||||||
|
|
||||||
|
# 获取最终回复
|
||||||
|
final_reply = ""
|
||||||
|
for msg in reversed(state.messages):
|
||||||
|
if isinstance(msg, AIMessage) and msg.content:
|
||||||
|
final_reply = msg.content
|
||||||
|
break
|
||||||
|
|
||||||
|
# 获取停止原因
|
||||||
|
stop_reason = getattr(state, "stop_reason", "")
|
||||||
|
if not final_reply and stop_reason:
|
||||||
|
# 如果没有回复但有停止原因,生成说明
|
||||||
|
reason_messages = {
|
||||||
|
"loop_detected": "[系统] 检测到工具调用循环,已终止。",
|
||||||
|
"max_steps": "[系统] 已达到最大步数限制。",
|
||||||
|
}
|
||||||
|
final_reply = reason_messages.get(stop_reason, f"[系统] 已终止,原因: {stop_reason}")
|
||||||
|
|
||||||
|
# 汇总元数据
|
||||||
|
metadata = {
|
||||||
|
"steps_taken": getattr(state, "current_step", 0),
|
||||||
|
"tools_used": getattr(state, "tools_used", []),
|
||||||
|
"stop_reason": stop_reason,
|
||||||
|
"llm_calls": getattr(state, "llm_calls", 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
info(f"[Finalize] 完成 - steps: {metadata['steps_taken']}, tools: {len(metadata['tools_used'])}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"final_reply": final_reply,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
113
backend/app/main_graph/nodes/tools.py
Normal file
113
backend/app/main_graph/nodes/tools.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""
|
||||||
|
Tools 节点 - 负责执行 tool_calls
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from backend.app.main_graph.state import AgentState
|
||||||
|
from backend.app.logger import info, error, debug
|
||||||
|
from backend.app.tools import ALL_TOOLS
|
||||||
|
from backend.app.agent.stream_context import get_stream_queue
|
||||||
|
|
||||||
|
|
||||||
|
async def tools_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Tools 节点:执行 AIMessage.tool_calls,返回 ToolMessage 列表
|
||||||
|
|
||||||
|
职责:
|
||||||
|
1. 获取最后一条 AIMessage 的 tool_calls
|
||||||
|
2. 遍历执行每个工具
|
||||||
|
3. 记录历史(tool_call_history, tool_result_history)
|
||||||
|
4. 更新步数 current_step += 1
|
||||||
|
5. 发送工具开始/结束事件
|
||||||
|
"""
|
||||||
|
queue = get_stream_queue()
|
||||||
|
is_streaming = queue is not None
|
||||||
|
|
||||||
|
# 获取最后一条 AIMessage
|
||||||
|
last_message = state.messages[-1]
|
||||||
|
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
|
||||||
|
info("[Tools] 没有工具调用,跳过")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
tool_calls = last_message.tool_calls
|
||||||
|
info(f"[Tools] 执行 {len(tool_calls)} 个工具调用")
|
||||||
|
|
||||||
|
# 获取历史记录
|
||||||
|
tool_call_history: List[dict] = list(getattr(state, "tool_call_history", []))
|
||||||
|
tool_result_history: List[str] = list(getattr(state, "tool_result_history", []))
|
||||||
|
tools_used: List[str] = list(getattr(state, "tools_used", []))
|
||||||
|
|
||||||
|
tool_messages = []
|
||||||
|
|
||||||
|
for tc in tool_calls:
|
||||||
|
tool_name = tc["name"]
|
||||||
|
tool_args = tc["args"]
|
||||||
|
tool_id = tc["id"]
|
||||||
|
|
||||||
|
tools_used.append(tool_name)
|
||||||
|
|
||||||
|
# 发送工具开始事件
|
||||||
|
if is_streaming:
|
||||||
|
await queue.put({
|
||||||
|
"type": "custom",
|
||||||
|
"data": {
|
||||||
|
"type": "tool_start",
|
||||||
|
"tool": tool_name,
|
||||||
|
"args": tool_args,
|
||||||
|
"id": tool_id
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 查找并执行工具
|
||||||
|
tool_result = ""
|
||||||
|
tool_found = False
|
||||||
|
|
||||||
|
for tool in ALL_TOOLS:
|
||||||
|
if tool.name == tool_name:
|
||||||
|
tool_found = True
|
||||||
|
try:
|
||||||
|
tool_result = await tool.ainvoke(tool_args)
|
||||||
|
debug(f"[Tools] 工具 {tool_name} 执行成功")
|
||||||
|
except Exception as e:
|
||||||
|
tool_result = f"工具调用出错: {str(e)}"
|
||||||
|
error(f"[Tools] 工具 {tool_name} 调用出错: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not tool_found:
|
||||||
|
tool_result = f"未找到工具: {tool_name}"
|
||||||
|
error(f"[Tools] 未找到工具: {tool_name}")
|
||||||
|
|
||||||
|
# 发送工具结束事件
|
||||||
|
if is_streaming:
|
||||||
|
await queue.put({
|
||||||
|
"type": "custom",
|
||||||
|
"data": {
|
||||||
|
"type": "tool_end",
|
||||||
|
"tool": tool_name,
|
||||||
|
"id": tool_id,
|
||||||
|
"result": str(tool_result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 记录历史
|
||||||
|
tool_call_history.append({"name": tool_name, "args": tool_args})
|
||||||
|
tool_result_history.append(str(tool_result))
|
||||||
|
|
||||||
|
# 创建 ToolMessage
|
||||||
|
tool_messages.append(
|
||||||
|
ToolMessage(content=str(tool_result), tool_call_id=tool_id, name=tool_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新步数
|
||||||
|
current_step = getattr(state, "current_step", 0) + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"messages": tool_messages,
|
||||||
|
"current_step": current_step,
|
||||||
|
"tool_call_history": tool_call_history,
|
||||||
|
"tool_result_history": tool_result_history,
|
||||||
|
"tools_used": tools_used,
|
||||||
|
}
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
- 统计:llm_calls, last_token_usage, last_elapsed_time
|
- 统计:llm_calls, last_token_usage, last_elapsed_time
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Sequence, Optional, Dict, Any
|
from typing import Annotated, Sequence, Optional, Dict, Any, List
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from langgraph.graph import add_messages
|
from langgraph.graph import add_messages
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
@@ -35,3 +35,14 @@ class AgentState:
|
|||||||
llm_calls: int = 0
|
llm_calls: int = 0
|
||||||
last_token_usage: Dict[str, Any] = field(default_factory=dict)
|
last_token_usage: Dict[str, Any] = field(default_factory=dict)
|
||||||
last_elapsed_time: float = 0.0
|
last_elapsed_time: float = 0.0
|
||||||
|
|
||||||
|
# ========== 新增字段: 工具调用历史 ==========
|
||||||
|
tool_call_history: List[dict] = field(default_factory=list)
|
||||||
|
tool_result_history: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
# ========== 新增字段: 停止控制 ==========
|
||||||
|
stop: bool = False
|
||||||
|
stop_reason: str = ""
|
||||||
|
|
||||||
|
# ========== 新增字段: 本轮使用的工具 ==========
|
||||||
|
tools_used: List[str] = field(default_factory=list)
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
"""
|
|
||||||
RAG 工具模块(完全异步)
|
|
||||||
|
|
||||||
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
|
||||||
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
|
|
||||||
|
|
||||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
|
||||||
"""
|
|
||||||
from typing import Callable, Optional
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
|
||||||
from langchain_core.retrievers import BaseRetriever
|
|
||||||
from ..rag.pipeline import RAGPipeline, create_rag_pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def create_rag_tool(
|
|
||||||
retriever: Optional[BaseRetriever] = None,
|
|
||||||
llm: Optional[BaseLanguageModel] = "default_small",
|
|
||||||
num_queries: int = 3,
|
|
||||||
rerank_top_n: int = 5,
|
|
||||||
collection_name: str = "rag_documents",
|
|
||||||
) -> Callable:
|
|
||||||
"""
|
|
||||||
创建一个配置好的 RAG 检索工具(完全异步)。
|
|
||||||
|
|
||||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
retriever: 基础检索器对象(可选,不提供则自动创建)
|
|
||||||
llm: 用于生成多路查询的语言模型。
|
|
||||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
|
||||||
- None / False: 不做查询改写
|
|
||||||
- BaseLanguageModel 实例: 自定义模型
|
|
||||||
num_queries: 生成的查询变体数量
|
|
||||||
rerank_top_n: 最终返回的文档数量
|
|
||||||
collection_name: Qdrant 集合名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Async LangChain Tool 函数
|
|
||||||
"""
|
|
||||||
pipeline = RAGPipeline(
|
|
||||||
retriever=retriever,
|
|
||||||
llm=llm,
|
|
||||||
num_queries=num_queries,
|
|
||||||
rerank_top_n=rerank_top_n,
|
|
||||||
collection_name=collection_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def search_knowledge_base(query: str) -> str:
|
|
||||||
"""
|
|
||||||
在知识库中搜索与查询相关的文档片段(完全异步)。
|
|
||||||
|
|
||||||
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
|
|
||||||
检索效果最优。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户提出的问题或查询字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
格式化后的相关文档内容
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
documents = await pipeline.aretrieve(query)
|
|
||||||
if not documents:
|
|
||||||
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
|
|
||||||
|
|
||||||
context = pipeline.format_context(documents)
|
|
||||||
return context
|
|
||||||
except Exception as e:
|
|
||||||
return f"检索过程中发生错误: {str(e)}"
|
|
||||||
|
|
||||||
return search_knowledge_base
|
|
||||||
@@ -2,124 +2,17 @@
|
|||||||
Agent Tools - 所有工具统一定义
|
Agent Tools - 所有工具统一定义
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from .rag import rag_search
|
||||||
from backend.app.logger import info
|
from .web_search import web_search
|
||||||
|
from .subgraph import contact_lookup, dictionary_lookup, news_analysis
|
||||||
# ========== RAG ==========
|
|
||||||
|
|
||||||
_rag_pipeline = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_rag_pipeline():
|
|
||||||
global _rag_pipeline
|
|
||||||
if _rag_pipeline is None:
|
|
||||||
from backend.app.rag.pipeline import RAGPipeline
|
|
||||||
_rag_pipeline = RAGPipeline(
|
|
||||||
num_queries=3,
|
|
||||||
rerank_top_n=5,
|
|
||||||
use_rerank=True,
|
|
||||||
return_parent_docs=True,
|
|
||||||
)
|
|
||||||
return _rag_pipeline
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def rag_search(query: str) -> str:
|
|
||||||
"""
|
|
||||||
检索知识库获取相关信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含检索结果和置信度的结构化回复,格式:
|
|
||||||
- 内容:检索到的相关信息
|
|
||||||
- 置信度评估:基于向量相似度、重排分数、LLM判断的综合评分
|
|
||||||
"""
|
|
||||||
info(f"[Tool] rag_search: {query[:30]}...")
|
|
||||||
try:
|
|
||||||
pipeline = _get_rag_pipeline()
|
|
||||||
# 使用带置信度的检索
|
|
||||||
result = await pipeline.aretrieve_with_confidence(query, original_query=query)
|
|
||||||
|
|
||||||
if not result.content:
|
|
||||||
return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度:0.0\n建议:可尝试联网搜索获取信息。"
|
|
||||||
|
|
||||||
# 构建包含置信度的回复
|
|
||||||
confidence_desc = "高"
|
|
||||||
if result.confidence < 0.4:
|
|
||||||
confidence_desc = "低"
|
|
||||||
elif result.confidence < 0.6:
|
|
||||||
confidence_desc = "中"
|
|
||||||
|
|
||||||
response = f"""【RAG检索结果】
|
|
||||||
{result.content}
|
|
||||||
|
|
||||||
【置信度评估】
|
|
||||||
- 综合置信度:{result.confidence:.2f}({confidence_desc})
|
|
||||||
- 向量相似度:{result.scores['embedding']:.2f}
|
|
||||||
- 重排分数:{result.scores['rerank']:.2f}
|
|
||||||
- LLM评估:{result.scores['llm']:.2f}
|
|
||||||
|
|
||||||
{'✅ 检索结果可信,可直接使用' if result.is_useful else '⚠️ 检索结果置信度较低,可能需要联网搜索补充'}"""
|
|
||||||
|
|
||||||
info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}")
|
|
||||||
return response
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
info(f"[Tool] rag_search 失败: {e}")
|
|
||||||
return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索"
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 联网搜索 ==========
|
|
||||||
|
|
||||||
@tool
|
|
||||||
def web_search(query: str) -> str:
|
|
||||||
"""联网搜索获取最新信息"""
|
|
||||||
info(f"[Tool] web_search: {query[:30]}...")
|
|
||||||
try:
|
|
||||||
from backend.app.core.web_search import web_search as search_fn
|
|
||||||
return search_fn(query, max_results=5)
|
|
||||||
except Exception as e:
|
|
||||||
info(f"[Tool] web_search 失败: {e}")
|
|
||||||
return f"联网搜索失败: {str(e)}"
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 子图工具 ==========
|
|
||||||
|
|
||||||
async def _call_subgraph(builder_fn, state_cls, query: str) -> str:
|
|
||||||
"""通用子图调用"""
|
|
||||||
try:
|
|
||||||
graph = builder_fn().compile()
|
|
||||||
state = state_cls(user_query=query)
|
|
||||||
result = await graph.ainvoke(state)
|
|
||||||
return result.get("final_result", "执行完成")
|
|
||||||
except Exception as e:
|
|
||||||
info(f"[Tool] 子图调用失败: {e}")
|
|
||||||
return f"执行失败: {str(e)}"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def contact_lookup(query: str) -> str:
|
|
||||||
"""查询通讯录"""
|
|
||||||
from backend.app.subgraphs.contact.graph import build_contact_subgraph
|
|
||||||
from backend.app.subgraphs.contact.state import ContactState
|
|
||||||
return await _call_subgraph(build_contact_subgraph, ContactState, query)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def dictionary_lookup(word: str) -> str:
|
|
||||||
"""查询词典/翻译"""
|
|
||||||
from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph
|
|
||||||
from backend.app.subgraphs.dictionary.state import DictionaryState
|
|
||||||
return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def news_analysis(topic: str) -> str:
|
|
||||||
"""分析新闻热点"""
|
|
||||||
from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph
|
|
||||||
from backend.app.subgraphs.news_analysis.state import NewsAnalysisState
|
|
||||||
return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic)
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 导出 ==========
|
|
||||||
|
|
||||||
ALL_TOOLS = [rag_search, web_search, contact_lookup, dictionary_lookup, news_analysis]
|
ALL_TOOLS = [rag_search, web_search, contact_lookup, dictionary_lookup, news_analysis]
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"rag_search",
|
||||||
|
"web_search",
|
||||||
|
"contact_lookup",
|
||||||
|
"dictionary_lookup",
|
||||||
|
"news_analysis",
|
||||||
|
"ALL_TOOLS",
|
||||||
|
]
|
||||||
|
|||||||
8
backend/app/tools/base.py
Normal file
8
backend/app/tools/base.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""
|
||||||
|
工具模块配置
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
__all__ = ["tool", "info"]
|
||||||
70
backend/app/tools/rag.py
Normal file
70
backend/app/tools/rag.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""
|
||||||
|
RAG 检索工具
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
|
||||||
|
_rag_pipeline = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rag_pipeline():
|
||||||
|
"""获取或创建 RAG pipeline 单例"""
|
||||||
|
global _rag_pipeline
|
||||||
|
if _rag_pipeline is None:
|
||||||
|
from backend.app.rag.pipeline import RAGPipeline
|
||||||
|
_rag_pipeline = RAGPipeline(
|
||||||
|
num_queries=3,
|
||||||
|
rerank_top_n=5,
|
||||||
|
use_rerank=True,
|
||||||
|
return_parent_docs=True,
|
||||||
|
)
|
||||||
|
return _rag_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def _format_confidence(result) -> str:
|
||||||
|
"""格式化置信度描述"""
|
||||||
|
if result.confidence < 0.4:
|
||||||
|
return "低"
|
||||||
|
elif result.confidence < 0.6:
|
||||||
|
return "中"
|
||||||
|
return "高"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def rag_search(query: str) -> str:
|
||||||
|
"""
|
||||||
|
检索知识库获取相关信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含检索结果和置信度的结构化回复
|
||||||
|
"""
|
||||||
|
info(f"[Tool] rag_search: {query[:30]}...")
|
||||||
|
try:
|
||||||
|
pipeline = _get_rag_pipeline()
|
||||||
|
result = await pipeline.aretrieve_with_confidence(query, original_query=query)
|
||||||
|
|
||||||
|
if not result.content:
|
||||||
|
return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度:0.0\n建议:可尝试联网搜索获取信息。"
|
||||||
|
|
||||||
|
confidence_desc = _format_confidence(result)
|
||||||
|
is_useful_note = "✅ 检索结果可信,可直接使用" if result.is_useful else "⚠️ 检索结果置信度较低,可能需要联网搜索补充"
|
||||||
|
|
||||||
|
response = f"""【RAG检索结果】
|
||||||
|
{result.content}
|
||||||
|
|
||||||
|
【置信度评估】
|
||||||
|
- 综合置信度:{result.confidence:.2f}({confidence_desc})
|
||||||
|
- 向量相似度:{result.scores['embedding']:.2f}
|
||||||
|
- 重排分数:{result.scores['rerank']:.2f}
|
||||||
|
- LLM评估:{result.scores['llm']:.2f}
|
||||||
|
|
||||||
|
{is_useful_note}"""
|
||||||
|
|
||||||
|
info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}")
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[Tool] rag_search 失败: {e}")
|
||||||
|
return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索"
|
||||||
42
backend/app/tools/subgraph.py
Normal file
42
backend/app/tools/subgraph.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
子图工具 - 通讯录、词典、新闻分析等
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_subgraph(builder_fn, state_cls, query: str) -> str:
|
||||||
|
"""通用子图调用"""
|
||||||
|
try:
|
||||||
|
graph = builder_fn().compile()
|
||||||
|
state = state_cls(user_query=query)
|
||||||
|
result = await graph.ainvoke(state)
|
||||||
|
return result.get("final_result", "执行完成")
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[Tool] 子图调用失败: {e}")
|
||||||
|
return f"执行失败: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def contact_lookup(query: str) -> str:
|
||||||
|
"""查询通讯录"""
|
||||||
|
from backend.app.subgraphs.contact.graph import build_contact_subgraph
|
||||||
|
from backend.app.subgraphs.contact.state import ContactState
|
||||||
|
return await _call_subgraph(build_contact_subgraph, ContactState, query)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def dictionary_lookup(word: str) -> str:
|
||||||
|
"""查询词典/翻译"""
|
||||||
|
from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph
|
||||||
|
from backend.app.subgraphs.dictionary.state import DictionaryState
|
||||||
|
return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def news_analysis(topic: str) -> str:
|
||||||
|
"""分析新闻热点"""
|
||||||
|
from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph
|
||||||
|
from backend.app.subgraphs.news_analysis.state import NewsAnalysisState
|
||||||
|
return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic)
|
||||||
18
backend/app/tools/web_search.py
Normal file
18
backend/app/tools/web_search.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
联网搜索工具
|
||||||
|
"""
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from backend.app.logger import info
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def web_search(query: str) -> str:
|
||||||
|
"""联网搜索获取最新信息"""
|
||||||
|
info(f"[Tool] web_search: {query[:30]}...")
|
||||||
|
try:
|
||||||
|
from backend.app.core.web_search import web_search as search_fn
|
||||||
|
return search_fn(query, max_results=5)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[Tool] web_search 失败: {e}")
|
||||||
|
return f"联网搜索失败: {str(e)}"
|
||||||
199
docs/superpowers/specs/2026-05-08-react-separation-design.md
Normal file
199
docs/superpowers/specs/2026-05-08-react-separation-design.md
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
# ReAct 循环分离设计
|
||||||
|
|
||||||
|
**日期**: 2026-05-08
|
||||||
|
**状态**: 已批准
|
||||||
|
**目标**: 将 "fat node" 中的 ReAct 循环拆分为独立的 agent/tools 图节点
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 目标
|
||||||
|
|
||||||
|
将当前 `agent.py` 中的"厚节点"(包含完整 while 循环、工具执行、循环检测)重构为标准 LangGraph 图架构:
|
||||||
|
- **推理节点 (agent)**:只负责单步 LLM 调用
|
||||||
|
- **工具节点 (tools)**:只负责执行工具
|
||||||
|
- **条件边**:控制循环逻辑
|
||||||
|
- **finalize 节点**:轻量后处理
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 新图结构
|
||||||
|
|
||||||
|
```
|
||||||
|
START
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
init_state ──→ memory_trigger ──→ agent ──┬──→ (条件边) ──→ tools ──→ agent (循环)
|
||||||
|
│ │ ▲
|
||||||
|
│ └──────────────┘
|
||||||
|
│ (无工具调用时)
|
||||||
|
▼
|
||||||
|
finalize
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
END
|
||||||
|
```
|
||||||
|
|
||||||
|
### 节点职责
|
||||||
|
|
||||||
|
| 节点 | 职责 | 输入 | 输出 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| `init_state` | 初始化状态,重置步数 | 无 | `{current_step: 0, max_steps: N}` |
|
||||||
|
| `memory_trigger` | 检测记忆指令,触发 Mem0 存储 | AgentState | 无修改 |
|
||||||
|
| `agent` | 单步 LLM 调用,输出 AIMessage | AgentState | `{messages: [AIMessage], ...}` |
|
||||||
|
| `tools` | 执行 tool_calls,返回 ToolMessage | AgentState | `{messages: [ToolMessage], current_step: N+1, ...}` |
|
||||||
|
| `finalize` | 轻量后处理 | AgentState | `{final_reply: str, metadata: {...}}` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 状态定义
|
||||||
|
|
||||||
|
在 `AgentState` 中新增字段:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class AgentState:
|
||||||
|
# 现有字段保留...
|
||||||
|
|
||||||
|
# 新增字段
|
||||||
|
tool_call_history: List[dict] = field(default_factory=list) # [{"name": "...", "args": {...}}]
|
||||||
|
tool_result_history: List[str] = field(default_factory=list) # ["结果1", "结果2", ...]
|
||||||
|
stop: bool = False # 手动停止标志
|
||||||
|
stop_reason: str = "" # 停止原因
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 节点实现
|
||||||
|
|
||||||
|
### 4.1 agent 节点
|
||||||
|
|
||||||
|
**职责**:单步 LLM 调用,不执行工具
|
||||||
|
|
||||||
|
**流程**:
|
||||||
|
1. 步数检查(已达上限则用无工具模型)
|
||||||
|
2. 循环检测(检测到异常则设置 `stop=True`)
|
||||||
|
3. 调用 LLM(带工具绑定)
|
||||||
|
4. 流式推送 token
|
||||||
|
|
||||||
|
**返回值**:
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"messages": [AIMessage(content=..., tool_calls=[...])],
|
||||||
|
"stop": bool,
|
||||||
|
"stop_reason": str,
|
||||||
|
"llm_calls": int + 1
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 tools 节点
|
||||||
|
|
||||||
|
**职责**:执行 `AIMessage.tool_calls`,生成 `ToolMessage`
|
||||||
|
|
||||||
|
**流程**:
|
||||||
|
1. 获取最后一条 `AIMessage`
|
||||||
|
2. 提取 `tool_calls` 列表
|
||||||
|
3. 遍历执行每个工具
|
||||||
|
4. 记录历史(tool_call_history, tool_result_history)
|
||||||
|
5. 更新步数 `current_step += 1`
|
||||||
|
6. 发送工具开始/结束事件(非 token 流)
|
||||||
|
|
||||||
|
**返回值**:
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"messages": [ToolMessage(...), ...],
|
||||||
|
"current_step": current_step + 1,
|
||||||
|
"tool_call_history": [...],
|
||||||
|
"tool_result_history": [...],
|
||||||
|
"tools_used": [tool_names] # 新增:记录本轮使用的工具
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3 finalize 节点
|
||||||
|
|
||||||
|
**职责**:轻量后处理
|
||||||
|
|
||||||
|
**流程**:
|
||||||
|
1. 从 `messages` 中提取最后一条 `AIMessage.content` 作为最终回复
|
||||||
|
2. 汇总元数据:步数、使用的工具、停止原因
|
||||||
|
3. 如果 `final_reply` 为空且有 `stop_reason`,生成说明文本
|
||||||
|
|
||||||
|
**返回值**:
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"final_reply": str,
|
||||||
|
"metadata": {
|
||||||
|
"steps_taken": int,
|
||||||
|
"tools_used": List[str],
|
||||||
|
"stop_reason": str,
|
||||||
|
"llm_calls": int
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.4 条件边
|
||||||
|
|
||||||
|
```python
|
||||||
|
def should_continue(state: AgentState) -> Literal["tools", "finalize"]:
|
||||||
|
"""根据 agent 节点输出决定下一步"""
|
||||||
|
# 手动停止标志
|
||||||
|
if getattr(state, "stop", False):
|
||||||
|
return "finalize"
|
||||||
|
|
||||||
|
# 检查是否有工具调用
|
||||||
|
last_msg = state.messages[-1]
|
||||||
|
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
||||||
|
return "tools"
|
||||||
|
|
||||||
|
return "finalize"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.5 循环回边
|
||||||
|
|
||||||
|
```
|
||||||
|
tools 节点执行完后 → 无条件回到 agent 节点
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 流式事件
|
||||||
|
|
||||||
|
| 阶段 | 事件类型 | 内容 |
|
||||||
|
|------|----------|------|
|
||||||
|
| agent | `node_start` | `{"node": "agent"}` |
|
||||||
|
| agent | `llm_token` | `{"token": "...", "reasoning_token": "..."}` |
|
||||||
|
| agent | `node_end` | `{"node": "agent"}` |
|
||||||
|
| tools | `tool_start` | `{"tool": "name", "args": {...}, "id": "..."}` |
|
||||||
|
| tools | `tool_end` | `{"tool": "name", "id": "...", "result": "..."}` |
|
||||||
|
| finalize | 无 | - |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 循环检测
|
||||||
|
|
||||||
|
保留现有 `_should_stop_for_loop` 函数,放在 agent 节点中。
|
||||||
|
|
||||||
|
**检测逻辑**:连续 2 次调用相同工具 + 参数相似 + 结果相似 → 设置 `stop=True, stop_reason="loop_detected"`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 文件变更
|
||||||
|
|
||||||
|
### 新增文件
|
||||||
|
- `backend/app/main_graph/nodes/tools.py` — tools 节点实现
|
||||||
|
- `backend/app/main_graph/nodes/finalize.py` — finalize 节点实现
|
||||||
|
|
||||||
|
### 修改文件
|
||||||
|
- `backend/app/main_graph/state.py` — 新增状态字段
|
||||||
|
- `backend/app/main_graph/main_graph_builder.py` — 重构图构建逻辑
|
||||||
|
- `backend/app/main_graph/nodes/agent.py` — 移除 while 循环和工具执行逻辑
|
||||||
|
|
||||||
|
### 删除
|
||||||
|
- 无
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 兼容性
|
||||||
|
|
||||||
|
- 对外接口(`process_message`, `process_message_stream`)保持不变
|
||||||
|
- 返回值格式调整:新增 `metadata` 字段
|
||||||
|
- checkpointer 兼容性:新增字段需设置默认值
|
||||||
@@ -125,9 +125,9 @@ async def run_single_test(graph, test_case: dict) -> dict:
|
|||||||
print("开始执行图...")
|
print("开始执行图...")
|
||||||
result = await graph.ainvoke(input_state, config=config)
|
result = await graph.ainvoke(input_state, config=config)
|
||||||
|
|
||||||
# 提取最终回复
|
# 提取最终回复(优先使用 final_reply)
|
||||||
reply = ""
|
reply = result.get("final_reply", "")
|
||||||
if result.get("messages"):
|
if not reply and result.get("messages"):
|
||||||
reply = result["messages"][-1].content
|
reply = result["messages"][-1].content
|
||||||
|
|
||||||
print(f"\n✓ 执行完成")
|
print(f"\n✓ 执行完成")
|
||||||
|
|||||||
Reference in New Issue
Block a user