refactor: 单图方案重构 + 动态模型选择 + chat_services优化
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
## 核心改动 ### 1. 单图方案重构 - 删除了多图(self.graphs),改为单图(self.graph) - 新增 MainGraphState.current_model 字段用于运行时注入模型 - llm_call 节点改为动态选择模型(create_dynamic_llm_call_node) ### 2. chat_services 优化 - 添加 _cached_services 缓存,避免重复初始化 - 新增 get_cached_chat_services() 函数,用于单图注入 - 新增 _check_http_service_available() 统一HTTP探测逻辑 - 减少重复代码,LocalVLLMChatProvider和LocalSmallModelProvider共用探测方法 ### 3. AIAgentService 重构 - initialize() 只构建一次图,传入 chat_services 字典 - 新增 _resolve_model() 模型回退逻辑 - 新增 _build_invocation() 统一构建调用参数 - process_message() 和 process_message_stream() 改为注入 current_model - 流式处理代码拆分,增加可读性 ### 4. 新增和删除文件 - 新增:backend/app/main_graph/main_graph_builder.py(图构建) - 新增:backend/app/main_graph/subgraph_wrapper.py(子图封装) - 新增:tools/test/test_tavily_search.py(测试) - 删除:backend/app/main_graph/graph.py(旧图) - 删除:backend/app/main_graph/utils/main_graph_builder.py(旧构建器) - 删除:backend/app/main_graph/utils/__init__.py ### 5. 其他更新 - README.md:新增模型服务使用情况详解章节 - backend/app/model_services/__init__.py:新增 get_cached_chat_services 导出 ## 方案优势 - 内存优化:N张图 → 1张图 - 灵活性:运行时动态选择模型,支持同会话不同模型 - 性能:模型服务缓存,初始化仅一次 - 可维护性:减少重复代码,统一HTTP探测逻辑
This commit is contained in:
@@ -1,25 +1,28 @@
|
||||
"""
|
||||
AI Agent 服务类 - 支持多模型动态切换
|
||||
AI Agent 服务类 - 单图方案 + 动态模型选择
|
||||
接收外部传入的 checkpointer,不负责管理连接生命周期
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
|
||||
|
||||
# 本地模块
|
||||
from ..main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from ..model_services import get_cached_chat_services
|
||||
from ..main_graph.main_graph_builder import build_react_main_graph
|
||||
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from ..main_graph.config import set_stream_writer
|
||||
from ..main_graph.utils.rag_initializer import init_rag_tool
|
||||
from ..core.intent_classifier import get_intent_classifier
|
||||
from ..logger import info, warning, error
|
||||
from ..logger import debug, info, warning, error
|
||||
from ..main_graph.state import MainGraphState, CurrentAction
|
||||
|
||||
|
||||
class AIAgentService:
|
||||
def __init__(self, checkpointer):
|
||||
self.checkpointer = checkpointer
|
||||
self.graphs = {}
|
||||
self.graph = None # 只有一张图
|
||||
self.chat_services = None # 缓存的模型字典
|
||||
self.tools = AVAILABLE_TOOLS.copy()
|
||||
self.tools_by_name = TOOLS_BY_NAME.copy()
|
||||
# 添加:意图分类器
|
||||
@@ -40,64 +43,94 @@ class AIAgentService:
|
||||
self.tools.append(rag_tool)
|
||||
self.tools_by_name[rag_tool.name] = rag_tool
|
||||
self.rag_tool = rag_tool # 保存到实例变量,供 config 注入
|
||||
|
||||
# 2. 构建各模型的 Graph(使用新版 React 模式)
|
||||
for name, llm in chat_services.items():
|
||||
try:
|
||||
info(f"🔄 初始化模型 '{name}'...")
|
||||
graph = build_react_main_graph(
|
||||
llm=llm,
|
||||
tools=self.tools,
|
||||
mem0_client=self.mem0_client
|
||||
).compile(checkpointer=self.checkpointer)
|
||||
self.graphs[name] = graph
|
||||
info(f"✅ 模型 '{name}' 初始化成功")
|
||||
except Exception as e:
|
||||
warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")
|
||||
|
||||
if not self.graphs:
|
||||
raise RuntimeError("没有可用的模型")
|
||||
|
||||
# 2. 获取缓存的模型字典
|
||||
self.chat_services = get_cached_chat_services()
|
||||
info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
|
||||
|
||||
# 3. 只构建一次图(传入 chat_services 字典)
|
||||
info(f"🔄 构建单图...")
|
||||
graph_builder = build_react_main_graph(
|
||||
chat_services=self.chat_services,
|
||||
tools=self.tools,
|
||||
mem0_client=self.mem0_client
|
||||
)
|
||||
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
|
||||
info(f"✅ 单图初始化完成")
|
||||
|
||||
return self
|
||||
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
|
||||
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
||||
if model not in self.graphs:
|
||||
# 回退到第一个可用模型
|
||||
available = list(self.graphs.keys())
|
||||
if not available:
|
||||
raise RuntimeError("没有可用的模型")
|
||||
model = available[0]
|
||||
warning(f"模型 '{model}' 不可用,已回退到 '{model}'")
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
"""
|
||||
解析并验证模型名称,不可用时回退到第一个可用模型
|
||||
|
||||
Args:
|
||||
model: 目标模型名称
|
||||
|
||||
Returns:
|
||||
实际使用的模型名称
|
||||
"""
|
||||
if not model or model not in self.chat_services:
|
||||
fallback = next(iter(self.chat_services.keys()))
|
||||
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
|
||||
return fallback
|
||||
return model
|
||||
|
||||
graph = self.graphs[model]
|
||||
def _build_invocation(
|
||||
self, message: str, thread_id: str, model: str, user_id: str
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
构建图调用所需的 config 和 input_state
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
thread_id: 会话 ID
|
||||
model: 模型名称
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
(config, input_state) 元组
|
||||
"""
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具
|
||||
"rag_tool": getattr(self, "rag_tool", None),
|
||||
},
|
||||
"metadata": {"user_id": user_id}
|
||||
}
|
||||
# 新版状态输入:传入完整的 MainGraphState,关键是 user_query
|
||||
from ..main_graph.state import MainGraphState, CurrentAction
|
||||
input_state = {
|
||||
"user_query": message,
|
||||
"messages": [{"role": "user", "content": message}],
|
||||
"user_id": user_id,
|
||||
"current_model": model,
|
||||
"current_action": CurrentAction.NONE
|
||||
}
|
||||
return config, input_state
|
||||
|
||||
result = await graph.ainvoke(input_state, config=config)
|
||||
async def process_message(
|
||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||
) -> dict:
|
||||
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
||||
# 解析模型名称
|
||||
resolved_model = self._resolve_model(model)
|
||||
|
||||
# 构建调用参数
|
||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
||||
|
||||
result = await self.graph.ainvoke(input_state, config=config)
|
||||
|
||||
reply = result.get("final_result", "")
|
||||
if not reply and result.get("messages"):
|
||||
reply = result["messages"][-1].content
|
||||
token_usage = result.get("debug_info", {}).get("token_usage", {})
|
||||
elapsed_time = result.get("debug_info", {}).get("elapsed_time", 0.0)
|
||||
token_usage = result.get("last_token_usage", {})
|
||||
elapsed_time = result.get("last_elapsed_time", 0.0)
|
||||
actual_model = result.get("current_model", resolved_model)
|
||||
|
||||
return {
|
||||
"reply": reply,
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
"elapsed_time": elapsed_time,
|
||||
"model_used": actual_model
|
||||
}
|
||||
|
||||
def _serialize_value(self, value):
|
||||
@@ -121,31 +154,169 @@ class AIAgentService:
|
||||
except (TypeError, ValueError):
|
||||
return str(value)
|
||||
|
||||
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
|
||||
"""流式处理消息,返回异步生成器(全部走 React 模式)"""
|
||||
graph = self.graphs.get(model_name)
|
||||
if not graph:
|
||||
raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
|
||||
async def _handle_message_chunk(
|
||||
self, chunk: Dict[str, Any], current_node: Optional[str], tool_calls_in_progress: Dict[str, Any]
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""处理 messages 类型的 chunk"""
|
||||
message_chunk, metadata = chunk["data"]
|
||||
node_name = metadata.get("langgraph_node", "unknown")
|
||||
new_current_node = current_node
|
||||
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具
|
||||
},
|
||||
"metadata": {"user_id": user_id}
|
||||
}
|
||||
input_state = {
|
||||
"user_query": message,
|
||||
"messages": [{"role": "user", "content": message}],
|
||||
"user_id": user_id,
|
||||
"current_action": CurrentAction.NONE
|
||||
}
|
||||
# 检测节点变化,发送节点开始事件
|
||||
if node_name != current_node:
|
||||
if current_node:
|
||||
yield {"type": "node_end", "node": current_node}
|
||||
yield {"type": "node_start", "node": node_name}
|
||||
new_current_node = node_name
|
||||
|
||||
# ========== 意图识别(保留用于日志)==========
|
||||
# 处理消息内容
|
||||
token_content = getattr(message_chunk, 'content', str(message_chunk))
|
||||
reasoning_token = ""
|
||||
if hasattr(message_chunk, 'additional_kwargs'):
|
||||
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
# 处理思考过程
|
||||
if reasoning_token:
|
||||
yield {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
# 处理工具调用
|
||||
elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
|
||||
for tool_call in message_chunk.tool_calls:
|
||||
tool_call_id = tool_call.get("id", "")
|
||||
tool_name = tool_call.get("name", "")
|
||||
tool_args = tool_call.get("args", {})
|
||||
|
||||
# 记录工具调用开始,避免重复
|
||||
if tool_call_id and tool_call_id not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[tool_call_id] = {
|
||||
"name": tool_name,
|
||||
"args": tool_args
|
||||
}
|
||||
yield {
|
||||
"type": "tool_call_start",
|
||||
"tool": tool_name,
|
||||
"args": tool_args,
|
||||
"id": tool_call_id
|
||||
}
|
||||
# 处理普通 token
|
||||
elif token_content:
|
||||
yield {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"token": token_content,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
|
||||
# 返回更新后的 current_node
|
||||
yield {"type": "_update_state", "current_node": new_current_node}
|
||||
|
||||
async def _handle_updates_chunk(
|
||||
self, chunk: Dict[str, Any], tool_calls_in_progress: Dict[str, Any], actual_model_used: str
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""处理 updates 类型的 chunk"""
|
||||
updates_data = chunk["data"]
|
||||
new_actual_model = actual_model_used
|
||||
|
||||
debug(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}")
|
||||
|
||||
# 特别检查 final_result 和 current_model
|
||||
if isinstance(updates_data, dict):
|
||||
if "final_result" in updates_data:
|
||||
debug(f"[Stream] 收到 final_result: {str(updates_data['final_result'])[:100]}...")
|
||||
if "current_model" in updates_data:
|
||||
new_actual_model = updates_data["current_model"]
|
||||
info(f"[Stream] 实际使用模型: {new_actual_model}")
|
||||
|
||||
serialized_data = self._serialize_value(updates_data)
|
||||
|
||||
# 检查是否有人工审核请求
|
||||
if "review_pending" in serialized_data and serialized_data["review_pending"]:
|
||||
review_id = serialized_data.get("review_id", "")
|
||||
content_to_review = serialized_data.get("content_to_review", "")
|
||||
yield {
|
||||
"type": "human_review_request",
|
||||
"review_id": review_id,
|
||||
"content": content_to_review
|
||||
}
|
||||
|
||||
# 检查是否有工具结果
|
||||
if "messages" in serialized_data:
|
||||
for msg in serialized_data["messages"]:
|
||||
# 检测工具结果消息
|
||||
if msg.get("role") == "tool":
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
tool_name = msg.get("name", "")
|
||||
tool_result = msg.get("content", "")
|
||||
|
||||
if tool_call_id and tool_call_id in tool_calls_in_progress:
|
||||
yield {
|
||||
"type": "tool_call_end",
|
||||
"tool": tool_name,
|
||||
"id": tool_call_id,
|
||||
"result": tool_result
|
||||
}
|
||||
del tool_calls_in_progress[tool_call_id]
|
||||
|
||||
yield {
|
||||
"type": "state_update",
|
||||
"data": serialized_data
|
||||
}
|
||||
|
||||
# 返回更新后的模型
|
||||
yield {"type": "_update_state", "actual_model_used": new_actual_model}
|
||||
|
||||
async def _handle_custom_chunk(self, chunk: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""处理 custom 类型的 chunk"""
|
||||
custom_data = chunk["data"]
|
||||
|
||||
# 处理我们从 react_reason_node 发送的自定义推理事件
|
||||
if isinstance(custom_data, dict):
|
||||
# 检查是否是我们的推理事件
|
||||
if "action" in custom_data and "reasoning" in custom_data:
|
||||
yield {
|
||||
"type": "react_reasoning",
|
||||
"step": custom_data.get("step", 1),
|
||||
"action": custom_data.get("action", "unknown"),
|
||||
"confidence": custom_data.get("confidence", 0),
|
||||
"reasoning": custom_data.get("reasoning", "")
|
||||
}
|
||||
else:
|
||||
# 处理其他自定义事件
|
||||
serialized_data = self._serialize_value(custom_data)
|
||||
yield {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
else:
|
||||
# 处理其他自定义事件
|
||||
serialized_data = self._serialize_value(custom_data)
|
||||
yield {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
|
||||
async def process_message_stream(
|
||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""流式处理消息,返回异步生成器"""
|
||||
# 解析模型名称
|
||||
resolved_model = self._resolve_model(model)
|
||||
|
||||
# 构建调用参数
|
||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
||||
|
||||
# ========== 意图识别(保留用于日志和后续路由)==========
|
||||
intent_result = await self.intent_classifier.classify(message)
|
||||
info(f"🧠 意图识别: {intent_result.intent_type} (置信度: {intent_result.confidence:.2f})")
|
||||
info(f"📝 推理: {intent_result.reasoning}")
|
||||
|
||||
# 注入意图到状态(让 hybrid_router 可以利用)
|
||||
input_state["intent_type"] = intent_result.intent_type.value
|
||||
input_state["intent_confidence"] = intent_result.confidence
|
||||
|
||||
# 发送意图分类事件
|
||||
yield {
|
||||
"type": "intent_classified",
|
||||
@@ -154,25 +325,26 @@ class AIAgentService:
|
||||
"reasoning": intent_result.reasoning
|
||||
}
|
||||
|
||||
# 发送路径决策事件(现在都是 react_loop)
|
||||
# 发送路径决策事件(目前硬编码,但状态中有意图信息供后续使用)
|
||||
yield {
|
||||
"type": "path_decision",
|
||||
"path": "react_loop",
|
||||
"intent": intent_result.intent_type.value
|
||||
}
|
||||
# ========================================
|
||||
# =============================================
|
||||
|
||||
# ========== React 循环路径 ==========
|
||||
info(f"🚀 开始执行 React 图,模型: {model_name}")
|
||||
info(f"🚀 开始执行单图,指定模型: {resolved_model}")
|
||||
current_node = None
|
||||
tool_calls_in_progress = {}
|
||||
tool_calls_in_progress: Dict[str, Any] = {}
|
||||
actual_model_used = resolved_model
|
||||
chunk_count = 0
|
||||
full_message_content = ""
|
||||
|
||||
try:
|
||||
info(f"📡 开始调用 graph.astream()...")
|
||||
chunk_count = 0
|
||||
full_message_content = "" # 收集完整消息内容
|
||||
|
||||
async for chunk in graph.astream(
|
||||
async for chunk in self.graph.astream(
|
||||
input_state,
|
||||
config=config,
|
||||
stream_mode=["messages", "updates", "custom"],
|
||||
@@ -181,156 +353,58 @@ class AIAgentService:
|
||||
):
|
||||
chunk_count += 1
|
||||
chunk_type = chunk["type"]
|
||||
processed_event = {}
|
||||
|
||||
if chunk_type == "messages":
|
||||
message_chunk, metadata = chunk["data"]
|
||||
node_name = metadata.get("langgraph_node", "unknown")
|
||||
|
||||
# 检测节点变化,发送节点开始事件
|
||||
if node_name != current_node:
|
||||
if current_node:
|
||||
yield {
|
||||
"type": "node_end",
|
||||
"node": current_node
|
||||
}
|
||||
yield {
|
||||
"type": "node_start",
|
||||
"node": node_name
|
||||
}
|
||||
current_node = node_name
|
||||
|
||||
# 处理消息内容
|
||||
token_content = getattr(message_chunk, 'content', str(message_chunk))
|
||||
reasoning_token = ""
|
||||
if hasattr(message_chunk, 'additional_kwargs'):
|
||||
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
# 处理思考过程
|
||||
if reasoning_token:
|
||||
processed_event = {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
# 处理工具调用
|
||||
elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
|
||||
for tool_call in message_chunk.tool_calls:
|
||||
tool_call_id = tool_call.get("id", "")
|
||||
tool_name = tool_call.get("name", "")
|
||||
tool_args = tool_call.get("args", {})
|
||||
|
||||
# 记录工具调用开始
|
||||
if tool_call_id not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[tool_call_id] = {
|
||||
"name": tool_name,
|
||||
"args": tool_args
|
||||
}
|
||||
yield {
|
||||
"type": "tool_call_start",
|
||||
"tool": tool_name,
|
||||
"args": tool_args,
|
||||
"id": tool_call_id
|
||||
}
|
||||
# 处理普通 token - 只收集,不打印单个 token
|
||||
elif token_content:
|
||||
processed_event = {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"token": token_content,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
if node_name == "llm_call":
|
||||
full_message_content += token_content
|
||||
async for event in self._handle_message_chunk(
|
||||
chunk, current_node, tool_calls_in_progress
|
||||
):
|
||||
if event.get("type") == "_update_state":
|
||||
current_node = event.get("current_node", current_node)
|
||||
else:
|
||||
# 如果是 llm_call 节点的 token,收集完整消息
|
||||
if (
|
||||
event.get("type") == "llm_token"
|
||||
and event.get("node") == "llm_call"
|
||||
and "token" in event
|
||||
):
|
||||
full_message_content += event["token"]
|
||||
yield event
|
||||
|
||||
elif chunk_type == "updates":
|
||||
updates_data = chunk["data"]
|
||||
info(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}")
|
||||
# 特别检查 final_result
|
||||
if isinstance(updates_data, dict) and "final_result" in updates_data:
|
||||
info(f"[Stream] 收到 final_result: {str(updates_data['final_result'])[:100]}...")
|
||||
serialized_data = self._serialize_value(updates_data)
|
||||
if "review_pending" in serialized_data and serialized_data["review_pending"]:
|
||||
review_id = serialized_data.get("review_id", "")
|
||||
content_to_review = serialized_data.get("content_to_review", "")
|
||||
yield {
|
||||
"type": "human_review_request",
|
||||
"review_id": review_id,
|
||||
"content": content_to_review
|
||||
}
|
||||
|
||||
# 检查是否有工具结果
|
||||
if "messages" in serialized_data:
|
||||
for msg in serialized_data["messages"]:
|
||||
# 检测工具结果消息
|
||||
if msg.get("role") == "tool":
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
tool_name = msg.get("name", "")
|
||||
tool_output = msg.get("content", "")
|
||||
|
||||
if tool_call_id in tool_calls_in_progress:
|
||||
yield {
|
||||
"type": "tool_call_end",
|
||||
"tool": tool_name,
|
||||
"id": tool_call_id,
|
||||
"result": tool_output
|
||||
}
|
||||
del tool_calls_in_progress[tool_call_id]
|
||||
|
||||
processed_event = {
|
||||
"type": "state_update",
|
||||
"data": serialized_data
|
||||
}
|
||||
async for event in self._handle_updates_chunk(
|
||||
chunk, tool_calls_in_progress, actual_model_used
|
||||
):
|
||||
if event.get("type") == "_update_state":
|
||||
actual_model_used = event.get("actual_model_used", actual_model_used)
|
||||
else:
|
||||
yield event
|
||||
|
||||
elif chunk_type == "custom":
|
||||
custom_data = chunk["data"]
|
||||
|
||||
# 处理我们从 react_reason_node 发送的自定义推理事件
|
||||
if isinstance(custom_data, dict):
|
||||
# 检查是否是我们的推理事件
|
||||
if "action" in custom_data and "reasoning" in custom_data:
|
||||
yield {
|
||||
"type": "react_reasoning",
|
||||
"step": custom_data.get("step", 1),
|
||||
"action": custom_data.get("action", "unknown"),
|
||||
"confidence": custom_data.get("confidence", 0),
|
||||
"reasoning": custom_data.get("reasoning", "")
|
||||
}
|
||||
else:
|
||||
# 处理其他自定义事件
|
||||
serialized_data = self._serialize_value(custom_data)
|
||||
processed_event = {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
else:
|
||||
# 处理其他自定义事件
|
||||
serialized_data = self._serialize_value(custom_data)
|
||||
processed_event = {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
|
||||
if processed_event:
|
||||
yield processed_event
|
||||
async for event in self._handle_custom_chunk(chunk):
|
||||
yield event
|
||||
|
||||
# 完整消息集合完成后,一次性打印
|
||||
info(f"✅ graph.astream() 完成,共 {chunk_count} 个chunks")
|
||||
info(f"✅ graph.astream() 完成,共 {chunk_count} 个 chunks")
|
||||
if full_message_content:
|
||||
info(f"📄 完整消息内容: {repr(full_message_content)}")
|
||||
info(f"🤖 实际使用模型: {actual_model_used}")
|
||||
|
||||
except Exception as e:
|
||||
error(f"❌ 执行 React 图时出错: {e}")
|
||||
error(f"❌ 执行单图时出错: {e}")
|
||||
import traceback
|
||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
# 发送结束事件
|
||||
if current_node:
|
||||
yield {
|
||||
"type": "node_end",
|
||||
"node": current_node
|
||||
"type": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
finally:
|
||||
# 无论成功或失败,都发送结束事件,保证前端平稳关闭
|
||||
if current_node:
|
||||
yield {
|
||||
"type": "node_end",
|
||||
"node": current_node
|
||||
}
|
||||
yield {
|
||||
"type": "done",
|
||||
"model_used": actual_model_used
|
||||
}
|
||||
yield {
|
||||
"type": "done"
|
||||
}
|
||||
@@ -22,9 +22,8 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
|
||||
"3. 📇 通讯录子系统 - 查询联系人、添加联系人、管理通讯录\n"
|
||||
"4. 🔍 RAG检索 - 从知识库中检索相关信息回答问题\n\n"
|
||||
"【用户背景信息】\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳:\n"
|
||||
"{memory_context}\n"
|
||||
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
|
||||
"【可用工具与使用规则】\n"
|
||||
f"{tools_section}\n"
|
||||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||||
|
||||
@@ -127,6 +127,13 @@ BACKEND_PORT = _get_int("BACKEND_PORT")
|
||||
MEMORY_SUMMARIZE_INTERVAL = _get_int("MEMORY_SUMMARIZE_INTERVAL")
|
||||
|
||||
|
||||
# ========== Tavily 搜索配置 ==========
|
||||
# Tavily API:https://app.tavily.com
|
||||
# 免费额度:1000次/天
|
||||
TAVILY_API_KEY = _get_str("TAVILY_API_KEY")
|
||||
TAVILY_MAX_RESULTS = _get_int("TAVILY_MAX_RESULTS") or 5
|
||||
|
||||
|
||||
# ========== Graph 执行追踪配置 ==========
|
||||
# 是否启用 Graph 流转追踪(通过环境变量控制)
|
||||
ENABLE_GRAPH_TRACE = _get_bool("ENABLE_GRAPH_TRACE")
|
||||
|
||||
@@ -33,18 +33,29 @@ class WebSearchTool:
|
||||
|
||||
def search(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
||||
"""
|
||||
使用多种方式搜索
|
||||
|
||||
使用多种方式搜索,按优先级尝试
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
max_results: 返回结果数量,默认使用初始化时的设置
|
||||
|
||||
|
||||
Returns:
|
||||
搜索结果列表
|
||||
"""
|
||||
num_results = max_results or self.max_results
|
||||
|
||||
# 方式 1: 尝试用 ddgs 包
|
||||
# 方式 1: Tavily (需要 API Key,质量最高)
|
||||
try:
|
||||
return self._search_tavily(query, num_results)
|
||||
except ImportError:
|
||||
print("[WebSearch] tavily 未安装,尝试其他搜索方式")
|
||||
except Exception as e:
|
||||
if "API_KEY" in str(e) or "未配置" in str(e):
|
||||
print(f"[WebSearch] Tavily API Key 未配置: {e}")
|
||||
else:
|
||||
print(f"[WebSearch] Tavily 搜索失败: {e}")
|
||||
|
||||
# 方式 2: 尝试用 ddgs 包
|
||||
try:
|
||||
from ddgs import DDGS
|
||||
print(f"[WebSearch] 使用 ddgs 搜索: {query}")
|
||||
@@ -65,29 +76,7 @@ class WebSearchTool:
|
||||
print("[WebSearch] ddgs 未安装,尝试 duckduckgo-search")
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] ddgs 搜索失败: {e}")
|
||||
|
||||
# 方式 2: 尝试用旧的 duckduckgo-search 包
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
print(f"[WebSearch] 使用 duckduckgo-search 搜索: {query}")
|
||||
with DDGS() as ddgs:
|
||||
results = list(ddgs.text(query, max_results=num_results))
|
||||
if results:
|
||||
search_results = []
|
||||
for r in results:
|
||||
search_results.append(SearchResult(
|
||||
title=r.get("title", ""),
|
||||
url=r.get("href", ""),
|
||||
snippet=r.get("body", ""),
|
||||
source="DuckDuckGo"
|
||||
))
|
||||
print(f"[WebSearch] duckduckgo-search 返回 {len(search_results)} 条结果")
|
||||
return search_results
|
||||
except ImportError:
|
||||
print("[WebSearch] duckduckgo-search 未安装")
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] duckduckgo-search 搜索失败: {e}")
|
||||
|
||||
|
||||
# 方式 3: 尝试用简单 HTTP 请求
|
||||
try:
|
||||
return self._search_http(query, num_results)
|
||||
@@ -97,6 +86,34 @@ class WebSearchTool:
|
||||
# 方式 4: 返回模拟数据作为最后兜底
|
||||
return self._search_mock(query, num_results)
|
||||
|
||||
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""使用 Tavily API 搜索"""
|
||||
from tavily import TavilyClient
|
||||
from app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS
|
||||
|
||||
if not TAVILY_API_KEY:
|
||||
raise ValueError("TAVILY_API_KEY 未配置")
|
||||
|
||||
client = TavilyClient(api_key=TAVILY_API_KEY)
|
||||
response = client.search(
|
||||
query=query,
|
||||
max_results=min(max_results, TAVILY_MAX_RESULTS or 5),
|
||||
include_answer=True,
|
||||
include_raw_content=False
|
||||
)
|
||||
|
||||
results = []
|
||||
for item in response.get("results", []):
|
||||
results.append(SearchResult(
|
||||
title=item.get("title", ""),
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("content", ""),
|
||||
source="Tavily"
|
||||
))
|
||||
|
||||
print(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
||||
return results
|
||||
|
||||
def _search_http(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源"""
|
||||
print(f"[WebSearch] 尝试 HTTP 搜索")
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""
|
||||
LangGraph 核心组件重新导出
|
||||
统一导入入口,避免直接依赖 langgraph
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, START, END, add_messages
|
||||
|
||||
__all__ = ["StateGraph", "START", "END", "add_messages"]
|
||||
229
backend/app/main_graph/main_graph_builder.py
Normal file
229
backend/app/main_graph/main_graph_builder.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
主图构建器 - 构建整合后的完整主图
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from typing import Dict, Any
|
||||
|
||||
from .state import MainGraphState
|
||||
from .nodes.reasoning import react_reason_node
|
||||
from .nodes.web_search import web_search_node
|
||||
from .nodes.error_handling import error_handling_node
|
||||
from .nodes.routing import init_state_node, route_by_reasoning, should_summarize
|
||||
from .nodes.hybrid_router import (
|
||||
hybrid_router_node,
|
||||
route_from_hybrid_decision,
|
||||
check_fast_path_success,
|
||||
)
|
||||
from .nodes.fast_paths import (
|
||||
fast_chitchat_node,
|
||||
fast_rag_node,
|
||||
fast_tool_node,
|
||||
)
|
||||
from .nodes.llm_call import create_dynamic_llm_call_node
|
||||
from .nodes.rag_nodes import rag_retrieve_node
|
||||
from .nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from .nodes.summarize import create_summarize_node
|
||||
from .nodes.finalize import finalize_node
|
||||
from ..subgraphs.contact import build_contact_subgraph
|
||||
from ..subgraphs.dictionary import build_dictionary_subgraph
|
||||
from ..subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from ..logger import info
|
||||
|
||||
from .subgraph_wrapper import create_subgraph_nodes
|
||||
|
||||
|
||||
# ========== 主图构建 ==========
|
||||
|
||||
def build_react_main_graph(
|
||||
chat_services: dict,
|
||||
tools=None,
|
||||
mem0_client=None,
|
||||
use_hybrid_router: bool = True
|
||||
) -> StateGraph:
|
||||
"""
|
||||
构建整合后的完整主图(支持混合路由 + 动态模型选择)
|
||||
|
||||
Args:
|
||||
chat_services: 模型名称 -> ChatModel 实例 的字典
|
||||
tools: 工具列表
|
||||
mem0_client: Mem0 客户端实例
|
||||
use_hybrid_router: 是否使用混合路由(快速路径 + React 循环)
|
||||
|
||||
Returns:
|
||||
StateGraph: 构建好的图
|
||||
"""
|
||||
# 创建图
|
||||
graph = StateGraph(MainGraphState)
|
||||
|
||||
# 设置全局 mem0_client
|
||||
if mem0_client:
|
||||
set_mem0_client(mem0_client)
|
||||
|
||||
# ========== 创建节点 ==========
|
||||
|
||||
# LLM 调用节点
|
||||
llm_node = create_dynamic_llm_call_node(chat_services, tools or [])
|
||||
|
||||
# 记忆节点
|
||||
retrieve_memory_node = None
|
||||
summarize_node = None
|
||||
if mem0_client:
|
||||
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
|
||||
summarize_node = create_summarize_node(mem0_client)
|
||||
|
||||
# 子图节点
|
||||
contact_graph = build_contact_subgraph()
|
||||
dictionary_graph = build_dictionary_subgraph()
|
||||
news_analysis_graph = build_news_analysis_subgraph()
|
||||
subgraph_nodes = create_subgraph_nodes(
|
||||
contact_graph, dictionary_graph, news_analysis_graph
|
||||
)
|
||||
|
||||
# ========== 添加节点到图 ==========
|
||||
|
||||
# 阶段 1: 记忆检索
|
||||
if retrieve_memory_node:
|
||||
graph.add_node("retrieve_memory", retrieve_memory_node)
|
||||
graph.add_node("memory_trigger", memory_trigger_node)
|
||||
|
||||
# 阶段 2: 初始化
|
||||
graph.add_node("init_state", init_state_node)
|
||||
|
||||
# 阶段 3: 混合路由(可选)
|
||||
if use_hybrid_router:
|
||||
graph.add_node("hybrid_router", hybrid_router_node)
|
||||
graph.add_node("fast_chitchat", fast_chitchat_node)
|
||||
graph.add_node("fast_rag", fast_rag_node)
|
||||
graph.add_node("fast_tool", fast_tool_node)
|
||||
|
||||
# 阶段 4: React 循环推理(始终保留)
|
||||
graph.add_node("react_reason", react_reason_node)
|
||||
graph.add_node("rag_retrieve", rag_retrieve_node)
|
||||
graph.add_node("web_search", web_search_node)
|
||||
graph.add_node("handle_error", error_handling_node)
|
||||
|
||||
if llm_node is not None:
|
||||
graph.add_node("llm_call", llm_node)
|
||||
|
||||
# 子图节点
|
||||
for node_name, node_func in subgraph_nodes.items():
|
||||
graph.add_node(node_name, node_func)
|
||||
|
||||
# 阶段 5: 完成处理
|
||||
if summarize_node:
|
||||
graph.add_node("summarize", summarize_node)
|
||||
graph.add_node("finalize", finalize_node)
|
||||
|
||||
# ========== 添加边 ==========
|
||||
|
||||
# 阶段 1: 记忆检索
|
||||
_add_memory_edges(graph, retrieve_memory_node)
|
||||
|
||||
# 阶段 2: 初始化
|
||||
graph.add_edge("memory_trigger", "init_state")
|
||||
|
||||
# 阶段 3: 路由分支
|
||||
_add_routing_edges(graph, use_hybrid_router, llm_node)
|
||||
|
||||
# 阶段 4: React 循环边
|
||||
_add_react_loop_edges(graph, subgraph_nodes)
|
||||
|
||||
# 阶段 5: 完成阶段
|
||||
_add_finalize_edges(graph, llm_node, summarize_node)
|
||||
|
||||
info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})")
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def _add_memory_edges(graph: StateGraph, retrieve_memory_node) -> None:
|
||||
"""添加记忆检索阶段的边"""
|
||||
if retrieve_memory_node:
|
||||
graph.add_edge(START, "retrieve_memory")
|
||||
graph.add_edge("retrieve_memory", "memory_trigger")
|
||||
else:
|
||||
graph.add_edge(START, "memory_trigger")
|
||||
|
||||
|
||||
def _add_routing_edges(graph: StateGraph, use_hybrid_router: bool, llm_node) -> None:
|
||||
"""添加路由阶段的边"""
|
||||
if use_hybrid_router:
|
||||
graph.add_edge("init_state", "hybrid_router")
|
||||
|
||||
# 混合路由条件分支
|
||||
graph.add_conditional_edges(
|
||||
"hybrid_router",
|
||||
route_from_hybrid_decision,
|
||||
{
|
||||
"fast_chitchat": "fast_chitchat",
|
||||
"fast_rag": "fast_rag",
|
||||
"fast_tool": "fast_tool",
|
||||
"react_loop": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
# 快速路径的完成检查
|
||||
for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]:
|
||||
graph.add_conditional_edges(
|
||||
fast_node,
|
||||
check_fast_path_success,
|
||||
{
|
||||
"llm_call": "llm_call",
|
||||
"escalate": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
info(f"✅ [图构建] 混合路由模式已启用")
|
||||
else:
|
||||
graph.add_edge("init_state", "react_reason")
|
||||
info(f"✅ [图构建] 纯 React 模式")
|
||||
|
||||
|
||||
def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) -> None:
|
||||
"""添加 React 循环阶段的边"""
|
||||
subgraph_names = list(subgraph_nodes.keys())
|
||||
|
||||
# React 推理的条件分支
|
||||
graph.add_conditional_edges(
|
||||
"react_reason",
|
||||
route_by_reasoning,
|
||||
{
|
||||
"rag_retrieve": "rag_retrieve",
|
||||
"web_search": "web_search",
|
||||
**{name: name for name in subgraph_names},
|
||||
"handle_error": "handle_error",
|
||||
"llm_call": "llm_call"
|
||||
}
|
||||
)
|
||||
|
||||
# 循环边(回到 react_reason)
|
||||
loop_back_nodes = ["rag_retrieve", "web_search", "handle_error"] + subgraph_names
|
||||
for node_name in loop_back_nodes:
|
||||
graph.add_edge(node_name, "react_reason")
|
||||
|
||||
|
||||
def _add_finalize_edges(graph: StateGraph, llm_node, summarize_node) -> None:
|
||||
"""添加完成阶段的边"""
|
||||
if llm_node is not None:
|
||||
if summarize_node:
|
||||
graph.add_conditional_edges(
|
||||
"llm_call",
|
||||
should_summarize,
|
||||
{
|
||||
"summarize": "summarize",
|
||||
"finalize": "finalize"
|
||||
}
|
||||
)
|
||||
graph.add_edge("summarize", "finalize")
|
||||
else:
|
||||
graph.add_edge("llm_call", "finalize")
|
||||
|
||||
graph.add_edge("finalize", END)
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
__all__ = [
|
||||
"build_react_main_graph",
|
||||
]
|
||||
@@ -6,8 +6,8 @@
|
||||
from .reasoning import react_reason_node
|
||||
from .web_search import web_search_node
|
||||
from .error_handling import error_handling_node
|
||||
from .routing import init_state_node, route_by_reasoning
|
||||
from .llm_call import create_llm_call_node
|
||||
from .routing import init_state_node, route_by_reasoning, should_summarize
|
||||
from .llm_call import create_dynamic_llm_call_node
|
||||
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
|
||||
|
||||
# 记忆节点
|
||||
@@ -38,7 +38,8 @@ __all__ = [
|
||||
"web_search_node",
|
||||
"error_handling_node",
|
||||
"route_by_reasoning",
|
||||
"create_llm_call_node",
|
||||
"should_summarize",
|
||||
"create_dynamic_llm_call_node",
|
||||
"rag_retrieve_node",
|
||||
"rag_re_retrieve_node",
|
||||
# 记忆节点
|
||||
|
||||
@@ -5,7 +5,7 @@ LLM 调用节点模块
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
@@ -14,29 +14,34 @@ from app.agent.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error
|
||||
|
||||
def create_llm_call_node(llm, tools: list):
|
||||
|
||||
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
|
||||
"""
|
||||
工厂函数:创建 LLM 调用节点
|
||||
|
||||
工厂函数:创建动态 LLM 调用节点(根据 state.current_model 选择模型)
|
||||
|
||||
Args:
|
||||
llm: LangChain LLM 实例
|
||||
chat_services: 模型名称 -> ChatModel 实例 的字典
|
||||
tools: 工具列表
|
||||
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
# 构建调用链
|
||||
# 预构建所有模型的 tools 绑定(避免每次调用都 bind)
|
||||
bound_models: Dict[str, Any] = {}
|
||||
for name, llm in chat_services.items():
|
||||
if tools:
|
||||
bound_models[name] = llm.bind_tools(tools)
|
||||
else:
|
||||
bound_models[name] = llm
|
||||
|
||||
# 预构建 prompt
|
||||
prompt = create_system_prompt(tools)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
|
||||
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
|
||||
chain = prompt | llm_with_tools
|
||||
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
LLM 调用节点(动态选择模型)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
@@ -46,7 +51,7 @@ def create_llm_call_node(llm, tools: list):
|
||||
更新后的状态字典
|
||||
"""
|
||||
log_state_change("llm_call", state, "进入")
|
||||
|
||||
|
||||
memory_context = getattr(state, "memory_context", "暂无用户信息")
|
||||
start_time = time.time()
|
||||
|
||||
@@ -62,9 +67,20 @@ def create_llm_call_node(llm, tools: list):
|
||||
"last_elapsed_time": elapsed_time,
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
}
|
||||
|
||||
|
||||
# 动态选择模型
|
||||
model_name = getattr(state, "current_model", "")
|
||||
if not model_name or model_name not in bound_models:
|
||||
# 回退到第一个可用模型
|
||||
fallback_name = next(iter(bound_models.keys()))
|
||||
info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'")
|
||||
model_name = fallback_name
|
||||
|
||||
llm_with_tools = bound_models[model_name]
|
||||
info(f"[llm_call] 使用模型: {model_name}")
|
||||
|
||||
try:
|
||||
# 添加 RAG 上下文到消息
|
||||
# 添加上下文到消息
|
||||
messages_with_context = list(state.messages)
|
||||
if state.rag_context:
|
||||
from langchain_core.messages import SystemMessage
|
||||
@@ -77,9 +93,10 @@ def create_llm_call_node(llm, tools: list):
|
||||
break
|
||||
if not inserted:
|
||||
messages_with_context.insert(0, rag_system_msg)
|
||||
|
||||
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chain = prompt | llm_with_tools
|
||||
chunks = []
|
||||
async for chunk in chain.astream(
|
||||
{
|
||||
@@ -89,7 +106,7 @@ def create_llm_call_node(llm, tools: list):
|
||||
config=config
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
@@ -97,14 +114,14 @@ def create_llm_call_node(llm, tools: list):
|
||||
response = response + chunk
|
||||
else:
|
||||
response = AIMessage(content="")
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||||
token_usage = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
|
||||
# 尝试从 response_metadata 中提取
|
||||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
meta = response.response_metadata
|
||||
@@ -112,33 +129,33 @@ def create_llm_call_node(llm, tools: list):
|
||||
token_usage = meta['token_usage']
|
||||
elif 'usage' in meta:
|
||||
token_usage = meta['usage']
|
||||
|
||||
|
||||
# 尝试从 additional_kwargs 中提取
|
||||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||||
add_kwargs = response.additional_kwargs
|
||||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||||
token_usage = add_kwargs['llm_output']['token_usage']
|
||||
|
||||
|
||||
# 提取具体的 token 数值
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
debug(f"📥 [LLM输出] 模型: {model_name} 返回的完整响应:")
|
||||
debug(f" 消息类型: {response.type.upper()}")
|
||||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||||
debug("-"*80)
|
||||
debug(f"{response.content}")
|
||||
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
# 检查是否有工具调用
|
||||
has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0
|
||||
|
||||
@@ -151,21 +168,22 @@ def create_llm_call_node(llm, tools: list):
|
||||
"final_result": response.content,
|
||||
"success": True,
|
||||
"current_phase": "done",
|
||||
"has_tool_calls": has_tool_calls
|
||||
"has_tool_calls": has_tool_calls,
|
||||
"current_model": model_name # 记录实际使用的模型
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", state, "离开")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f"\n❌ [LLM错误] 模型 {model_name} 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f" 错误类型: {type(e).__name__}")
|
||||
error(f" 错误信息: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
# 返回一个友好的错误消息
|
||||
error_response = AIMessage(
|
||||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||||
@@ -178,10 +196,11 @@ def create_llm_call_node(llm, tools: list):
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
"final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。",
|
||||
"success": False,
|
||||
"current_phase": "done"
|
||||
"current_phase": "done",
|
||||
"current_model": model_name
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", state, "离开(异常)")
|
||||
return error_result
|
||||
|
||||
return call_llm
|
||||
|
||||
return call_llm
|
||||
|
||||
@@ -118,3 +118,21 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
|
||||
info(f"[条件路由] 动作={latest_action}, 目标={target}")
|
||||
return target
|
||||
|
||||
|
||||
# ========== 完成阶段条件路由函数 ==========
|
||||
|
||||
def should_summarize(state: MainGraphState) -> str:
|
||||
"""
|
||||
检查是否需要总结对话(对话足够长时)
|
||||
|
||||
Args:
|
||||
state: 当前图状态
|
||||
|
||||
Returns:
|
||||
"summarize" 或 "finalize"
|
||||
"""
|
||||
if state.turns_since_last_summary >= 5: # 每5轮对话总结一次
|
||||
return "summarize"
|
||||
else:
|
||||
return "finalize"
|
||||
|
||||
@@ -6,7 +6,7 @@ Main Graph State Definition - React Mode Enhanced
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict, List
|
||||
from dataclasses import dataclass, field
|
||||
from app.main_graph.graph import add_messages
|
||||
from langgraph.graph import add_messages
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ class MainGraphState:
|
||||
# ========== 主图控制字段 ==========
|
||||
user_query: str = ""
|
||||
current_action: CurrentAction = CurrentAction.NONE
|
||||
current_model: str = "" # 新增:本次请求使用的模型
|
||||
intent_confidence: float = 0.0
|
||||
|
||||
# ========== React 推理专用字段 ==========
|
||||
|
||||
159
backend/app/main_graph/subgraph_wrapper.py
Normal file
159
backend/app/main_graph/subgraph_wrapper.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
子图包装器 - 为子图添加错误处理和事件追踪
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ..logger import info
|
||||
|
||||
|
||||
def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
"""
|
||||
包装子图,使其错误能传递给主图
|
||||
|
||||
Args:
|
||||
subgraph: 编译好的子图
|
||||
name: 子图名称(用于错误标识)
|
||||
|
||||
Returns: 包装后的节点函数
|
||||
"""
|
||||
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
# 发送子图开始事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_start",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"开始执行 {name} 子图..."
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送开始事件: {e}")
|
||||
|
||||
try:
|
||||
# 调用子图
|
||||
result = subgraph.invoke(state)
|
||||
|
||||
# 更新主图状态
|
||||
subgraph_result = None
|
||||
if name == "contact":
|
||||
state.contact_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "dictionary":
|
||||
state.dictionary_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "news_analysis":
|
||||
state.news_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
|
||||
# 设置最终结果
|
||||
if subgraph_result:
|
||||
state.final_result = subgraph_result
|
||||
else:
|
||||
state.final_result = "子图执行完成"
|
||||
|
||||
# 标记成功
|
||||
state.success = True
|
||||
state.current_phase = "done"
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "subgraph_completed",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name}子图执行完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 发送子图完成事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_complete",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行完成"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送完成事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
# 捕获子图错误,传递给主图
|
||||
error_record = ErrorRecord(
|
||||
error_type=f"{name}SubgraphError",
|
||||
error_message=str(e),
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source=f"{name}_subgraph",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=1,
|
||||
context={"user_query": state.user_query}
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
state.success = False
|
||||
|
||||
# 发送子图错误事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_error",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行失败: {str(e)}"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送错误事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
return wrapped_node
|
||||
|
||||
|
||||
def create_subgraph_nodes(contact_graph, dictionary_graph, news_analysis_graph) -> Dict[str, Any]:
|
||||
"""
|
||||
创建所有子图节点的字典
|
||||
|
||||
Args:
|
||||
contact_graph: 联系人子图
|
||||
dictionary_graph: 词典子图
|
||||
news_analysis_graph: 新闻分析子图
|
||||
|
||||
Returns:
|
||||
子图节点字典 {name: wrapped_node}
|
||||
"""
|
||||
return {
|
||||
"contact_subgraph": wrap_subgraph_for_error_handling(
|
||||
contact_graph.compile(), "contact"
|
||||
),
|
||||
"dictionary_subgraph": wrap_subgraph_for_error_handling(
|
||||
dictionary_graph.compile(), "dictionary"
|
||||
),
|
||||
"news_analysis_subgraph": wrap_subgraph_for_error_handling(
|
||||
news_analysis_graph.compile(), "news_analysis"
|
||||
),
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
"""主图工具函数"""
|
||||
@@ -1,371 +0,0 @@
|
||||
"""
|
||||
整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState
|
||||
"""
|
||||
|
||||
from ..graph import StateGraph, START, END
|
||||
from typing import Dict, Any, Optional
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ..state import MainGraphState
|
||||
from ..nodes.reasoning import react_reason_node
|
||||
from ..nodes.web_search import web_search_node
|
||||
from ..nodes.error_handling import error_handling_node
|
||||
from ..nodes.routing import init_state_node, route_by_reasoning
|
||||
from ..nodes.hybrid_router import (
|
||||
hybrid_router_node,
|
||||
route_from_hybrid_decision,
|
||||
check_fast_path_success,
|
||||
)
|
||||
from ..nodes.fast_paths import (
|
||||
fast_chitchat_node,
|
||||
fast_rag_node,
|
||||
fast_tool_node,
|
||||
)
|
||||
from ..nodes.llm_call import create_llm_call_node
|
||||
from ..nodes.rag_nodes import rag_retrieve_node
|
||||
from ..nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from ..nodes.summarize import create_summarize_node
|
||||
from ..nodes.finalize import finalize_node
|
||||
from ...subgraphs.contact import build_contact_subgraph
|
||||
from ...subgraphs.dictionary import build_dictionary_subgraph
|
||||
from ...subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...logger import info, debug
|
||||
|
||||
|
||||
# ========== 检查是否需要总结 ==========
|
||||
def should_summarize(state: MainGraphState) -> str:
|
||||
"""
|
||||
检查是否需要总结对话(对话足够长时)
|
||||
|
||||
Args:
|
||||
state: 当前图状态
|
||||
|
||||
Returns:
|
||||
"summarize" 或 "finalize"
|
||||
"""
|
||||
if state.turns_since_last_summary >= 5: # 每5轮对话总结一次
|
||||
return "summarize"
|
||||
else:
|
||||
return "finalize"
|
||||
|
||||
|
||||
# ========== 子图包装器(处理子图错误传递)==========
|
||||
def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
"""
|
||||
包装子图,使其错误能传递给主图
|
||||
|
||||
Args:
|
||||
subgraph: 编译好的子图
|
||||
name: 子图名称(用于错误标识)
|
||||
|
||||
Returns: 包装后的节点函数
|
||||
"""
|
||||
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
# 发送子图开始事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_start",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"开始执行 {name} 子图..."
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送开始事件: {e}")
|
||||
|
||||
try:
|
||||
# 调用子图
|
||||
result = subgraph.invoke(state)
|
||||
|
||||
# 更新主图状态
|
||||
subgraph_result = None
|
||||
if name == "contact":
|
||||
state.contact_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "dictionary":
|
||||
state.dictionary_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "news_analysis":
|
||||
state.news_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
|
||||
# 关键:设置最终结果
|
||||
if subgraph_result:
|
||||
state.final_result = subgraph_result
|
||||
else:
|
||||
state.final_result = "子图执行完成"
|
||||
|
||||
# 标记成功
|
||||
state.success = True
|
||||
state.current_phase = "done"
|
||||
# 标记不再需要推理,避免循环
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "subgraph_completed",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name}子图执行完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 发送子图完成事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_complete",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行完成"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送完成事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
# 捕获子图错误,传递给主图
|
||||
from ..state import ErrorRecord, ErrorSeverity
|
||||
from datetime import datetime
|
||||
|
||||
error_record = ErrorRecord(
|
||||
error_type=f"{name}SubgraphError",
|
||||
error_message=str(e),
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source=f"{name}_subgraph",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=1,
|
||||
context={"user_query": state.user_query}
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
state.success = False
|
||||
|
||||
# 发送子图错误事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_error",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行失败: {str(e)}"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送错误事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
return wrapped_node
|
||||
|
||||
# ========== 主图构建 ==========
|
||||
|
||||
def build_react_main_graph(llm=None, tools=None, mem0_client=None, use_hybrid_router: bool = True) -> StateGraph:
|
||||
"""
|
||||
构建整合后的完整主图(支持混合路由)
|
||||
|
||||
Args:
|
||||
llm: LangChain ChatModel 实例
|
||||
tools: 工具列表
|
||||
mem0_client: Mem0 客户端实例
|
||||
use_hybrid_router: 是否使用混合路由(快速路径 + React 循环)
|
||||
|
||||
Returns:
|
||||
StateGraph: 构建好的图
|
||||
"""
|
||||
# 创建图
|
||||
graph = StateGraph(MainGraphState)
|
||||
|
||||
# 设置全局 mem0_client
|
||||
if mem0_client:
|
||||
set_mem0_client(mem0_client)
|
||||
|
||||
# 创建节点
|
||||
llm_node = None
|
||||
if llm is not None:
|
||||
llm_node = create_llm_call_node(llm, tools or [])
|
||||
|
||||
retrieve_memory_node = None
|
||||
summarize_node = None
|
||||
if mem0_client:
|
||||
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
|
||||
summarize_node = create_summarize_node(mem0_client)
|
||||
|
||||
# ========== 添加节点 ==========
|
||||
|
||||
# 第一阶段:记忆检索
|
||||
if retrieve_memory_node:
|
||||
graph.add_node("retrieve_memory", retrieve_memory_node)
|
||||
graph.add_node("memory_trigger", memory_trigger_node)
|
||||
|
||||
# 第二阶段:初始化
|
||||
graph.add_node("init_state", init_state_node)
|
||||
|
||||
# ========== 混合路由节点(如果启用) ==========
|
||||
if use_hybrid_router:
|
||||
graph.add_node("hybrid_router", hybrid_router_node)
|
||||
graph.add_node("fast_chitchat", fast_chitchat_node)
|
||||
graph.add_node("fast_rag", fast_rag_node)
|
||||
graph.add_node("fast_tool", fast_tool_node)
|
||||
|
||||
# 第三阶段:React 循环推理(始终保留)
|
||||
graph.add_node("react_reason", react_reason_node)
|
||||
graph.add_node("rag_retrieve", rag_retrieve_node)
|
||||
graph.add_node("web_search", web_search_node)
|
||||
graph.add_node("handle_error", error_handling_node)
|
||||
|
||||
if llm_node is not None:
|
||||
graph.add_node("llm_call", llm_node)
|
||||
|
||||
# 子图节点
|
||||
contact_graph = build_contact_subgraph()
|
||||
dictionary_graph = build_dictionary_subgraph()
|
||||
news_analysis_graph = build_news_analysis_subgraph()
|
||||
|
||||
graph.add_node(
|
||||
"contact_subgraph",
|
||||
wrap_subgraph_for_error_handling(contact_graph.compile(), "contact")
|
||||
)
|
||||
graph.add_node(
|
||||
"dictionary_subgraph",
|
||||
wrap_subgraph_for_error_handling(dictionary_graph.compile(), "dictionary")
|
||||
)
|
||||
graph.add_node(
|
||||
"news_analysis_subgraph",
|
||||
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
|
||||
)
|
||||
|
||||
# 第四阶段:完成处理
|
||||
if summarize_node:
|
||||
graph.add_node("summarize", summarize_node)
|
||||
graph.add_node("finalize", finalize_node)
|
||||
|
||||
# ========== 添加边 ==========
|
||||
|
||||
# 第一阶段:记忆检索
|
||||
if retrieve_memory_node:
|
||||
graph.add_edge(START, "retrieve_memory")
|
||||
graph.add_edge("retrieve_memory", "memory_trigger")
|
||||
else:
|
||||
graph.add_edge(START, "memory_trigger")
|
||||
|
||||
# 进入初始化
|
||||
graph.add_edge("memory_trigger", "init_state")
|
||||
|
||||
# ========== 混合路由分支(如果启用) ==========
|
||||
if use_hybrid_router:
|
||||
graph.add_edge("init_state", "hybrid_router")
|
||||
|
||||
# 从 hybrid_router 条件分支
|
||||
graph.add_conditional_edges(
|
||||
"hybrid_router",
|
||||
route_from_hybrid_decision,
|
||||
{
|
||||
"fast_chitchat": "fast_chitchat",
|
||||
"fast_rag": "fast_rag",
|
||||
"fast_tool": "fast_tool",
|
||||
"react_loop": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
# 快速路径的完成检查
|
||||
for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]:
|
||||
graph.add_conditional_edges(
|
||||
fast_node,
|
||||
check_fast_path_success,
|
||||
{
|
||||
"llm_call": "llm_call",
|
||||
"escalate": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
info(f"✅ [图构建] 混合路由模式已启用")
|
||||
else:
|
||||
# 无混合路由,直接到 react_reason
|
||||
graph.add_edge("init_state", "react_reason")
|
||||
info(f"✅ [图构建] 纯 React 模式")
|
||||
|
||||
# ========== React 循环边(始终保留) ==========
|
||||
graph.add_conditional_edges(
|
||||
"react_reason",
|
||||
route_by_reasoning,
|
||||
{
|
||||
"rag_retrieve": "rag_retrieve",
|
||||
"web_search": "web_search",
|
||||
"contact_subgraph": "contact_subgraph",
|
||||
"dictionary_subgraph": "dictionary_subgraph",
|
||||
"news_analysis_subgraph": "news_analysis_subgraph",
|
||||
"handle_error": "handle_error",
|
||||
"llm_call": "llm_call"
|
||||
}
|
||||
)
|
||||
|
||||
# 循环边(rag、web_search、子图、error都回到 reason)
|
||||
graph.add_edge("rag_retrieve", "react_reason")
|
||||
graph.add_edge("web_search", "react_reason")
|
||||
graph.add_edge("contact_subgraph", "react_reason")
|
||||
graph.add_edge("dictionary_subgraph", "react_reason")
|
||||
graph.add_edge("news_analysis_subgraph", "react_reason")
|
||||
graph.add_edge("handle_error", "react_reason")
|
||||
|
||||
# ========== 最终完成阶段 ==========
|
||||
if llm_node is not None:
|
||||
if summarize_node:
|
||||
# 检查是否需要总结
|
||||
graph.add_conditional_edges(
|
||||
"llm_call",
|
||||
should_summarize,
|
||||
{
|
||||
"summarize": "summarize",
|
||||
"finalize": "finalize"
|
||||
}
|
||||
)
|
||||
graph.add_edge("summarize", "finalize")
|
||||
else:
|
||||
# 没有 summarize 节点,直接 finalize
|
||||
graph.add_edge("llm_call", "finalize")
|
||||
|
||||
# 完成
|
||||
graph.add_edge("finalize", END)
|
||||
|
||||
info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})")
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# ========== 兼容性:保留旧的函数名 ==========
|
||||
def build_main_graph() -> StateGraph:
|
||||
"""
|
||||
兼容性函数:旧代码调用 build_main_graph() 时返回 React 版本
|
||||
"""
|
||||
return build_react_main_graph()
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
__all__ = [
|
||||
"build_react_main_graph",
|
||||
"build_main_graph",
|
||||
"wrap_subgraph_for_error_handling"
|
||||
]
|
||||
@@ -6,11 +6,17 @@
|
||||
|
||||
from .embedding_services import get_embedding_service
|
||||
from .rerank_services import get_rerank_service, BaseRerankService
|
||||
from .chat_services import get_small_llm_service
|
||||
from .chat_services import (
|
||||
get_small_llm_service,
|
||||
get_cached_chat_services,
|
||||
get_all_chat_services
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_embedding_service",
|
||||
"get_rerank_service",
|
||||
"get_small_llm_service",
|
||||
"get_cached_chat_services",
|
||||
"get_all_chat_services",
|
||||
"BaseRerankService"
|
||||
]
|
||||
|
||||
@@ -33,6 +33,21 @@ from app.config import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存已初始化的模型字典
|
||||
_cached_services: Dict[str, BaseChatModel] | None = None
|
||||
|
||||
|
||||
def _check_http_service_available(base_url: str, api_key: str = "", timeout: float = 2.0) -> bool:
|
||||
"""通过探测 /models 端点检查 HTTP API 是否可用(内部工具函数)"""
|
||||
try:
|
||||
import httpx
|
||||
client = httpx.Client(base_url=base_url.rstrip('/'), timeout=timeout)
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
resp = client.get("/models", headers=headers)
|
||||
return resp.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
@@ -54,46 +69,8 @@ class LocalVLLMChatProvider(BaseServiceProvider[BaseChatModel]):
|
||||
logger.warning("VLLM_BASE_URL 未配置")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 先测试主机名能否解析
|
||||
import httpx
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(VLLM_BASE_URL)
|
||||
host = parsed_url.hostname
|
||||
port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443)
|
||||
|
||||
# 测试能否建立 TCP 连接(快速失败)
|
||||
import socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(2.0)
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
sock.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"本地 VLLM 服务无法连接: {host}:{port} - {e}")
|
||||
return False
|
||||
|
||||
# 再尝试调用简单的 API(比如 models 接口)
|
||||
client = httpx.Client(base_url=VLLM_BASE_URL.rstrip('/'), timeout=5.0)
|
||||
headers = {}
|
||||
if LLM_API_KEY:
|
||||
headers["Authorization"] = f"Bearer {LLM_API_KEY}"
|
||||
|
||||
try:
|
||||
response = client.get("/models", headers=headers)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"本地 VLLM 服务可用: {self._model}")
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 如果 /v1/models 失败,也认为服务不可用
|
||||
logger.warning(f"本地 VLLM 服务响应异常")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"本地 VLLM 服务不可用: {e}")
|
||||
return False
|
||||
# 使用统一的 HTTP 探测方法
|
||||
return _check_http_service_available(VLLM_BASE_URL, LLM_API_KEY, timeout=2.0)
|
||||
|
||||
def get_service(self) -> BaseChatModel:
|
||||
"""
|
||||
@@ -238,45 +215,8 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
logger.warning("SMALL_VLLM_BASE_URL 未配置,本地小模型不可用")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 先测试主机名能否解析
|
||||
import httpx
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(self._base_url)
|
||||
host = parsed_url.hostname
|
||||
port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443)
|
||||
|
||||
# 测试能否建立 TCP 连接(快速失败)
|
||||
import socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(2.0)
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
sock.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"本地小模型服务无法连接: {host}:{port} - {e}")
|
||||
return False
|
||||
|
||||
# 再尝试调用简单的 API
|
||||
client = httpx.Client(base_url=self._base_url.rstrip('/'), timeout=5.0)
|
||||
headers = {}
|
||||
if self._api_key:
|
||||
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||
|
||||
try:
|
||||
response = client.get("/models", headers=headers)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"本地小模型服务可用: {self._model}")
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.warning(f"本地小模型服务响应异常")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"本地小模型服务不可用: {e}")
|
||||
return False
|
||||
# 使用统一的 HTTP 探测方法
|
||||
return _check_http_service_available(self._base_url, self._api_key, timeout=2.0)
|
||||
|
||||
def get_service(self) -> BaseChatModel:
|
||||
"""获取本地小模型服务"""
|
||||
@@ -358,25 +298,18 @@ def get_chat_service() -> BaseChatModel:
|
||||
return chain.get_available_service()
|
||||
|
||||
|
||||
def get_all_chat_services() -> Dict[str, BaseChatModel]:
|
||||
"""
|
||||
获取所有可用的生成式大模型服务(用于多模型切换)
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseChatModel]: 模型名称 -> ChatModel 实例 的字典
|
||||
"""
|
||||
def _init_chat_services() -> Dict[str, BaseChatModel]:
|
||||
"""实际初始化所有可用模型(仅在首次调用)"""
|
||||
services = {}
|
||||
|
||||
for name, provider_factory in CHAT_PROVIDERS.items():
|
||||
try:
|
||||
provider = provider_factory()
|
||||
if provider.is_available():
|
||||
logger.info(f"模型 '{name}' 可用")
|
||||
services[name] = provider.get_service()
|
||||
else:
|
||||
logger.warning(f"模型 '{name}' 不可用,跳过")
|
||||
logger.info(f"已加载模型: {name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"初始化模型 '{name}' 失败: {e}")
|
||||
logger.warning(f"模型 {name} 初始化失败: {e}")
|
||||
|
||||
if not services:
|
||||
raise RuntimeError(f"没有可用的生成式大模型,尝试了: {list(CHAT_PROVIDERS.keys())}")
|
||||
@@ -384,6 +317,25 @@ def get_all_chat_services() -> Dict[str, BaseChatModel]:
|
||||
return services
|
||||
|
||||
|
||||
def get_cached_chat_services() -> Dict[str, BaseChatModel]:
|
||||
"""获取缓存的可用模型字典(用于单图动态注入)"""
|
||||
global _cached_services
|
||||
if _cached_services is None:
|
||||
_cached_services = _init_chat_services()
|
||||
return _cached_services
|
||||
|
||||
|
||||
def get_all_chat_services() -> Dict[str, BaseChatModel]:
|
||||
"""
|
||||
获取所有可用的生成式大模型服务(用于多模型切换,保留兼容性)
|
||||
新代码请使用 get_cached_chat_services() 获取缓存版本
|
||||
|
||||
Returns:
|
||||
Dict[str, BaseChatModel]: 模型名称 -> ChatModel 实例 的字典
|
||||
"""
|
||||
return get_cached_chat_services()
|
||||
|
||||
|
||||
def get_small_llm_service() -> BaseChatModel:
|
||||
"""
|
||||
获取轻量级大模型服务(用于查询改写、意图分类等简单任务)
|
||||
|
||||
@@ -4,7 +4,7 @@ Contact Subgraph Builder
|
||||
支持 API 注入的工厂模式
|
||||
"""
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
from .state import ContactState
|
||||
from .nodes import create_contact_nodes
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Dictionary Subgraph Builder - Complete
|
||||
"""
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
from .state import DictionaryState
|
||||
from .nodes import (
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
News Analysis Subgraph Builder
|
||||
"""
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
from .state import NewsAnalysisState
|
||||
from .nodes import (
|
||||
|
||||
Reference in New Issue
Block a user