Files
ailine/agent.py
2026-04-13 19:49:18 +08:00

85 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer,不负责管理连接生命周期
"""
import os
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
# 本地模块
from graph_builder import GraphBuilder
from tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
# Load variables from a local .env file into os.environ at import time, so the
# os.getenv() lookups in AIAgentService (ZHIPUAI_API_KEY, VLLM_LOCAL_KEY) can be
# configured through .env. With python-dotenv's default (override=False),
# variables already present in the environment are not replaced.
load_dotenv()
class AIAgentService:
    """Async AI agent service that can switch between several chat models at runtime.

    One compiled graph is kept per model name in ``self.graphs``. The
    checkpointer is supplied by the caller and is only handed to
    ``builder.compile``; this class never opens or closes it.
    """

    # Model settings that used to be hard-coded in the factory methods. They are
    # class attributes with the same values so a subclass can override them.
    ZHIPU_MODEL = "glm-4.7-flash"
    LOCAL_BASE_URL = "http://localhost:8000/v1"
    LOCAL_MODEL = "gemma-4-E2B-it"
    # Placeholder used when VLLM_LOCAL_KEY is unset. vLLM's OpenAI-compatible
    # server examples pass "EMPTY" when the server runs without --api-key.
    LOCAL_PLACEHOLDER_KEY = "EMPTY"

    def __init__(self, checkpointer):
        """Store the checkpointer; graphs are built later by ``initialize``.

        Args:
            checkpointer: An already-initialized checkpointer (documented by the
                original author as an AsyncPostgresSaver instance). The caller
                owns its lifecycle.
        """
        self.checkpointer = checkpointer
        # model name -> compiled graph; populated by initialize()
        self.graphs = {}

    def _create_zhipu_llm(self):
        """Create the hosted Zhipu chat model.

        Raises:
            ValueError: If ZHIPUAI_API_KEY is unset or empty.
        """
        api_key = os.getenv("ZHIPUAI_API_KEY")
        if not api_key:
            raise ValueError("ZHIPUAI_API_KEY not set in environment")
        return ChatZhipuAI(
            model=self.ZHIPU_MODEL,
            api_key=api_key,
            temperature=0.1,
            max_tokens=4096,
        )

    @classmethod
    def _local_api_key(cls) -> str:
        """Return the API key for the local vLLM server.

        An empty key is replaced by a non-empty placeholder. Some
        langchain-openai versions treat a falsy key as "not provided" and then
        require OPENAI_API_KEY, which would break the local model when
        VLLM_LOCAL_KEY is not set.
        """
        return os.getenv("VLLM_LOCAL_KEY") or cls.LOCAL_PLACEHOLDER_KEY

    def _create_local_llm(self):
        """Create a ChatOpenAI client pointed at the local vLLM service."""
        return ChatOpenAI(
            base_url=self.LOCAL_BASE_URL,
            api_key=SecretStr(self._local_api_key()),
            model=self.LOCAL_MODEL,
        )

    async def initialize(self):
        """Build and compile one graph per configured model with the shared checkpointer.

        Each model is set up independently. A failure is printed and skipped so
        that one bad configuration does not disable the others.

        Returns:
            ``self``, so callers can write ``svc = await AIAgentService(cp).initialize()``.

        Raises:
            RuntimeError: If no model could be initialized.
        """
        model_configs = {
            "zhipu": self._create_zhipu_llm,
            "local": self._create_local_llm,
        }
        for model_name, llm_creator in model_configs.items():
            try:
                llm = llm_creator()
                builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
                graph = builder.compile(checkpointer=self.checkpointer)
            except Exception as e:
                # Deliberately best-effort: report and move on to the next model.
                print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
            else:
                self.graphs[model_name] = graph
                print(f"✅ 模型 '{model_name}' 初始化成功")
        if not self.graphs:
            raise RuntimeError("没有可用的模型,请检查配置")
        return self

    async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str:
        """Run one user message through the selected model's graph and return the reply.

        Args:
            message: The user's message text.
            thread_id: Conversation id, passed to the graph as
                ``config["configurable"]["thread_id"]``.
            model: Key of the graph to use. If it is not available, the first
                initialized model is used instead and a warning is printed.

        Returns:
            The ``content`` of the last entry in the result's ``"messages"`` list.

        Raises:
            RuntimeError: If no graph is available (``initialize`` was not
                awaited, or every model failed to initialize).
        """
        if not self.graphs:
            # Without this guard, next() below raises StopIteration, which Python
            # converts into an opaque "coroutine raised StopIteration" error.
            raise RuntimeError("没有可用的模型,请先调用 initialize()")
        if model not in self.graphs:
            fallback_model = next(iter(self.graphs))
            print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
            model = fallback_model
        graph = self.graphs[model]
        config = {"configurable": {"thread_id": thread_id}}
        input_state = {"messages": [HumanMessage(content=message)]}
        result = await graph.ainvoke(input_state, config=config)
        return result["messages"][-1].content