"""LangGraph state-graph construction module.

Fully object-oriented style: every graph node is a method of ``GraphBuilder``
and no nested functions are defined. Node-transition tracing is controlled by
the ``ENABLE_GRAPH_TRACE`` environment variable.
"""
import asyncio
import functools
import inspect
import operator
import os
import time
import traceback
import uuid
from dataclasses import dataclass
from typing import Annotated, Any, Literal

from langchain_core.language_models import BaseChatModel, BaseLLM
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langgraph.graph import END, START, StateGraph
from langgraph.runtime import Runtime
from langgraph.store.postgres.aio import AsyncPostgresStore
from typing_extensions import TypedDict

# Local modules
from app.logger import debug, error, info, warning

# Whether to trace graph node transitions (env var, enabled unless set to something other than "true").
ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"

# Text injected into the system prompt when no long-term memory is available.
_NO_MEMORY_CONTEXT = "暂无用户信息"


class MessagesState(TypedDict):
    """Conversation state shared by all graph nodes."""

    # Chat history; the operator.add reducer appends each node's returned messages.
    messages: Annotated[list[AnyMessage], operator.add]
    # Number of successful LLM calls recorded in this state.
    llm_calls: int
    # Long-term memory text written by retrieve_memory and injected into the system prompt.
    memory_context: str
    # Token usage details of the most recent LLM call.
    last_token_usage: dict
    # Duration of the most recent LLM call, in seconds.
    last_elapsed_time: float


@dataclass
class GraphContext:
    """Per-invocation context exposed to nodes as ``runtime.context``."""

    # Identifies the user; used to namespace long-term memories.
    user_id: str
    # More context fields can be added here.


