Files
ailine/app/agent.py
root 8dd94c6c19
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 27s
添加长期记忆
2026-04-14 17:34:12 +08:00

152 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import os
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
# 本地模块
from app.graph_builder import GraphBuilder, GraphContext
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.logger import debug, info, warning, error
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres.aio import AsyncPostgresStore
load_dotenv()
class AIAgentService:
"""异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
def __init__(self, checkpointer: AsyncPostgresSaver, store: AsyncPostgresStore):
"""
初始化服务
Args:
checkpointer: 已经初始化的 AsyncPostgresSaver 实例
store: 已经初始化的 AsyncPostgresStore 实例
"""
self.checkpointer = checkpointer
self.store = store
self.graphs = {} # 存储不同模型对应的 graph 实例
def _create_zhipu_llm(self):
"""创建智谱在线 LLM"""
api_key = os.getenv("ZHIPUAI_API_KEY")
if not api_key:
raise ValueError("ZHIPUAI_API_KEY not set in environment")
return ChatZhipuAI(
model="glm-4.7-flash",
api_key=api_key,
temperature=0.1,
max_tokens=4096,
timeout=60.0, # 请求超时时间(秒)
max_retries=2, # 失败后自动重试次数
)
def _create_deepseek_llm(self):
"""创建 DeepSeek LLM使用 OpenAI 兼容 API"""
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
raise ValueError("DEEPSEEK_API_KEY not set in environment")
return ChatOpenAI(
base_url="https://api.deepseek.com",
api_key=SecretStr(api_key),
model="deepseek-reasoner", # deepseek-chat: 非思考模式, deepseek-reasoner: 思考模式
temperature=0.1,
max_tokens=4096,
timeout=60.0, # 请求超时时间(秒)
max_retries=2, # 失败后自动重试次数
)
def _create_local_llm(self):
"""创建本地 vLLM 服务 LLM"""
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
vllm_base_url = os.getenv(
"VLLM_BASE_URL",
"http://115.190.121.151:18000/v1"
)
return ChatOpenAI(
base_url=vllm_base_url,
api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")),
model="gemma-4-E2B-it",
timeout=60.0, # 请求超时时间(秒)
max_retries=2, # 失败后自动重试次数
)
async def initialize(self):
"""预编译所有模型的 graph使用传入的 checkpointer 和 store"""
model_configs = {
"zhipu": self._create_zhipu_llm,
"deepseek": self._create_deepseek_llm,
"local": self._create_local_llm,
}
for model_name, llm_creator in model_configs.items():
try:
info(f"🔄 正在初始化模型 '{model_name}'...")
llm = llm_creator()
# 测试 LLM 连接(可选,用于调试)
if model_name == "local":
debug(f" 测试 vLLM 连接: {os.getenv('VLLM_BASE_URL', '未设置')}")
elif model_name == "deepseek":
debug(f" 测试 DeepSeek API 连接: https://api.deepseek.com")
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
graph = builder.compile(checkpointer=self.checkpointer, store=self.store)
self.graphs[model_name] = graph
info(f"✅ 模型 '{model_name}' 初始化成功")
except Exception as e:
import traceback
error_detail = traceback.format_exc()
warning(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
debug(f" 详细错误:\n{error_detail}")
if not self.graphs:
raise RuntimeError("没有可用的模型,请检查配置。可能的原因:\n"
"1. ZHIPUAI_API_KEY 未配置或无效\n"
"2. DEEPSEEK_API_KEY 未配置或无效\n"
"3. vLLM 服务未启动或地址错误 (VLLM_BASE_URL)\n"
"4. 网络连接问题")
return self
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
"""
处理用户消息返回包含回复、token统计和耗时的字典
Returns:
dict: {
"reply": str, # AI 回复内容
"token_usage": dict, # Token 使用详情
"elapsed_time": float # 调用耗时(秒)
}
"""
if model not in self.graphs:
fallback_model = next(iter(self.graphs.keys()))
warning(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
model = fallback_model
graph = self.graphs[model]
config = {"configurable": {"thread_id": thread_id}}
input_state = {"messages": [HumanMessage(content=message)]}
context = GraphContext(user_id=user_id)
result = await graph.ainvoke(input_state, config=config, context=context)
reply = result["messages"][-1].content
token_usage = result.get("last_token_usage", {})
elapsed_time = result.get("last_elapsed_time", 0.0)
return {
"reply": reply,
"token_usage": token_usage,
"elapsed_time": elapsed_time
}