Files
ailine/app/graph_builder.py
root de68916c5a
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 33s
长期记忆
2026-04-14 20:39:58 +08:00

404 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
"""
import operator
import asyncio
import time
import os
from typing import Literal, Annotated, Any
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from langgraph.store.postgres.aio import AsyncPostgresStore
from langgraph.runtime import Runtime
from dataclasses import dataclass
import uuid
from langchain_core.prompt_values import ChatPromptValue
# 本地模块
from app.logger import debug, info, warning, error
# 是否启用 Graph 流转追踪(通过环境变量控制)
ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"
class MessagesState(TypedDict):
    """Conversation state carried through the graph.

    Fields returned by nodes are merged into this state; ``messages`` is
    annotated with ``operator.add`` so each node's returned list is
    appended to the accumulated history rather than replacing it.
    """
    # Accumulated message history (reducer: operator.add appends new messages).
    messages: Annotated[list[AnyMessage], operator.add]
    # Running count of LLM invocations in this conversation turn.
    llm_calls: int
    # Long-term memory text injected into the system prompt ("" when none found).
    memory_context: str
    # Token usage details of the most recent LLM call ({} on failure).
    last_token_usage: dict
    # Wall-clock duration of the most recent LLM call, in seconds.
    last_elapsed_time: float
@dataclass
class GraphContext:
    """Per-run context made available to nodes via ``Runtime[GraphContext]``."""
    # Identifier used to namespace this user's long-term memories in the store.
    user_id: str
    # Extendable with more per-run context fields as needed.
def _log_state_change(node_name: str, state: MessagesState, prefix: str = "进入"):
    """Trace helper: log a node transition plus a summary of the state.

    Args:
        node_name: Name of the graph node being traced.
        state: Current graph state at the transition point.
        prefix: Transition label (e.g. "进入" / "离开").
    """
    if not ENABLE_GRAPH_TRACE:
        return

    messages = state.get("messages", [])
    msg_count = len(messages)

    # Summarize the newest message, if any: type tag + first 100 chars, one line.
    last_info = ""
    if messages:
        last_msg = messages[-1]
        preview = str(last_msg.content)[:100].replace("\n", " ")
        last_info = f"{last_msg.type.upper()}: {preview}"

    info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}")
def _print_llm_input(prompt_value: ChatPromptValue) -> ChatPromptValue:
    """Pass-through tap for the chain: dump the fully formatted LLM input.

    Args:
        prompt_value: ChatPromptValue holding the formatted message list.

    Returns:
        The same ``prompt_value``, unmodified, so the chain is unaffected.
    """
    if not ENABLE_GRAPH_TRACE:
        return prompt_value

    messages = prompt_value.messages  # ChatPromptValue exposes .messages
    rule = "=" * 80

    debug("\n" + rule)
    debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:")
    debug(f" 总消息数: {len(messages)}")
    debug("-" * 80)
    for i, msg in enumerate(messages):
        # Full content, not truncated.
        debug(f" [{i}] {msg.type.upper():10s}: {str(msg.content)}")
    debug("\n" + rule + "\n")

    return prompt_value
class GraphBuilder:
    """LangGraph state-graph builder — every node is a method on this class.

    Graph shape: START -> retrieve_memory -> llm_call -> (tool_node loop |
    save_memory -> END | END), with routing decided by ``should_continue``.

    Fixes vs. previous revision:
    - ``asyncio.get_running_loop()`` replaces the deprecated
      ``asyncio.get_event_loop()`` inside coroutines (Python 3.10+).
    - ``save_memory`` no longer crashes when a message's ``content`` is a
      list (multimodal messages): it is stringified before ``.lower()``.
    - Token-usage extraction is factored into ``_extract_token_usage``.
    """

    def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]):
        """Initialize the builder.

        Args:
            llm: Chat model instance (must support ``bind_tools``).
            tools: Tool list to expose to the model.
            tools_by_name: Mapping from tool name to tool object.
        """
        self.llm = llm
        self.tools = tools
        self.tools_by_name = tools_by_name
        self._llm_with_tools = llm.bind_tools(tools)
        self._prompt = self._create_prompt()
        # Chain: format prompt -> log outgoing messages -> tool-bound LLM.
        self._chain = (
            self._prompt
            | RunnableLambda(_print_llm_input)
            | self._llm_with_tools
        )

    @staticmethod
    def _create_prompt() -> ChatPromptTemplate:
        """Build the system prompt template (static: needs no instance state).

        The template expects two variables: ``memory_context`` (long-term
        memory text) and ``messages`` (conversation history placeholder).
        """
        system_template = (
            "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
            "【用户背景信息】\n"
            "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
            "{memory_context}\n"
            "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
            "【可用工具与使用规则】\n"
            "- 获取温度/天气:`get_current_temperature`\n"
            "- 读取文本文件:`read_local_file`(限定目录 `./user_docs`\n"
            "- 读取PDF摘要`read_pdf_summary`(限定目录 `./user_docs`\n"
            "- 读取Excel表格`read_excel_as_markdown`(限定目录 `./user_docs`\n"
            "- 抓取网页内容:`fetch_webpage_content`\n"
            "工具调用时请直接返回所需参数,无需额外说明。\n\n"
            "【回答要求(必须遵守)】\n"
            "1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
            "2. 优先利用已知用户信息进行个性化回复。\n"
            "3. 若无信息可依,礼貌询问或提供通用帮助。"
        )
        return ChatPromptTemplate.from_messages([
            ("system", system_template),
            MessagesPlaceholder(variable_name="messages")
        ])

    @staticmethod
    def _extract_token_usage(response: Any) -> dict:
        """Extract a token-usage dict from an LLM response message.

        Providers report usage in different places; this checks, in order:
        ``response_metadata['token_usage']``, ``response_metadata['usage']``,
        then ``additional_kwargs['llm_output']['token_usage']``.

        Returns:
            The provider's usage dict, or ``{}`` when none is present.
        """
        if hasattr(response, 'response_metadata') and response.response_metadata:
            meta = response.response_metadata
            if 'token_usage' in meta:
                return meta['token_usage']
            if 'usage' in meta:
                return meta['usage']
        if hasattr(response, 'additional_kwargs'):
            add_kwargs = response.additional_kwargs
            if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
                return add_kwargs['llm_output']['token_usage']
        return {}

    async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """LLM invocation node (async).

        ``self._chain.invoke`` is synchronous, so it runs in the default
        executor to avoid blocking the event loop. On failure a friendly
        AIMessage is returned instead of raising, so the graph completes.

        Returns:
            Partial state update: new message, call counter, usage, timing.
        """
        _log_state_change("llm_call", state, "进入")
        memory_context = state.get("memory_context", "暂无用户信息")
        # get_running_loop(), not the deprecated get_event_loop():
        # we are guaranteed to be inside a running loop here.
        loop = asyncio.get_running_loop()
        start_time = time.time()
        try:
            response = await loop.run_in_executor(
                None,
                lambda: self._chain.invoke({
                    "messages": state["messages"],
                    "memory_context": memory_context
                })
            )
            elapsed_time = time.time() - start_time

            # Token usage (tolerant of different provider metadata formats).
            token_usage = self._extract_token_usage(response)
            input_tokens = 0
            output_tokens = 0
            if token_usage:
                input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
                output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))

            # Response statistics.
            info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}")
            info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
            if token_usage:
                debug(f"📋 [LLM统计] 详细用量: {token_usage}")

            # Full LLM output, for tracing.
            debug("\n" + "="*80)
            debug("📥 [LLM输出] 大模型返回的完整响应:")
            debug(f" 消息类型: {response.type.upper()}")
            debug(f" 内容长度: {len(str(response.content))} 字符")
            debug("-"*80)
            debug(f"{response.content}")
            debug("="*80 + "\n")

            result = {
                "messages": [response],
                "llm_calls": state.get('llm_calls', 0) + 1,
                "last_token_usage": token_usage,
                "last_elapsed_time": elapsed_time
            }
            _log_state_change("llm_call", {**state, **result}, "离开")
            return result
        except Exception as e:
            elapsed_time = time.time() - start_time
            error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
            error(f" 错误类型: {type(e).__name__}")
            error(f" 错误信息: {str(e)}")
            import traceback
            traceback.print_exc()
            debug("="*80 + "\n")
            # Return a friendly error message instead of propagating.
            error_response = AIMessage(
                content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
            )
            error_result = {
                "messages": [error_response],
                "llm_calls": state.get('llm_calls', 0),
                "last_token_usage": {},
                "last_elapsed_time": elapsed_time
            }
            _log_state_change("llm_call", state, "离开(异常)")
            return error_result

    async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """Tool execution node (async).

        For each tool call on the last AIMessage, prefers the tool's async
        ``ainvoke``; otherwise runs the synchronous ``invoke`` in a thread
        pool. Each outcome (result, missing tool, or exception text) becomes
        a ToolMessage tied to its tool_call_id.
        """
        _log_state_change("tool_node", state, "进入")
        last_message = state['messages'][-1]
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            _log_state_change("tool_node", state, "离开(无工具调用)")
            return {"messages": []}
        results = []
        # Running loop, not the deprecated get_event_loop().
        loop = asyncio.get_running_loop()
        info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")
        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool_id = tool_call["id"]
            tool_func = self.tools_by_name.get(tool_name)
            debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")
            if tool_func is None:
                err_msg = f"Tool {tool_name} not found"
                debug(f" └─ ❌ {err_msg}")
                results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
                continue
            try:
                if hasattr(tool_func, 'ainvoke'):
                    # Tool supports async invocation — use it directly.
                    observation = await tool_func.ainvoke(tool_args)
                else:
                    # Default-arg lambda binds the current tool_args value,
                    # avoiding the classic late-binding closure bug.
                    observation = await loop.run_in_executor(
                        None,
                        lambda args=tool_args: tool_func.invoke(args)
                    )
                result_preview = str(observation).replace("\n", " ")
                debug(f" └─ ✅ 结果: {result_preview}")
                results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
            except Exception as e:
                debug(f" └─ ❌ 异常: {e}")
                results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
        info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
        result = {"messages": results}
        _log_state_change("tool_node", {**state, **result}, "离开")
        return result

    @staticmethod
    def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
        """Routing: next step is tool execution, memory saving, or END."""
        last_message = state["messages"][-1]
        # 1. Pending tool calls take priority: go to the tool node.
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            if ENABLE_GRAPH_TRACE:
                info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
            return 'tool_node'
        # 2. Final AI reply (no tool calls): attempt to save memory.
        if isinstance(last_message, AIMessage):
            if ENABLE_GRAPH_TRACE:
                info(f"🔀 [路由决策] 收到 AI 最终回复(无工具调用) → 转向 'save_memory'")
            return 'save_memory'
        # 3. Anything else (e.g. a bare user message): end the run.
        if ENABLE_GRAPH_TRACE:
            info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
        return 'END'

    async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """Search the store for long-term memories relevant to the last message.

        Returns:
            ``{"memory_context": text}`` — joined memory entries, or ""
            when nothing is found or the search fails.
        """
        _log_state_change("retrieve_memory", state, "进入")
        user_id = runtime.context.user_id
        namespace = ("memories", user_id)
        query = str(state["messages"][-1].content)
        debug(f"\n{'='*60}")
        debug(f"🔎 [记忆检索] 开始检索")
        debug(f" ├─ 用户ID: {user_id}")
        debug(f" ├─ 命名空间: {namespace}")
        debug(f" ├─ 查询内容: '{query}'")
        debug(f" └─ 消息总数: {len(state['messages'])}")
        try:
            memories = await runtime.store.asearch(namespace, query=query)
            debug(f"✅ [记忆检索] 检索完成,找到 {len(memories)} 条相关记忆")
            if memories:
                memory_text = "\n".join([m.value["data"] for m in memories])
                debug(f"📚 [记忆内容]")
                for i, memory in enumerate(memories, 1):
                    debug(f" [{i}] {memory.value['data']}")
                debug(f"{'='*60}\n")
                result = {"memory_context": memory_text}
                _log_state_change("retrieve_memory", {**state, **result}, "离开")
                return result
            else:
                debug(f"⚠️ [记忆检索] 未找到相关记忆")
                debug(f"{'='*60}\n")
                result = {"memory_context": ""}
                _log_state_change("retrieve_memory", {**state, **result}, "离开")
                return result
        except Exception as e:
            # Best-effort: a failed retrieval must not break the conversation.
            error(f"❌ [记忆检索] 检索失败: {e}")
            import traceback
            traceback.print_exc()
            debug(f"{'='*60}\n")
            result = {"memory_context": ""}
            _log_state_change("retrieve_memory", {**state, **result}, "离开(异常)")
            return result

    async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """Extract and persist long-term memory from the conversation.

        Trigger is a simple keyword check on the last user message; the
        stored value is a plain sentence (an LLM-based extractor could
        replace this later). Always returns ``{}`` (no state change).
        """
        _log_state_change("save_memory", state, "进入")
        # The last human message is the source of any memory to save.
        user_messages = [msg for msg in state["messages"] if msg.type == "human"]
        if not user_messages:
            _log_state_change("save_memory", state, "离开(无用户消息)")
            return {}
        # str() first: content may be a list for multimodal messages,
        # which has no .lower().
        last_user_msg = str(user_messages[-1].content).lower()
        # Simple trigger: message mentions remembering/saving.
        if any(keyword in last_user_msg for keyword in ["记住", "保存", "别忘了"]):
            memory_content = f"用户说过:{last_user_msg}"
            user_id = runtime.context.user_id
            namespace = ("memories", user_id)
            await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
            info(f"✅ 长期记忆已保存:{memory_content}")
        _log_state_change("save_memory", state, "离开")
        return {}

    def build(self) -> StateGraph:
        """Assemble and return the (uncompiled) StateGraph.

        Nodes are bound instance methods; the caller compiles the graph
        with its own checkpointer/store.
        """
        builder = StateGraph(MessagesState, context_schema=GraphContext)
        builder.add_node("retrieve_memory", self.retrieve_memory)
        builder.add_node("llm_call", self.call_llm)
        builder.add_node("tool_node", self.call_tools)
        builder.add_node("save_memory", self.save_memory)
        builder.add_edge(START, "retrieve_memory")
        builder.add_edge("retrieve_memory", "llm_call")
        builder.add_conditional_edges(
            "llm_call",
            self.should_continue,
            {
                "tool_node": "tool_node",
                "save_memory": "save_memory",
                'END': END
            }
        )
        builder.add_edge("tool_node", "llm_call")
        builder.add_edge("save_memory", END)
        return builder