89 lines
3.1 KiB
Python
89 lines
3.1 KiB
Python
"""
|
||
AI Agent 服务类
|
||
"""
|
||
|
||
from typing import Any, AsyncGenerator, Dict, Optional

from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

from backend.app.model_services import get_cached_chat_services
from backend.app.main_graph.main_graph_builder import build_agent_graph
from backend.app.logger import info
from backend.app.memory.mem0_client import Mem0Client

from .service_config import ServiceConfig
from .stream_handler import run_graph_stream
|
||
|
||
|
||
class AIAgentService:
    """Builds the agent graph and runs user messages through it.

    Lifecycle: construct with a checkpointer, then ``await initialize()``
    before calling ``process_message`` or ``process_message_stream``.
    """

    def __init__(self, checkpointer):
        """Store the checkpointer; all other state is created by ``initialize()``.

        Args:
            checkpointer: Passed unchanged as ``checkpointer=`` to
                ``graph_builder.compile`` during ``initialize()``.
        """
        self.checkpointer = checkpointer
        self.graph = None
        self.config: Optional[ServiceConfig] = None
        self.mem0_client = None
        # Declared here so the attribute exists before initialize() runs.
        self.chat_services = None

    async def initialize(self) -> "AIAgentService":
        """Create the memory client, load the chat services and compile the graph.

        Returns:
            This instance, so the call can be chained.
        """
        self.mem0_client = Mem0Client()

        self.chat_services = get_cached_chat_services()
        info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")

        graph_builder = build_agent_graph(
            chat_services=self.chat_services,
            mem0_client=self.mem0_client,
        )
        self.graph = graph_builder.compile(checkpointer=self.checkpointer)

        # Set last: a non-None config therefore implies the graph is compiled.
        self.config = ServiceConfig(self.chat_services)
        info("✅ Agent 图初始化完成")

        return self

    def _resolve_and_build(
        self, message: str, thread_id: str, model: str, user_id: str
    ):
        """Resolve the model name and build the graph invocation arguments.

        Args:
            message: User message forwarded to ``config.build_invocation``.
            thread_id: Conversation thread id forwarded to ``build_invocation``.
            model: Requested model name, resolved via ``config.resolve_model``.
            user_id: User id forwarded to ``build_invocation``.

        Returns:
            ``(resolved_model, invocation)`` where ``invocation`` is whatever
            ``config.build_invocation`` returns; callers in this class unpack
            it as ``(config, input_state)``.

        Raises:
            RuntimeError: If ``initialize()`` has not completed yet.
        """
        if self.config is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(
                "AIAgentService is not initialized; call 'await initialize()' first"
            )

        resolved_model = self.config.resolve_model(model)
        invocation = self.config.build_invocation(
            message, thread_id, resolved_model, user_id
        )
        return resolved_model, invocation

    async def process_message(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> dict:
        """Run one message through the graph and return the reply with run stats.

        Returns:
            A dict with keys ``reply``, ``token_usage``, ``elapsed_time``,
            ``model_used`` and ``metadata``.

        Raises:
            RuntimeError: If ``initialize()`` has not completed yet.
        """
        resolved_model, (config, input_state) = self._resolve_and_build(
            message, thread_id, model, user_id
        )

        result = await self.graph.ainvoke(input_state, config=config)

        reply = result.get("final_reply", "")
        if not reply and result.get("messages"):
            # No explicit final reply: fall back to the last message's content.
            reply = result["messages"][-1].content

        return {
            "reply": reply,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
            "model_used": resolved_model,
            "metadata": result.get("metadata", {}),
        }

    async def process_message_stream(
        self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream graph events for one message.

        Events from ``run_graph_stream`` are yielded unchanged, except that an
        event whose ``type`` is ``"done"`` is yielded as a copy with
        ``model_used`` added.

        Raises:
            RuntimeError: On first iteration, if ``initialize()`` has not
                completed yet.
        """
        resolved_model, (config, input_state) = self._resolve_and_build(
            message, thread_id, model, user_id
        )

        info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")

        async for event in run_graph_stream(self.graph, input_state, config):
            if event.get("type") == "done":
                yield {**event, "model_used": resolved_model}
            else:
                yield event
|