Files
ailine/agent.py

85 lines
2.9 KiB
Python
Raw Normal View History

2026-04-13 19:49:18 +08:00
"""
AI Agent 服务类 - 支持多模型动态切换。
接收外部传入的 checkpointer，不负责管理连接生命周期。
"""
2026-04-12 01:46:51 +08:00
import os
2026-04-13 19:49:18 +08:00
from dotenv import load_dotenv
2026-04-12 01:46:51 +08:00
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
2026-04-13 19:49:18 +08:00
# 本地模块
from graph_builder import GraphBuilder
from tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
2026-04-12 01:46:51 +08:00
2026-04-13 19:49:18 +08:00
# Populate os.environ from a .env file (if one is found) so the os.getenv()
# lookups for API keys further down can see those values. Variables already
# present in the real environment are kept as-is (override=False).
load_dotenv(override=False)
2026-04-12 01:46:51 +08:00
2026-04-13 19:49:18 +08:00
class AIAgentService:
    """Async AI agent service that can switch between several LLM backends.

    One compiled graph (built via ``GraphBuilder``) is kept per backend name
    ("zhipu", "local"). The checkpointer is supplied by the caller and is only
    handed to ``compile()``; this class never opens or closes it, so its
    connection lifecycle remains the caller's responsibility.
    """

    def __init__(self, checkpointer):
        """Store the checkpointer; graphs are built later by ``initialize()``.

        Args:
            checkpointer: An already-initialized checkpointer (e.g. an
                ``AsyncPostgresSaver`` instance) passed to every compiled graph.
        """
        self.checkpointer = checkpointer
        # Backend name -> compiled graph; populated by initialize().
        self.graphs: dict[str, object] = {}

    def _create_zhipu_llm(self):
        """Create the hosted Zhipu chat model.

        Raises:
            ValueError: If ``ZHIPUAI_API_KEY`` is unset or empty.
        """
        api_key = os.getenv("ZHIPUAI_API_KEY")
        if not api_key:
            raise ValueError("ZHIPUAI_API_KEY not set in environment")
        return ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=api_key,
            temperature=0.1,
            max_tokens=4096,
        )

    def _create_local_llm(self):
        """Create a chat model that talks to a local vLLM server via ChatOpenAI.

        The endpoint defaults to ``http://localhost:8000/v1`` and the model to
        ``gemma-4-E2B-it``; both can be overridden with the ``VLLM_BASE_URL``
        and ``VLLM_MODEL`` environment variables. The API key is read from
        ``VLLM_LOCAL_KEY`` (empty string if unset).
        """
        return ChatOpenAI(
            base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
            api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")),
            model=os.getenv("VLLM_MODEL", "gemma-4-E2B-it"),
        )

    async def initialize(self):
        """Build and compile one graph per configured backend.

        Backends whose construction raises are reported and skipped, so the
        service can start with only a subset of models available.

        Returns:
            ``self``, so callers can write
            ``service = await AIAgentService(cp).initialize()``.

        Raises:
            RuntimeError: If no backend could be initialized.
        """
        model_configs = {
            "zhipu": self._create_zhipu_llm,
            "local": self._create_local_llm,
        }
        for model_name, llm_creator in model_configs.items():
            try:
                llm = llm_creator()
                builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
                graph = builder.compile(checkpointer=self.checkpointer)
            except Exception as e:
                # Deliberate best-effort: one backend failing (e.g. a missing
                # ZHIPUAI_API_KEY) must not stop the others from loading.
                print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
                continue
            self.graphs[model_name] = graph
            print(f"✅ 模型 '{model_name}' 初始化成功")
        if not self.graphs:
            raise RuntimeError("没有可用的模型,请检查配置")
        return self

    async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str:
        """Send ``message`` to the selected backend and return the final reply.

        ``thread_id`` is passed in the graph config under ``configurable`` so
        the checkpointer can associate the call with a conversation thread.
        If ``model`` is not available, the first successfully initialized
        backend (in initialization order) is used and a warning is printed.

        Raises:
            RuntimeError: If no graph is available, i.e. ``initialize()`` was
                not awaited or every backend failed to load.
        """
        # Without this guard next(iter({})) raises StopIteration, which inside
        # a coroutine surfaces as an opaque "coroutine raised StopIteration".
        if not self.graphs:
            raise RuntimeError("没有可用的模型,请先调用 initialize()")
        if model not in self.graphs:
            fallback_model = next(iter(self.graphs))
            print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
            model = fallback_model
        graph = self.graphs[model]
        config = {"configurable": {"thread_id": thread_id}}
        input_state = {"messages": [HumanMessage(content=message)]}
        result = await graph.ainvoke(input_state, config=config)
        # The last message in the resulting state is returned as the answer.
        return result["messages"][-1].content