321 lines
14 KiB
Python
321 lines
14 KiB
Python
"""
|
||
LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
|
||
"""
|
||
|
||
import operator
|
||
import asyncio
|
||
import time
|
||
from typing import Literal, Annotated, Any
|
||
from langchain_core.language_models import BaseLLM
|
||
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
|
||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||
from langgraph.graph import StateGraph, START, END
|
||
from typing_extensions import TypedDict
|
||
from langgraph.store.postgres.aio import AsyncPostgresStore
|
||
from langgraph.runtime import Runtime
|
||
from dataclasses import dataclass
|
||
import uuid
|
||
|
||
# 本地模块
|
||
from app.logger import debug, info, warning, error
|
||
|
||
|
||
class MessagesState(TypedDict):
|
||
"""对话状态类型定义"""
|
||
messages: Annotated[list[AnyMessage], operator.add]
|
||
llm_calls: int
|
||
memory_context:str
|
||
last_token_usage: dict # 本次调用的 token 使用详情
|
||
last_elapsed_time: float # 本次调用耗时(秒)
|
||
|
||
@dataclass
|
||
class GraphContext:
|
||
user_id: str
|
||
# 可扩展更多上下文信息
|
||
|
||
class GraphBuilder:
|
||
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
|
||
|
||
def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]):
|
||
"""
|
||
初始化构建器
|
||
|
||
Args:
|
||
llm: 大语言模型实例
|
||
tools: 工具列表
|
||
tools_by_name: 名称到工具函数的映射
|
||
"""
|
||
self.llm = llm
|
||
self.tools = tools
|
||
self.tools_by_name = tools_by_name
|
||
self._llm_with_tools = llm.bind_tools(tools)
|
||
self._prompt = self._create_prompt()
|
||
self._chain = self._prompt | self._llm_with_tools
|
||
|
||
@staticmethod
|
||
def _create_prompt() -> ChatPromptTemplate:
|
||
"""创建系统提示模板(静态方法,无需访问实例)"""
|
||
return ChatPromptTemplate.from_messages([
|
||
SystemMessage(content=(
|
||
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
|
||
"【用户背景信息】\n"
|
||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
|
||
"{memory_context}\n"
|
||
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
|
||
"【可用工具与使用规则】\n"
|
||
"- 获取温度/天气:`get_current_temperature`\n"
|
||
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
|
||
"- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
|
||
"- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
|
||
"- 抓取网页内容:`fetch_webpage_content`\n"
|
||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||
"【回答要求(必须遵守)】\n"
|
||
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
|
||
"2. 优先利用已知用户信息进行个性化回复。\n"
|
||
"3. 若无信息可依,礼貌询问或提供通用帮助。"
|
||
)),
|
||
MessagesPlaceholder(variable_name="messages")
|
||
])
|
||
|
||
async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||
"""
|
||
LLM 调用节点(异步方法)
|
||
注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
|
||
"""
|
||
memory_context = state.get("memory_context", "暂无用户信息")
|
||
|
||
# 构建完整的输入消息列表(用于调试打印)
|
||
system_prompt = self._prompt.messages[0] # SystemMessage
|
||
if isinstance(system_prompt, SystemMessage):
|
||
system_content = system_prompt.content.format(memory_context=memory_context)
|
||
else:
|
||
system_content = str(system_prompt.content)
|
||
|
||
input_messages = [SystemMessage(content=system_content)] + state["messages"]
|
||
|
||
# 打印发送给大模型的最终输入
|
||
debug("\n" + "="*80)
|
||
debug("📤 [LLM输入] 发送给大模型的完整消息:")
|
||
debug(f" 总消息数: {len(input_messages)}")
|
||
for i, msg in enumerate(input_messages):
|
||
content_preview = str(msg.content) # 不截断,完整输出
|
||
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
|
||
debug("="*80 + "\n")
|
||
|
||
loop = asyncio.get_event_loop()
|
||
start_time = time.time()
|
||
|
||
try:
|
||
response = await loop.run_in_executor(
|
||
None,
|
||
lambda: self._chain.invoke({
|
||
"messages": state["messages"],
|
||
"memory_context": memory_context
|
||
})
|
||
)
|
||
|
||
elapsed_time = time.time() - start_time
|
||
|
||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||
token_usage = {}
|
||
input_tokens = 0
|
||
output_tokens = 0
|
||
|
||
# 尝试从 response_metadata 中提取
|
||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||
meta = response.response_metadata
|
||
if 'token_usage' in meta:
|
||
token_usage = meta['token_usage']
|
||
elif 'usage' in meta:
|
||
token_usage = meta['usage']
|
||
|
||
# 尝试从 additional_kwargs 中提取
|
||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||
add_kwargs = response.additional_kwargs
|
||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||
token_usage = add_kwargs['llm_output']['token_usage']
|
||
|
||
# 提取具体的 token 数值
|
||
if token_usage:
|
||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||
|
||
# 打印响应统计信息
|
||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||
if token_usage:
|
||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||
|
||
# 打印 LLM 的完整输出
|
||
debug("\n" + "="*80)
|
||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||
debug(f" 消息类型: {response.type.upper()}")
|
||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||
debug("-"*80)
|
||
debug(f"{response.content}")
|
||
debug("="*80 + "\n")
|
||
|
||
return {
|
||
"messages": [response],
|
||
"llm_calls": state.get('llm_calls', 0) + 1,
|
||
"last_token_usage": token_usage,
|
||
"last_elapsed_time": elapsed_time
|
||
}
|
||
|
||
except Exception as e:
|
||
elapsed_time = time.time() - start_time
|
||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||
error(f" 错误类型: {type(e).__name__}")
|
||
error(f" 错误信息: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
debug("="*80 + "\n")
|
||
|
||
# 返回一个友好的错误消息
|
||
error_response = AIMessage(
|
||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||
)
|
||
return {
|
||
"messages": [error_response],
|
||
"llm_calls": state.get('llm_calls', 0),
|
||
"last_token_usage": {},
|
||
"last_elapsed_time": elapsed_time
|
||
}
|
||
|
||
async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||
"""
|
||
工具执行节点(异步方法)
|
||
对于每个工具调用,在线程池中执行同步工具函数
|
||
"""
|
||
last_message = state['messages'][-1]
|
||
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
|
||
return {"messages": []}
|
||
|
||
results = []
|
||
loop = asyncio.get_event_loop()
|
||
|
||
for tool_call in last_message.tool_calls:
|
||
tool_name = tool_call["name"]
|
||
tool_args = tool_call["args"]
|
||
tool_id = tool_call["id"]
|
||
tool_func = self.tools_by_name.get(tool_name)
|
||
|
||
if tool_func is None:
|
||
results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id))
|
||
continue
|
||
|
||
try:
|
||
# 修复闭包问题:将变量作为默认参数传入 lambda
|
||
# 如果工具支持异步 (ainvoke),优先使用异步调用
|
||
if hasattr(tool_func, 'ainvoke'):
|
||
observation = await tool_func.ainvoke(tool_args)
|
||
else:
|
||
observation = await loop.run_in_executor(
|
||
None,
|
||
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
|
||
)
|
||
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
||
except Exception as e:
|
||
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
||
|
||
return {"messages": results}
|
||
|
||
@staticmethod
|
||
def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
|
||
"""决定下一步:工具调用、保存记忆还是结束"""
|
||
last_message = state["messages"][-1]
|
||
|
||
# 1. 如果需要调用工具,优先进入工具节点
|
||
if isinstance(last_message, AIMessage) and last_message.tool_calls:
|
||
return 'tool_node'
|
||
|
||
# 2. 如果是 AI 的最终回复,可以考虑进入记忆保存节点(可增加判断逻辑)
|
||
# 这里简单处理:只要没有工具调用,且是 AI 消息,就尝试保存记忆。
|
||
if isinstance(last_message, AIMessage):
|
||
return 'save_memory'
|
||
|
||
# 3. 其他情况(如只有用户消息)直接结束
|
||
return 'END'
|
||
|
||
async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||
"""搜索并返回长期记忆"""
|
||
user_id = runtime.context.user_id
|
||
namespace = ("memories", user_id)
|
||
query = str(state["messages"][-1].content)
|
||
|
||
debug(f"\n{'='*60}")
|
||
debug(f"🔎 [记忆检索] 开始检索")
|
||
debug(f" ├─ 用户ID: {user_id}")
|
||
debug(f" ├─ 命名空间: {namespace}")
|
||
debug(f" ├─ 查询内容: '{query}'")
|
||
debug(f" └─ 消息总数: {len(state['messages'])}")
|
||
|
||
try:
|
||
memories = await runtime.store.asearch(namespace, query=query)
|
||
debug(f"✅ [记忆检索] 检索完成,找到 {len(memories)} 条相关记忆")
|
||
|
||
if memories:
|
||
memory_text = "\n".join([m.value["data"] for m in memories])
|
||
debug(f"📚 [记忆内容]")
|
||
for i, memory in enumerate(memories, 1):
|
||
debug(f" [{i}] {memory.value['data']}")
|
||
debug(f"{'='*60}\n")
|
||
return {"memory_context": memory_text}
|
||
else:
|
||
debug(f"⚠️ [记忆检索] 未找到相关记忆")
|
||
debug(f"{'='*60}\n")
|
||
return {"memory_context": ""}
|
||
|
||
except Exception as e:
|
||
error(f"❌ [记忆检索] 检索失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
debug(f"{'='*60}\n")
|
||
return {"memory_context": ""}
|
||
|
||
async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||
"""尝试从对话中提取并保存长期记忆"""
|
||
# 获取最后一条用户消息(通常是要记住的内容的来源)
|
||
user_messages = [msg for msg in state["messages"] if msg.type == "human"]
|
||
if not user_messages:
|
||
return {}
|
||
|
||
last_user_msg = user_messages[-1].content.lower()
|
||
|
||
# 简单触发逻辑:包含"记住"或"保存"等关键词
|
||
if any(keyword in last_user_msg for keyword in ["记住", "保存", "别忘了"]):
|
||
# 提取记忆内容(这里仅作示例,实际可用 LLM 提取)
|
||
memory_content = f"用户说过:{last_user_msg}"
|
||
user_id = runtime.context.user_id
|
||
namespace = ("memories", user_id)
|
||
await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
|
||
info(f"✅ 长期记忆已保存:{memory_content}")
|
||
|
||
return {}
|
||
|
||
def build(self) -> StateGraph:
|
||
"""
|
||
构建未编译的状态图(返回 StateGraph 实例)
|
||
图中节点直接使用实例方法 call_llm, call_tools
|
||
"""
|
||
builder = StateGraph(MessagesState,context_schema=GraphContext)
|
||
builder.add_node("retrieve_memory", self.retrieve_memory)
|
||
builder.add_node("llm_call", self.call_llm)
|
||
builder.add_node("tool_node", self.call_tools)
|
||
builder.add_node("save_memory", self.save_memory)
|
||
|
||
builder.add_edge(START, "retrieve_memory")
|
||
builder.add_edge("retrieve_memory", "llm_call")
|
||
builder.add_conditional_edges(
|
||
"llm_call",
|
||
self.should_continue,
|
||
{
|
||
"tool_node": "tool_node",
|
||
"save_memory": "save_memory",
|
||
'END': END
|
||
}
|
||
)
|
||
builder.add_edge("tool_node", "llm_call")
|
||
builder.add_edge("save_memory", END)
|
||
|
||
return builder |