Files
ailine/backend/app/agent/service.py
root 048f57a89f
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
集成三个子图到主Agent架构 + 修复前后端字段不匹配问题
主要变更:
1. 创建 subgraph_tools.py - 将三个子图包装为 LangChain 工具
2. 更新 graph_tools.py - 删除旧工具,添加子图工具
3. 更新系统提示词 - 介绍三个子系统 + RAG 能力
4. 简化 backend.py - 删除独立子图 API 端点
5. 修复 service.py 字段名不匹配问题 - content -> token
6. 前端界面优化 - 移动子图测试到侧边栏、删除测试审核按钮
7. 添加 pyjwt 依赖到 requirements.txt
8. 更新 docker-compose.yml - 添加前端代码挂载
2026-04-27 15:23:50 +08:00

397 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import json
import asyncio
# 本地模块
from ..graph.graph_builder import GraphBuilder, GraphContext
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
from .rag_initializer import init_rag_tool
from .intent_classifier import get_intent_classifier
from ..logger import info, warning
class AIAgentService:
    """Chat agent service that keeps one compiled LangGraph per chat model.

    The checkpointer is supplied by the caller and only handed to
    ``compile()``; this class never opens or closes its connection.
    Call ``await service.initialize()`` before handling messages.
    """

    # Minimum intent-classifier confidence required to skip the ReAct loop.
    FAST_PATH_MIN_CONFIDENCE = 0.6
    # Seconds between characters when a fast-path reply (generated in a single
    # LLM call) is replayed to the client as a token stream.
    STREAM_CHAR_DELAY = 0.03
    # System prompts for fast-path intents; any other intent uses the default.
    _FAST_PATH_PROMPTS = {
        "chitchat": "你是一个友好的助手,请礼貌回应用户的问候或闲聊。",
        "clarify": "用户的问题不够明确,请礼貌地询问更多细节,以便更好地帮助用户。",
    }
    _DEFAULT_FAST_PROMPT = "请简洁回答用户的问题。"
    _UNAVAILABLE_REPLY = "抱歉,服务暂时不可用。"

    def __init__(self, checkpointer):
        self.checkpointer = checkpointer
        # model name -> compiled graph (filled by initialize()).
        self.graphs = {}
        # model name -> chat model instance, only for models whose graph compiled.
        self.llms = {}
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()
        # Intent classifier used for hybrid routing in process_message_stream.
        self.intent_classifier = get_intent_classifier()
        # Optional RAG pipeline; when set, confident "knowledge" intents use the fast RAG path.
        self.rag_pipeline = None

    async def initialize(self):
        """Register the optional RAG tool and compile one graph per chat model.

        Returns:
            ``self``, so callers can chain ``await AIAgentService(cp).initialize()``.

        Raises:
            RuntimeError: if no model's graph could be built.
        """
        # 1. Optional RAG tool, backed by the local vLLM chat model.
        def create_local_llm():
            provider = LocalVLLMChatProvider()
            return provider.get_service()

        rag_tool = await init_rag_tool(create_local_llm)
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool

        # 2. One compiled graph per model; a failing model is logged and skipped.
        for name, llm in get_all_chat_services().items():
            try:
                info(f"🔄 初始化模型 '{name}'...")
                builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
                self.graphs[name] = builder.compile(checkpointer=self.checkpointer)
                # Cache the model so the fast path reuses it instead of rebuilding services.
                self.llms[name] = llm
                info(f"✅ 模型 '{name}' 初始化成功")
            except Exception as e:
                warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")
        if not self.graphs:
            raise RuntimeError("没有可用的模型")
        return self

    @staticmethod
    def _build_config(thread_id, user_id):
        """Build the LangGraph run config: checkpoint thread id plus user metadata."""
        return {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id},
        }

    async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
        """Run the full graph once and return the final reply.

        If *model* has no compiled graph, the first available model is used
        and a warning is logged.

        Returns:
            ``{"reply": str, "token_usage": dict, "elapsed_time": float}``;
            the last two are read from the graph's final state and default to
            ``{}`` / ``0.0`` when absent.

        Raises:
            RuntimeError: if no model is available at all.
        """
        if model not in self.graphs:
            available = list(self.graphs)
            if not available:
                raise RuntimeError("没有可用的模型")
            # Keep the requested name so the warning reports both sides of the fallback.
            requested, model = model, available[0]
            warning(f"模型 '{requested}' 不可用,已回退到 '{model}'")
        graph = self.graphs[model]
        result = await graph.ainvoke(
            {"messages": [{"role": "user", "content": message}]},
            config=self._build_config(thread_id, user_id),
            context=GraphContext(user_id=user_id),
        )
        return {
            "reply": result["messages"][-1].content,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable values.

        Message-like objects (anything with a ``content`` attribute) become
        dicts with ``role``/``content``/``additional_kwargs``/``tool_calls``.
        ``tool_call_id`` and ``name`` are added when the object carries a
        non-None value for them, as tool result messages do. This lets
        downstream code match results to calls. Dicts, lists and tuples are
        walked; any other value that ``json.dumps`` rejects becomes ``str()``.
        """
        if hasattr(value, "content"):
            serialized = {
                "role": getattr(value, "type", "message"),
                "content": self._serialize_value(getattr(value, "content", "")),
                "additional_kwargs": self._serialize_value(getattr(value, "additional_kwargs", {})),
                "tool_calls": self._serialize_value(getattr(value, "tool_calls", [])),
            }
            for attr in ("tool_call_id", "name"):
                attr_value = getattr(value, attr, None)
                if attr_value is not None:
                    serialized[attr] = self._serialize_value(attr_value)
            return serialized
        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        try:
            json.dumps(value)
            return value
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _iter_state_updates(serialized_data):
        """Yield each per-node state-update dict from a serialized ``updates`` chunk.

        LangGraph's ``updates`` stream mode keys its payload by node name
        (``{node_name: state_update}``). A flat dict that already has
        ``messages`` or ``review_pending`` at top level is yielded as-is.
        """
        if not isinstance(serialized_data, dict):
            return
        if "messages" in serialized_data or "review_pending" in serialized_data:
            yield serialized_data
            return
        for node_update in serialized_data.values():
            if isinstance(node_update, dict):
                yield node_update

    def _review_request_events(self, serialized_data):
        """Yield a ``human_review_request`` event for each update with ``review_pending`` set."""
        for update in self._iter_state_updates(serialized_data):
            if update.get("review_pending"):
                yield {
                    "type": "human_review_request",
                    "review_id": update.get("review_id", ""),
                    "content": update.get("content_to_review", ""),
                }

    def _tool_result_events(self, serialized_data, tool_calls_in_progress):
        """Yield ``tool_call_end`` for tool messages that answer a pending call.

        Matched ids are removed from *tool_calls_in_progress*. Messages whose
        id was never announced via ``tool_call_start`` are ignored.
        """
        for update in self._iter_state_updates(serialized_data):
            messages = update.get("messages") or []
            if isinstance(messages, dict):
                # A node may return a single message rather than a list.
                messages = [messages]
            for msg in messages:
                if not isinstance(msg, dict) or msg.get("role") != "tool":
                    continue
                tool_call_id = msg.get("tool_call_id", "")
                pending = tool_calls_in_progress.pop(tool_call_id, None)
                if pending is None:
                    continue
                yield {
                    "type": "tool_call_end",
                    # ToolMessage.name can be empty; fall back to the name seen at call start.
                    "tool": msg.get("name") or pending["name"],
                    "id": tool_call_id,
                    "result": msg.get("content", ""),
                }

    @staticmethod
    def _tool_call_start_events(message_chunk, tool_calls_in_progress):
        """Yield ``tool_call_start`` once per new tool-call id in *message_chunk*.

        Records each new id in *tool_calls_in_progress*. Fragments without an
        id are skipped: when streaming, the continuation chunks of a tool call
        arrive with ``id=None``, and no tool result could ever match them.
        """
        for tool_call in getattr(message_chunk, "tool_calls", None) or []:
            tool_call_id = tool_call.get("id")
            if not tool_call_id or tool_call_id in tool_calls_in_progress:
                continue
            tool_name = tool_call.get("name", "")
            tool_args = tool_call.get("args", {})
            tool_calls_in_progress[tool_call_id] = {"name": tool_name, "args": tool_args}
            yield {
                "type": "tool_call_start",
                "tool": tool_name,
                "args": tool_args,
                "id": tool_call_id,
            }

    def _can_use_fast_path(self, intent_str, confidence):
        """Return True when a confidently classified intent may skip the ReAct loop."""
        # Written as `not >=` so a NaN confidence also routes to the ReAct loop.
        if not confidence >= self.FAST_PATH_MIN_CONFIDENCE:
            return False
        if intent_str in ("chitchat", "clarify"):
            return True
        return intent_str == "knowledge" and bool(self.rag_pipeline)

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Stream events for one user message, using hybrid routing.

        First emits ``intent_classified`` and ``path_decision``. It then
        streams either the full ReAct graph (``react_loop``) or a
        single-call fast reply (``fast``), and always ends with ``done``.

        Raises:
            ValueError: on first iteration, if *model_name* has no compiled graph.
        """
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")

        intent_result = await self.intent_classifier.classify(message)
        intent_str = intent_result.intent_type.value
        info(f"🧠 意图识别: {intent_result.intent_type} (置信度: {intent_result.confidence:.2f})")
        info(f"📝 推理: {intent_result.reasoning}")
        yield {
            "type": "intent_classified",
            "intent": intent_str,
            "confidence": intent_result.confidence,
            "reasoning": intent_result.reasoning,
        }

        use_react_loop = not self._can_use_fast_path(intent_str, intent_result.confidence)
        yield {
            "type": "path_decision",
            "path": "react_loop" if use_react_loop else "fast",
            "intent": intent_str,
        }

        if use_react_loop:
            events = self._stream_react_loop(graph, message, thread_id, user_id)
        else:
            events = self._stream_fast_path(message, intent_str, model_name)
        async for event in events:
            yield event

    async def _stream_react_loop(self, graph, message, thread_id, user_id):
        """Run *graph* and translate its v2 multi-mode stream into client events.

        Emits ``node_start``/``node_end`` around each run of message chunks
        from one node, and ``llm_token`` for reasoning and content tokens.
        Also emits ``tool_call_start``/``tool_call_end`` for tool calls,
        ``human_review_request`` when an update sets ``review_pending``,
        pass-through ``state_update``/``custom`` events, and a final ``done``.
        """
        current_node = None
        # tool_call_id -> {"name", "args"} for calls started but not yet answered.
        tool_calls_in_progress = {}
        async for chunk in graph.astream(
            {"messages": [{"role": "user", "content": message}]},
            config=self._build_config(thread_id, user_id),
            context=GraphContext(user_id=user_id),
            stream_mode=["messages", "updates", "custom"],
            version="v2",
            subgraphs=True,
        ):
            chunk_type = chunk["type"]
            if chunk_type == "messages":
                message_chunk, metadata = chunk["data"]
                node_name = metadata.get("langgraph_node", "unknown")
                if node_name != current_node:
                    if current_node:
                        yield {"type": "node_end", "node": current_node}
                    yield {"type": "node_start", "node": node_name}
                    current_node = node_name

                token_content = getattr(message_chunk, "content", str(message_chunk))
                reasoning_token = ""
                if hasattr(message_chunk, "additional_kwargs"):
                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")

                if reasoning_token:
                    yield {
                        "type": "llm_token",
                        "node": node_name,
                        "reasoning_token": reasoning_token,
                    }
                elif getattr(message_chunk, "tool_calls", None):
                    for event in self._tool_call_start_events(message_chunk, tool_calls_in_progress):
                        yield event
                elif token_content:
                    yield {
                        "type": "llm_token",
                        "node": node_name,
                        "token": token_content,
                        "reasoning_token": reasoning_token,
                    }
            elif chunk_type == "updates":
                serialized_data = self._serialize_value(chunk["data"])
                for event in self._review_request_events(serialized_data):
                    yield event
                for event in self._tool_result_events(serialized_data, tool_calls_in_progress):
                    yield event
                yield {"type": "state_update", "data": serialized_data}
            elif chunk_type == "custom":
                yield {"type": "custom", "data": self._serialize_value(chunk["data"])}

        if current_node:
            yield {"type": "node_end", "node": current_node}
        yield {"type": "done"}

    async def _stream_fast_path(self, message, intent_str, model_name):
        """Answer without the ReAct loop, replay the reply per character, then ``done``.

        NOTE(review): this path calls the LLM directly. The turn is not
        written to the checkpointer thread, and the reply does not see prior
        conversation history.
        """
        if intent_str == "knowledge" and self.rag_pipeline:
            yield {"type": "node_start", "node": "fast_rag"}
            yield {"type": "reasoning", "node": "fast_rag", "content": "正在查询知识库..."}
            reply = await self._generate_rag_reply(message, model_name)
            yield {"type": "node_end", "node": "fast_rag"}
        else:
            system_prompt = self._FAST_PATH_PROMPTS.get(intent_str, self._DEFAULT_FAST_PROMPT)
            reply = await self._generate_fast_reply(message, system_prompt, model_name)
        async for event in self._stream_text(reply):
            yield event
        yield {"type": "done"}

    async def _stream_text(self, text, node="fast_path"):
        """Yield *text* one character at a time as ``llm_token`` events (typewriter effect)."""
        for char in text:
            yield {"type": "llm_token", "node": node, "token": char}
            await asyncio.sleep(self.STREAM_CHAR_DELAY)

    def _resolve_llm(self, model_name=None):
        """Return the chat model for *model_name*, else the first initialized model.

        Uses the models cached by ``initialize()``. If none are cached, it
        falls back to ``get_all_chat_services()``. Returns None when nothing
        matches.
        """
        llms = self.llms or get_all_chat_services()
        if model_name in llms:
            return llms[model_name]
        return llms.get(next(iter(self.graphs), "zhipu"))

    async def _generate_fast_reply(self, message: str, system_prompt: str, model_name=None) -> str:
        """Generate a reply with one direct LLM call (no ReAct loop).

        Args:
            message: the user's message.
            system_prompt: instruction text prepended to the message.
            model_name: preferred model; defaults to the first initialized model.
        """
        llm = self._resolve_llm(model_name)
        if not llm:
            return self._UNAVAILABLE_REPLY
        prompt = f"{system_prompt}\n\n用户: {message}"
        response = await llm.ainvoke(prompt)
        return response.content if hasattr(response, "content") else str(response)

    async def _generate_rag_reply(self, message: str, model_name=None) -> str:
        """Answer *message* from documents retrieved by ``self.rag_pipeline``.

        Falls back to a plain fast reply when no RAG pipeline is configured.
        """
        if not self.rag_pipeline:
            return await self._generate_fast_reply(message, self._DEFAULT_FAST_PROMPT, model_name)
        docs = await self.rag_pipeline.aretrieve(message)
        context_text = self.rag_pipeline.format_context(docs) or "(无相关文档)"
        llm = self._resolve_llm(model_name)
        if not llm:
            return self._UNAVAILABLE_REPLY
        prompt = f"""请根据以下参考文档回答用户问题。
参考文档:
{context_text}
用户问题: {message}
"""
        response = await llm.ainvoke(prompt)
        return response.content if hasattr(response, "content") else str(response)