Files
ailine/app/graph_builder.py
root 8dd94c6c19
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 27s
添加长期记忆
2026-04-14 17:34:12 +08:00

321 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
"""
import operator
import asyncio
import time
from typing import Literal, Annotated, Any
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from langgraph.store.postgres.aio import AsyncPostgresStore
from langgraph.runtime import Runtime
from dataclasses import dataclass
import uuid
# 本地模块
from app.logger import debug, info, warning, error
class MessagesState(TypedDict):
    """Conversation state shared by every node of the graph."""
    # Conversation history. The operator.add reducer concatenates the message
    # lists returned by nodes onto the existing history instead of replacing it.
    messages: Annotated[list[AnyMessage], operator.add]
    # Count of successful LLM calls; the LLM node increments it on success and
    # leaves it unchanged when the call fails.
    llm_calls: int
    # Long-term memory text retrieved for the current user ("" when none found).
    memory_context: str
    last_token_usage: dict  # Token usage details reported for the latest LLM call ({} on failure)
    last_elapsed_time: float  # Duration of the latest LLM call, in seconds
@dataclass
class GraphContext:
    """Per-run context handed to the graph; nodes read it via ``runtime.context``."""
    # Identifies the current user; the memory nodes use it to build the store
    # namespace ("memories", user_id).
    user_id: str
    # Further per-run context fields can be added here.
class GraphBuilder:
    """Builder for the LangGraph state graph; every node is a bound method.

    Graph wired by :meth:`build`::

        START -> retrieve_memory -> llm_call -(tool calls)-> tool_node -> llm_call
                                             -(other AI reply)-> save_memory -> END
    """

    # Keywords in the latest user message that trigger saving a long-term memory.
    _MEMORY_TRIGGER_KEYWORDS = ("记住", "保存", "别忘了")

    def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]):
        """Prepare the prompt -> tool-bound LLM chain used by the nodes.

        Args:
            llm: Language model instance. It must support ``bind_tools``
                (called below). NOTE(review): in LangChain that is a chat-model
                method, so the ``BaseLLM`` hint looks too narrow -- confirm.
            tools: Tools bound to the model so it can emit tool calls.
            tools_by_name: Mapping from tool name to the tool object that
                :meth:`call_tools` runs (via ``ainvoke`` or ``invoke``).
        """
        self.llm = llm
        self.tools = tools
        self.tools_by_name = tools_by_name
        self._llm_with_tools = llm.bind_tools(tools)
        self._prompt = self._create_prompt()
        self._chain = self._prompt | self._llm_with_tools

    @staticmethod
    def _create_prompt() -> ChatPromptTemplate:
        """Create the chat prompt: a system template plus the message history.

        The system part is declared as a ``("system", template)`` tuple so that
        ``{memory_context}`` is a real template variable. A ``SystemMessage``
        instance would be kept verbatim by ChatPromptTemplate, and the model
        would see the literal ``{memory_context}`` placeholder instead of the
        retrieved memories.
        """
        return ChatPromptTemplate.from_messages([
            ("system", (
                "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
                "【用户背景信息】\n"
                "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
                "{memory_context}\n"
                "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
                "【可用工具与使用规则】\n"
                "- 获取温度/天气:`get_current_temperature`\n"
                "- 读取文本文件:`read_local_file`(限定目录 `./user_docs`\n"
                "- 读取PDF摘要`read_pdf_summary`(限定目录 `./user_docs`\n"
                "- 读取Excel表格`read_excel_as_markdown`(限定目录 `./user_docs`\n"
                "- 抓取网页内容:`fetch_webpage_content`\n"
                "工具调用时请直接返回所需参数,无需额外说明。\n\n"
                "【回答要求(必须遵守)】\n"
                "1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
                "2. 优先利用已知用户信息进行个性化回复。\n"
                "3. 若无信息可依,礼貌询问或提供通用帮助。"
            )),
            MessagesPlaceholder(variable_name="messages"),
        ])

    async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """LLM node: send the conversation plus long-term memory to the model.

        The chain is awaited through its native ``ainvoke``. This keeps the
        event loop free, and LangChain's context-propagated config (callbacks,
        tracing, streaming) still reaches the model. ``loop.run_in_executor``
        does not copy ``contextvars`` into the worker thread.

        Returns:
            A state update with the model reply appended to ``messages``, the
            incremented ``llm_calls`` counter, and the call's token usage and
            duration. On failure, a friendly ``AIMessage`` is returned instead,
            ``llm_calls`` is left unchanged and the token usage is ``{}``.
        """
        # retrieve_memory stores "" when nothing is found; use an explicit
        # placeholder so the prompt section is never silently empty.
        memory_context = state.get("memory_context") or "暂无用户信息"
        inputs = {"messages": state["messages"], "memory_context": memory_context}

        start_time = time.perf_counter()
        try:
            # Format through the real template so the log shows exactly what
            # the model receives (including the substituted memory).
            self._log_llm_input(self._prompt.format_messages(**inputs))
            response = await self._chain.ainvoke(inputs)
        except Exception as e:
            # Node-level boundary: degrade to an apology instead of failing the run.
            import traceback
            elapsed_time = time.perf_counter() - start_time
            error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
            error(f" 错误类型: {type(e).__name__}")
            error(f" 错误信息: {str(e)}")
            error(traceback.format_exc())
            debug("="*80 + "\n")
            error_response = AIMessage(
                content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
            )
            return {
                "messages": [error_response],
                "llm_calls": state.get('llm_calls', 0),
                "last_token_usage": {},
                "last_elapsed_time": elapsed_time
            }

        # Post-processing stays outside the try: a reporting problem must never
        # discard a valid model reply.
        elapsed_time = time.perf_counter() - start_time
        token_usage, input_tokens, output_tokens = self._extract_token_usage(response)
        info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}")
        info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
        if token_usage:
            debug(f"📋 [LLM统计] 详细用量: {token_usage}")
        self._log_llm_output(response)
        return {
            "messages": [response],
            "llm_calls": state.get('llm_calls', 0) + 1,
            "last_token_usage": token_usage,
            "last_elapsed_time": elapsed_time
        }

    @staticmethod
    def _log_llm_input(input_messages: list) -> None:
        """Debug-log every message sent to the model, without truncation."""
        debug("\n" + "="*80)
        debug("📤 [LLM输入] 发送给大模型的完整消息:")
        debug(f" 总消息数: {len(input_messages)}")
        for i, msg in enumerate(input_messages):
            content_preview = str(msg.content)
            debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
        debug("="*80 + "\n")

    @staticmethod
    def _log_llm_output(response: Any) -> None:
        """Debug-log the full model response: type, length and content."""
        debug("\n" + "="*80)
        debug("📥 [LLM输出] 大模型返回的完整响应:")
        debug(f" 消息类型: {response.type.upper()}")
        debug(f" 内容长度: {len(str(response.content))} 字符")
        debug("-"*80)
        debug(f"{response.content}")
        debug("="*80 + "\n")

    @staticmethod
    def _extract_token_usage(response: Any) -> tuple[dict, int, int]:
        """Extract token usage from an LLM response, tolerating provider differences.

        Looks in this order:
        1. ``response_metadata['token_usage']`` or ``response_metadata['usage']``;
        2. ``additional_kwargs['llm_output']['token_usage']``;
        3. ``usage_metadata`` as a final fallback.

        Returns:
            ``(raw_usage, input_tokens, output_tokens)``. ``raw_usage`` is ``{}``
            when nothing usable was reported. Missing or ``None`` counts become 0.
        """
        token_usage: Any = {}
        meta = getattr(response, "response_metadata", None)
        if isinstance(meta, dict):
            token_usage = meta.get("token_usage") or meta.get("usage") or {}
        if not token_usage:
            add_kwargs = getattr(response, "additional_kwargs", None)
            llm_output = add_kwargs.get("llm_output") if isinstance(add_kwargs, dict) else None
            if isinstance(llm_output, dict):
                token_usage = llm_output.get("token_usage") or {}
        if not token_usage:
            usage_metadata = getattr(response, "usage_metadata", None)
            if usage_metadata:
                token_usage = dict(usage_metadata)
        if not isinstance(token_usage, dict):
            token_usage = {}
        # `or` (not a .get default) so an explicit None count also becomes 0.
        input_tokens = token_usage.get("prompt_tokens") or token_usage.get("input_tokens") or 0
        output_tokens = token_usage.get("completion_tokens") or token_usage.get("output_tokens") or 0
        return token_usage, input_tokens, output_tokens

    async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """Tool node: run every tool call requested by the last AI message.

        Tools exposing ``ainvoke`` are awaited directly. Others run ``invoke``
        in the default thread pool so they do not block the event loop. Unknown
        tools and tool failures are reported back to the model as ToolMessages
        (and logged) rather than raised, so the conversation can continue.
        """
        last_message = state['messages'][-1]
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            return {"messages": []}

        loop = asyncio.get_running_loop()
        results = []
        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool_id = tool_call["id"]
            tool_func = self.tools_by_name.get(tool_name)
            if tool_func is None:
                warning(f"⚠️ [工具调用] 未找到工具: {tool_name}")
                results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id, name=tool_name))
                continue
            try:
                if hasattr(tool_func, 'ainvoke'):
                    observation = await tool_func.ainvoke(tool_args)
                else:
                    # Pass the bound method and its argument directly: no closure,
                    # so there is nothing to late-bind.
                    observation = await loop.run_in_executor(None, tool_func.invoke, tool_args)
            except Exception as e:
                warning(f"⚠️ [工具调用] {tool_name} 执行失败: {e}")
                results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id, name=tool_name))
                continue
            results.append(ToolMessage(content=str(observation), tool_call_id=tool_id, name=tool_name))
        return {"messages": results}

    @staticmethod
    def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
        """Route after the LLM node.

        Returns:
            ``'tool_node'`` when the last AI message requests tool calls;
            ``'save_memory'`` for any other AI message (the final reply or the
            error fallback); ``'END'`` otherwise.
        """
        last_message = state["messages"][-1]
        # Tool calls take priority over everything else.
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            return 'tool_node'
        # Any AI message without tool calls is treated as the turn's final reply.
        if isinstance(last_message, AIMessage):
            return 'save_memory'
        return 'END'

    async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """Search the long-term store for memories relevant to the latest message.

        Memories live under ``("memories", user_id)`` and are read as
        ``value["data"]`` (the shape :meth:`save_memory` writes). Retrieval is
        best-effort: a missing store, a search failure, or no hits all yield
        ``memory_context == ""``. Malformed records are skipped individually.
        """
        user_id = runtime.context.user_id
        namespace = ("memories", user_id)
        query = str(state["messages"][-1].content)
        debug(f"\n{'='*60}")
        debug("🔎 [记忆检索] 开始检索")
        debug(f" ├─ 用户ID: {user_id}")
        debug(f" ├─ 命名空间: {namespace}")
        debug(f" ├─ 查询内容: '{query}'")
        debug(f" └─ 消息总数: {len(state['messages'])}")

        if runtime.store is None:
            # Graph compiled without a store: skip quietly instead of raising each turn.
            warning("⚠️ [记忆检索] 未配置存储 (store),跳过检索")
            debug(f"{'='*60}\n")
            return {"memory_context": ""}

        try:
            memories = await runtime.store.asearch(namespace, query=query)
        except Exception as e:
            import traceback
            error(f"❌ [记忆检索] 检索失败: {e}")
            error(traceback.format_exc())
            debug(f"{'='*60}\n")
            return {"memory_context": ""}

        debug(f"✅ [记忆检索] 检索完成,找到 {len(memories)} 条相关记忆")
        # One malformed record must not discard the remaining valid memories.
        texts = [
            str(m.value["data"])
            for m in memories
            if isinstance(m.value, dict) and "data" in m.value
        ]
        if not texts:
            debug("⚠️ [记忆检索] 未找到相关记忆")
            debug(f"{'='*60}\n")
            return {"memory_context": ""}

        debug("📚 [记忆内容]")
        for i, text in enumerate(texts, 1):
            debug(f" [{i}] {text}")
        debug(f"{'='*60}\n")
        return {"memory_context": "\n".join(texts)}

    async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
        """Persist the latest user message as a long-term memory when asked to.

        Saving is triggered only when the latest human message contains one of
        ``_MEMORY_TRIGGER_KEYWORDS``. The message is stored verbatim (original
        casing) under ``("memories", user_id)`` as ``{"data": ...}``. This step
        is best-effort: the reply has already been produced, so store failures
        are logged rather than raised. Always returns an empty state update.
        """
        last_user = next(
            (msg for msg in reversed(state["messages"]) if msg.type == "human"),
            None,
        )
        if last_user is None:
            return {}
        last_user_text = str(last_user.content)
        # Lower-case only for matching; storing the lowered text would corrupt
        # names and other case-sensitive details.
        lowered = last_user_text.lower()
        if not any(keyword in lowered for keyword in self._MEMORY_TRIGGER_KEYWORDS):
            return {}
        if runtime.store is None:
            warning("⚠️ [记忆保存] 未配置存储 (store),无法保存记忆")
            return {}

        # Example-level extraction: keep the raw utterance (an LLM could distil it).
        memory_content = f"用户说过:{last_user_text}"
        namespace = ("memories", runtime.context.user_id)
        try:
            await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
        except Exception as e:
            error(f"❌ [记忆保存] 保存失败: {e}")
            return {}
        info(f"✅ 长期记忆已保存:{memory_content}")
        return {}

    def build(self) -> StateGraph:
        """Build the uncompiled state graph wired to this instance's node methods.

        Returns:
            A ``StateGraph`` over ``MessagesState`` with ``GraphContext`` as its
            context schema. The caller compiles it -- presumably with a store,
            since the memory nodes read ``runtime.store`` (confirm at call site).
        """
        builder = StateGraph(MessagesState, context_schema=GraphContext)
        builder.add_node("retrieve_memory", self.retrieve_memory)
        builder.add_node("llm_call", self.call_llm)
        builder.add_node("tool_node", self.call_tools)
        builder.add_node("save_memory", self.save_memory)
        builder.add_edge(START, "retrieve_memory")
        builder.add_edge("retrieve_memory", "llm_call")
        builder.add_conditional_edges(
            "llm_call",
            self.should_continue,
            {
                "tool_node": "tool_node",
                "save_memory": "save_memory",
                'END': END
            }
        )
        builder.add_edge("tool_node", "llm_call")
        builder.add_edge("save_memory", END)
        return builder