def _content_to_text(content: Any) -> str:
    """Return the plain text of a message ``content`` value.

    Content is usually a string. Multimodal messages carry a list of parts
    (strings, or dicts with a ``"text"`` entry). Text parts are joined with
    newlines and non-text parts are ignored. Any other type falls back to ``str()``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        return "\n".join(parts)
    return str(content)


def _log_state_change(node_name: str, state: MessagesState, prefix: str = "进入"):
    """Log a one-line summary of the state when entering or leaving a node.

    Args:
        node_name: Name of the graph node.
        state: Current state (only ``messages`` is read).
        prefix: Label for the transition, e.g. "进入" (enter) or "离开" (leave).
    """
    if not ENABLE_GRAPH_TRACE:
        return

    messages = state.get("messages", [])
    msg_count = len(messages)
    last_msg = messages[-1] if messages else None

    last_info = ""
    if last_msg is not None:
        # Keep the log line short: first 100 chars, newlines flattened.
        content_preview = str(last_msg.content)[:100].replace("\n", " ")
        last_info = f"{last_msg.type.upper()}: {content_preview}"

    info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}")


def _print_llm_input(prompt_value: ChatPromptValue) -> ChatPromptValue:
    """Pass-through step for ``RunnableLambda``: dump the formatted prompt sent to the LLM.

    Args:
        prompt_value: Formatted prompt whose ``.messages`` are logged in full.

    Returns:
        ``prompt_value`` unchanged, so the chain continues normally.
    """
    if not ENABLE_GRAPH_TRACE:
        return prompt_value

    messages = prompt_value.messages
    debug("\n" + "=" * 80)
    debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:")
    debug(f" 总消息数: {len(messages)}")
    debug("-" * 80)
    for i, msg in enumerate(messages):
        content_preview = str(msg.content)  # full content, not truncated
        debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
    debug("\n" + "=" * 80 + "\n")
    return prompt_value


class GraphBuilder:
    """Builds the LangGraph state graph; every node is a method of this class."""

    def __init__(self, llm: BaseChatModel, tools: list, tools_by_name: dict[str, Any]):
        """Bind the tools to the model and assemble the prompt -> trace -> model chain.

        Args:
            llm: Chat model instance. It must support ``bind_tools``, which
                chat models provide and plain ``BaseLLM`` text models do not.
            tools: Tools exposed to the model.
            tools_by_name: Maps a tool name to the object executed for that tool
                (a LangChain tool/runnable, or a plain function).
        """
        self.llm = llm
        self.tools = tools
        self.tools_by_name = tools_by_name
        self._llm_with_tools = llm.bind_tools(tools)
        self._prompt = self._create_prompt()
        self._chain = (
            self._prompt
            | RunnableLambda(_print_llm_input)
            | self._llm_with_tools
        )

    @staticmethod
    def _create_prompt() -> ChatPromptTemplate:
        """Create the chat prompt: a system message with ``{memory_context}`` followed by the history."""
        system_template = (
            "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
            "【用户背景信息】\n"
            "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
            "{memory_context}\n"
            "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
            "【可用工具与使用规则】\n"
            "- 获取温度/天气:`get_current_temperature`\n"
            "- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
            "- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
            "- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
            "- 抓取网页内容:`fetch_webpage_content`\n"
            "工具调用时请直接返回所需参数,无需额外说明。\n\n"
            "【回答要求(必须遵守)】\n"
            "1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
            "2. 优先利用已知用户信息进行个性化回复。\n"
            "3. 若无信息可依,礼貌询问或提供通用帮助。"
        )
        return ChatPromptTemplate.from_messages([
            ("system", system_template),
            MessagesPlaceholder(variable_name="messages"),
        ])

    async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """LLM node: run the chain on the conversation and append the model reply.

        The chain is awaited with ``ainvoke``. This keeps the event loop free and
        runs the call in the node's own context. A ``run_in_executor`` worker
        thread does not inherit context variables, so it would lose that context.

        Args:
            state: Reads ``messages``, ``memory_context`` and ``llm_calls``.
            runtime: LangGraph runtime (not used by this node).

        Returns:
            On success: the reply in ``messages``, ``llm_calls`` incremented by 1,
            and ``last_token_usage`` / ``last_elapsed_time`` statistics.
            On failure: an apology ``AIMessage``, with ``llm_calls`` unchanged.
        """
        _log_state_change("llm_call", state, "进入")

        # retrieve_memory stores "" when nothing is found, so fall back on any
        # falsy value, not only on a missing key.
        memory_context = state.get("memory_context") or _NO_MEMORY_CONTEXT
        start_time = time.perf_counter()

        # Guard only the model call. A problem in the statistics/logging below
        # must not turn an otherwise valid response into the error reply.
        try:
            response = await self._chain.ainvoke({
                "messages": state["messages"],
                "memory_context": memory_context,
            })
        except Exception as e:
            return self._llm_error_result(state, e, time.perf_counter() - start_time)

        elapsed_time = time.perf_counter() - start_time
        token_usage, input_tokens, output_tokens = self._extract_token_usage(response)
        self._log_llm_response(response, elapsed_time, token_usage, input_tokens, output_tokens)

        result = {
            "messages": [response],
            "llm_calls": state.get('llm_calls', 0) + 1,
            "last_token_usage": token_usage,
            "last_elapsed_time": elapsed_time,
        }
        _log_state_change("llm_call", {**state, **result}, "离开")
        return result

    @staticmethod
    def _first_token_count(usage: dict, *keys: str) -> int:
        """Return the first value among ``keys`` in ``usage`` that converts to int, else 0."""
        for key in keys:
            value = usage.get(key)
            if value is None:
                continue
            try:
                return int(value)
            except (TypeError, ValueError):
                continue
        return 0

    @staticmethod
    def _extract_token_usage(response: Any) -> tuple[dict, int, int]:
        """Extract token usage from a model response, tolerating several provider formats.

        Lookup order: ``response_metadata["token_usage"]``, then
        ``response_metadata["usage"]``, then
        ``additional_kwargs["llm_output"]["token_usage"]``, then LangChain's
        standard ``usage_metadata``.

        Returns:
            ``(usage_dict, input_tokens, output_tokens)``. Missing data yields ``{}`` / 0.
        """
        usage: Any = {}

        meta = getattr(response, "response_metadata", None)
        if isinstance(meta, dict):
            usage = meta.get("token_usage") or meta.get("usage") or {}

        if not usage:
            add_kwargs = getattr(response, "additional_kwargs", None)
            llm_output = add_kwargs.get("llm_output") if isinstance(add_kwargs, dict) else None
            if isinstance(llm_output, dict):
                usage = llm_output.get("token_usage") or {}

        if not usage:
            usage = getattr(response, "usage_metadata", None) or {}

        if not isinstance(usage, dict):
            # Some SDKs hand back usage objects; convert when possible, otherwise drop them.
            try:
                usage = dict(usage)
            except (TypeError, ValueError):
                usage = {}

        input_tokens = GraphBuilder._first_token_count(usage, "prompt_tokens", "input_tokens")
        output_tokens = GraphBuilder._first_token_count(usage, "completion_tokens", "output_tokens")
        return usage, input_tokens, output_tokens

    @staticmethod
    def _log_llm_response(response: Any, elapsed_time: float, token_usage: dict,
                          input_tokens: int, output_tokens: int) -> None:
        """Log timing and token statistics, then dump the full model response."""
        info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
        info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
        if token_usage:
            debug(f"📋 [LLM统计] 详细用量: {token_usage}")

        debug("\n" + "="*80)
        debug("📥 [LLM输出] 大模型返回的完整响应:")
        debug(f" 消息类型: {response.type.upper()}")
        debug(f" 内容长度: {len(str(response.content))} 字符")
        debug("-"*80)
        debug(f"{response.content}")
        debug("="*80 + "\n")

    @staticmethod
    def _llm_error_result(state: MessagesState, exc: Exception, elapsed_time: float) -> dict:
        """Log a failed model call and build the fallback state update (apology reply)."""
        error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
        error(f" 错误类型: {type(exc).__name__}")
        error(f" 错误信息: {str(exc)}")
        traceback.print_exception(type(exc), exc, exc.__traceback__)
        debug("="*80 + "\n")

        # Friendly message so the conversation continues instead of the run failing.
        error_response = AIMessage(
            content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
        )
        error_result = {
            "messages": [error_response],
            "llm_calls": state.get('llm_calls', 0),
            "last_token_usage": {},
            "last_elapsed_time": elapsed_time,
        }
        _log_state_change("llm_call", state, "离开(异常)")
        return error_result

    @staticmethod
    async def _run_tool(tool_func: Any, tool_args: dict) -> Any:
        """Execute one tool with its arguments without blocking the event loop.

        Preference order:
        1. ``ainvoke`` (LangChain tools/runnables).
        2. A sync ``invoke``, run in the default executor.
        3. A plain coroutine function, awaited directly with the arguments as keywords.
        4. Any other callable, run in the default executor with the arguments as keywords.
        """
        if hasattr(tool_func, "ainvoke"):
            return await tool_func.ainvoke(tool_args)
        if inspect.iscoroutinefunction(tool_func):
            return await tool_func(**tool_args)
        loop = asyncio.get_running_loop()
        if hasattr(tool_func, "invoke"):
            return await loop.run_in_executor(None, tool_func.invoke, tool_args)
        # Plain function bound via bind_tools: tool-call args map to its keyword parameters.
        return await loop.run_in_executor(None, functools.partial(tool_func, **tool_args))

    async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """Tool node: run every tool call requested by the last AI message, one after another.

        Each tool call produces exactly one ``ToolMessage``. Unknown tools and
        tool exceptions become error text in that message rather than
        failing the graph run.

        Args:
            state: Reads ``messages``; the last one should be an ``AIMessage`` with ``tool_calls``.
            runtime: LangGraph runtime (not used by this node).

        Returns:
            ``{"messages": [ToolMessage, ...]}``, or an empty list when there is nothing to run.
        """
        _log_state_change("tool_node", state, "进入")

        last_message = state['messages'][-1]
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            _log_state_change("tool_node", state, "离开(无工具调用)")
            return {"messages": []}

        results = []
        info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")

        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool_id = tool_call["id"]
            tool_func = self.tools_by_name.get(tool_name)

            debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")

            if tool_func is None:
                err_msg = f"Tool {tool_name} not found"
                debug(f" └─ ❌ {err_msg}")
                results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
                continue

            try:
                observation = await self._run_tool(tool_func, tool_args)
                result_preview = str(observation).replace("\n", " ")
                debug(f" └─ ✅ 结果: {result_preview}")
                results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
            except Exception as e:
                # Report the failure back to the model instead of aborting the run.
                debug(f" └─ ❌ 异常: {e}")
                results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))

        info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
        result = {"messages": results}
        _log_state_change("tool_node", {**state, **result}, "离开")
        return result

    @staticmethod
    def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
        """Route after the LLM node.

        Returns:
            'tool_node' when the last AI message requests tools, 'save_memory' for
            any other AI message, otherwise 'END'.
        """
        last_message = state["messages"][-1]

        # 1. Tool calls take priority.
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            if ENABLE_GRAPH_TRACE:
                info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
            return 'tool_node'

        # 2. A final AI reply (no tool calls): try to save memory. Finer-grained checks could go here.
        if isinstance(last_message, AIMessage):
            if ENABLE_GRAPH_TRACE:
                info("🔀 [路由决策] 收到 AI 最终回复(无工具调用) → 转向 'save_memory'")
            return 'save_memory'

        # 3. Anything else (e.g. only a user message): finish.
        if ENABLE_GRAPH_TRACE:
            info("🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
        return 'END'

    async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """Search long-term memories relevant to the latest message.

        Searches ``runtime.store`` under ``("memories", user_id)``, using the text of
        the last message as the query. Search errors are logged and yield an empty
        context. Items without a ``"data"`` field are skipped.

        Returns:
            ``{"memory_context": text}``: the matching memories, one per line, or "".
        """
        _log_state_change("retrieve_memory", state, "进入")

        messages = state.get("messages", [])
        user_id = runtime.context.user_id
        namespace = ("memories", user_id)
        query = _content_to_text(messages[-1].content) if messages else ""

        debug(f"\n{'='*60}")
        debug("🔎 [记忆检索] 开始检索")
        debug(f" ├─ 用户ID: {user_id}")
        debug(f" ├─ 命名空间: {namespace}")
        debug(f" ├─ 查询内容: '{query}'")
        debug(f" └─ 消息总数: {len(messages)}")

        memory_text = ""
        leave_label = "离开"
        try:
            memories = await runtime.store.asearch(namespace, query=query)
            debug(f"✅ [记忆检索] 检索完成,找到 {len(memories)} 条相关记忆")

            # One malformed item must not discard every other memory.
            memory_texts = [str(m.value["data"]) for m in memories if "data" in m.value]
            if memory_texts:
                memory_text = "\n".join(memory_texts)
                debug("📚 [记忆内容]")
                for i, text in enumerate(memory_texts, 1):
                    debug(f" [{i}] {text}")
            else:
                debug("⚠️ [记忆检索] 未找到相关记忆")
            debug(f"{'='*60}\n")
        except Exception as e:
            error(f"❌ [记忆检索] 检索失败: {e}")
            traceback.print_exc()
            debug(f"{'='*60}\n")
            memory_text = ""
            leave_label = "离开(异常)"

        result = {"memory_context": memory_text}
        _log_state_change("retrieve_memory", {**state, **result}, leave_label)
        return result

    async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """Save the latest user message as a long-term memory when it asks to be remembered.

        Triggered when the most recent human message contains "记住", "保存" or
        "别忘了". The message text is stored as-is; lower-casing is used only for
        matching. It goes under ``("memories", user_id)`` with a random UUID key.
        Store errors are logged and do not fail the run.

        Returns:
            An empty dict (no state update).
        """
        _log_state_change("save_memory", state, "进入")

        # The most recent human message is the source of anything worth remembering.
        user_messages = [msg for msg in state["messages"] if msg.type == "human"]
        if not user_messages:
            _log_state_change("save_memory", state, "离开(无用户消息)")
            return {}

        last_user_text = _content_to_text(user_messages[-1].content)

        # Simple keyword trigger (an LLM-based extractor could replace this).
        if any(keyword in last_user_text.lower() for keyword in ("记住", "保存", "别忘了")):
            memory_content = f"用户说过:{last_user_text}"
            try:
                user_id = runtime.context.user_id
                namespace = ("memories", user_id)
                await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
                info(f"✅ 长期记忆已保存:{memory_content}")
            except Exception as e:
                error(f"❌ [记忆保存] 保存失败: {e}")
                traceback.print_exc()

        _log_state_change("save_memory", state, "离开")
        return {}

    def build(self) -> StateGraph:
        """Build the uncompiled state graph.

        Flow: START -> retrieve_memory -> llm_call. From llm_call, should_continue
        routes to tool_node (which loops back to llm_call), to save_memory -> END,
        or directly to END.
        """
        builder = StateGraph(MessagesState, context_schema=GraphContext)
        builder.add_node("retrieve_memory", self.retrieve_memory)
        builder.add_node("llm_call", self.call_llm)
        builder.add_node("tool_node", self.call_tools)
        builder.add_node("save_memory", self.save_memory)

        builder.add_edge(START, "retrieve_memory")
        builder.add_edge("retrieve_memory", "llm_call")
        builder.add_conditional_edges(
            "llm_call",
            self.should_continue,
            {
                "tool_node": "tool_node",
                "save_memory": "save_memory",
                'END': END,
            },
        )
        builder.add_edge("tool_node", "llm_call")
        builder.add_edge("save_memory", END)
        return builder