ailine/app/agent.py

"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
2026-04-12 01:46:51 +08:00
import os
import json

from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

# Local modules
from app.graph_builder import GraphBuilder, GraphContext
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.logger import debug, info, warning, error

load_dotenv()


class AIAgentService:
    """Async AI agent service with dynamic multi-model switching; uses an externally provided checkpointer."""

    def __init__(self, checkpointer: AsyncPostgresSaver):
        """
        Initialize the service.

        Args:
            checkpointer: An already-initialized AsyncPostgresSaver instance.
        """
        self.checkpointer = checkpointer
        self.graphs = {}  # Graph instances keyed by model name

    def _create_zhipu_llm(self):
        """Create the hosted Zhipu LLM."""
        api_key = os.getenv("ZHIPUAI_API_KEY")
        if not api_key:
            raise ValueError("ZHIPUAI_API_KEY not set in environment")
        return ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=api_key,
            temperature=0.1,
            max_tokens=4096,
            timeout=120.0,  # Request timeout, raised from the previous 60 s
            max_retries=3,  # Retry count, raised from the previous 2
            streaming=True,  # Make sure streaming output is enabled
        )

    def _create_deepseek_llm(self):
        """Create the DeepSeek LLM (via its OpenAI-compatible API)."""
        api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            raise ValueError("DEEPSEEK_API_KEY not set in environment")
        return ChatOpenAI(
            base_url="https://api.deepseek.com",
            api_key=SecretStr(api_key),
            model="deepseek-reasoner",  # deepseek-chat: non-reasoning mode; deepseek-reasoner: reasoning mode
            temperature=0.1,
            max_tokens=4096,
            timeout=60.0,  # Request timeout in seconds
            max_retries=2,  # Automatic retries on failure
            streaming=True,  # Make sure streaming output is enabled
        )

    def _create_local_llm(self):
        """Create the LLM served by the local vLLM instance."""
        # vLLM server address: read from the environment first, to support
        # Docker, FRP tunneling, and plain local development
        vllm_base_url = os.getenv(
            "VLLM_BASE_URL",
            "http://127.0.0.1:8081/v1"
        )
        return ChatOpenAI(
            base_url=vllm_base_url,
            api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
            model="gemma-4-E2B-it",
            timeout=60.0,  # Request timeout in seconds
            max_retries=2,  # Automatic retries on failure
            streaming=True,  # Make sure streaming output is enabled
        )

    async def initialize(self):
        """Precompile a graph for every model, using the provided checkpointer."""
        model_configs = {
            "local": self._create_local_llm,        # Local model first
            "deepseek": self._create_deepseek_llm,  # DeepSeek in the middle
            "zhipu": self._create_zhipu_llm,        # GLM-4.7 last
        }
        for model_name, llm_creator in model_configs.items():
            try:
                info(f"🔄 Initializing model '{model_name}'...")
                llm = llm_creator()
                builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
                graph = builder.compile(checkpointer=self.checkpointer)
                self.graphs[model_name] = graph
                info(f"✅ Model '{model_name}' initialized successfully")
            except Exception as e:
                import traceback
                error_detail = traceback.format_exc()
                warning(f"⚠️ Model '{model_name}' failed to initialize: {e}")
                debug(f"Details:\n{error_detail}")
        if not self.graphs:
            raise RuntimeError(
                "No usable models; check the configuration. Possible causes:\n"
                "1. ZHIPUAI_API_KEY is missing or invalid\n"
                "2. DEEPSEEK_API_KEY is missing or invalid\n"
                "3. The vLLM server is down or its address is wrong (VLLM_BASE_URL)\n"
                "4. Network connectivity problems"
            )
        return self

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """
        Process a user message and return the reply together with token and timing stats.

        Returns:
            dict: {
                "reply": str,           # AI reply content
                "token_usage": dict,    # Token usage details
                "elapsed_time": float   # Call duration in seconds
            }
        """
        # Use the requested model if it is registered; otherwise fall back to
        # the first available one
        if model not in self.graphs:
            warning(f"Model '{model}' is unavailable; switching to another registered model")
            if not self.graphs:
                raise RuntimeError(f"No models available. Registered models: {list(self.graphs.keys())}")
            # Additional availability checks could be added here before picking one
            model = next(iter(self.graphs))
            info(f"Switched to available model: '{model}'")
        graph = self.graphs[model]
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}  # Written into metadata for history queries
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)
        result = await graph.ainvoke(input_state, config=config, context=context)
        reply = result["messages"][-1].content
        token_usage = result.get("last_token_usage", {})
        elapsed_time = result.get("last_elapsed_time", 0.0)
        return {
            "reply": reply,
            "token_usage": token_usage,
            "elapsed_time": elapsed_time
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable structures."""
        if hasattr(value, 'content'):
            # LangChain message object; recurse into its fields so nested
            # values are made serializable too
            msg_type = getattr(value, 'type', 'message')
            return {
                "role": msg_type,
                "content": getattr(value, 'content', ''),
                "additional_kwargs": self._serialize_value(getattr(value, 'additional_kwargs', {})),
                "tool_calls": self._serialize_value(getattr(value, 'tool_calls', []))
            }
        elif isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        elif isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        else:
            try:
                json.dumps(value)
                return value
            except (TypeError, ValueError):
                return str(value)
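
    # Illustrative example (hypothetical values): an AIMessage with content="hi"
    # and no tool calls serializes to
    #   {"role": "ai", "content": "hi", "additional_kwargs": {}, "tool_calls": []}
    # and anything json.dumps cannot handle (e.g. a datetime) falls back to str().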

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """
        Process a message as a stream; returns an async generator.

        Args:
            message: User message
            thread_id: Thread ID
            model_name: Model name
            user_id: User ID

        Yields:
            dicts containing the event type and its data
        """
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"Model '{model_name}' not found or not initialized")
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)
        # With a list of stream modes and subgraphs=True, astream yields
        # (namespace, mode, data) tuples rather than plain chunks
        async for _namespace, chunk_type, data in graph.astream(
            input_state,
            config=config,
            context=context,
            stream_mode=["messages", "updates", "custom"],  # Combine several modes, including custom
            subgraphs=True  # Enable this if the graph uses subgraphs
        ):
            processed_event = {}
            # 1. Handle the LLM token stream (drives the typewriter effect)
            if chunk_type == "messages":
                message_chunk, metadata = data
                # Pull the node name out of the metadata
                node_name = metadata.get("langgraph_node", "unknown")
                # Read the content via getattr, since message_chunk may not be a string
                token_content = getattr(message_chunk, 'content', str(message_chunk))
                # Extract the DeepSeek reasoner's reasoning tokens
                reasoning_token = ""
                if hasattr(message_chunk, 'additional_kwargs'):
                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
                # [DEBUG] Temporary: log only when reasoning_token is non-empty so it is easy to spot
                if reasoning_token:
                    debug(f"💡 [Reasoning token captured]: {repr(reasoning_token)}")
                processed_event = {
                    "type": "llm_token",
                    "node": node_name,
                    "token": token_content,
                    "reasoning_token": reasoning_token,
                    "metadata": metadata  # Optional metadata
                }
            # 2. Handle state updates (a node finished executing)
            elif chunk_type == "updates":
                # Serialize everything inside the update
                serialized_data = self._serialize_value(data)
                processed_event = {
                    "type": "state_update",
                    "data": serialized_data
                }
                # Also keep a top-level messages field for the old frontend contract (optional)
                if "messages" in serialized_data:
                    processed_event["messages"] = serialized_data["messages"]
            # 3. Handle custom data (if any)
            elif chunk_type == "custom":
                # Custom events need serialization as well
                serialized_data = self._serialize_value(data)
                processed_event = {
                    "type": "custom",
                    "data": serialized_data
                }
            # 4. Other types (debug, tasks, ...): handle as needed
            else:
                # Skip types we do not care about
                continue
            # Only emit events that actually carry data
            if processed_event:
                yield processed_event
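

# A minimal usage sketch, not part of the original module. It assumes a local
# Postgres DSN (the POSTGRES_DSN variable here is hypothetical) and uses
# AsyncPostgresSaver.from_conn_string to own the checkpointer's lifecycle.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        dsn = os.getenv("POSTGRES_DSN", "postgresql://postgres:postgres@localhost:5432/ailine")
        async with AsyncPostgresSaver.from_conn_string(dsn) as checkpointer:
            await checkpointer.setup()  # Create the checkpoint tables on first run
            service = await AIAgentService(checkpointer).initialize()
            result = await service.process_message("Hello!", thread_id="demo-thread")
            print(result["reply"])

    asyncio.run(_demo())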