This commit is contained in:
@@ -4,19 +4,34 @@ LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
|
||||
|
||||
import operator
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Literal, Annotated, Any
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore
|
||||
from langgraph.runtime import Runtime
|
||||
from dataclasses import dataclass
|
||||
import uuid
|
||||
|
||||
# 本地模块
|
||||
from app.logger import debug, info, warning, error
|
||||
|
||||
|
||||
class MessagesState(TypedDict):
    """Conversation state carried between graph nodes."""
    # Conversation history; operator.add makes node updates append rather than replace.
    messages: Annotated[list[AnyMessage], operator.add]
    # How many times the LLM node has been invoked in this run.
    llm_calls: int
    # Long-term memory retrieved for the current user (empty string when none found).
    memory_context: str
    # Token-usage details reported by the most recent LLM call.
    last_token_usage: dict
    # Wall-clock duration of the most recent LLM call, in seconds.
    last_elapsed_time: float
|
||||
|
||||
@dataclass
class GraphContext:
    """Per-run context handed to graph nodes through the LangGraph runtime."""
    # ID of the user whose long-term memories are read and written.
    user_id: str
    # Extend with additional context fields as needed.
|
||||
|
||||
class GraphBuilder:
|
||||
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
|
||||
@@ -42,33 +57,132 @@ class GraphBuilder:
|
||||
"""创建系统提示模板(静态方法,无需访问实例)"""
|
||||
return ChatPromptTemplate.from_messages([
|
||||
SystemMessage(content=(
|
||||
"你是一个个人生活助手和数据分析助手。请说中文。"
|
||||
"当用户询问天气或温度时,使用get_current_temperature工具获取信息。"
|
||||
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
|
||||
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。"
|
||||
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
|
||||
"【用户背景信息】\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
|
||||
"{memory_context}\n"
|
||||
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
|
||||
"【可用工具与使用规则】\n"
|
||||
"- 获取温度/天气:`get_current_temperature`\n"
|
||||
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
|
||||
"- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
|
||||
"- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
|
||||
"- 抓取网页内容:`fetch_webpage_content`\n"
|
||||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||||
"【回答要求(必须遵守)】\n"
|
||||
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
|
||||
"2. 优先利用已知用户信息进行个性化回复。\n"
|
||||
"3. 若无信息可依,礼貌询问或提供通用帮助。"
|
||||
)),
|
||||
MessagesPlaceholder(variable_name="message")
|
||||
MessagesPlaceholder(variable_name="messages")
|
||||
])
|
||||
|
||||
async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
    """
    LLM call node (async).

    Formats the system prompt with the retrieved memory context, invokes the
    synchronous chain in a worker thread so the event loop is not blocked,
    and records token usage plus wall-clock latency.

    Returns a partial state update: new messages, incremented llm_calls,
    last_token_usage and last_elapsed_time. On failure, returns a friendly
    AIMessage instead of raising, so the graph can finish the turn.
    """
    memory_context = state.get("memory_context", "暂无用户信息")

    # Rebuild the full input message list for debug logging only; the real
    # call goes through self._chain, which applies the prompt template itself.
    system_prompt = self._prompt.messages[0]  # SystemMessage
    if isinstance(system_prompt, SystemMessage):
        system_content = system_prompt.content.format(memory_context=memory_context)
    else:
        system_content = str(system_prompt.content)

    input_messages = [SystemMessage(content=system_content)] + state["messages"]

    # Log the exact input that will be sent to the model.
    debug("\n" + "="*80)
    debug("📤 [LLM输入] 发送给大模型的完整消息:")
    debug(f" 总消息数: {len(input_messages)}")
    for i, msg in enumerate(input_messages):
        content_preview = str(msg.content)  # full content, no truncation
        debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
    debug("="*80 + "\n")

    # FIX: asyncio.get_event_loop() is deprecated inside coroutines;
    # get_running_loop() is the correct call here.
    loop = asyncio.get_running_loop()
    # FIX: use the monotonic clock for durations (immune to wall-clock jumps).
    start_time = time.monotonic()

    try:
        # The chain is synchronous; run it in the default thread-pool executor.
        response = await loop.run_in_executor(
            None,
            lambda: self._chain.invoke({
                "messages": state["messages"],
                "memory_context": memory_context
            })
        )

        elapsed_time = time.monotonic() - start_time

        # Extract token usage (metadata layout differs between LLM providers).
        token_usage = {}
        input_tokens = 0
        output_tokens = 0

        # Preferred location: response_metadata.
        if hasattr(response, 'response_metadata') and response.response_metadata:
            meta = response.response_metadata
            if 'token_usage' in meta:
                token_usage = meta['token_usage']
            elif 'usage' in meta:
                token_usage = meta['usage']

        # Fallback: additional_kwargs['llm_output'].
        if not token_usage and hasattr(response, 'additional_kwargs'):
            add_kwargs = response.additional_kwargs
            if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
                token_usage = add_kwargs['llm_output']['token_usage']

        # Normalize the concrete token counts (OpenAI vs Anthropic key names).
        if token_usage:
            input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
            output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))

        # Call statistics.
        info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
        info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
        if token_usage:
            debug(f"📋 [LLM统计] 详细用量: {token_usage}")

        # Log the model's full response.
        debug("\n" + "="*80)
        debug("📥 [LLM输出] 大模型返回的完整响应:")
        debug(f" 消息类型: {response.type.upper()}")
        debug(f" 内容长度: {len(str(response.content))} 字符")
        debug("-"*80)
        debug(f"{response.content}")
        debug("="*80 + "\n")

        return {
            "messages": [response],
            "llm_calls": state.get('llm_calls', 0) + 1,
            "last_token_usage": token_usage,
            "last_elapsed_time": elapsed_time
        }

    except Exception as e:
        elapsed_time = time.monotonic() - start_time
        error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
        error(f" 错误类型: {type(e).__name__}")
        error(f" 错误信息: {str(e)}")
        import traceback
        traceback.print_exc()
        debug("="*80 + "\n")

        # Degrade gracefully: hand back a friendly error message instead of
        # crashing the whole turn.
        error_response = AIMessage(
            content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
        )
        return {
            "messages": [error_response],
            "llm_calls": state.get('llm_calls', 0),
            "last_token_usage": {},
            "last_elapsed_time": elapsed_time
        }
