Files
ailine/backend/app/agent/agent_service.py
root 0f1691b578
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
修复:更新 msgpack 序列化配置,加入新的类型
修复问题:
1. 更新模块路径从 "app.core.intent" 到 "backend.app.core.intent"
2. 添加新的状态类型:
   - ReactReasoningState
   - HybridRouterState
   - FastPathState
3. 添加 HybridRouterResult
2026-05-06 14:36:16 +08:00

441 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 单图方案 + 动态模型选择
接收外部传入的 checkpointer,不负责管理连接生命周期
"""
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
# 本地模块
from ..model_services import get_cached_chat_services
from ..main_graph.main_graph_builder import build_react_main_graph
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from ..main_graph.config import set_stream_writer
from ..main_graph.utils.rag_initializer import init_rag_tool
from backend.app.core.intent_classifier import get_intent_classifier
from backend.app.logger import debug, info, warning, error
from ..main_graph.state import MainGraphState, CurrentAction
# ========== 自定义类型序列化器 ==========
def create_serde() -> JsonPlusSerializer:
    """Build a ``JsonPlusSerializer`` that allow-lists this project's custom types.

    Returns:
        A serializer whose ``allowed_msgpack_modules`` holds one
        ``(module path, type name)`` pair per registered type, in the order
        they are declared below.
    """
    # The imported names are not referenced afterwards; the imports are kept so
    # this function still raises ImportError when a listed module or type is
    # missing. NOTE(review): confirm that early failure is the intent.
    from backend.app.core.intent import ReasoningAction, RetrievalConfig, ReasoningResult
    from backend.app.main_graph.state import (
        CurrentAction, ErrorSeverity, ErrorRecord,
        ReactReasoningState, HybridRouterState, FastPathState
    )
    from backend.app.main_graph.nodes.hybrid_router import HybridRouterResult

    # Registered type names grouped by the module path they are listed under.
    registered_types = (
        (
            "backend.app.core.intent",
            ("ReasoningAction", "RetrievalConfig", "ReasoningResult"),
        ),
        (
            "backend.app.main_graph.state",
            (
                "CurrentAction",
                "ErrorSeverity",
                "ErrorRecord",
                "ReactReasoningState",
                "HybridRouterState",
                "FastPathState",
            ),
        ),
        (
            "backend.app.main_graph.nodes.hybrid_router",
            ("HybridRouterResult",),
        ),
    )

    # Flatten to the list of (module, name) tuples the serializer expects.
    allow_list = [
        (module_path, type_name)
        for module_path, type_names in registered_types
        for type_name in type_names
    ]
    return JsonPlusSerializer(allowed_msgpack_modules=allow_list)
class AIAgentService:
    """Single-graph AI agent service with per-request model selection.

    One graph is compiled in :meth:`initialize` and reused for every request;
    the model to use is passed through the input state. The checkpointer is
    supplied by the caller: this class stores it and hands it to ``compile``
    but never opens or closes it.
    """

    def __init__(self, checkpointer):
        """Store the checkpointer and set up the not-yet-initialized state.

        Args:
            checkpointer: Checkpointer instance later passed to ``compile``.
        """
        self.checkpointer = checkpointer
        self.graph = None  # the single compiled graph, set by initialize()
        self.chat_services = None  # cached model dict, set by initialize()
        # Copies, so appending the RAG tool later does not touch the imported registries.
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()
        self.intent_classifier = get_intent_classifier()
        # Optional RAG pipeline; nothing in this class assigns it.
        self.rag_pipeline = None
        # RAG tool injected into the run config; stays None when no tool is available.
        self.rag_tool = None
        # Mem0 client, created in initialize().
        self.mem0_client = None

    async def initialize(self):
        """Create the Mem0 client, load tools and models, and compile the graph.

        Returns:
            This instance, so the call can be chained after construction.
        """
        # 0. Mem0 client (imported lazily, as in the original code).
        from ..memory.mem0_client import Mem0Client
        self.mem0_client = Mem0Client()

        # 1. Register the RAG tool when one is available.
        rag_tool = await init_rag_tool()
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool
            self.rag_tool = rag_tool  # kept for injection into the run config

        # 2. Load the cached model dictionary.
        self.chat_services = get_cached_chat_services()
        info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")

        # 3. Build the graph once, passing the whole model dictionary.
        info(f"🔄 构建单图...")
        graph_builder = build_react_main_graph(
            chat_services=self.chat_services,
            tools=self.tools,
            mem0_client=self.mem0_client
        )
        # Per the original note, the serde is attached when the checkpointer is
        # created, so only the checkpointer itself is passed here.
        self.graph = graph_builder.compile(checkpointer=self.checkpointer)
        info(f"✅ 单图初始化完成")
        return self

    def _resolve_model(self, model: str) -> str:
        """Validate a model name, falling back to the first available model.

        Args:
            model: Requested model name; may be empty.

        Returns:
            ``model`` when it is a key of ``self.chat_services``, otherwise the
            first key of that dictionary.

        Raises:
            RuntimeError: If no models are loaded (``initialize`` was not
                called, or it produced an empty model dictionary).
        """
        if not self.chat_services:
            # Fail with a clear message instead of a TypeError (None) or a bare
            # StopIteration (empty dict) from the fallback lookup below.
            raise RuntimeError("没有可用的模型, 请先调用 initialize() 并检查模型配置")
        if model and model in self.chat_services:
            return model
        fallback = next(iter(self.chat_services.keys()))
        warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
        return fallback

    def _build_invocation(
        self, message: str, thread_id: str, model: str, user_id: str
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build the ``config`` and ``input_state`` for one graph run.

        Args:
            message: User message text.
            thread_id: Conversation (thread) id.
            model: Model name to record in the state.
            user_id: User id.

        Returns:
            A ``(config, input_state)`` tuple.
        """
        config = {
            "configurable": {
                "thread_id": thread_id,
                # getattr keeps this safe for instances built without __init__.
                "rag_tool": getattr(self, "rag_tool", None),
            },
            "metadata": {"user_id": user_id}
        }
        input_state = {
            "user_query": message,
            "messages": [{"role": "user", "content": message}],
            "user_id": user_id,
            "current_model": model,
            "current_action": CurrentAction.NONE
        }
        return config, input_state

    async def process_message(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> dict:
        """Run the graph once and return the reply with usage statistics.

        Args:
            message: User message text.
            thread_id: Conversation (thread) id.
            model: Requested model name; resolved via ``_resolve_model``.
            user_id: User id.

        Returns:
            A dict with keys ``reply``, ``token_usage``, ``elapsed_time`` and
            ``model_used``.
        """
        resolved_model = self._resolve_model(model)
        config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
        result = await self.graph.ainvoke(input_state, config=config)

        reply = result.get("final_result", "")
        if not reply and result.get("messages"):
            # Fall back to the last message; it may be a message object or a plain dict.
            last_message = result["messages"][-1]
            if isinstance(last_message, dict):
                reply = last_message.get("content", "")
            else:
                reply = last_message.content

        return {
            "reply": reply,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
            "model_used": result.get("current_model", resolved_model)
        }

    def _serialize_value(self, value):
        """Recursively convert a value into a JSON-serializable structure.

        Objects exposing a ``content`` attribute are treated as messages and
        become dicts with ``role``, ``content``, ``additional_kwargs`` and
        ``tool_calls``. ``tool_call_id`` and ``name`` are added when the object
        carries them, because ``_handle_updates_chunk`` needs both to match a
        tool result to its call. Dicts, lists and tuples are converted element
        by element; anything ``json.dumps`` rejects becomes ``str(value)``.

        Args:
            value: Any value taken from a graph stream chunk.

        Returns:
            A JSON-serializable equivalent of ``value``.
        """
        if hasattr(value, 'content'):
            serialized = {
                "role": getattr(value, 'type', 'message'),
                # Nested fields are serialized too so the result is JSON-safe throughout.
                "content": self._serialize_value(getattr(value, 'content', '')),
                "additional_kwargs": self._serialize_value(getattr(value, 'additional_kwargs', {})),
                "tool_calls": self._serialize_value(getattr(value, 'tool_calls', []))
            }
            for attr in ("tool_call_id", "name"):
                attr_value = getattr(value, attr, None)
                if attr_value is not None:
                    serialized[attr] = self._serialize_value(attr_value)
            return serialized
        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return str(value)
        return value

    async def _handle_message_chunk(
        self, chunk: Dict[str, Any], current_node: Optional[str], tool_calls_in_progress: Dict[str, Any]
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Translate one ``messages`` chunk into client events.

        Args:
            chunk: Stream chunk whose ``data`` is a ``(message_chunk, metadata)`` pair.
            current_node: Node that produced the previous message chunk, if any.
            tool_calls_in_progress: Tool calls already announced, keyed by call
                id; mutated in place when a new call is seen.

        Yields:
            ``node_end`` / ``node_start`` when the node changes, then at most
            one kind of payload event (``llm_token`` with a reasoning token,
            ``tool_call_start`` events, or ``llm_token`` with a content token),
            and finally an internal ``_update_state`` event carrying the new
            ``current_node`` for the caller to consume.
        """
        message_chunk, metadata = chunk["data"]
        node_name = metadata.get("langgraph_node", "unknown")
        new_current_node = current_node

        # Emit node boundary events when the producing node changes.
        if node_name != current_node:
            if current_node:
                yield {"type": "node_end", "node": current_node}
            yield {"type": "node_start", "node": node_name}
            new_current_node = node_name

        token_content = getattr(message_chunk, 'content', str(message_chunk))
        reasoning_token = ""
        if hasattr(message_chunk, 'additional_kwargs'):
            reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")

        # The three branches are mutually exclusive, in this priority order.
        if reasoning_token:
            yield {
                "type": "llm_token",
                "node": node_name,
                "reasoning_token": reasoning_token
            }
        elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
            for tool_call in message_chunk.tool_calls:
                tool_call_id = tool_call.get("id", "")
                tool_name = tool_call.get("name", "")
                tool_args = tool_call.get("args", {})
                # Announce each tool call only once.
                if tool_call_id and tool_call_id not in tool_calls_in_progress:
                    tool_calls_in_progress[tool_call_id] = {
                        "name": tool_name,
                        "args": tool_args
                    }
                    yield {
                        "type": "tool_call_start",
                        "tool": tool_name,
                        "args": tool_args,
                        "id": tool_call_id
                    }
        elif token_content:
            yield {
                "type": "llm_token",
                "node": node_name,
                "token": token_content,
                "reasoning_token": reasoning_token
            }

        # Internal event: hands the updated node back to the caller.
        yield {"type": "_update_state", "current_node": new_current_node}

    @staticmethod
    def _iter_update_payloads(serialized_data) -> list:
        """Return the dicts of an ``updates`` chunk that may carry state fields.

        The chunk data itself comes first (flat layout), followed by each of
        its dict values (per-node layout, ``{node_name: {field: value}}``).
        Non-dict data yields an empty list.
        """
        if not isinstance(serialized_data, dict):
            return []
        payloads = [serialized_data]
        payloads.extend(v for v in serialized_data.values() if isinstance(v, dict))
        return payloads

    async def _handle_updates_chunk(
        self, chunk: Dict[str, Any], tool_calls_in_progress: Dict[str, Any], actual_model_used: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Translate one ``updates`` chunk into client events.

        State fields are looked up both at the top level of the chunk data and
        one level down, because LangGraph's ``updates`` mode keys the payload
        by node name.

        Args:
            chunk: Stream chunk whose ``data`` holds the state updates.
            tool_calls_in_progress: Tool calls awaiting a result, keyed by call
                id; entries are removed when their result is seen.
            actual_model_used: Model name known so far.

        Yields:
            ``human_review_request`` and ``tool_call_end`` events when the
            update contains them, one ``state_update`` event with the whole
            serialized data, and finally an internal ``_update_state`` event
            carrying ``actual_model_used``.
        """
        updates_data = chunk["data"]
        new_actual_model = actual_model_used
        debug(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}")

        serialized_data = self._serialize_value(updates_data)

        for payload in self._iter_update_payloads(serialized_data):
            if "final_result" in payload:
                debug(f"[Stream] 收到 final_result: {str(payload['final_result'])[:100]}...")
            if "current_model" in payload:
                new_actual_model = payload["current_model"]
                info(f"[Stream] 实际使用模型: {new_actual_model}")

            # Human review request.
            if payload.get("review_pending"):
                yield {
                    "type": "human_review_request",
                    "review_id": payload.get("review_id", ""),
                    "content": payload.get("content_to_review", "")
                }

            # Tool results: `messages` may be a list or a single serialized message.
            raw_messages = payload.get("messages")
            if isinstance(raw_messages, dict):
                messages = [raw_messages]
            elif isinstance(raw_messages, list):
                messages = raw_messages
            else:
                messages = []
            for msg in messages:
                if not isinstance(msg, dict) or msg.get("role") != "tool":
                    continue
                tool_call_id = msg.get("tool_call_id", "")
                if tool_call_id and tool_call_id in tool_calls_in_progress:
                    yield {
                        "type": "tool_call_end",
                        "tool": msg.get("name", ""),
                        "id": tool_call_id,
                        "result": msg.get("content", "")
                    }
                    del tool_calls_in_progress[tool_call_id]

        yield {
            "type": "state_update",
            "data": serialized_data
        }
        # Internal event: hands the updated model name back to the caller.
        yield {"type": "_update_state", "actual_model_used": new_actual_model}

    async def _handle_custom_chunk(self, chunk: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
        """Translate one ``custom`` chunk into a client event.

        Args:
            chunk: Stream chunk whose ``data`` is the custom payload.

        Yields:
            A ``react_reasoning`` event when the payload is a dict containing
            both ``action`` and ``reasoning``; otherwise a ``custom`` event
            wrapping the serialized payload.
        """
        custom_data = chunk["data"]
        is_reasoning_event = (
            isinstance(custom_data, dict)
            and "action" in custom_data
            and "reasoning" in custom_data
        )
        if is_reasoning_event:
            yield {
                "type": "react_reasoning",
                "step": custom_data.get("step", 1),
                "action": custom_data.get("action", "unknown"),
                "confidence": custom_data.get("confidence", 0),
                "reasoning": custom_data.get("reasoning", "")
            }
            return
        yield {
            "type": "custom",
            "data": self._serialize_value(custom_data)
        }

    async def process_message_stream(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the graph in streaming mode and yield client events.

        Args:
            message: User message text.
            thread_id: Conversation (thread) id.
            model: Requested model name; resolved via ``_resolve_model``.
            user_id: User id.

        Yields:
            ``intent_classified`` and ``path_decision`` first, then the events
            produced by the chunk handlers, an ``error`` event if the graph run
            raises, and finally ``node_end`` (when a node is open) and ``done``.
        """
        resolved_model = self._resolve_model(model)
        config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)

        # Intent classification: logged and injected into the state for routing.
        intent_result = await self.intent_classifier.classify(message)
        info(f"🧠 意图识别: {intent_result.intent_type} (置信度: {intent_result.confidence:.2f})")
        info(f"📝 推理: {intent_result.reasoning}")
        input_state["intent_type"] = intent_result.intent_type.value
        input_state["intent_confidence"] = intent_result.confidence

        yield {
            "type": "intent_classified",
            "intent": intent_result.intent_type.value,
            "confidence": intent_result.confidence,
            "reasoning": intent_result.reasoning
        }
        # The path is currently hard-coded; the intent travels in the state.
        yield {
            "type": "path_decision",
            "path": "react_loop",
            "intent": intent_result.intent_type.value
        }

        info(f"🚀 开始执行单图,指定模型: {resolved_model}")
        current_node = None
        tool_calls_in_progress: Dict[str, Any] = {}
        actual_model_used = resolved_model
        chunk_count = 0
        # Tokens from the llm_call node, joined once at the end for logging.
        token_parts = []

        try:
            info(f"📡 开始调用 graph.astream()...")
            async for chunk in self.graph.astream(
                input_state,
                config=config,
                stream_mode=["messages", "updates", "custom"],
                version="v2",
                subgraphs=True
            ):
                chunk_count += 1
                chunk_type = chunk["type"]

                if chunk_type == "messages":
                    async for event in self._handle_message_chunk(
                        chunk, current_node, tool_calls_in_progress
                    ):
                        if event.get("type") == "_update_state":
                            current_node = event.get("current_node", current_node)
                            continue
                        # Only string tokens are collected; non-string content
                        # would otherwise break the concatenation.
                        if (
                            event.get("type") == "llm_token"
                            and event.get("node") == "llm_call"
                            and isinstance(event.get("token"), str)
                        ):
                            token_parts.append(event["token"])
                        yield event
                elif chunk_type == "updates":
                    async for event in self._handle_updates_chunk(
                        chunk, tool_calls_in_progress, actual_model_used
                    ):
                        if event.get("type") == "_update_state":
                            actual_model_used = event.get("actual_model_used", actual_model_used)
                        else:
                            yield event
                elif chunk_type == "custom":
                    async for event in self._handle_custom_chunk(chunk):
                        yield event

            info(f"✅ graph.astream() 完成,共 {chunk_count} 个 chunks")
            full_message_content = "".join(token_parts)
            if full_message_content:
                info(f"📄 完整消息内容: {repr(full_message_content)}")
            info(f"🤖 实际使用模型: {actual_model_used}")
        except Exception as e:
            error(f"❌ 执行单图时出错: {e}")
            import traceback
            error(f"📋 堆栈: {traceback.format_exc()}")
            yield {
                "type": "error",
                "message": str(e)
            }

        # Closing events are sent after both success and handled failure so the
        # front end can shut down cleanly. They are deliberately NOT in a
        # `finally` block: yielding there while the generator is being closed
        # or cancelled raises "async generator ignored GeneratorExit" or delays
        # cancellation.
        if current_node:
            yield {
                "type": "node_end",
                "node": current_node
            }
        yield {
            "type": "done",
            "model_used": actual_model_used
        }