refactor: 单图方案重构 + 动态模型选择 + chat_services优化
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 12m9s
## 核心改动 ### 1. 单图方案重构 - 删除了多图(self.graphs),改为单图(self.graph) - 新增 MainGraphState.current_model 字段用于运行时注入模型 - llm_call 节点改为动态选择模型(create_dynamic_llm_call_node) ### 2. chat_services 优化 - 添加 _cached_services 缓存,避免重复初始化 - 新增 get_cached_chat_services() 函数,用于单图注入 - 新增 _check_http_service_available() 统一HTTP探测逻辑 - 减少重复代码,LocalVLLMChatProvider和LocalSmallModelProvider共用探测方法 ### 3. AIAgentService 重构 - initialize() 只构建一次图,传入 chat_services 字典 - 新增 _resolve_model() 模型回退逻辑 - 新增 _build_invocation() 统一构建调用参数 - process_message() 和 process_message_stream() 改为注入 current_model - 流式处理代码拆分,增加可读性 ### 4. 新增和删除文件 - 新增:backend/app/main_graph/main_graph_builder.py(图构建) - 新增:backend/app/main_graph/subgraph_wrapper.py(子图封装) - 新增:tools/test/test_tavily_search.py(测试) - 删除:backend/app/main_graph/graph.py(旧图) - 删除:backend/app/main_graph/utils/main_graph_builder.py(旧构建器) - 删除:backend/app/main_graph/utils/__init__.py ### 5. 其他更新 - README.md:新增模型服务使用情况详解章节 - backend/app/model_services/__init__.py:新增 get_cached_chat_services 导出 ## 方案优势 - 内存优化:N张图 → 1张图 - 灵活性:运行时动态选择模型,支持同会话不同模型 - 性能:模型服务缓存,初始化仅一次 - 可维护性:减少重复代码,统一HTTP探测逻辑
This commit is contained in:
@@ -1,8 +0,0 @@
|
||||
"""
|
||||
LangGraph 核心组件重新导出
|
||||
统一导入入口,避免直接依赖 langgraph
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, START, END, add_messages
|
||||
|
||||
__all__ = ["StateGraph", "START", "END", "add_messages"]
|
||||
229
backend/app/main_graph/main_graph_builder.py
Normal file
229
backend/app/main_graph/main_graph_builder.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
主图构建器 - 构建整合后的完整主图
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from typing import Dict, Any
|
||||
|
||||
from .state import MainGraphState
|
||||
from .nodes.reasoning import react_reason_node
|
||||
from .nodes.web_search import web_search_node
|
||||
from .nodes.error_handling import error_handling_node
|
||||
from .nodes.routing import init_state_node, route_by_reasoning, should_summarize
|
||||
from .nodes.hybrid_router import (
|
||||
hybrid_router_node,
|
||||
route_from_hybrid_decision,
|
||||
check_fast_path_success,
|
||||
)
|
||||
from .nodes.fast_paths import (
|
||||
fast_chitchat_node,
|
||||
fast_rag_node,
|
||||
fast_tool_node,
|
||||
)
|
||||
from .nodes.llm_call import create_dynamic_llm_call_node
|
||||
from .nodes.rag_nodes import rag_retrieve_node
|
||||
from .nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from .nodes.summarize import create_summarize_node
|
||||
from .nodes.finalize import finalize_node
|
||||
from ..subgraphs.contact import build_contact_subgraph
|
||||
from ..subgraphs.dictionary import build_dictionary_subgraph
|
||||
from ..subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from ..logger import info
|
||||
|
||||
from .subgraph_wrapper import create_subgraph_nodes
|
||||
|
||||
|
||||
# ========== 主图构建 ==========
|
||||
|
||||
def build_react_main_graph(
|
||||
chat_services: dict,
|
||||
tools=None,
|
||||
mem0_client=None,
|
||||
use_hybrid_router: bool = True
|
||||
) -> StateGraph:
|
||||
"""
|
||||
构建整合后的完整主图(支持混合路由 + 动态模型选择)
|
||||
|
||||
Args:
|
||||
chat_services: 模型名称 -> ChatModel 实例 的字典
|
||||
tools: 工具列表
|
||||
mem0_client: Mem0 客户端实例
|
||||
use_hybrid_router: 是否使用混合路由(快速路径 + React 循环)
|
||||
|
||||
Returns:
|
||||
StateGraph: 构建好的图
|
||||
"""
|
||||
# 创建图
|
||||
graph = StateGraph(MainGraphState)
|
||||
|
||||
# 设置全局 mem0_client
|
||||
if mem0_client:
|
||||
set_mem0_client(mem0_client)
|
||||
|
||||
# ========== 创建节点 ==========
|
||||
|
||||
# LLM 调用节点
|
||||
llm_node = create_dynamic_llm_call_node(chat_services, tools or [])
|
||||
|
||||
# 记忆节点
|
||||
retrieve_memory_node = None
|
||||
summarize_node = None
|
||||
if mem0_client:
|
||||
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
|
||||
summarize_node = create_summarize_node(mem0_client)
|
||||
|
||||
# 子图节点
|
||||
contact_graph = build_contact_subgraph()
|
||||
dictionary_graph = build_dictionary_subgraph()
|
||||
news_analysis_graph = build_news_analysis_subgraph()
|
||||
subgraph_nodes = create_subgraph_nodes(
|
||||
contact_graph, dictionary_graph, news_analysis_graph
|
||||
)
|
||||
|
||||
# ========== 添加节点到图 ==========
|
||||
|
||||
# 阶段 1: 记忆检索
|
||||
if retrieve_memory_node:
|
||||
graph.add_node("retrieve_memory", retrieve_memory_node)
|
||||
graph.add_node("memory_trigger", memory_trigger_node)
|
||||
|
||||
# 阶段 2: 初始化
|
||||
graph.add_node("init_state", init_state_node)
|
||||
|
||||
# 阶段 3: 混合路由(可选)
|
||||
if use_hybrid_router:
|
||||
graph.add_node("hybrid_router", hybrid_router_node)
|
||||
graph.add_node("fast_chitchat", fast_chitchat_node)
|
||||
graph.add_node("fast_rag", fast_rag_node)
|
||||
graph.add_node("fast_tool", fast_tool_node)
|
||||
|
||||
# 阶段 4: React 循环推理(始终保留)
|
||||
graph.add_node("react_reason", react_reason_node)
|
||||
graph.add_node("rag_retrieve", rag_retrieve_node)
|
||||
graph.add_node("web_search", web_search_node)
|
||||
graph.add_node("handle_error", error_handling_node)
|
||||
|
||||
if llm_node is not None:
|
||||
graph.add_node("llm_call", llm_node)
|
||||
|
||||
# 子图节点
|
||||
for node_name, node_func in subgraph_nodes.items():
|
||||
graph.add_node(node_name, node_func)
|
||||
|
||||
# 阶段 5: 完成处理
|
||||
if summarize_node:
|
||||
graph.add_node("summarize", summarize_node)
|
||||
graph.add_node("finalize", finalize_node)
|
||||
|
||||
# ========== 添加边 ==========
|
||||
|
||||
# 阶段 1: 记忆检索
|
||||
_add_memory_edges(graph, retrieve_memory_node)
|
||||
|
||||
# 阶段 2: 初始化
|
||||
graph.add_edge("memory_trigger", "init_state")
|
||||
|
||||
# 阶段 3: 路由分支
|
||||
_add_routing_edges(graph, use_hybrid_router, llm_node)
|
||||
|
||||
# 阶段 4: React 循环边
|
||||
_add_react_loop_edges(graph, subgraph_nodes)
|
||||
|
||||
# 阶段 5: 完成阶段
|
||||
_add_finalize_edges(graph, llm_node, summarize_node)
|
||||
|
||||
info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})")
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def _add_memory_edges(graph: StateGraph, retrieve_memory_node) -> None:
|
||||
"""添加记忆检索阶段的边"""
|
||||
if retrieve_memory_node:
|
||||
graph.add_edge(START, "retrieve_memory")
|
||||
graph.add_edge("retrieve_memory", "memory_trigger")
|
||||
else:
|
||||
graph.add_edge(START, "memory_trigger")
|
||||
|
||||
|
||||
def _add_routing_edges(graph: StateGraph, use_hybrid_router: bool, llm_node) -> None:
|
||||
"""添加路由阶段的边"""
|
||||
if use_hybrid_router:
|
||||
graph.add_edge("init_state", "hybrid_router")
|
||||
|
||||
# 混合路由条件分支
|
||||
graph.add_conditional_edges(
|
||||
"hybrid_router",
|
||||
route_from_hybrid_decision,
|
||||
{
|
||||
"fast_chitchat": "fast_chitchat",
|
||||
"fast_rag": "fast_rag",
|
||||
"fast_tool": "fast_tool",
|
||||
"react_loop": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
# 快速路径的完成检查
|
||||
for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]:
|
||||
graph.add_conditional_edges(
|
||||
fast_node,
|
||||
check_fast_path_success,
|
||||
{
|
||||
"llm_call": "llm_call",
|
||||
"escalate": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
info(f"✅ [图构建] 混合路由模式已启用")
|
||||
else:
|
||||
graph.add_edge("init_state", "react_reason")
|
||||
info(f"✅ [图构建] 纯 React 模式")
|
||||
|
||||
|
||||
def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) -> None:
|
||||
"""添加 React 循环阶段的边"""
|
||||
subgraph_names = list(subgraph_nodes.keys())
|
||||
|
||||
# React 推理的条件分支
|
||||
graph.add_conditional_edges(
|
||||
"react_reason",
|
||||
route_by_reasoning,
|
||||
{
|
||||
"rag_retrieve": "rag_retrieve",
|
||||
"web_search": "web_search",
|
||||
**{name: name for name in subgraph_names},
|
||||
"handle_error": "handle_error",
|
||||
"llm_call": "llm_call"
|
||||
}
|
||||
)
|
||||
|
||||
# 循环边(回到 react_reason)
|
||||
loop_back_nodes = ["rag_retrieve", "web_search", "handle_error"] + subgraph_names
|
||||
for node_name in loop_back_nodes:
|
||||
graph.add_edge(node_name, "react_reason")
|
||||
|
||||
|
||||
def _add_finalize_edges(graph: StateGraph, llm_node, summarize_node) -> None:
|
||||
"""添加完成阶段的边"""
|
||||
if llm_node is not None:
|
||||
if summarize_node:
|
||||
graph.add_conditional_edges(
|
||||
"llm_call",
|
||||
should_summarize,
|
||||
{
|
||||
"summarize": "summarize",
|
||||
"finalize": "finalize"
|
||||
}
|
||||
)
|
||||
graph.add_edge("summarize", "finalize")
|
||||
else:
|
||||
graph.add_edge("llm_call", "finalize")
|
||||
|
||||
graph.add_edge("finalize", END)
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
__all__ = [
|
||||
"build_react_main_graph",
|
||||
]
|
||||
@@ -6,8 +6,8 @@
|
||||
from .reasoning import react_reason_node
|
||||
from .web_search import web_search_node
|
||||
from .error_handling import error_handling_node
|
||||
from .routing import init_state_node, route_by_reasoning
|
||||
from .llm_call import create_llm_call_node
|
||||
from .routing import init_state_node, route_by_reasoning, should_summarize
|
||||
from .llm_call import create_dynamic_llm_call_node
|
||||
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
|
||||
|
||||
# 记忆节点
|
||||
@@ -38,7 +38,8 @@ __all__ = [
|
||||
"web_search_node",
|
||||
"error_handling_node",
|
||||
"route_by_reasoning",
|
||||
"create_llm_call_node",
|
||||
"should_summarize",
|
||||
"create_dynamic_llm_call_node",
|
||||
"rag_retrieve_node",
|
||||
"rag_re_retrieve_node",
|
||||
# 记忆节点
|
||||
|
||||
@@ -5,7 +5,7 @@ LLM 调用节点模块
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
@@ -14,29 +14,34 @@ from app.agent.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error
|
||||
|
||||
def create_llm_call_node(llm, tools: list):
|
||||
|
||||
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
|
||||
"""
|
||||
工厂函数:创建 LLM 调用节点
|
||||
|
||||
工厂函数:创建动态 LLM 调用节点(根据 state.current_model 选择模型)
|
||||
|
||||
Args:
|
||||
llm: LangChain LLM 实例
|
||||
chat_services: 模型名称 -> ChatModel 实例 的字典
|
||||
tools: 工具列表
|
||||
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
# 构建调用链
|
||||
# 预构建所有模型的 tools 绑定(避免每次调用都 bind)
|
||||
bound_models: Dict[str, Any] = {}
|
||||
for name, llm in chat_services.items():
|
||||
if tools:
|
||||
bound_models[name] = llm.bind_tools(tools)
|
||||
else:
|
||||
bound_models[name] = llm
|
||||
|
||||
# 预构建 prompt
|
||||
prompt = create_system_prompt(tools)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
|
||||
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
|
||||
chain = prompt | llm_with_tools
|
||||
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
LLM 调用节点(动态选择模型)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
@@ -46,7 +51,7 @@ def create_llm_call_node(llm, tools: list):
|
||||
更新后的状态字典
|
||||
"""
|
||||
log_state_change("llm_call", state, "进入")
|
||||
|
||||
|
||||
memory_context = getattr(state, "memory_context", "暂无用户信息")
|
||||
start_time = time.time()
|
||||
|
||||
@@ -62,9 +67,20 @@ def create_llm_call_node(llm, tools: list):
|
||||
"last_elapsed_time": elapsed_time,
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
}
|
||||
|
||||
|
||||
# 动态选择模型
|
||||
model_name = getattr(state, "current_model", "")
|
||||
if not model_name or model_name not in bound_models:
|
||||
# 回退到第一个可用模型
|
||||
fallback_name = next(iter(bound_models.keys()))
|
||||
info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'")
|
||||
model_name = fallback_name
|
||||
|
||||
llm_with_tools = bound_models[model_name]
|
||||
info(f"[llm_call] 使用模型: {model_name}")
|
||||
|
||||
try:
|
||||
# 添加 RAG 上下文到消息
|
||||
# 添加上下文到消息
|
||||
messages_with_context = list(state.messages)
|
||||
if state.rag_context:
|
||||
from langchain_core.messages import SystemMessage
|
||||
@@ -77,9 +93,10 @@ def create_llm_call_node(llm, tools: list):
|
||||
break
|
||||
if not inserted:
|
||||
messages_with_context.insert(0, rag_system_msg)
|
||||
|
||||
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chain = prompt | llm_with_tools
|
||||
chunks = []
|
||||
async for chunk in chain.astream(
|
||||
{
|
||||
@@ -89,7 +106,7 @@ def create_llm_call_node(llm, tools: list):
|
||||
config=config
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
@@ -97,14 +114,14 @@ def create_llm_call_node(llm, tools: list):
|
||||
response = response + chunk
|
||||
else:
|
||||
response = AIMessage(content="")
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||||
token_usage = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
|
||||
# 尝试从 response_metadata 中提取
|
||||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
meta = response.response_metadata
|
||||
@@ -112,33 +129,33 @@ def create_llm_call_node(llm, tools: list):
|
||||
token_usage = meta['token_usage']
|
||||
elif 'usage' in meta:
|
||||
token_usage = meta['usage']
|
||||
|
||||
|
||||
# 尝试从 additional_kwargs 中提取
|
||||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||||
add_kwargs = response.additional_kwargs
|
||||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||||
token_usage = add_kwargs['llm_output']['token_usage']
|
||||
|
||||
|
||||
# 提取具体的 token 数值
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
debug(f"📥 [LLM输出] 模型: {model_name} 返回的完整响应:")
|
||||
debug(f" 消息类型: {response.type.upper()}")
|
||||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||||
debug("-"*80)
|
||||
debug(f"{response.content}")
|
||||
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
# 检查是否有工具调用
|
||||
has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0
|
||||
|
||||
@@ -151,21 +168,22 @@ def create_llm_call_node(llm, tools: list):
|
||||
"final_result": response.content,
|
||||
"success": True,
|
||||
"current_phase": "done",
|
||||
"has_tool_calls": has_tool_calls
|
||||
"has_tool_calls": has_tool_calls,
|
||||
"current_model": model_name # 记录实际使用的模型
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", state, "离开")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f"\n❌ [LLM错误] 模型 {model_name} 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f" 错误类型: {type(e).__name__}")
|
||||
error(f" 错误信息: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
|
||||
# 返回一个友好的错误消息
|
||||
error_response = AIMessage(
|
||||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||||
@@ -178,10 +196,11 @@ def create_llm_call_node(llm, tools: list):
|
||||
"turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1,
|
||||
"final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。",
|
||||
"success": False,
|
||||
"current_phase": "done"
|
||||
"current_phase": "done",
|
||||
"current_model": model_name
|
||||
}
|
||||
|
||||
|
||||
log_state_change("llm_call", state, "离开(异常)")
|
||||
return error_result
|
||||
|
||||
return call_llm
|
||||
|
||||
return call_llm
|
||||
|
||||
@@ -118,3 +118,21 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
|
||||
info(f"[条件路由] 动作={latest_action}, 目标={target}")
|
||||
return target
|
||||
|
||||
|
||||
# ========== 完成阶段条件路由函数 ==========
|
||||
|
||||
def should_summarize(state: MainGraphState) -> str:
|
||||
"""
|
||||
检查是否需要总结对话(对话足够长时)
|
||||
|
||||
Args:
|
||||
state: 当前图状态
|
||||
|
||||
Returns:
|
||||
"summarize" 或 "finalize"
|
||||
"""
|
||||
if state.turns_since_last_summary >= 5: # 每5轮对话总结一次
|
||||
return "summarize"
|
||||
else:
|
||||
return "finalize"
|
||||
|
||||
@@ -6,7 +6,7 @@ Main Graph State Definition - React Mode Enhanced
|
||||
from enum import Enum, auto
|
||||
from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict, List
|
||||
from dataclasses import dataclass, field
|
||||
from app.main_graph.graph import add_messages
|
||||
from langgraph.graph import add_messages
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ class MainGraphState:
|
||||
# ========== 主图控制字段 ==========
|
||||
user_query: str = ""
|
||||
current_action: CurrentAction = CurrentAction.NONE
|
||||
current_model: str = "" # 新增:本次请求使用的模型
|
||||
intent_confidence: float = 0.0
|
||||
|
||||
# ========== React 推理专用字段 ==========
|
||||
|
||||
159
backend/app/main_graph/subgraph_wrapper.py
Normal file
159
backend/app/main_graph/subgraph_wrapper.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
子图包装器 - 为子图添加错误处理和事件追踪
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ..logger import info
|
||||
|
||||
|
||||
def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
"""
|
||||
包装子图,使其错误能传递给主图
|
||||
|
||||
Args:
|
||||
subgraph: 编译好的子图
|
||||
name: 子图名称(用于错误标识)
|
||||
|
||||
Returns: 包装后的节点函数
|
||||
"""
|
||||
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
# 发送子图开始事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_start",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"开始执行 {name} 子图..."
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送开始事件: {e}")
|
||||
|
||||
try:
|
||||
# 调用子图
|
||||
result = subgraph.invoke(state)
|
||||
|
||||
# 更新主图状态
|
||||
subgraph_result = None
|
||||
if name == "contact":
|
||||
state.contact_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "dictionary":
|
||||
state.dictionary_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "news_analysis":
|
||||
state.news_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
|
||||
# 设置最终结果
|
||||
if subgraph_result:
|
||||
state.final_result = subgraph_result
|
||||
else:
|
||||
state.final_result = "子图执行完成"
|
||||
|
||||
# 标记成功
|
||||
state.success = True
|
||||
state.current_phase = "done"
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "subgraph_completed",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name}子图执行完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 发送子图完成事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_complete",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行完成"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送完成事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
# 捕获子图错误,传递给主图
|
||||
error_record = ErrorRecord(
|
||||
error_type=f"{name}SubgraphError",
|
||||
error_message=str(e),
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source=f"{name}_subgraph",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=1,
|
||||
context={"user_query": state.user_query}
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
state.success = False
|
||||
|
||||
# 发送子图错误事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_error",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行失败: {str(e)}"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送错误事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
return wrapped_node
|
||||
|
||||
|
||||
def create_subgraph_nodes(contact_graph, dictionary_graph, news_analysis_graph) -> Dict[str, Any]:
|
||||
"""
|
||||
创建所有子图节点的字典
|
||||
|
||||
Args:
|
||||
contact_graph: 联系人子图
|
||||
dictionary_graph: 词典子图
|
||||
news_analysis_graph: 新闻分析子图
|
||||
|
||||
Returns:
|
||||
子图节点字典 {name: wrapped_node}
|
||||
"""
|
||||
return {
|
||||
"contact_subgraph": wrap_subgraph_for_error_handling(
|
||||
contact_graph.compile(), "contact"
|
||||
),
|
||||
"dictionary_subgraph": wrap_subgraph_for_error_handling(
|
||||
dictionary_graph.compile(), "dictionary"
|
||||
),
|
||||
"news_analysis_subgraph": wrap_subgraph_for_error_handling(
|
||||
news_analysis_graph.compile(), "news_analysis"
|
||||
),
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
"""主图工具函数"""
|
||||
@@ -1,371 +0,0 @@
|
||||
"""
|
||||
整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState
|
||||
"""
|
||||
|
||||
from ..graph import StateGraph, START, END
|
||||
from typing import Dict, Any, Optional
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ..state import MainGraphState
|
||||
from ..nodes.reasoning import react_reason_node
|
||||
from ..nodes.web_search import web_search_node
|
||||
from ..nodes.error_handling import error_handling_node
|
||||
from ..nodes.routing import init_state_node, route_by_reasoning
|
||||
from ..nodes.hybrid_router import (
|
||||
hybrid_router_node,
|
||||
route_from_hybrid_decision,
|
||||
check_fast_path_success,
|
||||
)
|
||||
from ..nodes.fast_paths import (
|
||||
fast_chitchat_node,
|
||||
fast_rag_node,
|
||||
fast_tool_node,
|
||||
)
|
||||
from ..nodes.llm_call import create_llm_call_node
|
||||
from ..nodes.rag_nodes import rag_retrieve_node
|
||||
from ..nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from ..nodes.summarize import create_summarize_node
|
||||
from ..nodes.finalize import finalize_node
|
||||
from ...subgraphs.contact import build_contact_subgraph
|
||||
from ...subgraphs.dictionary import build_dictionary_subgraph
|
||||
from ...subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...logger import info, debug
|
||||
|
||||
|
||||
# ========== 检查是否需要总结 ==========
|
||||
def should_summarize(state: MainGraphState) -> str:
|
||||
"""
|
||||
检查是否需要总结对话(对话足够长时)
|
||||
|
||||
Args:
|
||||
state: 当前图状态
|
||||
|
||||
Returns:
|
||||
"summarize" 或 "finalize"
|
||||
"""
|
||||
if state.turns_since_last_summary >= 5: # 每5轮对话总结一次
|
||||
return "summarize"
|
||||
else:
|
||||
return "finalize"
|
||||
|
||||
|
||||
# ========== 子图包装器(处理子图错误传递)==========
|
||||
def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
"""
|
||||
包装子图,使其错误能传递给主图
|
||||
|
||||
Args:
|
||||
subgraph: 编译好的子图
|
||||
name: 子图名称(用于错误标识)
|
||||
|
||||
Returns: 包装后的节点函数
|
||||
"""
|
||||
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
# 发送子图开始事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_start",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"开始执行 {name} 子图..."
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送开始事件: {e}")
|
||||
|
||||
try:
|
||||
# 调用子图
|
||||
result = subgraph.invoke(state)
|
||||
|
||||
# 更新主图状态
|
||||
subgraph_result = None
|
||||
if name == "contact":
|
||||
state.contact_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "dictionary":
|
||||
state.dictionary_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
elif name == "news_analysis":
|
||||
state.news_result = result
|
||||
subgraph_result = result.get("final_result", "")
|
||||
|
||||
# 关键:设置最终结果
|
||||
if subgraph_result:
|
||||
state.final_result = subgraph_result
|
||||
else:
|
||||
state.final_result = "子图执行完成"
|
||||
|
||||
# 标记成功
|
||||
state.success = True
|
||||
state.current_phase = "done"
|
||||
# 标记不再需要推理,避免循环
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "subgraph_completed",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name}子图执行完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 发送子图完成事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_complete",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行完成"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送完成事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
# 捕获子图错误,传递给主图
|
||||
from ..state import ErrorRecord, ErrorSeverity
|
||||
from datetime import datetime
|
||||
|
||||
error_record = ErrorRecord(
|
||||
error_type=f"{name}SubgraphError",
|
||||
error_message=str(e),
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source=f"{name}_subgraph",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=1,
|
||||
context={"user_query": state.user_query}
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
state.success = False
|
||||
|
||||
# 发送子图错误事件
|
||||
if config:
|
||||
try:
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks:
|
||||
await adispatch_custom_event(
|
||||
"react_reasoning",
|
||||
{
|
||||
"step": state.reasoning_step,
|
||||
"action": f"{name}_subgraph_error",
|
||||
"confidence": 1.0,
|
||||
"reasoning": f"{name} 子图执行失败: {str(e)}"
|
||||
},
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[{name}_subgraph] 无法发送错误事件: {e}")
|
||||
|
||||
return state
|
||||
|
||||
return wrapped_node
|
||||
|
||||
# ========== 主图构建 ==========
|
||||
|
||||
def build_react_main_graph(llm=None, tools=None, mem0_client=None, use_hybrid_router: bool = True) -> StateGraph:
|
||||
"""
|
||||
构建整合后的完整主图(支持混合路由)
|
||||
|
||||
Args:
|
||||
llm: LangChain ChatModel 实例
|
||||
tools: 工具列表
|
||||
mem0_client: Mem0 客户端实例
|
||||
use_hybrid_router: 是否使用混合路由(快速路径 + React 循环)
|
||||
|
||||
Returns:
|
||||
StateGraph: 构建好的图
|
||||
"""
|
||||
# 创建图
|
||||
graph = StateGraph(MainGraphState)
|
||||
|
||||
# 设置全局 mem0_client
|
||||
if mem0_client:
|
||||
set_mem0_client(mem0_client)
|
||||
|
||||
# 创建节点
|
||||
llm_node = None
|
||||
if llm is not None:
|
||||
llm_node = create_llm_call_node(llm, tools or [])
|
||||
|
||||
retrieve_memory_node = None
|
||||
summarize_node = None
|
||||
if mem0_client:
|
||||
retrieve_memory_node = create_retrieve_memory_node(mem0_client)
|
||||
summarize_node = create_summarize_node(mem0_client)
|
||||
|
||||
# ========== 添加节点 ==========
|
||||
|
||||
# 第一阶段:记忆检索
|
||||
if retrieve_memory_node:
|
||||
graph.add_node("retrieve_memory", retrieve_memory_node)
|
||||
graph.add_node("memory_trigger", memory_trigger_node)
|
||||
|
||||
# 第二阶段:初始化
|
||||
graph.add_node("init_state", init_state_node)
|
||||
|
||||
# ========== 混合路由节点(如果启用) ==========
|
||||
if use_hybrid_router:
|
||||
graph.add_node("hybrid_router", hybrid_router_node)
|
||||
graph.add_node("fast_chitchat", fast_chitchat_node)
|
||||
graph.add_node("fast_rag", fast_rag_node)
|
||||
graph.add_node("fast_tool", fast_tool_node)
|
||||
|
||||
# 第三阶段:React 循环推理(始终保留)
|
||||
graph.add_node("react_reason", react_reason_node)
|
||||
graph.add_node("rag_retrieve", rag_retrieve_node)
|
||||
graph.add_node("web_search", web_search_node)
|
||||
graph.add_node("handle_error", error_handling_node)
|
||||
|
||||
if llm_node is not None:
|
||||
graph.add_node("llm_call", llm_node)
|
||||
|
||||
# 子图节点
|
||||
contact_graph = build_contact_subgraph()
|
||||
dictionary_graph = build_dictionary_subgraph()
|
||||
news_analysis_graph = build_news_analysis_subgraph()
|
||||
|
||||
graph.add_node(
|
||||
"contact_subgraph",
|
||||
wrap_subgraph_for_error_handling(contact_graph.compile(), "contact")
|
||||
)
|
||||
graph.add_node(
|
||||
"dictionary_subgraph",
|
||||
wrap_subgraph_for_error_handling(dictionary_graph.compile(), "dictionary")
|
||||
)
|
||||
graph.add_node(
|
||||
"news_analysis_subgraph",
|
||||
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
|
||||
)
|
||||
|
||||
# 第四阶段:完成处理
|
||||
if summarize_node:
|
||||
graph.add_node("summarize", summarize_node)
|
||||
graph.add_node("finalize", finalize_node)
|
||||
|
||||
# ========== 添加边 ==========
|
||||
|
||||
# 第一阶段:记忆检索
|
||||
if retrieve_memory_node:
|
||||
graph.add_edge(START, "retrieve_memory")
|
||||
graph.add_edge("retrieve_memory", "memory_trigger")
|
||||
else:
|
||||
graph.add_edge(START, "memory_trigger")
|
||||
|
||||
# 进入初始化
|
||||
graph.add_edge("memory_trigger", "init_state")
|
||||
|
||||
# ========== 混合路由分支(如果启用) ==========
|
||||
if use_hybrid_router:
|
||||
graph.add_edge("init_state", "hybrid_router")
|
||||
|
||||
# 从 hybrid_router 条件分支
|
||||
graph.add_conditional_edges(
|
||||
"hybrid_router",
|
||||
route_from_hybrid_decision,
|
||||
{
|
||||
"fast_chitchat": "fast_chitchat",
|
||||
"fast_rag": "fast_rag",
|
||||
"fast_tool": "fast_tool",
|
||||
"react_loop": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
# 快速路径的完成检查
|
||||
for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]:
|
||||
graph.add_conditional_edges(
|
||||
fast_node,
|
||||
check_fast_path_success,
|
||||
{
|
||||
"llm_call": "llm_call",
|
||||
"escalate": "react_reason"
|
||||
}
|
||||
)
|
||||
|
||||
info(f"✅ [图构建] 混合路由模式已启用")
|
||||
else:
|
||||
# 无混合路由,直接到 react_reason
|
||||
graph.add_edge("init_state", "react_reason")
|
||||
info(f"✅ [图构建] 纯 React 模式")
|
||||
|
||||
# ========== React 循环边(始终保留) ==========
|
||||
graph.add_conditional_edges(
|
||||
"react_reason",
|
||||
route_by_reasoning,
|
||||
{
|
||||
"rag_retrieve": "rag_retrieve",
|
||||
"web_search": "web_search",
|
||||
"contact_subgraph": "contact_subgraph",
|
||||
"dictionary_subgraph": "dictionary_subgraph",
|
||||
"news_analysis_subgraph": "news_analysis_subgraph",
|
||||
"handle_error": "handle_error",
|
||||
"llm_call": "llm_call"
|
||||
}
|
||||
)
|
||||
|
||||
# 循环边(rag、web_search、子图、error都回到 reason)
|
||||
graph.add_edge("rag_retrieve", "react_reason")
|
||||
graph.add_edge("web_search", "react_reason")
|
||||
graph.add_edge("contact_subgraph", "react_reason")
|
||||
graph.add_edge("dictionary_subgraph", "react_reason")
|
||||
graph.add_edge("news_analysis_subgraph", "react_reason")
|
||||
graph.add_edge("handle_error", "react_reason")
|
||||
|
||||
# ========== 最终完成阶段 ==========
|
||||
if llm_node is not None:
|
||||
if summarize_node:
|
||||
# 检查是否需要总结
|
||||
graph.add_conditional_edges(
|
||||
"llm_call",
|
||||
should_summarize,
|
||||
{
|
||||
"summarize": "summarize",
|
||||
"finalize": "finalize"
|
||||
}
|
||||
)
|
||||
graph.add_edge("summarize", "finalize")
|
||||
else:
|
||||
# 没有 summarize 节点,直接 finalize
|
||||
graph.add_edge("llm_call", "finalize")
|
||||
|
||||
# 完成
|
||||
graph.add_edge("finalize", END)
|
||||
|
||||
info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})")
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# ========== 兼容性:保留旧的函数名 ==========
|
||||
def build_main_graph() -> StateGraph:
|
||||
"""
|
||||
兼容性函数:旧代码调用 build_main_graph() 时返回 React 版本
|
||||
"""
|
||||
return build_react_main_graph()
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
__all__ = [
|
||||
"build_react_main_graph",
|
||||
"build_main_graph",
|
||||
"wrap_subgraph_for_error_handling"
|
||||
]
|
||||
Reference in New Issue
Block a user