Files
ailine/backend/app/agent/service.py
root 8b354b7ccc
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 47m14s
重构代码,统一config配置
2026-04-21 11:02:16 +08:00

154 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import json
# 本地模块
from ..graph.graph_builder import GraphBuilder, GraphContext
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from .llm_factory import LLMFactory
from .rag_initializer import init_rag_tool
from ..logger import info, warning
class AIAgentService:
    """Agent service holding one compiled LangGraph graph per configured model.

    The checkpointer is supplied by the caller and only handed to
    ``compile``; this class never opens or closes it.
    """

    def __init__(self, checkpointer):
        """Store the external checkpointer and take private copies of the tool registry.

        Copies are used so that ``initialize`` can add the RAG tool without
        mutating the module-level ``AVAILABLE_TOOLS`` / ``TOOLS_BY_NAME``.
        """
        self.checkpointer = checkpointer
        # Model name -> compiled graph; populated by initialize().
        self.graphs = {}
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()

    async def initialize(self):
        """Register the optional RAG tool and compile a graph for every model creator.

        A model whose creation or compilation raises is logged and skipped.

        Returns:
            ``self``, so construction and initialization can be chained.

        Raises:
            RuntimeError: if no model graph could be built at all.
        """
        # 1. RAG tool: skipped when init_rag_tool returns a falsy value.
        rag_tool = await init_rag_tool(LLMFactory.create_local)
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool

        # 2. One compiled graph per entry in LLMFactory.CREATORS.
        for name, creator in LLMFactory.CREATORS.items():
            try:
                info(f"🔄 初始化模型 '{name}'...")
                llm = creator()
                builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
                self.graphs[name] = builder.compile(checkpointer=self.checkpointer)
                info(f"✅ 模型 '{name}' 初始化成功")
            except Exception as e:
                # Best effort: one failing model must not block the others.
                warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")

        if not self.graphs:
            raise RuntimeError("没有可用的模型")
        return self

    def _fallback_model_name(self, requested: str) -> str:
        """Return ``requested`` if it has a compiled graph, else the first registered name.

        Raises:
            RuntimeError: if no graphs are registered.
        """
        if requested in self.graphs:
            return requested
        available = list(self.graphs)
        if not available:
            raise RuntimeError("没有可用的模型")
        return available[0]

    @staticmethod
    def _build_run_args(message: str, thread_id: str, user_id: str) -> tuple:
        """Build the ``(input_state, config)`` pair shared by both processing entry points.

        The config routes the run to ``thread_id`` via ``configurable`` and
        attaches ``user_id`` as run metadata; the input state holds the user
        message as a single role/content dict.
        """
        input_state = {"messages": [{"role": "user", "content": message}]}
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id},
        }
        return input_state, config

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """Run one non-streaming turn and return the reply plus usage statistics.

        If ``model`` has no compiled graph, the first available model is used
        instead and a warning naming both models is logged.

        Args:
            message: The user's message text.
            thread_id: Conversation id passed to the checkpointer via ``configurable``.
            model: Name of the model graph to use.
            user_id: Passed both as run metadata and in the ``GraphContext``.

        Returns:
            A dict with ``reply`` (content of the last message in the result
            state), ``token_usage`` (state key ``last_token_usage``, default
            ``{}``) and ``elapsed_time`` (state key ``last_elapsed_time``,
            default ``0.0``).

        Raises:
            RuntimeError: if no model graph is available at all.
        """
        resolved = self._fallback_model_name(model)
        if resolved != model:
            # Report both the requested and the substituted model name.
            warning(f"模型 '{model}' 不可用,已回退到 '{resolved}'")
        graph = self.graphs[resolved]

        input_state, config = self._build_run_args(message, thread_id, user_id)
        context = GraphContext(user_id=user_id)
        result = await graph.ainvoke(input_state, config=config, context=context)

        return {
            "reply": result["messages"][-1].content,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable values.

        Conversion rules:
        - Objects with a ``content`` attribute (messages) become a dict with
          ``role`` (their ``type``, default ``"message"``), ``content``,
          ``additional_kwargs`` and ``tool_calls``. These fields are serialized
          recursively as well, since they may hold non-JSON values.
        - dicts are converted value by value; lists and tuples become lists.
        - Anything ``json.dumps`` accepts is returned unchanged; anything else
          becomes ``str(value)``.
        """
        if hasattr(value, "content"):
            return {
                "role": getattr(value, "type", "message"),
                "content": self._serialize_value(value.content),
                "additional_kwargs": self._serialize_value(getattr(value, "additional_kwargs", {})),
                "tool_calls": self._serialize_value(getattr(value, "tool_calls", [])),
            }
        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return str(value)
        return value

    @staticmethod
    def _collect_update_messages(serialized_updates):
        """Extract message payloads from a serialized "updates" stream chunk.

        LangGraph documents "updates" data as ``{node_name: state_update}``, so
        messages normally sit one level below the top. A top-level "messages"
        key is still returned as-is.

        Returns:
            The list of serialized messages, or None when no update carries a
            "messages" key (or the data is not a dict).
        """
        if not isinstance(serialized_updates, dict):
            return None
        if "messages" in serialized_updates:
            return serialized_updates["messages"]

        collected = None
        for node_update in serialized_updates.values():
            # Nodes that returned nothing (or a non-dict) carry no messages.
            if not isinstance(node_update, dict) or "messages" not in node_update:
                continue
            if collected is None:
                collected = []
            node_messages = node_update["messages"]
            # A node may return a single message instead of a list.
            if isinstance(node_messages, list):
                collected.extend(node_messages)
            else:
                collected.append(node_messages)
        return collected

    def _to_stream_event(self, chunk):
        """Translate one v2 stream chunk into the event dict yielded to callers.

        Handles chunk types "messages" (-> ``llm_token``), "updates"
        (-> ``state_update``) and "custom" (-> ``custom``).

        Returns:
            The event dict, or None for any other chunk type.
        """
        chunk_type = chunk["type"]

        if chunk_type == "messages":
            message_chunk, metadata = chunk["data"]
            extra = getattr(message_chunk, "additional_kwargs", None)
            return {
                "type": "llm_token",
                "node": metadata.get("langgraph_node", "unknown"),
                "token": getattr(message_chunk, "content", str(message_chunk)),
                # additional_kwargs["reasoning_content"] when present, else "".
                "reasoning_token": extra.get("reasoning_content", "") if extra else "",
                "metadata": metadata,
            }

        if chunk_type == "updates":
            serialized = self._serialize_value(chunk["data"])
            event = {"type": "state_update", "data": serialized}
            messages = self._collect_update_messages(serialized)
            if messages is not None:
                event["messages"] = messages
            return event

        if chunk_type == "custom":
            return {"type": "custom", "data": self._serialize_value(chunk["data"])}

        return None

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Stream one turn as an async generator of event dicts.

        Unlike ``process_message`` there is no fallback: an unknown model is an
        error. Because this is an async generator, the error surfaces on the
        first iteration rather than at call time.

        Yields:
            Event dicts produced by ``_to_stream_event``; unhandled chunk types
            are skipped.

        Raises:
            ValueError: if ``model_name`` has no compiled graph.
        """
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")

        input_state, config = self._build_run_args(message, thread_id, user_id)
        context = GraphContext(user_id=user_id)
        # This code relies on version="v2" chunks being dicts with "type" and "data".
        async for chunk in graph.astream(
            input_state,
            config=config,
            context=context,
            stream_mode=["messages", "updates", "custom"],
            version="v2",
            subgraphs=True,
        ):
            event = self._to_stream_event(chunk)
            if event:
                yield event