添加长期记忆
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 27s

This commit is contained in:
2026-04-14 17:34:12 +08:00
parent 1bea2491c5
commit 8dd94c6c19
12 changed files with 953 additions and 197 deletions

View File

@@ -4,19 +4,34 @@ LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
import operator
import asyncio
import time
from typing import Literal, Annotated, Any
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from langgraph.store.postgres.aio import AsyncPostgresStore
from langgraph.runtime import Runtime
from dataclasses import dataclass
import uuid
# 本地模块
from app.logger import debug, info, warning, error
class MessageState(TypedDict):
class MessagesState(TypedDict):
"""对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add]
llm_calls: int
memory_context:str
last_token_usage: dict # 本次调用的 token 使用详情
last_elapsed_time: float # 本次调用耗时(秒)
@dataclass
class GraphContext:
user_id: str
# 可扩展更多上下文信息
class GraphBuilder:
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
@@ -42,33 +57,132 @@ class GraphBuilder:
"""创建系统提示模板(静态方法,无需访问实例)"""
return ChatPromptTemplate.from_messages([
SystemMessage(content=(
"你是一个个人生活助手和数据分析助手。请说中文。"
"用户询问天气或温度时,使用get_current_temperature工具获取信息。"
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。"
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
"用户背景信息】\n"
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
"{memory_context}\n"
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
"【可用工具与使用规则】\n"
"- 获取温度/天气:`get_current_temperature`\n"
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`\n"
"- 读取PDF摘要`read_pdf_summary`(限定目录 `./user_docs`\n"
"- 读取Excel表格`read_excel_as_markdown`(限定目录 `./user_docs`\n"
"- 抓取网页内容:`fetch_webpage_content`\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
"2. 优先利用已知用户信息进行个性化回复。\n"
"3. 若无信息可依,礼貌询问或提供通用帮助。"
)),
MessagesPlaceholder(variable_name="message")
MessagesPlaceholder(variable_name="messages")
])
async def call_llm(self, state: MessageState) -> dict:
async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
"""
LLM 调用节点(异步方法)
注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
"""
memory_context = state.get("memory_context", "暂无用户信息")
# 构建完整的输入消息列表(用于调试打印)
system_prompt = self._prompt.messages[0] # SystemMessage
if isinstance(system_prompt, SystemMessage):
system_content = system_prompt.content.format(memory_context=memory_context)
else:
system_content = str(system_prompt.content)
input_messages = [SystemMessage(content=system_content)] + state["messages"]
# 打印发送给大模型的最终输入
debug("\n" + "="*80)
debug("📤 [LLM输入] 发送给大模型的完整消息:")
debug(f" 总消息数: {len(input_messages)}")
for i, msg in enumerate(input_messages):
content_preview = str(msg.content) # 不截断,完整输出
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
debug("="*80 + "\n")
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
None,
lambda: self._chain.invoke({"message": state["messages"]})
)
return {
"messages": [response],
"llm_calls": state.get('llm_calls', 0) + 1
}
start_time = time.time()
try:
response = await loop.run_in_executor(
None,
lambda: self._chain.invoke({
"messages": state["messages"],
"memory_context": memory_context
})
)
elapsed_time = time.time() - start_time
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
token_usage = {}
input_tokens = 0
output_tokens = 0
# 尝试从 response_metadata 中提取
if hasattr(response, 'response_metadata') and response.response_metadata:
meta = response.response_metadata
if 'token_usage' in meta:
token_usage = meta['token_usage']
elif 'usage' in meta:
token_usage = meta['usage']
# 尝试从 additional_kwargs 中提取
if not token_usage and hasattr(response, 'additional_kwargs'):
add_kwargs = response.additional_kwargs
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
token_usage = add_kwargs['llm_output']['token_usage']
# 提取具体的 token 数值
if token_usage:
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
# 打印响应统计信息
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}")
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
if token_usage:
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
# 打印 LLM 的完整输出
debug("\n" + "="*80)
debug("📥 [LLM输出] 大模型返回的完整响应:")
debug(f" 消息类型: {response.type.upper()}")
debug(f" 内容长度: {len(str(response.content))} 字符")
debug("-"*80)
debug(f"{response.content}")
debug("="*80 + "\n")
return {
"messages": [response],
"llm_calls": state.get('llm_calls', 0) + 1,
"last_token_usage": token_usage,
"last_elapsed_time": elapsed_time
}
except Exception as e:
elapsed_time = time.time() - start_time
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
error(f" 错误类型: {type(e).__name__}")
error(f" 错误信息: {str(e)}")
import traceback
traceback.print_exc()
debug("="*80 + "\n")
# 返回一个友好的错误消息
error_response = AIMessage(
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
)
return {
"messages": [error_response],
"llm_calls": state.get('llm_calls', 0),
"last_token_usage": {},
"last_elapsed_time": elapsed_time
}
async def call_tools(self, state: MessageState) -> dict:
async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
"""
工具执行节点(异步方法)
对于每个工具调用,在线程池中执行同步工具函数
@@ -91,11 +205,15 @@ class GraphBuilder:
continue
try:
# 同步工具函数在线程池中执行
observation = await loop.run_in_executor(
None,
lambda: tool_func.invoke(tool_args)
)
# 修复闭包问题:将变量作为默认参数传入 lambda
# 如果工具支持异步 (ainvoke),优先使用异步调用
if hasattr(tool_func, 'ainvoke'):
observation = await tool_func.ainvoke(tool_args)
else:
observation = await loop.run_in_executor(
None,
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
)
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
except Exception as e:
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
@@ -103,25 +221,101 @@ class GraphBuilder:
return {"messages": results}
@staticmethod
def should_continue(state: MessageState) -> Literal['tool_node', END]:
"""
条件边判断(静态方法)
决定下一步是进入工具节点还是结束
"""
def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
"""决定下一步:工具调用、保存记忆还是结束"""
last_message = state["messages"][-1]
if isinstance(last_message, AIMessage) and bool(last_message.tool_calls):
# 1. 如果需要调用工具,优先进入工具节点
if isinstance(last_message, AIMessage) and last_message.tool_calls:
return 'tool_node'
return END
# 2. 如果是 AI 的最终回复,可以考虑进入记忆保存节点(可增加判断逻辑)
# 这里简单处理:只要没有工具调用,且是 AI 消息,就尝试保存记忆。
if isinstance(last_message, AIMessage):
return 'save_memory'
# 3. 其他情况(如只有用户消息)直接结束
return 'END'
async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
"""搜索并返回长期记忆"""
user_id = runtime.context.user_id
namespace = ("memories", user_id)
query = str(state["messages"][-1].content)
debug(f"\n{'='*60}")
debug(f"🔎 [记忆检索] 开始检索")
debug(f" ├─ 用户ID: {user_id}")
debug(f" ├─ 命名空间: {namespace}")
debug(f" ├─ 查询内容: '{query}'")
debug(f" └─ 消息总数: {len(state['messages'])}")
try:
memories = await runtime.store.asearch(namespace, query=query)
debug(f"✅ [记忆检索] 检索完成,找到 {len(memories)} 条相关记忆")
if memories:
memory_text = "\n".join([m.value["data"] for m in memories])
debug(f"📚 [记忆内容]")
for i, memory in enumerate(memories, 1):
debug(f" [{i}] {memory.value['data']}")
debug(f"{'='*60}\n")
return {"memory_context": memory_text}
else:
debug(f"⚠️ [记忆检索] 未找到相关记忆")
debug(f"{'='*60}\n")
return {"memory_context": ""}
except Exception as e:
error(f"❌ [记忆检索] 检索失败: {e}")
import traceback
traceback.print_exc()
debug(f"{'='*60}\n")
return {"memory_context": ""}
async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
"""尝试从对话中提取并保存长期记忆"""
# 获取最后一条用户消息(通常是要记住的内容的来源)
user_messages = [msg for msg in state["messages"] if msg.type == "human"]
if not user_messages:
return {}
last_user_msg = user_messages[-1].content.lower()
# 简单触发逻辑:包含"记住"或"保存"等关键词
if any(keyword in last_user_msg for keyword in ["记住", "保存", "别忘了"]):
# 提取记忆内容(这里仅作示例,实际可用 LLM 提取)
memory_content = f"用户说过:{last_user_msg}"
user_id = runtime.context.user_id
namespace = ("memories", user_id)
await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
info(f"✅ 长期记忆已保存:{memory_content}")
return {}
def build(self) -> StateGraph:
"""
构建未编译的状态图(返回 StateGraph 实例)
图中节点直接使用实例方法 call_llm, call_tools
"""
builder = StateGraph(MessageState)
builder = StateGraph(MessagesState,context_schema=GraphContext)
builder.add_node("retrieve_memory", self.retrieve_memory)
builder.add_node("llm_call", self.call_llm)
builder.add_node("tool_node", self.call_tools)
builder.add_edge(START, "llm_call")
builder.add_conditional_edges("llm_call", self.should_continue, ["tool_node", END])
builder.add_node("save_memory", self.save_memory)
builder.add_edge(START, "retrieve_memory")
builder.add_edge("retrieve_memory", "llm_call")
builder.add_conditional_edges(
"llm_call",
self.should_continue,
{
"tool_node": "tool_node",
"save_memory": "save_memory",
'END': END
}
)
builder.add_edge("tool_node", "llm_call")
builder.add_edge("save_memory", END)
return builder