长期记忆
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 33s

This commit is contained in:
2026-04-14 20:39:58 +08:00
parent 8dd94c6c19
commit de68916c5a
2 changed files with 130 additions and 54 deletions

View File

@@ -91,13 +91,6 @@ class AIAgentService:
try: try:
info(f"🔄 正在初始化模型 '{model_name}'...") info(f"🔄 正在初始化模型 '{model_name}'...")
llm = llm_creator() llm = llm_creator()
# 测试 LLM 连接(可选,用于调试)
if model_name == "local":
debug(f" 测试 vLLM 连接: {os.getenv('VLLM_BASE_URL', '未设置')}")
elif model_name == "deepseek":
debug(f" 测试 DeepSeek API 连接: https://api.deepseek.com")
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build() builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
graph = builder.compile(checkpointer=self.checkpointer, store=self.store) graph = builder.compile(checkpointer=self.checkpointer, store=self.store)
self.graphs[model_name] = graph self.graphs[model_name] = graph

View File

@@ -5,21 +5,27 @@ LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
import operator import operator
import asyncio import asyncio
import time import time
import os
from typing import Literal, Annotated, Any from typing import Literal, Annotated, Any
from langchain_core.language_models import BaseLLM from langchain_core.language_models import BaseLLM
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langgraph.graph import StateGraph, START, END from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langgraph.store.postgres.aio import AsyncPostgresStore from langgraph.store.postgres.aio import AsyncPostgresStore
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from dataclasses import dataclass from dataclasses import dataclass
import uuid import uuid
from langchain_core.prompt_values import ChatPromptValue
# 本地模块 # 本地模块
from app.logger import debug, info, warning, error from app.logger import debug, info, warning, error
# Switch for graph flow tracing, read once at import time from the
# ENABLE_GRAPH_TRACE environment variable. When the variable is unset the
# default "true" applies; any value other than "true" (case-insensitive)
# turns tracing off.
ENABLE_GRAPH_TRACE = os.environ.get("ENABLE_GRAPH_TRACE", "true").lower() == "true"
class MessagesState(TypedDict): class MessagesState(TypedDict):
"""对话状态类型定义""" """对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add] messages: Annotated[list[AnyMessage], operator.add]
@@ -33,6 +39,53 @@ class GraphContext:
user_id: str user_id: str
# 可扩展更多上下文信息 # 可扩展更多上下文信息
def _log_state_change(node_name: str, state: MessagesState, prefix: str = "进入"):
    """
    Emit a one-line trace describing the state at a graph node boundary.

    Nothing is logged unless ENABLE_GRAPH_TRACE is truthy. Otherwise a single
    info-level line is written containing the node name, the transition
    label, the number of messages in the state, and a preview of the last
    message (its type in upper case plus the first 100 characters of its
    content with newlines replaced by spaces).

    Args:
        node_name: Name of the node, printed inside square brackets.
        state: Current graph state; only its "messages" entry is read, and a
            missing entry is treated as an empty list.
        prefix: Transition label printed after the node name.
    """
    if ENABLE_GRAPH_TRACE:
        history = state.get("messages", [])
        total = len(history)
        tail = history[-1] if history else None

        # The preview stays empty when there is no last message to describe.
        summary = ""
        if tail:
            preview = str(tail.content)[:100].replace("\n", " ")
            summary = f"{tail.type.upper()}: {preview}"

        info(f"🔄 [{node_name}] {prefix} | 消息数:{total} | 最后一条:{summary}")
def _print_llm_input(prompt_value: ChatPromptValue) -> ChatPromptValue:
    """
    Pass-through step that dumps the formatted prompt to the debug log.

    When ENABLE_GRAPH_TRACE is truthy, every message in ``prompt_value`` is
    written to the debug log in full (no truncation), framed by separator
    lines. The input is always returned unchanged so the function can sit
    between two steps of a chain without affecting the data flowing through.

    Args:
        prompt_value: Formatted prompt; its ``messages`` attribute is read.

    Returns:
        The same ``prompt_value`` object that was passed in.
    """
    if ENABLE_GRAPH_TRACE:
        rendered = prompt_value.messages
        rule = "=" * 80

        debug("\n" + rule)
        debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:")
        debug(f" 总消息数: {len(rendered)}")
        debug("-" * 80)

        for index, message in enumerate(rendered):
            # Full content on purpose: this is the exact text the model sees.
            body = str(message.content)
            debug(f" [{index}] {message.type.upper():10s}: {body}")

        debug("\n" + rule + "\n")

    return prompt_value
class GraphBuilder: class GraphBuilder:
"""LangGraph 状态图构建器 - 所有节点均为类方法""" """LangGraph 状态图构建器 - 所有节点均为类方法"""
@@ -50,30 +103,35 @@ class GraphBuilder:
self.tools_by_name = tools_by_name self.tools_by_name = tools_by_name
self._llm_with_tools = llm.bind_tools(tools) self._llm_with_tools = llm.bind_tools(tools)
self._prompt = self._create_prompt() self._prompt = self._create_prompt()
self._chain = self._prompt | self._llm_with_tools self._chain = (
self._prompt
| RunnableLambda(_print_llm_input)
| self._llm_with_tools
)
@staticmethod @staticmethod
def _create_prompt() -> ChatPromptTemplate: def _create_prompt() -> ChatPromptTemplate:
"""创建系统提示模板(静态方法,无需访问实例)""" """创建系统提示模板(静态方法,无需访问实例)"""
system_template = (
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
"【用户背景信息】\n"
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
"{memory_context}\n"
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
"【可用工具与使用规则】\n"
"- 获取温度/天气:`get_current_temperature`\n"
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`\n"
"- 读取PDF摘要`read_pdf_summary`(限定目录 `./user_docs`\n"
"- 读取Excel表格`read_excel_as_markdown`(限定目录 `./user_docs`\n"
"- 抓取网页内容:`fetch_webpage_content`\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
"2. 优先利用已知用户信息进行个性化回复。\n"
"3. 若无信息可依,礼貌询问或提供通用帮助。"
)
return ChatPromptTemplate.from_messages([ return ChatPromptTemplate.from_messages([
SystemMessage(content=( ("system", system_template),
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
"【用户背景信息】\n"
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
"{memory_context}\n"
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
"【可用工具与使用规则】\n"
"- 获取温度/天气:`get_current_temperature`\n"
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`\n"
"- 读取PDF摘要`read_pdf_summary`(限定目录 `./user_docs`\n"
"- 读取Excel表格`read_excel_as_markdown`(限定目录 `./user_docs`\n"
"- 抓取网页内容:`fetch_webpage_content`\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
"2. 优先利用已知用户信息进行个性化回复。\n"
"3. 若无信息可依,礼貌询问或提供通用帮助。"
)),
MessagesPlaceholder(variable_name="messages") MessagesPlaceholder(variable_name="messages")
]) ])
@@ -82,26 +140,9 @@ class GraphBuilder:
LLM 调用节点(异步方法) LLM 调用节点(异步方法)
注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环 注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
""" """
_log_state_change("llm_call", state, "进入")
memory_context = state.get("memory_context", "暂无用户信息") memory_context = state.get("memory_context", "暂无用户信息")
# 构建完整的输入消息列表(用于调试打印)
system_prompt = self._prompt.messages[0] # SystemMessage
if isinstance(system_prompt, SystemMessage):
system_content = system_prompt.content.format(memory_context=memory_context)
else:
system_content = str(system_prompt.content)
input_messages = [SystemMessage(content=system_content)] + state["messages"]
# 打印发送给大模型的最终输入
debug("\n" + "="*80)
debug("📤 [LLM输入] 发送给大模型的完整消息:")
debug(f" 总消息数: {len(input_messages)}")
for i, msg in enumerate(input_messages):
content_preview = str(msg.content) # 不截断,完整输出
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
debug("="*80 + "\n")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
start_time = time.time() start_time = time.time()
@@ -155,13 +196,16 @@ class GraphBuilder:
debug(f"{response.content}") debug(f"{response.content}")
debug("="*80 + "\n") debug("="*80 + "\n")
return { result = {
"messages": [response], "messages": [response],
"llm_calls": state.get('llm_calls', 0) + 1, "llm_calls": state.get('llm_calls', 0) + 1,
"last_token_usage": token_usage, "last_token_usage": token_usage,
"last_elapsed_time": elapsed_time "last_elapsed_time": elapsed_time
} }
_log_state_change("llm_call", {**state, **result}, "离开")
return result
except Exception as e: except Exception as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)") error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
@@ -175,33 +219,45 @@ class GraphBuilder:
error_response = AIMessage( error_response = AIMessage(
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。" content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
) )
return { error_result = {
"messages": [error_response], "messages": [error_response],
"llm_calls": state.get('llm_calls', 0), "llm_calls": state.get('llm_calls', 0),
"last_token_usage": {}, "last_token_usage": {},
"last_elapsed_time": elapsed_time "last_elapsed_time": elapsed_time
} }
_log_state_change("llm_call", state, "离开(异常)")
return error_result
async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict: async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
""" """
工具执行节点(异步方法) 工具执行节点(异步方法)
对于每个工具调用,在线程池中执行同步工具函数 对于每个工具调用,在线程池中执行同步工具函数
""" """
_log_state_change("tool_node", state, "进入")
last_message = state['messages'][-1] last_message = state['messages'][-1]
if not isinstance(last_message, AIMessage) or not last_message.tool_calls: if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
_log_state_change("tool_node", state, "离开(无工具调用)")
return {"messages": []} return {"messages": []}
results = [] results = []
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")
for tool_call in last_message.tool_calls: for tool_call in last_message.tool_calls:
tool_name = tool_call["name"] tool_name = tool_call["name"]
tool_args = tool_call["args"] tool_args = tool_call["args"]
tool_id = tool_call["id"] tool_id = tool_call["id"]
tool_func = self.tools_by_name.get(tool_name) tool_func = self.tools_by_name.get(tool_name)
debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")
if tool_func is None: if tool_func is None:
results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id)) err_msg = f"Tool {tool_name} not found"
debug(f" └─ ❌ {err_msg}")
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
continue continue
try: try:
@@ -214,11 +270,20 @@ class GraphBuilder:
None, None,
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值 lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
) )
# 字符打印
result_preview = str(observation).replace("\n", " ")
debug(f" └─ ✅ 结果: {result_preview}")
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id)) results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
except Exception as e: except Exception as e:
debug(f" └─ ❌ 异常: {e}")
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id)) results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
return {"messages": results} info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
result = {"messages": results}
_log_state_change("tool_node", {**state, **result}, "离开")
return result
@staticmethod @staticmethod
def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']: def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
@@ -227,18 +292,26 @@ class GraphBuilder:
# 1. 如果需要调用工具,优先进入工具节点 # 1. 如果需要调用工具,优先进入工具节点
if isinstance(last_message, AIMessage) and last_message.tool_calls: if isinstance(last_message, AIMessage) and last_message.tool_calls:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
return 'tool_node' return 'tool_node'
# 2. 如果是 AI 的最终回复,可以考虑进入记忆保存节点(可增加判断逻辑) # 2. 如果是 AI 的最终回复,可以考虑进入记忆保存节点(可增加判断逻辑)
# 这里简单处理:只要没有工具调用,且是 AI 消息,就尝试保存记忆。 # 这里简单处理:只要没有工具调用,且是 AI 消息,就尝试保存记忆。
if isinstance(last_message, AIMessage): if isinstance(last_message, AIMessage):
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 收到 AI 最终回复(无工具调用) → 转向 'save_memory'")
return 'save_memory' return 'save_memory'
# 3. 其他情况(如只有用户消息)直接结束 # 3. 其他情况(如只有用户消息)直接结束
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
return 'END' return 'END'
async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict: async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
"""搜索并返回长期记忆""" """搜索并返回长期记忆"""
_log_state_change("retrieve_memory", state, "进入")
user_id = runtime.context.user_id user_id = runtime.context.user_id
namespace = ("memories", user_id) namespace = ("memories", user_id)
query = str(state["messages"][-1].content) query = str(state["messages"][-1].content)
@@ -260,24 +333,33 @@ class GraphBuilder:
for i, memory in enumerate(memories, 1): for i, memory in enumerate(memories, 1):
debug(f" [{i}] {memory.value['data']}") debug(f" [{i}] {memory.value['data']}")
debug(f"{'='*60}\n") debug(f"{'='*60}\n")
return {"memory_context": memory_text} result = {"memory_context": memory_text}
_log_state_change("retrieve_memory", {**state, **result}, "离开")
return result
else: else:
debug(f"⚠️ [记忆检索] 未找到相关记忆") debug(f"⚠️ [记忆检索] 未找到相关记忆")
debug(f"{'='*60}\n") debug(f"{'='*60}\n")
return {"memory_context": ""} result = {"memory_context": ""}
_log_state_change("retrieve_memory", {**state, **result}, "离开")
return result
except Exception as e: except Exception as e:
error(f"❌ [记忆检索] 检索失败: {e}") error(f"❌ [记忆检索] 检索失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
debug(f"{'='*60}\n") debug(f"{'='*60}\n")
return {"memory_context": ""} result = {"memory_context": ""}
_log_state_change("retrieve_memory", {**state, **result}, "离开(异常)")
return result
async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict: async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
"""尝试从对话中提取并保存长期记忆""" """尝试从对话中提取并保存长期记忆"""
_log_state_change("save_memory", state, "进入")
# 获取最后一条用户消息(通常是要记住的内容的来源) # 获取最后一条用户消息(通常是要记住的内容的来源)
user_messages = [msg for msg in state["messages"] if msg.type == "human"] user_messages = [msg for msg in state["messages"] if msg.type == "human"]
if not user_messages: if not user_messages:
_log_state_change("save_memory", state, "离开(无用户消息)")
return {} return {}
last_user_msg = user_messages[-1].content.lower() last_user_msg = user_messages[-1].content.lower()
@@ -291,6 +373,7 @@ class GraphBuilder:
await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content}) await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
info(f"✅ 长期记忆已保存:{memory_content}") info(f"✅ 长期记忆已保存:{memory_content}")
_log_state_change("save_memory", state, "离开")
return {} return {}
def build(self) -> StateGraph: def build(self) -> StateGraph: