""" AI Agent 服务类 - 支持多模型动态切换 接收外部传入的 checkpointer,不负责管理连接生命周期 """ import os from dotenv import load_dotenv from langchain_community.chat_models import ChatZhipuAI from langchain_core.messages import HumanMessage from langchain_openai import ChatOpenAI from pydantic import SecretStr # 本地模块 from graph_builder import GraphBuilder from tools import AVAILABLE_TOOLS, TOOLS_BY_NAME load_dotenv() class AIAgentService: """异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer""" def __init__(self, checkpointer): """ 初始化服务 Args: checkpointer: 已经初始化的 AsyncPostgresSaver 实例 """ self.checkpointer = checkpointer self.graphs = {} # 存储不同模型对应的 graph 实例 def _create_zhipu_llm(self): """创建智谱在线 LLM""" api_key = os.getenv("ZHIPUAI_API_KEY") if not api_key: raise ValueError("ZHIPUAI_API_KEY not set in environment") return ChatZhipuAI( model="glm-4.7-flash", api_key=api_key, temperature=0.1, max_tokens=4096, ) def _create_local_llm(self): """创建本地 vLLM 服务 LLM""" return ChatOpenAI( base_url="http://localhost:8000/v1", api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")), model="gemma-4-E2B-it", ) async def initialize(self): """预编译所有模型的 graph(使用传入的 checkpointer)""" model_configs = { "zhipu": self._create_zhipu_llm, "local": self._create_local_llm, } for model_name, llm_creator in model_configs.items(): try: llm = llm_creator() builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build() graph = builder.compile(checkpointer=self.checkpointer) self.graphs[model_name] = graph print(f"✅ 模型 '{model_name}' 初始化成功") except Exception as e: print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}") if not self.graphs: raise RuntimeError("没有可用的模型,请检查配置") return self async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str: """处理用户消息,返回最终答案""" if model not in self.graphs: fallback_model = next(iter(self.graphs.keys())) print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'") model = fallback_model graph = self.graphs[model] config = {"configurable": {"thread_id": thread_id}} input_state = {"messages": [HumanMessage(content=message)]} result = await graph.ainvoke(input_state, config=config) return result["messages"][-1].content