218 lines
8.5 KiB
Python
218 lines
8.5 KiB
Python
"""
AI Agent service class - supports dynamic switching between multiple models.

The service receives an externally created checkpointer and does not manage
that checkpointer's connection lifecycle.
"""
|
||
|
||
import os
|
||
from dotenv import load_dotenv
|
||
from langchain_community.chat_models import ChatZhipuAI
|
||
from langchain_openai import ChatOpenAI
|
||
from pydantic import SecretStr
|
||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||
|
||
# 本地模块
|
||
from app.graph_builder import GraphBuilder, GraphContext
|
||
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||
from app.logger import debug, info, warning, error
|
||
|
||
|
||
# Populate os.environ from a .env file (if present) at import time so the
# os.getenv(...) lookups in the LLM factories below can see those values.
# Variables already set in the real environment win (override=False).
load_dotenv(override=False)
|
||
|
||
|
||
class AIAgentService:
    """Async AI agent service with per-request model switching.

    One compiled LangGraph graph is kept per model name ("zhipu", "deepseek",
    "local"). The checkpointer is supplied by the caller; this class only
    passes it to ``compile`` and never opens or closes its connection.
    """

    def __init__(self, checkpointer: AsyncPostgresSaver):
        """Store the externally managed checkpointer.

        Args:
            checkpointer: An already-initialised ``AsyncPostgresSaver``
                instance. Its connection lifecycle is owned by the caller.
        """
        self.checkpointer = checkpointer
        # model name -> compiled graph; populated by ``initialize()``.
        self.graphs: dict = {}

    def _create_zhipu_llm(self):
        """Create the hosted Zhipu (GLM) chat model.

        Raises:
            ValueError: If ``ZHIPUAI_API_KEY`` is not set.
        """
        api_key = os.getenv("ZHIPUAI_API_KEY")
        if not api_key:
            raise ValueError("ZHIPUAI_API_KEY not set in environment")
        return ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=api_key,
            temperature=0.1,
            max_tokens=4096,
            timeout=60.0,  # request timeout in seconds
            max_retries=2,  # automatic retries after a failure
        )

    def _create_deepseek_llm(self):
        """Create the DeepSeek chat model via its OpenAI-compatible API.

        Raises:
            ValueError: If ``DEEPSEEK_API_KEY`` is not set.
        """
        api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            raise ValueError("DEEPSEEK_API_KEY not set in environment")
        return ChatOpenAI(
            base_url="https://api.deepseek.com",
            api_key=SecretStr(api_key),
            model="deepseek-reasoner",  # deepseek-chat: non-thinking mode, deepseek-reasoner: thinking mode
            temperature=0.1,
            max_tokens=4096,
            timeout=60.0,  # request timeout in seconds
            max_retries=2,  # automatic retries after a failure
        )

    def _create_local_llm(self):
        """Create a chat model backed by a local vLLM (OpenAI-compatible) server."""
        # Read the endpoint from the environment first so Docker, FRP tunnelling
        # and local development can each point at a different address.
        vllm_base_url = os.getenv(
            "VLLM_BASE_URL",
            "http://localhost:8081/v1"
        )

        return ChatOpenAI(
            base_url=vllm_base_url,
            # NOTE(review): the key is read from LLAMACPP_API_KEY although this
            # targets vLLM; "token-abc123" is only a placeholder fallback.
            api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
            model="gemma-4-E2B-it",
            timeout=60.0,  # request timeout in seconds
            max_retries=2,  # automatic retries after a failure
        )

    async def initialize(self):
        """Build and compile one graph per configured model.

        Each graph is compiled with the checkpointer passed to ``__init__``.
        A model whose construction fails (for example because its API key is
        missing) is logged and skipped. The service can still run with the
        remaining models.

        Returns:
            ``self``, so construction and initialisation can be chained.

        Raises:
            RuntimeError: If no model could be initialised at all.
        """
        import traceback

        model_configs = {
            "zhipu": self._create_zhipu_llm,
            "deepseek": self._create_deepseek_llm,
            "local": self._create_local_llm,
        }

        for model_name, llm_creator in model_configs.items():
            try:
                info(f"🔄 正在初始化模型 '{model_name}'...")
                llm = llm_creator()
                builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
                graph = builder.compile(checkpointer=self.checkpointer)
                self.graphs[model_name] = graph
                info(f"✅ 模型 '{model_name}' 初始化成功")
            except Exception as e:
                # Deliberate best effort: one broken backend must not stop the
                # others from loading. The full traceback goes to debug logs.
                error_detail = traceback.format_exc()
                warning(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
                debug(f" 详细错误:\n{error_detail}")

        if not self.graphs:
            raise RuntimeError("没有可用的模型,请检查配置。可能的原因:\n"
                               "1. ZHIPUAI_API_KEY 未配置或无效\n"
                               "2. DEEPSEEK_API_KEY 未配置或无效\n"
                               "3. vLLM 服务未启动或地址错误 (VLLM_BASE_URL)\n"
                               "4. 网络连接问题")

        return self

    def _first_available_model(self) -> str:
        """Return the name of the first registered model (in insertion order).

        Raises:
            RuntimeError: If no graph is registered, e.g. ``initialize`` was not
                awaited or every model failed to load.
        """
        for name in self.graphs:
            return name
        raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}")

    async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
        """Run one user message through the graph and return the final result.

        If ``model`` is not registered, the first registered model is used instead.

        Args:
            message: The user's message text.
            thread_id: Conversation thread id, used as the checkpoint key.
            model: Requested model name.
            user_id: User id written into run metadata and the graph context.

        Returns:
            dict: {
                "reply": str,          # AI reply content
                "token_usage": dict,   # token usage details
                "elapsed_time": float  # call duration in seconds
            }

        Raises:
            RuntimeError: If no model is registered at all.
        """
        if model not in self.graphs:
            warning(f"警告: 模型 '{model}' 不可用,尝试切换到其他可用模型")
            model = self._first_available_model()
            info(f"已切换到可用模型: '{model}'")

        graph = self.graphs[model]
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}  # stored in metadata for history queries
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        result = await graph.ainvoke(input_state, config=config, context=context)

        reply = result["messages"][-1].content
        token_usage = result.get("last_token_usage", {})
        elapsed_time = result.get("last_elapsed_time", 0.0)

        return {
            "reply": reply,
            "token_usage": token_usage,
            "elapsed_time": elapsed_time
        }

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Stream one user message through the graph as an async generator.

        If ``model_name`` is not registered, the first registered model is used.

        Args:
            message: The user's message text.
            thread_id: Conversation thread id, used as the checkpoint key.
            model_name: Requested model name.
            user_id: User id written into run metadata and the graph context.

        Yields:
            dict events, one of:
              {"type": "token", "content": ...}      streamed chat-model output
              {"type": "tool_start", "tool": name}
              {"type": "tool_end", "tool": name}
              {"type": "done", "reply": ..., "token_usage": ..., "elapsed_time": ...}

        Raises:
            RuntimeError: If no model is registered at all. Previously this path
                leaked a bare StopIteration, which async generators turn into an
                opaque RuntimeError.
        """
        graph = self.graphs.get(model_name)
        if graph is None:
            warning(f"警告: 模型 '{model_name}' 不可用,使用默认模型")
            model_name = self._first_available_model()
            graph = self.graphs[model_name]

        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        # astream_events (v2) yields fine-grained run events for the whole graph.
        async for event in graph.astream_events(input_state, config=config, context=context, version="v2"):
            kind = event["event"]

            if kind == "on_chat_model_stream":
                # Incremental chat-model output; skip empty chunks.
                content = event["data"]["chunk"].content
                if content:
                    yield {"type": "token", "content": content}

            elif kind == "on_tool_start":
                yield {"type": "tool_start", "tool": event["name"]}

            elif kind == "on_tool_end":
                yield {"type": "tool_end", "tool": event["name"]}

            # Final state of the top-level graph run. The root run has no parent
            # ids, so a nested run that happens to be named "LangGraph" does not
            # emit a second "done" event.
            elif (
                kind == "on_chain_end"
                and event["name"] == "LangGraph"
                and not event.get("parent_ids")
            ):
                output = event["data"]["output"]
                reply = output["messages"][-1].content if output.get("messages") else ""
                token_usage = output.get("last_token_usage", {})
                elapsed_time = output.get("last_elapsed_time", 0.0)

                yield {
                    "type": "done",
                    "reply": reply,
                    "token_usage": token_usage,
                    "elapsed_time": elapsed_time
                }