|
||||
|
||||
async def call_tools(self, state: MessageState) -> dict:
|
||||
async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""
|
||||
工具执行节点(异步方法)
|
||||
对于每个工具调用,在线程池中执行同步工具函数
|
||||
@@ -91,11 +205,15 @@ class GraphBuilder:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 同步工具函数在线程池中执行
|
||||
observation = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: tool_func.invoke(tool_args)
|
||||
)
|
||||
# 修复闭包问题:将变量作为默认参数传入 lambda
|
||||
# 如果工具支持异步 (ainvoke),优先使用异步调用
|
||||
if hasattr(tool_func, 'ainvoke'):
|
||||
observation = await tool_func.ainvoke(tool_args)
|
||||
else:
|
||||
observation = await loop.run_in_executor(
|
||||
None,
|
||||
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
|
||||
)
|
||||
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
||||
except Exception as e:
|
||||
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
||||
@@ -103,25 +221,101 @@ class GraphBuilder:
|
||||
return {"messages": results}
|
||||
|
||||
@staticmethod
def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
    """Route the graph: run tools, persist memory, or finish the turn."""
    last_message = state["messages"][-1]

    # Non-AI tail (e.g. only a user message): nothing left to do.
    if not isinstance(last_message, AIMessage):
        return 'END'

    # Pending tool requests take priority; otherwise the AI produced its
    # final reply, so attempt to save long-term memory before ending.
    return 'tool_node' if last_message.tool_calls else 'save_memory'
|
||||
|
||||
async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
    """Search the long-term store for memories relevant to the latest message."""
    user_id = runtime.context.user_id
    namespace = ("memories", user_id)
    # The most recent message doubles as the semantic search query.
    query = str(state["messages"][-1].content)

    debug(f"\n{'='*60}")
    debug(f"🔎 [记忆检索] 开始检索")
    debug(f" ├─ 用户ID: {user_id}")
    debug(f" ├─ 命名空间: {namespace}")
    debug(f" ├─ 查询内容: '{query}'")
    debug(f" └─ 消息总数: {len(state['messages'])}")

    try:
        found = await runtime.store.asearch(namespace, query=query)
        debug(f"✅ [记忆检索] 检索完成,找到 {len(found)} 条相关记忆")

        if not found:
            debug(f"⚠️ [记忆检索] 未找到相关记忆")
            debug(f"{'='*60}\n")
            return {"memory_context": ""}

        memory_text = "\n".join(item.value["data"] for item in found)
        debug(f"📚 [记忆内容]")
        for idx, item in enumerate(found, 1):
            debug(f" [{idx}] {item.value['data']}")
        debug(f"{'='*60}\n")
        return {"memory_context": memory_text}

    except Exception as e:
        # Retrieval is best-effort: log the failure and continue without memory.
        error(f"❌ [记忆检索] 检索失败: {e}")
        import traceback
        traceback.print_exc()
        debug(f"{'='*60}\n")
        return {"memory_context": ""}
|
||||
|
||||
async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
    """Extract and persist a long-term memory from the conversation, if triggered."""
    # The most recent human message is the candidate source of the memory.
    human_msgs = [m for m in state["messages"] if m.type == "human"]
    if not human_msgs:
        return {}

    latest = human_msgs[-1].content.lower()

    # Simple keyword trigger; no trigger means nothing to persist.
    triggers = ("记住", "保存", "别忘了")
    if not any(word in latest for word in triggers):
        return {}

    # Memory extraction is illustrative here; an LLM could do this instead.
    memory_content = f"用户说过:{latest}"
    namespace = ("memories", runtime.context.user_id)
    await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
    info(f"✅ 长期记忆已保存:{memory_content}")

    return {}
|
||||
|
||||
def build(self) -> StateGraph:
    """
    Assemble and return the (uncompiled) state graph.

    Flow: START -> retrieve_memory -> llm_call; should_continue then routes
    to tool_node (which loops back to llm_call), to save_memory (then END),
    or straight to END.
    """
    builder = StateGraph(MessagesState, context_schema=GraphContext)

    # Register every node first, then wire the edges.
    for node_name, handler in (
        ("retrieve_memory", self.retrieve_memory),
        ("llm_call", self.call_llm),
        ("tool_node", self.call_tools),
        ("save_memory", self.save_memory),
    ):
        builder.add_node(node_name, handler)

    builder.add_edge(START, "retrieve_memory")
    builder.add_edge("retrieve_memory", "llm_call")
    builder.add_conditional_edges(
        "llm_call",
        self.should_continue,
        {
            "tool_node": "tool_node",
            "save_memory": "save_memory",
            'END': END,
        },
    )
    builder.add_edge("tool_node", "llm_call")
    builder.add_edge("save_memory", END)

    return builder
|
||||
Reference in New Issue
Block a user