""" LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数 """ import operator import asyncio import time from typing import Literal, Annotated, Any from langchain_core.language_models import BaseLLM from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langgraph.graph import StateGraph, START, END from typing_extensions import TypedDict from langgraph.store.postgres.aio import AsyncPostgresStore from langgraph.runtime import Runtime from dataclasses import dataclass import uuid # 本地模块 from app.logger import debug, info, warning, error class MessagesState(TypedDict): """对话状态类型定义""" messages: Annotated[list[AnyMessage], operator.add] llm_calls: int memory_context:str last_token_usage: dict # 本次调用的 token 使用详情 last_elapsed_time: float # 本次调用耗时(秒) @dataclass class GraphContext: user_id: str # 可扩展更多上下文信息 class GraphBuilder: """LangGraph 状态图构建器 - 所有节点均为类方法""" def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]): """ 初始化构建器 Args: llm: 大语言模型实例 tools: 工具列表 tools_by_name: 名称到工具函数的映射 """ self.llm = llm self.tools = tools self.tools_by_name = tools_by_name self._llm_with_tools = llm.bind_tools(tools) self._prompt = self._create_prompt() self._chain = self._prompt | self._llm_with_tools @staticmethod def _create_prompt() -> ChatPromptTemplate: """创建系统提示模板(静态方法,无需访问实例)""" return ChatPromptTemplate.from_messages([ SystemMessage(content=( "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n" "【用户背景信息】\n" "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n" "{memory_context}\n" "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n" "【可用工具与使用规则】\n" "- 获取温度/天气:`get_current_temperature`\n" "- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n" "- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n" "- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n" "- 抓取网页内容:`fetch_webpage_content`\n" "工具调用时请直接返回所需参数,无需额外说明。\n\n" "【回答要求(必须遵守)】\n" "1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n" "2. 优先利用已知用户信息进行个性化回复。\n" "3. 若无信息可依,礼貌询问或提供通用帮助。" )), MessagesPlaceholder(variable_name="messages") ]) async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict: """ LLM 调用节点(异步方法) 注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环 """ memory_context = state.get("memory_context", "暂无用户信息") # 构建完整的输入消息列表(用于调试打印) system_prompt = self._prompt.messages[0] # SystemMessage if isinstance(system_prompt, SystemMessage): system_content = system_prompt.content.format(memory_context=memory_context) else: system_content = str(system_prompt.content) input_messages = [SystemMessage(content=system_content)] + state["messages"] # 打印发送给大模型的最终输入 debug("\n" + "="*80) debug("📤 [LLM输入] 发送给大模型的完整消息:") debug(f" 总消息数: {len(input_messages)}") for i, msg in enumerate(input_messages): content_preview = str(msg.content) # 不截断,完整输出 debug(f" [{i}] {msg.type.upper():10s}: {content_preview}") debug("="*80 + "\n") loop = asyncio.get_event_loop() start_time = time.time() try: response = await loop.run_in_executor( None, lambda: self._chain.invoke({ "messages": state["messages"], "memory_context": memory_context }) ) elapsed_time = time.time() - start_time # 提取 token 用量(兼容不同 LLM 提供商的元数据格式) token_usage = {} input_tokens = 0 output_tokens = 0 # 尝试从 response_metadata 中提取 if hasattr(response, 'response_metadata') and response.response_metadata: meta = response.response_metadata if 'token_usage' in meta: token_usage = meta['token_usage'] elif 'usage' in meta: token_usage = meta['usage'] # 尝试从 additional_kwargs 中提取 if not token_usage and hasattr(response, 'additional_kwargs'): add_kwargs = response.additional_kwargs if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']: token_usage = add_kwargs['llm_output']['token_usage'] # 提取具体的 token 数值 if token_usage: input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0)) output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0)) # 打印响应统计信息 info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒") info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}") if token_usage: debug(f"📋 [LLM统计] 详细用量: {token_usage}") # 打印 LLM 的完整输出 debug("\n" + "="*80) debug("📥 [LLM输出] 大模型返回的完整响应:") debug(f" 消息类型: {response.type.upper()}") debug(f" 内容长度: {len(str(response.content))} 字符") debug("-"*80) debug(f"{response.content}") debug("="*80 + "\n") return { "messages": [response], "llm_calls": state.get('llm_calls', 0) + 1, "last_token_usage": token_usage, "last_elapsed_time": elapsed_time } except Exception as e: elapsed_time = time.time() - start_time error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)") error(f" 错误类型: {type(e).__name__}") error(f" 错误信息: {str(e)}") import traceback traceback.print_exc() debug("="*80 + "\n") # 返回一个友好的错误消息 error_response = AIMessage( content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。" ) return { "messages": [error_response], "llm_calls": state.get('llm_calls', 0), "last_token_usage": {}, "last_elapsed_time": elapsed_time } async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict: """ 工具执行节点(异步方法) 对于每个工具调用,在线程池中执行同步工具函数 """ last_message = state['messages'][-1] if not isinstance(last_message, AIMessage) or not last_message.tool_calls: return {"messages": []} results = [] loop = asyncio.get_event_loop() for tool_call in last_message.tool_calls: tool_name = tool_call["name"] tool_args = tool_call["args"] tool_id = tool_call["id"] tool_func = self.tools_by_name.get(tool_name) if tool_func is None: results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id)) continue try: # 修复闭包问题:将变量作为默认参数传入 lambda # 如果工具支持异步 (ainvoke),优先使用异步调用 if hasattr(tool_func, 'ainvoke'): observation = await tool_func.ainvoke(tool_args) else: observation = await loop.run_in_executor( None, lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值 ) results.append(ToolMessage(content=str(observation), tool_call_id=tool_id)) except Exception as e: results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id)) return {"messages": results} @staticmethod def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']: """决定下一步:工具调用、保存记忆还是结束""" last_message = state["messages"][-1] # 1. 如果需要调用工具,优先进入工具节点 if isinstance(last_message, AIMessage) and last_message.tool_calls: return 'tool_node' # 2. 如果是 AI 的最终回复,可以考虑进入记忆保存节点(可增加判断逻辑) # 这里简单处理:只要没有工具调用,且是 AI 消息,就尝试保存记忆。 if isinstance(last_message, AIMessage): return 'save_memory' # 3. 其他情况(如只有用户消息)直接结束 return 'END' async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict: """搜索并返回长期记忆""" user_id = runtime.context.user_id namespace = ("memories", user_id) query = str(state["messages"][-1].content) debug(f"\n{'='*60}") debug(f"🔎 [记忆检索] 开始检索") debug(f" ├─ 用户ID: {user_id}") debug(f" ├─ 命名空间: {namespace}") debug(f" ├─ 查询内容: '{query}'") debug(f" └─ 消息总数: {len(state['messages'])}") try: memories = await runtime.store.asearch(namespace, query=query) debug(f"✅ [记忆检索] 检索完成,找到 {len(memories)} 条相关记忆") if memories: memory_text = "\n".join([m.value["data"] for m in memories]) debug(f"📚 [记忆内容]") for i, memory in enumerate(memories, 1): debug(f" [{i}] {memory.value['data']}") debug(f"{'='*60}\n") return {"memory_context": memory_text} else: debug(f"⚠️ [记忆检索] 未找到相关记忆") debug(f"{'='*60}\n") return {"memory_context": ""} except Exception as e: error(f"❌ [记忆检索] 检索失败: {e}") import traceback traceback.print_exc() debug(f"{'='*60}\n") return {"memory_context": ""} async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict: """尝试从对话中提取并保存长期记忆""" # 获取最后一条用户消息(通常是要记住的内容的来源) user_messages = [msg for msg in state["messages"] if msg.type == "human"] if not user_messages: return {} last_user_msg = user_messages[-1].content.lower() # 简单触发逻辑:包含"记住"或"保存"等关键词 if any(keyword in last_user_msg for keyword in ["记住", "保存", "别忘了"]): # 提取记忆内容(这里仅作示例,实际可用 LLM 提取) memory_content = f"用户说过:{last_user_msg}" user_id = runtime.context.user_id namespace = ("memories", user_id) await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content}) info(f"✅ 长期记忆已保存:{memory_content}") return {} def build(self) -> StateGraph: """ 构建未编译的状态图(返回 StateGraph 实例) 图中节点直接使用实例方法 call_llm, call_tools """ builder = StateGraph(MessagesState,context_schema=GraphContext) builder.add_node("retrieve_memory", self.retrieve_memory) builder.add_node("llm_call", self.call_llm) builder.add_node("tool_node", self.call_tools) builder.add_node("save_memory", self.save_memory) builder.add_edge(START, "retrieve_memory") builder.add_edge("retrieve_memory", "llm_call") builder.add_conditional_edges( "llm_call", self.should_continue, { "tool_node": "tool_node", "save_memory": "save_memory", 'END': END } ) builder.add_edge("tool_node", "llm_call") builder.add_edge("save_memory", END) return builder