refactor: 统一使用新版 React 模式图,移除旧版 GraphBuilder 调用

This commit is contained in:
2026-05-01 00:13:13 +08:00
parent 3e438b6e1c
commit 9d4cf15c96

View File

@@ -7,7 +7,7 @@ import json
import asyncio
# 本地模块
from app.main_graph.graph_builder import GraphBuilder, GraphContext
from app.main_graph.utils.subgraph_builder import build_react_main_graph
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.main_graph.config import set_stream_writer
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
@@ -36,13 +36,12 @@ class AIAgentService:
self.tools.append(rag_tool)
self.tools_by_name[rag_tool.name] = rag_tool
# 2. 构建各模型的 Graph
# 2. 构建各模型的 Graph(使用新版 React 模式)
chat_services = get_all_chat_services()
for name, llm in chat_services.items():
try:
info(f"🔄 初始化模型 '{name}'...")
builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
graph = builder.compile(checkpointer=self.checkpointer)
graph = build_react_main_graph().compile(checkpointer=self.checkpointer)
self.graphs[name] = graph
info(f"✅ 模型 '{name}' 初始化成功")
except Exception as e:
@@ -67,14 +66,22 @@ class AIAgentService:
"configurable": {"thread_id": thread_id},
"metadata": {"user_id": user_id}
}
input_state = {"messages": [{"role": "user", "content": message}]}
context = GraphContext(user_id=user_id)
# 新版状态输入:传入完整的 MainGraphState关键是 user_query
from app.main_graph.state import MainGraphState, CurrentAction
input_state = {
"user_query": message,
"messages": [{"role": "user", "content": message}],
"user_id": user_id,
"current_action": CurrentAction.NONE
}
result = await graph.ainvoke(input_state, config=config, context=context)
result = await graph.ainvoke(input_state, config=config)
reply = result["messages"][-1].content
token_usage = result.get("last_token_usage", {})
elapsed_time = result.get("last_elapsed_time", 0.0)
reply = result.get("final_result", "")
if not reply and result.get("messages"):
reply = result["messages"][-1].content
token_usage = result.get("debug_info", {}).get("token_usage", {})
elapsed_time = result.get("debug_info", {}).get("elapsed_time", 0.0)
return {
"reply": reply,