Files
ailine/backend/app/agent/agent_service.py
root 57a917b2c6
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m36s
remove: 移除快速路径逻辑,全部走 React 模式
2026-05-01 11:24:13 +08:00

280 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import json
import asyncio
# 本地模块
from app.main_graph.utils.main_graph_builder import build_react_main_graph
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.main_graph.config import set_stream_writer
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
from app.main_graph.utils.rag_initializer import init_rag_tool
from app.core.intent_classifier import get_intent_classifier
from app.logger import info, warning
from app.main_graph.state import MainGraphState, CurrentAction
class AIAgentService:
    """Agent service that keeps one compiled React-mode graph per chat model.

    The checkpointer is injected by the caller and only handed to
    ``compile()``; this class does not open or close its connection.
    ``await initialize()`` must run before messages are processed, otherwise
    ``self.graphs`` is empty.
    """

    def __init__(self, checkpointer):
        """Store the injected checkpointer and set up per-instance registries.

        Args:
            checkpointer: LangGraph checkpointer passed to ``compile()``; its
                lifecycle is managed by the caller.
        """
        self.checkpointer = checkpointer
        # model name -> compiled graph; populated by initialize()
        self.graphs = {}
        # Copies, so registering the RAG tool does not mutate the module-level
        # AVAILABLE_TOOLS / TOOLS_BY_NAME registries.
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()
        # Intent classifier; its result is currently only logged and reported.
        self.intent_classifier = get_intent_classifier()
        # Optional RAG pipeline, assigned from outside when needed.
        self.rag_pipeline = None

    async def initialize(self):
        """Register the optional RAG tool and compile one graph per chat service.

        Returns:
            ``self``, so construction and initialization can be chained.

        Raises:
            RuntimeError: if no graph could be compiled.
        """
        # 1. Optional RAG tool; init_rag_tool receives a factory for the
        #    LocalVLLMChatProvider chat service and may return a falsy value.
        def create_local_llm():
            provider = LocalVLLMChatProvider()
            return provider.get_service()

        rag_tool = await init_rag_tool(create_local_llm)
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool

        # 2. Compile a React-mode graph for every configured chat service.
        # NOTE(review): neither `llm` nor `self.tools` is passed to
        # build_react_main_graph(), so every entry in self.graphs is built the
        # same way. Confirm how the graph selects its model/tools; if the
        # builder accepts them they should be passed here.
        chat_services = get_all_chat_services()
        for name, llm in chat_services.items():
            try:
                info(f"🔄 初始化模型 '{name}'...")
                graph = build_react_main_graph().compile(checkpointer=self.checkpointer)
                self.graphs[name] = graph
                info(f"✅ 模型 '{name}' 初始化成功")
            except Exception as e:
                # One failing model must not prevent the others from loading.
                warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")
        if not self.graphs:
            raise RuntimeError("没有可用的模型")
        return self

    @staticmethod
    def _build_run_args(message, thread_id, user_id):
        """Build the ``(input_state, config)`` pair shared by both entry points.

        ``thread_id`` selects the checkpointer thread; ``user_id`` is placed in
        both the run metadata and the graph state. ``user_query`` is the key
        input field of the React graph state.
        """
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id},
        }
        input_state = {
            "user_query": message,
            "messages": [{"role": "user", "content": message}],
            "user_id": user_id,
            "current_action": CurrentAction.NONE,
        }
        return input_state, config

    @staticmethod
    def _extract_reply(result):
        """Return ``final_result``, falling back to the last message's content.

        The last message may be a LangChain message object (``.content``) or a
        plain ``{"role": ..., "content": ...}`` dict such as the one this class
        feeds into the graph.
        """
        reply = result.get("final_result", "")
        if reply:
            return reply
        messages = result.get("messages")
        if not messages:
            return reply
        last = messages[-1]
        if isinstance(last, dict):
            return last.get("content", "")
        return last.content

    async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
        """Run the graph to completion for a single user message.

        Args:
            message: user text.
            thread_id: checkpointer thread id for the conversation.
            model: graph name; unknown names fall back to the first available
                graph (with a warning).
            user_id: stored in the run metadata and the graph state.

        Returns:
            dict with ``reply``, ``token_usage`` (default ``{}``) and
            ``elapsed_time`` (default ``0.0``); the last two are read from the
            result's ``debug_info``.

        Raises:
            RuntimeError: if no graph is available at all.
        """
        if model not in self.graphs:
            available = list(self.graphs)
            if not available:
                raise RuntimeError("没有可用的模型")
            # Keep the requested name so the warning shows both sides of the
            # fallback (previously both placeholders printed the fallback).
            requested, model = model, available[0]
            warning(f"模型 '{requested}' 不可用,已回退到 '{model}'")

        graph = self.graphs[model]
        input_state, config = self._build_run_args(message, thread_id, user_id)
        result = await graph.ainvoke(input_state, config=config)

        # `or {}` also tolerates an explicit None in the state.
        debug_info = result.get("debug_info") or {}
        return {
            "reply": self._extract_reply(result),
            "token_usage": debug_info.get("token_usage", {}),
            "elapsed_time": debug_info.get("elapsed_time", 0.0),
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable values.

        - Objects with a ``content`` attribute (messages) become a dict with
          ``role`` (the object's ``type``, default ``"message"``), ``content``,
          ``additional_kwargs`` and ``tool_calls``. ``tool_call_id`` and
          ``name`` are added when present and not None, which tool-result
          messages need so they can be matched to their tool call.
        - dicts are converted value-wise; lists and tuples become lists.
        - Any other value is returned unchanged if ``json.dumps`` accepts it,
          otherwise ``str(value)`` is returned.
        """
        if hasattr(value, "content"):
            serialized = {
                "role": getattr(value, "type", "message"),
                "content": self._serialize_value(getattr(value, "content", "")),
                "additional_kwargs": self._serialize_value(getattr(value, "additional_kwargs", {})),
                "tool_calls": self._serialize_value(getattr(value, "tool_calls", [])),
            }
            for key in ("tool_call_id", "name"):
                extra = getattr(value, key, None)
                if extra is not None:
                    serialized[key] = self._serialize_value(extra)
            return serialized
        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return str(value)
        return value

    @staticmethod
    def _state_payloads(update_data):
        """Return the state-update dicts contained in one ``updates`` chunk.

        LangGraph's ``updates`` stream mode keys each update by node name
        (``{node_name: update}``), so the per-node dict values are returned. A
        flat update that already contains ``messages`` or ``review_pending``
        is returned as-is.
        """
        if not isinstance(update_data, dict):
            return []
        if "messages" in update_data or "review_pending" in update_data:
            return [update_data]
        return [v for v in update_data.values() if isinstance(v, dict)]

    def _events_from_message_chunk(self, message_chunk, node_name, tool_calls_in_progress):
        """Yield the token / tool-call-start events for one ``messages`` chunk.

        At most one kind is produced, checked in this order: a reasoning token
        (``additional_kwargs["reasoning_content"]``), tool-call starts, or a
        plain content token. New tool calls are recorded in
        ``tool_calls_in_progress`` (id -> {"name", "args"}).
        """
        token_content = getattr(message_chunk, "content", str(message_chunk))
        reasoning_token = ""
        if hasattr(message_chunk, "additional_kwargs"):
            reasoning_token = (message_chunk.additional_kwargs or {}).get("reasoning_content", "")

        if reasoning_token:
            yield {
                "type": "llm_token",
                "node": node_name,
                "reasoning_token": reasoning_token,
            }
            return

        tool_calls = getattr(message_chunk, "tool_calls", None)
        if tool_calls:
            for tool_call in tool_calls:
                tool_call_id = tool_call.get("id") or ""
                # Streamed continuation chunks carry tool calls without an id;
                # recording them would open a call that can never be closed.
                if not tool_call_id or tool_call_id in tool_calls_in_progress:
                    continue
                tool_name = tool_call.get("name", "")
                tool_args = tool_call.get("args", {})
                tool_calls_in_progress[tool_call_id] = {
                    "name": tool_name,
                    "args": tool_args,
                }
                yield {
                    "type": "tool_call_start",
                    "tool": tool_name,
                    "args": tool_args,
                    "id": tool_call_id,
                }
            return

        if token_content:
            yield {
                "type": "llm_token",
                "node": node_name,
                "token": token_content,
                "reasoning_token": reasoning_token,
            }

    def _events_from_updates(self, update_data, tool_calls_in_progress):
        """Yield the events for one ``updates`` chunk.

        Order: ``human_review_request`` for each payload with a truthy
        ``review_pending``; ``tool_call_end`` for each tool-result message
        whose ``tool_call_id`` is still open (it is then removed from
        ``tool_calls_in_progress``); finally one ``state_update`` with the
        whole serialized chunk.
        """
        serialized_data = self._serialize_value(update_data)
        for payload in self._state_payloads(serialized_data):
            if payload.get("review_pending"):
                yield {
                    "type": "human_review_request",
                    "review_id": payload.get("review_id", ""),
                    "content": payload.get("content_to_review", ""),
                }
            messages = payload.get("messages") or []
            # A node may return a single message instead of a list; after
            # serialization that is a dict, not a list of dicts.
            if isinstance(messages, dict):
                messages = [messages]
            for msg in messages:
                if not isinstance(msg, dict) or msg.get("role") != "tool":
                    continue
                tool_call_id = msg.get("tool_call_id", "")
                if tool_call_id in tool_calls_in_progress:
                    yield {
                        "type": "tool_call_end",
                        "tool": msg.get("name", ""),
                        "id": tool_call_id,
                        "result": msg.get("content", ""),
                    }
                    del tool_calls_in_progress[tool_call_id]
        yield {
            "type": "state_update",
            "data": serialized_data,
        }

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Stream one message through the React graph as an async generator of event dicts.

        Yields, in order: ``intent_classified``, ``path_decision`` (always
        ``react_loop``), then events derived from the graph stream
        (``node_start``/``node_end``, ``llm_token``, ``tool_call_start``,
        ``tool_call_end``, ``human_review_request``, ``state_update``,
        ``custom``), a closing ``node_end`` for the last node seen, and
        ``done``.

        Raises:
            ValueError: if ``model_name`` has no compiled graph (raised on the
                first iteration; unlike ``process_message`` there is no fallback).
        """
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
        input_state, config = self._build_run_args(message, thread_id, user_id)

        # Intent classification is kept for logging/reporting only; every
        # message takes the React loop path.
        intent_result = await self.intent_classifier.classify(message)
        info(f"🧠 意图识别: {intent_result.intent_type} (置信度: {intent_result.confidence:.2f})")
        info(f"📝 推理: {intent_result.reasoning}")
        yield {
            "type": "intent_classified",
            "intent": intent_result.intent_type.value,
            "confidence": intent_result.confidence,
            "reasoning": intent_result.reasoning,
        }
        yield {
            "type": "path_decision",
            "path": "react_loop",
            "intent": intent_result.intent_type.value,
        }

        current_node = None
        # tool call id -> {"name", "args"} for calls started but not yet ended
        tool_calls_in_progress = {}
        async for chunk in graph.astream(
            input_state,
            config=config,
            stream_mode=["messages", "updates", "custom"],
            version="v2",
            subgraphs=True,
        ):
            chunk_type = chunk["type"]
            if chunk_type == "messages":
                message_chunk, metadata = chunk["data"]
                node_name = metadata.get("langgraph_node", "unknown")
                # Emit node boundaries whenever the producing node changes.
                if node_name != current_node:
                    if current_node:
                        yield {"type": "node_end", "node": current_node}
                    yield {"type": "node_start", "node": node_name}
                    current_node = node_name
                for event in self._events_from_message_chunk(message_chunk, node_name, tool_calls_in_progress):
                    yield event
            elif chunk_type == "updates":
                for event in self._events_from_updates(chunk["data"], tool_calls_in_progress):
                    yield event
            elif chunk_type == "custom":
                yield {
                    "type": "custom",
                    "data": self._serialize_value(chunk["data"]),
                }

        if current_node:
            yield {"type": "node_end", "node": current_node}
        yield {"type": "done"}