""" AI Agent 服务类 - 支持多模型动态切换 接收外部传入的 checkpointer,不负责管理连接生命周期 """ import os from dotenv import load_dotenv from langchain_community.chat_models import ChatZhipuAI from langchain_openai import ChatOpenAI from pydantic import SecretStr from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver # 本地模块 from app.graph_builder import GraphBuilder, GraphContext from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME from app.logger import debug, info, warning, error load_dotenv() class AIAgentService: """异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer""" def __init__(self, checkpointer: AsyncPostgresSaver): """ 初始化服务 Args: checkpointer: 已经初始化的 AsyncPostgresSaver 实例 """ self.checkpointer = checkpointer self.graphs = {} # 存储不同模型对应的 graph 实例 def _create_zhipu_llm(self): """创建智谱在线 LLM""" api_key = os.getenv("ZHIPUAI_API_KEY") if not api_key: raise ValueError("ZHIPUAI_API_KEY not set in environment") return ChatZhipuAI( model="glm-4.7-flash", api_key=api_key, temperature=0.1, max_tokens=4096, timeout=60.0, # 请求超时时间(秒) max_retries=2, # 失败后自动重试次数 ) def _create_deepseek_llm(self): """创建 DeepSeek LLM(使用 OpenAI 兼容 API)""" api_key = os.getenv("DEEPSEEK_API_KEY") if not api_key: raise ValueError("DEEPSEEK_API_KEY not set in environment") return ChatOpenAI( base_url="https://api.deepseek.com", api_key=SecretStr(api_key), model="deepseek-reasoner", # deepseek-chat: 非思考模式, deepseek-reasoner: 思考模式 temperature=0.1, max_tokens=4096, timeout=60.0, # 请求超时时间(秒) max_retries=2, # 失败后自动重试次数 ) def _create_local_llm(self): """创建本地 vLLM 服务 LLM""" # vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发 vllm_base_url = os.getenv( "VLLM_BASE_URL", "http://localhost:8081/v1" ) return ChatOpenAI( base_url=vllm_base_url, api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")), model="gemma-4-E2B-it", timeout=60.0, # 请求超时时间(秒) max_retries=2, # 失败后自动重试次数 ) async def initialize(self): """预编译所有模型的 graph(使用传入的 checkpointer)""" model_configs = { "zhipu": self._create_zhipu_llm, "deepseek": self._create_deepseek_llm, "local": self._create_local_llm, } for model_name, llm_creator in model_configs.items(): try: info(f"🔄 正在初始化模型 '{model_name}'...") llm = llm_creator() builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build() graph = builder.compile(checkpointer=self.checkpointer) self.graphs[model_name] = graph info(f"✅ 模型 '{model_name}' 初始化成功") except Exception as e: import traceback error_detail = traceback.format_exc() warning(f"⚠️ 模型 '{model_name}' 初始化失败: {e}") debug(f" 详细错误:\n{error_detail}") if not self.graphs: raise RuntimeError("没有可用的模型,请检查配置。可能的原因:\n" "1. ZHIPUAI_API_KEY 未配置或无效\n" "2. DEEPSEEK_API_KEY 未配置或无效\n" "3. vLLM 服务未启动或地址错误 (VLLM_BASE_URL)\n" "4. 网络连接问题") return self async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict: """ 处理用户消息,返回包含回复、token统计和耗时的字典 Returns: dict: { "reply": str, # AI 回复内容 "token_usage": dict, # Token 使用详情 "elapsed_time": float # 调用耗时(秒) } """ # 尝试使用指定模型,如果不可用则循环尝试其他模型 if model not in self.graphs: warning(f"警告: 模型 '{model}' 不可用,尝试切换到其他可用模型") found = False for available_model in self.graphs.keys(): try: # 这里可以添加额外的模型可用性检查逻辑 model = available_model found = True info(f"已切换到可用模型: '{model}'") break except Exception as e: warning(f"模型 '{available_model}' 也不可用: {str(e)}") continue if not found: raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}") graph = self.graphs[model] config = { "configurable": {"thread_id": thread_id}, "metadata": {"user_id": user_id} # 写入 metadata 供历史查询使用 } input_state = {"messages": [{"role": "user", "content": message}]} context = GraphContext(user_id=user_id) result = await graph.ainvoke(input_state, config=config, context=context) reply = result["messages"][-1].content token_usage = result.get("last_token_usage", {}) elapsed_time = result.get("last_elapsed_time", 0.0) return { "reply": reply, "token_usage": token_usage, "elapsed_time": elapsed_time } async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"): """ 流式处理消息,返回异步生成器 Args: message: 用户消息 thread_id: 线程 ID model_name: 模型名称 user_id: 用户 ID Yields: 字典,包含事件类型和数据 """ graph = self.graphs.get(model_name) if not graph: warning(f"警告: 模型 '{model_name}' 不可用,使用默认模型") model_name = next(iter(self.graphs.keys())) graph = self.graphs[model_name] config = { "configurable": {"thread_id": thread_id}, "metadata": {"user_id": user_id} } input_state = {"messages": [{"role": "user", "content": message}]} context = GraphContext(user_id=user_id) # 使用 astream_events 获取流式事件 async for event in graph.astream_events(input_state, config=config, context=context, version="v2"): kind = event["event"] # 聊天模型流式输出 if kind == "on_chat_model_stream": content = event["data"]["chunk"].content if content: yield {"type": "token", "content": content} # 工具调用开始 elif kind == "on_tool_start": tool_name = event["name"] yield {"type": "tool_start", "tool": tool_name} # 工具调用结束 elif kind == "on_tool_end": tool_name = event["name"] yield {"type": "tool_end", "tool": tool_name} # 链结束,获取最终结果 elif kind == "on_chain_end" and event["name"] == "LangGraph": output = event["data"]["output"] reply = output["messages"][-1].content if output.get("messages") else "" token_usage = output.get("last_token_usage", {}) elapsed_time = output.get("last_elapsed_time", 0.0) yield { "type": "done", "reply": reply, "token_usage": token_usage, "elapsed_time": elapsed_time }