Files
ailine/backend/app/agent/agent_service.py
root 5b41598d50
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m41s
重构:简化流式架构,将 ReAct 循环移入 agent 节点
主要变更:
- 简化 agent_service:移除复杂双协程,只用 stream_mode=["updates"]
- stream_context:提供更清晰的 API (set_stream_queue/get_stream_queue)
- main_graph_builder:简化图结构,移除 tools 节点和条件边
- agent 节点:包含完整 ReAct 循环 + 流式 Tool Calling 拼接
- 前端:适配新的事件格式
- 添加测试文件:test_full_react_streaming.py, test_stream.py
2026-05-07 02:56:35 +08:00

196 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent service class - fully simplified version.
Implemented per the guide; stream_mode="messages" is not used, to avoid duplicate tokens.
"""
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
# 本地模块
from backend.app.model_services import get_cached_chat_services
from backend.app.main_graph.main_graph_builder import build_agent_graph
from backend.app.logger import debug, info, warning, error
from backend.app.main_graph.state import AgentState
from .stream_context import set_stream_queue
class AIAgentService:
    """Service wrapper around the compiled agent graph.

    Holds the checkpointer, the model registry returned by
    ``get_cached_chat_services()``, the Mem0 client and the compiled graph,
    and exposes a one-shot entry point (``process_message``) and a streaming
    entry point (``process_message_stream``).
    """

    def __init__(self, checkpointer: Any) -> None:
        """Store the checkpointer; real setup is deferred to ``initialize``.

        Args:
            checkpointer: Passed to ``graph_builder.compile(checkpointer=...)``.
        """
        self.checkpointer = checkpointer
        self.graph = None
        self.chat_services = None
        # Mem0 client, created in initialize().
        self.mem0_client = None

    async def initialize(self) -> "AIAgentService":
        """Create the Mem0 client, load the model registry and compile the graph.

        Returns:
            ``self``, so the call can be chained after construction.
        """
        # 0. Create the Mem0 client (imported lazily, as in the original).
        from ..memory.mem0_client import Mem0Client
        self.mem0_client = Mem0Client()

        # 1. Fetch the cached model registry.
        self.chat_services = get_cached_chat_services()
        info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services)}")

        # 2. Build the graph, then compile it with the checkpointer.
        info("🔄 构建 Agent 图...")
        graph_builder = build_agent_graph(
            chat_services=self.chat_services,
            mem0_client=self.mem0_client,
        )
        self.graph = graph_builder.compile(checkpointer=self.checkpointer)
        info("✅ Agent 图初始化完成")
        return self

    def _require_graph(self) -> Any:
        """Return the compiled graph, or fail clearly if it was never built.

        Raises:
            RuntimeError: If ``initialize()`` has not set ``self.graph`` yet.
        """
        if self.graph is None:
            raise RuntimeError(
                "Agent graph is not initialized; call initialize() first"
            )
        return self.graph

    def _resolve_model(self, model: str) -> str:
        """Validate a model name, falling back to the first registered model.

        Args:
            model: Requested model name; may be empty.

        Returns:
            ``model`` if it is a key of ``self.chat_services``, otherwise the
            first key of ``self.chat_services``.

        Raises:
            RuntimeError: If the registry is empty or was never loaded.
        """
        # Without this guard an empty registry makes next() raise
        # StopIteration, which coroutines and async generators turn into an
        # opaque "raised StopIteration" RuntimeError.
        if not self.chat_services:
            raise RuntimeError(
                "No chat models are available; call initialize() and check "
                "the model configuration"
            )
        if model and model in self.chat_services:
            return model
        fallback = next(iter(self.chat_services))
        warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
        return fallback

    def _build_invocation(
        self, message: str, thread_id: str, model: str, user_id: str
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build the ``config`` and ``input_state`` for one graph invocation.

        Args:
            message: User message text.
            thread_id: Conversation / thread ID.
            model: Resolved model name.
            user_id: User ID.

        Returns:
            A ``(config, input_state)`` tuple.
        """
        from langchain_core.messages import HumanMessage

        config = {
            "configurable": {
                "thread_id": thread_id,
                # Previously ``model`` was accepted but dropped, so the
                # selection never reached the graph.
                # NOTE(review): graph nodes must read
                # config["configurable"]["model"] for this to take effect;
                # confirm against the agent node.
                "model": model,
            },
            "metadata": {"user_id": user_id},
        }
        input_state = {
            "messages": [HumanMessage(content=message)],
            "user_id": user_id,
        }
        return config, input_state

    async def process_message(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> dict:
        """Run the graph once and return the final reply with its statistics.

        Args:
            message: User message text.
            thread_id: Conversation / thread ID.
            model: Requested model name; resolved through ``_resolve_model``.
            user_id: User ID.

        Returns:
            A dict with ``reply`` (content of the last message, or ``""``),
            ``token_usage`` (``last_token_usage`` from the result, default
            ``{}``), ``elapsed_time`` (``last_elapsed_time``, default ``0.0``)
            and ``model_used``.

        Raises:
            RuntimeError: If no model is available or the graph is not built.
        """
        resolved_model = self._resolve_model(model)
        graph = self._require_graph()
        config, input_state = self._build_invocation(
            message, thread_id, resolved_model, user_id
        )

        result = await graph.ainvoke(input_state, config=config)

        messages = result.get("messages")
        reply = messages[-1].content if messages else ""
        return {
            "reply": reply,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
            "model_used": resolved_model,
        }

    async def process_message_stream(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream events for one message.

        A background task runs the graph and pushes events onto a queue; this
        generator drains the queue and yields each event to the caller.

        Args:
            message: User message text.
            thread_id: Conversation / thread ID.
            model: Requested model name; resolved through ``_resolve_model``.
            user_id: User ID.

        Yields:
            Every event taken from the queue. This method itself enqueues
            ``{"type": "graph_update", "data": chunk}`` and, on failure,
            ``{"type": "error", "message": ...}``. After a normal end of
            stream it yields ``{"type": "done", "model_used": ...}``.

        Raises:
            RuntimeError: If no model is available.
        """
        resolved_model = self._resolve_model(model)
        config, input_state = self._build_invocation(
            message, thread_id, resolved_model, user_id
        )
        info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")

        # Queue shared with the graph run. It is registered via stream_context
        # before the task is created; presumably graph nodes push token events
        # onto it -- confirm in stream_context / the agent node.
        queue: asyncio.Queue = asyncio.Queue()
        set_stream_queue(queue)

        async def run_graph() -> None:
            """Run the graph, forwarding "updates" chunks to the queue.

            Only stream_mode=["updates"] is used; "messages" is deliberately
            avoided to prevent duplicate tokens.
            """
            try:
                graph = self._require_graph()
                info("📡 开始调用 graph.astream()...")
                async for chunk in graph.astream(
                    input_state,
                    config=config,
                    stream_mode=["updates"],
                    version="v2",
                    subgraphs=True,
                ):
                    await queue.put({"type": "graph_update", "data": chunk})
            except Exception as e:
                # Report the failure to the consumer instead of dying silently.
                error(f"❌ 执行图时出错: {e}")
                import traceback
                error(f"📋 堆栈: {traceback.format_exc()}")
                await queue.put({"type": "error", "message": str(e)})
            finally:
                # End-of-stream sentinel. The queue is unbounded, so
                # put_nowait cannot fail and nothing is awaited while the
                # task is being cancelled.
                queue.put_nowait(None)

        bg_task = asyncio.create_task(run_graph())
        try:
            while True:
                event = await queue.get()
                if event is None:
                    break
                yield event
        except GeneratorExit:
            # The consumer closed the generator (e.g. client disconnected);
            # the finally block below cancels the background task.
            info("⚠️ GeneratorExit取消后台任务")
            raise
        finally:
            # Always make sure the background task is stopped and awaited.
            if not bg_task.done():
                info("⏹️ 清理后台任务")
                bg_task.cancel()
                try:
                    await bg_task
                except asyncio.CancelledError:
                    info("✅ 后台任务已取消")

        # Only reached after a normal end of stream: tell the frontend the
        # stream is finished so it can close cleanly.
        yield {
            "type": "done",
            "model_used": resolved_model,
        }