This commit is contained in:
@@ -5,21 +5,27 @@ LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
|
||||
import operator
|
||||
import asyncio
|
||||
import time
|
||||
import os
|
||||
from typing import Literal, Annotated, Any
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore
|
||||
from langgraph.runtime import Runtime
|
||||
from dataclasses import dataclass
|
||||
import uuid
|
||||
|
||||
from langchain_core.prompt_values import ChatPromptValue
|
||||
# 本地模块
|
||||
from app.logger import debug, info, warning, error
|
||||
|
||||
|
||||
# 是否启用 Graph 流转追踪(通过环境变量控制)
|
||||
ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"
|
||||
|
||||
|
||||
class MessagesState(TypedDict):
|
||||
"""对话状态类型定义"""
|
||||
messages: Annotated[list[AnyMessage], operator.add]
|
||||
@@ -33,6 +39,53 @@ class GraphContext:
|
||||
user_id: str
|
||||
# 可扩展更多上下文信息
|
||||
|
||||
def _log_state_change(node_name: str, state: MessagesState, prefix: str = "进入"):
|
||||
"""
|
||||
通用辅助函数:打印节点状态变化
|
||||
|
||||
Args:
|
||||
node_name: 节点名称
|
||||
state: 当前状态
|
||||
prefix: 前缀("进入" 或 "离开")
|
||||
"""
|
||||
if not ENABLE_GRAPH_TRACE:
|
||||
return
|
||||
|
||||
messages = state.get("messages", [])
|
||||
msg_count = len(messages)
|
||||
last_msg = messages[-1] if messages else None
|
||||
last_info = ""
|
||||
if last_msg:
|
||||
content_preview = str(last_msg.content)[:100].replace("\n", " ")
|
||||
last_info = f"{last_msg.type.upper()}: {content_preview}"
|
||||
info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}")
|
||||
|
||||
def _print_llm_input(prompt_value: ChatPromptValue) -> ChatPromptValue:
|
||||
"""
|
||||
RunnableLambda 回调函数:打印格式化后发送给 LLM 的完整消息
|
||||
|
||||
Args:
|
||||
prompt_value: ChatPromptValue 对象,包含格式化后的消息列表
|
||||
|
||||
Returns:
|
||||
原样返回 prompt_value,不影响链式调用
|
||||
"""
|
||||
if not ENABLE_GRAPH_TRACE:
|
||||
return prompt_value
|
||||
|
||||
messages = prompt_value.messages # ChatPromptValue 提供 .messages 属性
|
||||
|
||||
debug("\n" + "=" * 80)
|
||||
debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:")
|
||||
debug(f" 总消息数: {len(messages)}")
|
||||
debug("-" * 80)
|
||||
for i, msg in enumerate(messages):
|
||||
content_preview = str(msg.content) # 完整输出
|
||||
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
|
||||
debug( "\n"+"=" * 80 + "\n")
|
||||
|
||||
return prompt_value
|
||||
|
||||
class GraphBuilder:
|
||||
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
|
||||
|
||||
@@ -50,30 +103,35 @@ class GraphBuilder:
|
||||
self.tools_by_name = tools_by_name
|
||||
self._llm_with_tools = llm.bind_tools(tools)
|
||||
self._prompt = self._create_prompt()
|
||||
self._chain = self._prompt | self._llm_with_tools
|
||||
self._chain = (
|
||||
self._prompt
|
||||
| RunnableLambda(_print_llm_input)
|
||||
| self._llm_with_tools
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_prompt() -> ChatPromptTemplate:
|
||||
"""创建系统提示模板(静态方法,无需访问实例)"""
|
||||
system_template = (
|
||||
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
|
||||
"【用户背景信息】\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
|
||||
"{memory_context}\n"
|
||||
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
|
||||
"【可用工具与使用规则】\n"
|
||||
"- 获取温度/天气:`get_current_temperature`\n"
|
||||
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
|
||||
"- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
|
||||
"- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
|
||||
"- 抓取网页内容:`fetch_webpage_content`\n"
|
||||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||||
"【回答要求(必须遵守)】\n"
|
||||
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
|
||||
"2. 优先利用已知用户信息进行个性化回复。\n"
|
||||
"3. 若无信息可依,礼貌询问或提供通用帮助。"
|
||||
)
|
||||
return ChatPromptTemplate.from_messages([
|
||||
SystemMessage(content=(
|
||||
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
|
||||
"【用户背景信息】\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
|
||||
"{memory_context}\n"
|
||||
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
|
||||
"【可用工具与使用规则】\n"
|
||||
"- 获取温度/天气:`get_current_temperature`\n"
|
||||
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
|
||||
"- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
|
||||
"- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
|
||||
"- 抓取网页内容:`fetch_webpage_content`\n"
|
||||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||||
"【回答要求(必须遵守)】\n"
|
||||
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
|
||||
"2. 优先利用已知用户信息进行个性化回复。\n"
|
||||
"3. 若无信息可依,礼貌询问或提供通用帮助。"
|
||||
)),
|
||||
("system", system_template),
|
||||
MessagesPlaceholder(variable_name="messages")
|
||||
])
|
||||
|
||||
@@ -82,26 +140,9 @@ class GraphBuilder:
|
||||
LLM 调用节点(异步方法)
|
||||
注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
|
||||
"""
|
||||
_log_state_change("llm_call", state, "进入")
|
||||
|
||||
memory_context = state.get("memory_context", "暂无用户信息")
|
||||
|
||||
# 构建完整的输入消息列表(用于调试打印)
|
||||
system_prompt = self._prompt.messages[0] # SystemMessage
|
||||
if isinstance(system_prompt, SystemMessage):
|
||||
system_content = system_prompt.content.format(memory_context=memory_context)
|
||||
else:
|
||||
system_content = str(system_prompt.content)
|
||||
|
||||
input_messages = [SystemMessage(content=system_content)] + state["messages"]
|
||||
|
||||
# 打印发送给大模型的最终输入
|
||||
debug("\n" + "="*80)
|
||||
debug("📤 [LLM输入] 发送给大模型的完整消息:")
|
||||
debug(f" 总消息数: {len(input_messages)}")
|
||||
for i, msg in enumerate(input_messages):
|
||||
content_preview = str(msg.content) # 不截断,完整输出
|
||||
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
start_time = time.time()
|
||||
|
||||
@@ -155,13 +196,16 @@ class GraphBuilder:
|
||||
debug(f"{response.content}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
return {
|
||||
result = {
|
||||
"messages": [response],
|
||||
"llm_calls": state.get('llm_calls', 0) + 1,
|
||||
"last_token_usage": token_usage,
|
||||
"last_elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
_log_state_change("llm_call", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
@@ -175,33 +219,45 @@ class GraphBuilder:
|
||||
error_response = AIMessage(
|
||||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||||
)
|
||||
return {
|
||||
error_result = {
|
||||
"messages": [error_response],
|
||||
"llm_calls": state.get('llm_calls', 0),
|
||||
"last_token_usage": {},
|
||||
"last_elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
_log_state_change("llm_call", state, "离开(异常)")
|
||||
return error_result
|
||||
|
||||
async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""
|
||||
工具执行节点(异步方法)
|
||||
对于每个工具调用,在线程池中执行同步工具函数
|
||||
"""
|
||||
_log_state_change("tool_node", state, "进入")
|
||||
|
||||
last_message = state['messages'][-1]
|
||||
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
|
||||
_log_state_change("tool_node", state, "离开(无工具调用)")
|
||||
return {"messages": []}
|
||||
|
||||
results = []
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")
|
||||
|
||||
for tool_call in last_message.tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
tool_id = tool_call["id"]
|
||||
tool_func = self.tools_by_name.get(tool_name)
|
||||
|
||||
debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")
|
||||
|
||||
if tool_func is None:
|
||||
results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id))
|
||||
err_msg = f"Tool {tool_name} not found"
|
||||
debug(f" └─ ❌ {err_msg}")
|
||||
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -214,11 +270,20 @@ class GraphBuilder:
|
||||
None,
|
||||
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
|
||||
)
|
||||
|
||||
# 字符打印
|
||||
result_preview = str(observation).replace("\n", " ")
|
||||
debug(f" └─ ✅ 结果: {result_preview}")
|
||||
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
||||
except Exception as e:
|
||||
debug(f" └─ ❌ 异常: {e}")
|
||||
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
||||
|
||||
return {"messages": results}
|
||||
info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
|
||||
|
||||
result = {"messages": results}
|
||||
_log_state_change("tool_node", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
|
||||
@@ -227,18 +292,26 @@ class GraphBuilder:
|
||||
|
||||
# 1. 如果需要调用工具,优先进入工具节点
|
||||
if isinstance(last_message, AIMessage) and last_message.tool_calls:
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
|
||||
return 'tool_node'
|
||||
|
||||
# 2. 如果是 AI 的最终回复,可以考虑进入记忆保存节点(可增加判断逻辑)
|
||||
# 这里简单处理:只要没有工具调用,且是 AI 消息,就尝试保存记忆。
|
||||
if isinstance(last_message, AIMessage):
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 收到 AI 最终回复(无工具调用) → 转向 'save_memory'")
|
||||
return 'save_memory'
|
||||
|
||||
# 3. 其他情况(如只有用户消息)直接结束
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
|
||||
return 'END'
|
||||
|
||||
async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""搜索并返回长期记忆"""
|
||||
_log_state_change("retrieve_memory", state, "进入")
|
||||
|
||||
user_id = runtime.context.user_id
|
||||
namespace = ("memories", user_id)
|
||||
query = str(state["messages"][-1].content)
|
||||
@@ -260,24 +333,33 @@ class GraphBuilder:
|
||||
for i, memory in enumerate(memories, 1):
|
||||
debug(f" [{i}] {memory.value['data']}")
|
||||
debug(f"{'='*60}\n")
|
||||
return {"memory_context": memory_text}
|
||||
result = {"memory_context": memory_text}
|
||||
_log_state_change("retrieve_memory", {**state, **result}, "离开")
|
||||
return result
|
||||
else:
|
||||
debug(f"⚠️ [记忆检索] 未找到相关记忆")
|
||||
debug(f"{'='*60}\n")
|
||||
return {"memory_context": ""}
|
||||
result = {"memory_context": ""}
|
||||
_log_state_change("retrieve_memory", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error(f"❌ [记忆检索] 检索失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
debug(f"{'='*60}\n")
|
||||
return {"memory_context": ""}
|
||||
result = {"memory_context": ""}
|
||||
_log_state_change("retrieve_memory", {**state, **result}, "离开(异常)")
|
||||
return result
|
||||
|
||||
async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""尝试从对话中提取并保存长期记忆"""
|
||||
_log_state_change("save_memory", state, "进入")
|
||||
|
||||
# 获取最后一条用户消息(通常是要记住的内容的来源)
|
||||
user_messages = [msg for msg in state["messages"] if msg.type == "human"]
|
||||
if not user_messages:
|
||||
_log_state_change("save_memory", state, "离开(无用户消息)")
|
||||
return {}
|
||||
|
||||
last_user_msg = user_messages[-1].content.lower()
|
||||
@@ -291,6 +373,7 @@ class GraphBuilder:
|
||||
await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
|
||||
info(f"✅ 长期记忆已保存:{memory_content}")
|
||||
|
||||
_log_state_change("save_memory", state, "离开")
|
||||
return {}
|
||||
|
||||
def build(self) -> StateGraph:
|
||||
|
||||
Reference in New Issue
Block a user