Files
ailine/backend/app/agent/agent_service.py
root 128aad0c22
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m31s
refactor: 重构快速路径流程,统一通过 llm_call 输出
- 重构 fast_paths.py,让 fast_chitchat 和 fast_rag 都进入 llm_call 而不是直接设置 final_result
- 修改 check_fast_path_success 函数返回 'llm_call' 而不是 'success'
- 更新 main_graph_builder.py 的条件边配置,支持路由到 llm_call
- 在快速路径节点中添加清除 state.final_result 的逻辑,避免复用旧结果
- 重构 RAG 工具初始化方式,使用模块级变量管理
- 修改 finalize.py 让它返回 final_result
- 更新 agent_service.py 的 RAG 工具注入方式
- 简化 hybrid_router.py 的代码结构
- 清理 rag_nodes.py 的全局变量相关代码
- 更新相关测试文件
2026-05-05 04:32:42 +08:00

345 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import json
import asyncio
# 本地模块
from ..main_graph.utils.main_graph_builder import build_react_main_graph
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from ..main_graph.config import set_stream_writer
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
from ..main_graph.utils.rag_initializer import init_rag_tool
from ..core.intent_classifier import get_intent_classifier
from ..logger import info, warning, error
from ..main_graph.state import MainGraphState, CurrentAction
class AIAgentService:
    """Agent service that keeps one compiled graph per chat model.

    The checkpointer is supplied by the caller and is only handed to
    ``compile()``; this class never opens or closes it.
    """

    def __init__(self, checkpointer):
        """Store the checkpointer and create empty registries.

        Graph construction, the Mem0 client and the RAG tool are set up later
        by :meth:`initialize`.

        Args:
            checkpointer: Checkpointer passed to every compiled graph.
        """
        self.checkpointer = checkpointer
        self.graphs = {}
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()
        # Intent classifier; the streaming path uses it for logs and events.
        self.intent_classifier = get_intent_classifier()
        # RAG pipeline (optional, set when needed).
        self.rag_pipeline = None
        # RAG tool injected into the graph config; filled in by initialize()
        # when init_rag_tool returns one. Declared here so the attribute
        # always exists.
        self.rag_tool = None
        # Mem0 client, created in initialize().
        self.mem0_client = None

    async def initialize(self):
        """Create the Mem0 client, the optional RAG tool and the model graphs.

        A model whose graph fails to build is logged and skipped so that the
        remaining models stay usable.

        Returns:
            This service instance.

        Raises:
            RuntimeError: If no graph could be built for any model.
        """
        from ..memory.mem0_client import Mem0Client

        # Mem0 needs an LLM: use the first available one (None if there is none).
        chat_services = get_all_chat_services()
        temp_llm = next(iter(chat_services.values()), None) if chat_services else None
        self.mem0_client = Mem0Client(temp_llm)

        # RAG tool (optional): init_rag_tool receives a factory for the local LLM.
        def create_local_llm():
            provider = LocalVLLMChatProvider()
            return provider.get_service()

        rag_tool = await init_rag_tool(create_local_llm)
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool
            # Kept on the instance so it can be injected into the run config.
            self.rag_tool = rag_tool

        # Build one graph per model (React-style main graph).
        for name, llm in chat_services.items():
            try:
                info(f"🔄 初始化模型 '{name}'...")
                graph = build_react_main_graph(
                    llm=llm,
                    tools=self.tools,
                    mem0_client=self.mem0_client
                ).compile(checkpointer=self.checkpointer)
                self.graphs[name] = graph
                info(f"✅ 模型 '{name}' 初始化成功")
            except Exception as e:
                # Best effort: one broken model must not block the others.
                warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")

        if not self.graphs:
            raise RuntimeError("没有可用的模型")
        return self

    def _build_config(self, thread_id, user_id):
        """Return the run config shared by the blocking and streaming paths."""
        return {
            "configurable": {
                "thread_id": thread_id,
                # Inject the RAG tool (None when it was never initialised).
                "rag_tool": getattr(self, "rag_tool", None),
            },
            "metadata": {"user_id": user_id}
        }

    @staticmethod
    def _build_input_state(message, user_id):
        """Return the initial graph state for one user message."""
        return {
            "user_query": message,
            "messages": [{"role": "user", "content": message}],
            "user_id": user_id,
            "current_action": CurrentAction.NONE
        }

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """Process a user message and return the reply with run statistics.

        Args:
            message: The user's message text.
            thread_id: Conversation thread id placed in the run config.
            model: Name of the model graph to use; falls back to the first
                available graph when the name is unknown.
            user_id: User id placed in the state and the config metadata.

        Returns:
            A dict with the keys ``reply``, ``token_usage`` and
            ``elapsed_time``.

        Raises:
            RuntimeError: If no graph is available at all.
        """
        if model not in self.graphs:
            # Fall back to the first available model.
            fallback = next(iter(self.graphs), None)
            if fallback is None:
                raise RuntimeError("没有可用的模型")
            # Keep the requested name so the warning reports both names.
            requested, model = model, fallback
            warning(f"模型 '{requested}' 不可用,已回退到 '{model}'")

        graph = self.graphs[model]
        config = self._build_config(thread_id, user_id)
        input_state = self._build_input_state(message, user_id)
        result = await graph.ainvoke(input_state, config=config)

        reply = result.get("final_result", "")
        if not reply and result.get("messages"):
            # No explicit final result: use the content of the last message.
            reply = result["messages"][-1].content

        # `or {}` also covers a debug_info key that is present but None.
        debug_info = result.get("debug_info") or {}
        return {
            "reply": reply,
            "token_usage": debug_info.get("token_usage", {}),
            "elapsed_time": debug_info.get("elapsed_time", 0.0)
        }

    def _serialize_value(self, value):
        """Recursively convert a value into a JSON-serialisable structure.

        Objects exposing a ``content`` attribute are treated as messages and
        become dicts with ``role``, ``content``, ``additional_kwargs`` and
        ``tool_calls``; ``tool_call_id`` and ``name`` are added when the
        object carries them, because the streaming path needs both to match
        tool results to their calls. Dicts, lists and tuples are converted
        element by element (tuples become lists). Anything ``json.dumps``
        rejects is replaced by its ``str()`` form.
        """
        if hasattr(value, 'content'):
            serialized = {
                "role": getattr(value, 'type', 'message'),
                "content": self._serialize_value(getattr(value, 'content', '')),
                "additional_kwargs": self._serialize_value(getattr(value, 'additional_kwargs', {})),
                "tool_calls": self._serialize_value(getattr(value, 'tool_calls', []))
            }
            # Optional identifiers: only emitted when actually set.
            for attr in ("tool_call_id", "name"):
                attr_value = getattr(value, attr, None)
                if attr_value is not None:
                    serialized[attr] = self._serialize_value(attr_value)
            return serialized
        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        try:
            json.dumps(value)
            return value
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _update_payloads(data):
        """Return the dicts of an ``updates`` chunk that may hold state keys.

        The chunk itself is included, plus every dict found one level down.
        NOTE(review): LangGraph documents ``updates`` payloads as
        ``{node_name: node_update}``, so state keys are expected one level
        down; the top level is still checked to keep the earlier behaviour.
        Confirm against the installed LangGraph version.
        """
        if not isinstance(data, dict):
            return []
        return [data, *(v for v in data.values() if isinstance(v, dict))]

    @staticmethod
    def _message_chunk_events(message_chunk, node_name, tool_calls_in_progress):
        """Translate one ``messages`` chunk into a list of client events.

        Precedence: a reasoning token wins over tool calls, which win over
        plain content. New tool calls are recorded in
        ``tool_calls_in_progress`` (mutated in place) so that each one is
        announced only once.
        """
        token_content = getattr(message_chunk, 'content', str(message_chunk))
        additional_kwargs = getattr(message_chunk, 'additional_kwargs', None) or {}
        reasoning_token = additional_kwargs.get("reasoning_content", "")

        # Reasoning ("thinking") output.
        if reasoning_token:
            return [{
                "type": "llm_token",
                "node": node_name,
                "reasoning_token": reasoning_token
            }]

        # Tool calls.
        tool_calls = getattr(message_chunk, 'tool_calls', None)
        if tool_calls:
            events = []
            for tool_call in tool_calls:
                tool_call_id = tool_call.get("id", "")
                tool_name = tool_call.get("name", "")
                tool_args = tool_call.get("args", {})
                # Announce each tool call only the first time it is seen.
                if tool_call_id not in tool_calls_in_progress:
                    tool_calls_in_progress[tool_call_id] = {
                        "name": tool_name,
                        "args": tool_args
                    }
                    events.append({
                        "type": "tool_call_start",
                        "tool": tool_name,
                        "args": tool_args,
                        "id": tool_call_id
                    })
            return events

        # Plain token.
        if token_content:
            return [{
                "type": "llm_token",
                "node": node_name,
                "token": token_content,
                "reasoning_token": reasoning_token
            }]
        return []

    def _update_events(self, serialized_data, tool_calls_in_progress):
        """Translate one serialised ``updates`` chunk into client events.

        Emits, in order: ``human_review_request`` for payloads flagged
        ``review_pending``, ``tool_call_end`` for tool messages whose id is in
        ``tool_calls_in_progress`` (the entry is removed), and finally one
        ``state_update`` carrying the whole serialised chunk.
        """
        events = []
        for payload in self._update_payloads(serialized_data):
            if payload.get("review_pending"):
                events.append({
                    "type": "human_review_request",
                    "review_id": payload.get("review_id", ""),
                    "content": payload.get("content_to_review", "")
                })

            # Look for tool results among the messages of this payload.
            messages = payload.get("messages")
            if not isinstance(messages, list):
                continue
            for msg in messages:
                if not isinstance(msg, dict) or msg.get("role") != "tool":
                    continue
                tool_call_id = msg.get("tool_call_id", "")
                if tool_call_id in tool_calls_in_progress:
                    events.append({
                        "type": "tool_call_end",
                        "tool": msg.get("name", ""),
                        "id": tool_call_id,
                        "result": msg.get("content", "")
                    })
                    del tool_calls_in_progress[tool_call_id]

        events.append({
            "type": "state_update",
            "data": serialized_data
        })
        return events

    def _custom_event(self, custom_data):
        """Translate one ``custom`` chunk into a single client event.

        Dicts carrying both ``action`` and ``reasoning`` become
        ``react_reasoning`` events; everything else is forwarded as a
        serialised ``custom`` event.
        """
        if isinstance(custom_data, dict) and "action" in custom_data and "reasoning" in custom_data:
            return {
                "type": "react_reasoning",
                "step": custom_data.get("step", 1),
                "action": custom_data.get("action", "unknown"),
                "confidence": custom_data.get("confidence", 0),
                "reasoning": custom_data.get("reasoning", "")
            }
        return {
            "type": "custom",
            "data": self._serialize_value(custom_data)
        }

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Stream the processing of a message as an async generator of events.

        Every request goes through the React graph. Events are dicts with a
        ``type`` key: ``intent_classified``, ``path_decision``, ``node_start``,
        ``node_end``, ``llm_token``, ``tool_call_start``, ``tool_call_end``,
        ``human_review_request``, ``state_update``, ``react_reasoning``,
        ``custom`` and a final ``done``.

        Args:
            message: The user's message text.
            thread_id: Conversation thread id placed in the run config.
            model_name: Name of the model graph to use.
            user_id: User id placed in the state and the config metadata.

        Raises:
            ValueError: If ``model_name`` has no initialised graph (raised on
                the first iteration, since this is a generator).
            Exception: Any error raised while streaming the graph is logged
                and re-raised.
        """
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")

        config = self._build_config(thread_id, user_id)
        input_state = self._build_input_state(message, user_id)

        # Intent classification: only logged and reported, never used to route.
        intent_result = await self.intent_classifier.classify(message)
        info(f"🧠 意图识别: {intent_result.intent_type} (置信度: {intent_result.confidence:.2f})")
        info(f"📝 推理: {intent_result.reasoning}")
        yield {
            "type": "intent_classified",
            "intent": intent_result.intent_type.value,
            "confidence": intent_result.confidence,
            "reasoning": intent_result.reasoning
        }
        # Path decision event (always react_loop now).
        yield {
            "type": "path_decision",
            "path": "react_loop",
            "intent": intent_result.intent_type.value
        }

        info(f"🚀 开始执行 React 图,模型: {model_name}")
        current_node = None
        tool_calls_in_progress = {}
        try:
            info("📡 开始调用 graph.astream()...")
            chunk_count = 0
            # Tokens of the llm_call node, joined once at the end for logging.
            llm_call_tokens = []
            async for chunk in graph.astream(
                input_state,
                config=config,
                stream_mode=["messages", "updates", "custom"],
                version="v2",
                subgraphs=True
            ):
                chunk_count += 1
                chunk_type = chunk["type"]

                if chunk_type == "messages":
                    message_chunk, metadata = chunk["data"]
                    node_name = metadata.get("langgraph_node", "unknown")
                    # A change of node closes the previous one and opens the new one.
                    if node_name != current_node:
                        if current_node:
                            yield {
                                "type": "node_end",
                                "node": current_node
                            }
                        yield {
                            "type": "node_start",
                            "node": node_name
                        }
                        current_node = node_name

                    for event in self._message_chunk_events(message_chunk, node_name, tool_calls_in_progress):
                        # Only collect tokens here; they are logged once at the end.
                        if node_name == "llm_call" and "token" in event:
                            token = event["token"]
                            llm_call_tokens.append(token if isinstance(token, str) else str(token))
                        yield event

                elif chunk_type == "updates":
                    updates_data = chunk["data"]
                    info(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}")
                    # Log final_result explicitly wherever it shows up.
                    for payload in self._update_payloads(updates_data):
                        if "final_result" in payload:
                            info(f"[Stream] 收到 final_result: {str(payload['final_result'])[:100]}...")

                    serialized_data = self._serialize_value(updates_data)
                    for event in self._update_events(serialized_data, tool_calls_in_progress):
                        yield event

                elif chunk_type == "custom":
                    yield self._custom_event(chunk["data"])

            info(f"✅ graph.astream() 完成,共 {chunk_count} 个chunks")
            full_message_content = "".join(llm_call_tokens)
            if full_message_content:
                info(f"📄 完整消息内容: {full_message_content!r}")
        except Exception as e:
            error(f"❌ 执行 React 图时出错: {e}")
            import traceback
            error(f"📋 堆栈: {traceback.format_exc()}")
            raise

        # Closing events.
        if current_node:
            yield {
                "type": "node_end",
                "node": current_node
            }
        yield {
            "type": "done"
        }