""" AI Agent 服务类 - 支持多模型动态切换 接收外部传入的 checkpointer,不负责管理连接生命周期 """ import os import json from dotenv import load_dotenv try: from langchain_community.chat_models import ChatZhipuAI HAS_ZHIPUAI = True except ImportError: HAS_ZHIPUAI = False ChatZhipuAI = None try: from langchain_openai import ChatOpenAI, OpenAIEmbeddings HAS_OPENAI = True except ImportError: HAS_OPENAI = False ChatOpenAI = None OpenAIEmbeddings = None from pydantic import SecretStr try: from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver HAS_POSTGRES_CHECKPOINT = True except ImportError: HAS_POSTGRES_CHECKPOINT = False AsyncPostgresSaver = None # 本地模块 from app.graph_builder import GraphBuilder, GraphContext from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME try: from app.rag import RAGPipeline from app.rag.tools import RAGTool HAS_RAG = True except ImportError as e: HAS_RAG = False RAGPipeline = None RAGTool = None from app.logger import debug, info, warning, error load_dotenv() class AIAgentService: """异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer""" def __init__(self, checkpointer: AsyncPostgresSaver): """ 初始化服务 Args: checkpointer: 已经初始化的 AsyncPostgresSaver 实例 """ self.checkpointer = checkpointer self.graphs = {} # 存储不同模型对应的 graph 实例 self.rag = None # RAG 检索实例 self.rag_tool = None # RAG 工具实例 def _create_zhipu_llm(self): """创建智谱在线 LLM""" if not HAS_ZHIPUAI: raise ImportError("智谱AI支持不可用,请安装langchain-community包") api_key = os.getenv("ZHIPUAI_API_KEY") if not api_key: raise ValueError("ZHIPUAI_API_KEY not set in environment") return ChatZhipuAI( model="glm-4.7-flash", api_key=api_key, temperature=0.1, max_tokens=4096, timeout=120.0, # 增加请求超时时间(秒),原为60秒 max_retries=3, # 增加重试次数,原为2次 streaming=True, # 确保开启流式输出 ) def _create_deepseek_llm(self): """创建 DeepSeek LLM(使用 OpenAI 兼容 API)""" if not HAS_OPENAI: raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包") api_key = os.getenv("DEEPSEEK_API_KEY") if not api_key: raise ValueError("DEEPSEEK_API_KEY not set in environment") return ChatOpenAI( base_url="https://api.deepseek.com", api_key=SecretStr(api_key), model="deepseek-reasoner", # deepseek-chat: 非思考模式, deepseek-reasoner: 思考模式 temperature=0.1, max_tokens=4096, timeout=60.0, # 请求超时时间(秒) max_retries=2, # 失败后自动重试次数 streaming=True, # 确保开启流式输出 ) def _create_local_llm(self): """创建本地 vLLM 服务 LLM""" if not HAS_OPENAI: raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包") # vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发 vllm_base_url = os.getenv( "VLLM_BASE_URL", "http://127.0.0.1:8081/v1" ) return ChatOpenAI( base_url=vllm_base_url, api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")), model="gemma-4-E2B-it", timeout=60.0, # 请求超时时间(秒) max_retries=2, # 失败后自动重试次数 streaming=True, # 确保开启流式输出 ) def _create_embeddings(self): """创建嵌入模型""" if not HAS_OPENAI: raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包") embedding_url = os.getenv( "LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1" ) return OpenAIEmbeddings( openai_api_base=embedding_url, openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"), model="text-embedding-ada-002", # 模型名称不重要,兼容即可 ) async def initialize(self): """预编译所有模型的 graph(使用传入的 checkpointer)""" # 先初始化 RAG 检索系统 if HAS_RAG and RAGPipeline is not None and RAGTool is not None: try: info("🔄 正在初始化 RAG 检索系统...") embeddings = self._create_embeddings() self.rag = RAGPipeline(embeddings=embeddings) self.rag_tool = RAGTool(self.rag).get_tool() info("✅ RAG 检索系统初始化成功") except Exception as e: warning(f"⚠️ RAG 检索系统初始化失败: {e}") self.rag = None self.rag_tool = None else: info("⏭️ RAG 检索系统不可用,跳过初始化") self.rag 
    async def initialize(self):
        """Pre-compile graphs for all models, using the provided checkpointer."""
        # Initialize the RAG retrieval system first.
        if HAS_RAG and RAGPipeline is not None and RAGTool is not None:
            try:
                info("🔄 Initializing the RAG retrieval system...")
                embeddings = self._create_embeddings()
                self.rag = RAGPipeline(embeddings=embeddings)
                self.rag_tool = RAGTool(self.rag).get_tool()
                info("✅ RAG retrieval system initialized")
            except Exception as e:
                warning(f"⚠️ RAG retrieval system initialization failed: {e}")
                self.rag = None
                self.rag_tool = None
        else:
            info("⏭️ RAG retrieval system unavailable; skipping initialization")
            self.rag = None
            self.rag_tool = None

        model_configs = {
            "local": self._create_local_llm,        # Local model first
            "deepseek": self._create_deepseek_llm,  # DeepSeek in the middle
            "zhipu": self._create_zhipu_llm,        # GLM-4.7 last
        }

        for model_name, llm_creator in model_configs.items():
            try:
                info(f"🔄 Initializing model '{model_name}'...")
                llm = llm_creator()

                # Build the tool list: base tools plus the RAG tool, if available.
                tools = AVAILABLE_TOOLS.copy()
                tools_by_name = TOOLS_BY_NAME.copy()
                if self.rag_tool is not None:
                    tools.append(self.rag_tool)
                    tools_by_name[self.rag_tool.name] = self.rag_tool

                builder = GraphBuilder(llm, tools, tools_by_name).build()
                graph = builder.compile(checkpointer=self.checkpointer)
                self.graphs[model_name] = graph
                info(f"✅ Model '{model_name}' initialized")
            except Exception as e:
                import traceback
                error_detail = traceback.format_exc()
                warning(f"⚠️ Model '{model_name}' initialization failed: {e}")
                debug(f"  Details:\n{error_detail}")

        if not self.graphs:
            raise RuntimeError(
                "No usable models; check your configuration. Possible causes:\n"
                "1. ZHIPUAI_API_KEY is missing or invalid\n"
                "2. DEEPSEEK_API_KEY is missing or invalid\n"
                "3. The vLLM service is down or misaddressed (VLLM_BASE_URL)\n"
                "4. Network connectivity problems"
            )
        return self

    async def process_message(self, message: str, thread_id: str, model: str = "local",
                              user_id: str = "default_user") -> dict:
        """
        Process a user message and return the reply, token stats, and elapsed time.

        Returns:
            dict: {
                "reply": str,          # AI reply content
                "token_usage": dict,   # Token usage details
                "elapsed_time": float  # Call duration in seconds
            }
        """
        # If the requested model is unavailable, fall back to the first
        # registered model. (Extra per-model health checks could go here.)
        if model not in self.graphs:
            warning(f"Model '{model}' is unavailable; switching to another registered model")
            if not self.graphs:
                raise RuntimeError(f"No usable models. Registered models: {list(self.graphs.keys())}")
            model = next(iter(self.graphs))
            info(f"Switched to available model: '{model}'")

        graph = self.graphs[model]
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id},  # Stored in metadata for history queries
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        result = await graph.ainvoke(input_state, config=config, context=context)
        reply = result["messages"][-1].content
        token_usage = result.get("last_token_usage", {})
        elapsed_time = result.get("last_elapsed_time", 0.0)
        return {
            "reply": reply,
            "token_usage": token_usage,
            "elapsed_time": elapsed_time,
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable structures."""
        if hasattr(value, 'content'):  # LangChain message object
            msg_type = getattr(value, 'type', 'message')
            return {
                "role": msg_type,
                "content": getattr(value, 'content', ''),
                "additional_kwargs": getattr(value, 'additional_kwargs', {}),
                "tool_calls": getattr(value, 'tool_calls', []),
            }
        elif isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        elif isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        else:
            try:
                json.dumps(value)
                return value
            except (TypeError, ValueError):
                return str(value)
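    # Illustrative example of _serialize_value on a LangChain message
    # (hypothetical input, shown for documentation only):
    #
    #   _serialize_value(AIMessage(content="hi"))
    #   -> {"role": "ai", "content": "hi", "additional_kwargs": {}, "tool_calls": []}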
    async def process_message_stream(self, message: str, thread_id: str, model_name: str,
                                     user_id: str = "default_user"):
        """
        Process a message as a stream, returning an async generator.

        Args:
            message: User message
            thread_id: Thread ID
            model_name: Model name
            user_id: User ID

        Yields:
            dict containing an event type and its data
        """
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"Model '{model_name}' not found or not initialized")

        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id},
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        async for chunk in graph.astream(
            input_state,
            config=config,
            context=context,
            stream_mode=["messages", "updates", "custom"],  # Combine several modes, including custom
            subgraphs=True,  # Enable this if the graph uses subgraphs
        ):
            # With a list of stream modes and subgraphs=True, astream yields
            # (namespace, mode, payload) tuples rather than bare payloads.
            _namespace, chunk_type, data = chunk
            processed_event = {}

            # 1. LLM token stream (typewriter effect)
            if chunk_type == "messages":
                message_chunk, metadata = data

                # Extract metadata
                node_name = metadata.get("langgraph_node", "unknown")

                # Read content via getattr, since message_chunk may not be a string
                token_content = getattr(message_chunk, 'content', str(message_chunk))

                # Extract DeepSeek reasoner "thinking" tokens
                reasoning_token = ""
                if hasattr(message_chunk, 'additional_kwargs'):
                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")

                # Temporary debug aid: log only when a reasoning token is captured
                if reasoning_token:
                    debug(f"💡 [Reasoning token captured]: {repr(reasoning_token)}")

                processed_event = {
                    "type": "llm_token",
                    "node": node_name,
                    "token": token_content,
                    "reasoning_token": reasoning_token,
                    "metadata": metadata,  # Optional metadata
                }

            # 2. State updates (a node finished executing)
            elif chunk_type == "updates":
                # Serialize everything in the update payload
                serialized_data = self._serialize_value(data)
                processed_event = {
                    "type": "state_update",
                    "data": serialized_data,
                }
                # Keep the legacy messages field for frontend compatibility (optional)
                if "messages" in serialized_data:
                    processed_event["messages"] = serialized_data["messages"]

            # 3. Custom data, if any
            elif chunk_type == "custom":
                # Custom events also need serialization
                serialized_data = self._serialize_value(data)
                processed_event = {
                    "type": "custom",
                    "data": serialized_data,
                }

            # 4. Other types (debug, tasks, etc.): skip
            else:
                continue

            # Only emit events that carry data
            if processed_event:
                yield processed_event
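
# Usage sketch (illustrative only; DB_URI is a placeholder). The caller owns
# the checkpointer lifecycle, e.g. in a FastAPI lifespan handler:
#
#   from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
#
#   async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
#       await checkpointer.setup()
#       service = await AIAgentService(checkpointer).initialize()
#
#       result = await service.process_message("Hello", thread_id="t-1")
#       print(result["reply"], result["token_usage"], result["elapsed_time"])
#
#       async for event in service.process_message_stream(
#           "Hello again", thread_id="t-1", model_name="local"
#       ):
#           if event["type"] == "llm_token":
#               print(event["token"], end="", flush=True)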