This commit is contained in:
17
app/nodes/__init__.py
Normal file
17
app/nodes/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
节点模块 - 导出所有 LangGraph 节点函数
|
||||
"""
|
||||
|
||||
from app.nodes.router import should_continue
|
||||
from app.nodes.llm_call import create_llm_call_node
|
||||
from app.nodes.tool_call import create_tool_call_node
|
||||
from app.nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from app.nodes.summarize import create_summarize_node
|
||||
|
||||
__all__ = [
|
||||
"should_continue",
|
||||
"create_llm_call_node",
|
||||
"create_tool_call_node",
|
||||
"create_retrieve_memory_node",
|
||||
"create_summarize_node",
|
||||
]
|
||||
139
app/nodes/llm_call.py
Normal file
139
app/nodes/llm_call.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
LLM 调用节点模块
|
||||
负责调用大语言模型并处理响应
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change, print_llm_input
|
||||
from app.logger import debug, info, error
|
||||
|
||||
|
||||
def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
"""
|
||||
工厂函数:创建 LLM 调用节点
|
||||
|
||||
Args:
|
||||
llm: LangChain LLM 实例
|
||||
tools: 工具列表
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
# 构建调用链
|
||||
prompt = create_system_prompt()
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
chain = prompt | RunnableLambda(print_llm_input) | llm_with_tools
|
||||
|
||||
async def call_llm(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
|
||||
Returns:
|
||||
更新后的状态字典
|
||||
"""
|
||||
log_state_change("llm_call", state, "进入")
|
||||
|
||||
memory_context = state.get("memory_context", "暂无用户信息")
|
||||
loop = asyncio.get_event_loop()
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: chain.invoke({
|
||||
"messages": state["messages"],
|
||||
"memory_context": memory_context
|
||||
})
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||||
token_usage = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
# 尝试从 response_metadata 中提取
|
||||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
meta = response.response_metadata
|
||||
if 'token_usage' in meta:
|
||||
token_usage = meta['token_usage']
|
||||
elif 'usage' in meta:
|
||||
token_usage = meta['usage']
|
||||
|
||||
# 尝试从 additional_kwargs 中提取
|
||||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||||
add_kwargs = response.additional_kwargs
|
||||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||||
token_usage = add_kwargs['llm_output']['token_usage']
|
||||
|
||||
# 提取具体的 token 数值
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
debug(f" 消息类型: {response.type.upper()}")
|
||||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||||
debug("-"*80)
|
||||
debug(f"{response.content}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
result = {
|
||||
"messages": [response],
|
||||
"llm_calls": state.get('llm_calls', 0) + 1,
|
||||
"last_token_usage": token_usage,
|
||||
"last_elapsed_time": elapsed_time,
|
||||
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 递增计数器
|
||||
}
|
||||
|
||||
log_state_change("llm_call", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f" 错误类型: {type(e).__name__}")
|
||||
error(f" 错误信息: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
debug("="*80 + "\n")
|
||||
|
||||
# 返回一个友好的错误消息
|
||||
error_response = AIMessage(
|
||||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||||
)
|
||||
error_result = {
|
||||
"messages": [error_response],
|
||||
"llm_calls": state.get('llm_calls', 0),
|
||||
"last_token_usage": {},
|
||||
"last_elapsed_time": elapsed_time,
|
||||
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 即使出错也递增计数器
|
||||
}
|
||||
|
||||
log_state_change("llm_call", state, "离开(异常)")
|
||||
return error_result
|
||||
|
||||
return call_llm
|
||||
75
app/nodes/retrieve_memory.py
Normal file
75
app/nodes/retrieve_memory.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
记忆检索节点模块
|
||||
负责从 Mem0 检索相关长期记忆
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug
|
||||
|
||||
|
||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
"""
|
||||
工厂函数:创建记忆检索节点
|
||||
|
||||
Args:
|
||||
mem0_client: Mem0 客户端实例
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
async def retrieve_memory(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
"""
|
||||
记忆检索节点 - 使用 Mem0
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
|
||||
Returns:
|
||||
包含 memory_context 的状态更新
|
||||
"""
|
||||
log_state_change("retrieve_memory", state, "进入")
|
||||
|
||||
user_id = runtime.context.user_id
|
||||
|
||||
# 兼容 dict 和对象两种消息格式
|
||||
last_msg = state["messages"][-1]
|
||||
if isinstance(last_msg, dict):
|
||||
query = str(last_msg.get("content", ""))
|
||||
else:
|
||||
query = str(last_msg.content)
|
||||
memory_text_parts = []
|
||||
|
||||
# 确保 Mem0 已初始化(懒加载)
|
||||
if not mem0_client._initialized:
|
||||
await mem0_client.initialize()
|
||||
|
||||
if mem0_client.mem0:
|
||||
try:
|
||||
# 异步调用 Mem0 语义检索
|
||||
facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
|
||||
|
||||
if facts:
|
||||
memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts))
|
||||
else:
|
||||
debug("🔍 [记忆检索] 未找到相关记忆")
|
||||
except Exception as e:
|
||||
from app.logger import warning
|
||||
warning(f"⚠️ Mem0 检索失败: {e}")
|
||||
else:
|
||||
from app.logger import warning
|
||||
warning("⚠️ Mem0 未初始化,跳过记忆检索")
|
||||
|
||||
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
|
||||
result = {"memory_context": memory_context}
|
||||
log_state_change("retrieve_memory", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
return retrieve_memory
|
||||
48
app/nodes/router.py
Normal file
48
app/nodes/router.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
路由决策节点
|
||||
根据当前状态决定下一步走向
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
|
||||
from app.state import MessagesState
|
||||
from app.logger import info
|
||||
|
||||
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'END']:
|
||||
"""
|
||||
决定下一步:工具调用、生成摘要还是结束
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
|
||||
Returns:
|
||||
下一个节点名称或 END
|
||||
"""
|
||||
last_message = state["messages"][-1]
|
||||
|
||||
# 1. 如果需要调用工具,优先进入工具节点
|
||||
if isinstance(last_message, AIMessage) and last_message.tool_calls:
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
|
||||
return 'tool_node'
|
||||
|
||||
# 2. 如果是 AI 的最终回复,判断是否达到摘要生成阈值
|
||||
if isinstance(last_message, AIMessage):
|
||||
turns = state.get("turns_since_last_summary", 0)
|
||||
if turns >= MEMORY_SUMMARIZE_INTERVAL:
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 收到 AI 最终回复,已达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 转向 'summarize'")
|
||||
return 'summarize'
|
||||
else:
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")
|
||||
return 'END'
|
||||
|
||||
# 3. 其他情况(如只有用户消息)直接结束
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
|
||||
return 'END'
|
||||
86
app/nodes/summarize.py
Normal file
86
app/nodes/summarize.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
记忆存储节点模块
|
||||
负责将对话历史提交给 Mem0 进行事实提取和存储
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error, warning
|
||||
|
||||
|
||||
def create_summarize_node(mem0_client: Mem0Client):
|
||||
"""
|
||||
工厂函数:创建记忆存储节点
|
||||
|
||||
Args:
|
||||
mem0_client: Mem0 客户端实例
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
async def summarize_conversation(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
"""
|
||||
记忆存储节点 - 使用 Mem0
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
|
||||
Returns:
|
||||
重置计数器的状态更新
|
||||
"""
|
||||
log_state_change("summarize", state, "进入")
|
||||
|
||||
messages = state["messages"]
|
||||
if len(messages) < 4:
|
||||
debug("📝 [记忆添加] 对话过短,跳过")
|
||||
return {"turns_since_last_summary": 0}
|
||||
|
||||
user_id = runtime.context.user_id
|
||||
|
||||
# 确保 Mem0 已初始化(懒加载)
|
||||
if not mem0_client._initialized:
|
||||
await mem0_client.initialize()
|
||||
|
||||
# 将整个对话历史转换为 Mem0 需要的消息格式
|
||||
mem0_messages = []
|
||||
for msg in messages:
|
||||
# 兼容 dict 和对象两种格式
|
||||
if isinstance(msg, dict):
|
||||
msg_type = msg.get("type", "")
|
||||
msg_content = msg.get("content", "")
|
||||
else:
|
||||
msg_type = getattr(msg, 'type', '')
|
||||
msg_content = getattr(msg, 'content', '')
|
||||
|
||||
if msg_type == "human":
|
||||
mem0_messages.append({"role": "user", "content": msg_content})
|
||||
elif msg_type == "ai":
|
||||
mem0_messages.append({"role": "assistant", "content": msg_content})
|
||||
elif msg_type == "tool":
|
||||
mem0_messages.append({"role": "system", "content": f"[工具返回] {msg_content}"})
|
||||
|
||||
if mem0_client.mem0:
|
||||
try:
|
||||
# 异步调用 Mem0 自动提取并存储事实
|
||||
success = await mem0_client.add_memories(
|
||||
mem0_messages,
|
||||
user_id=user_id
|
||||
)
|
||||
if success:
|
||||
info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取")
|
||||
except Exception as e:
|
||||
error(f"❌ Mem0 记忆添加失败: {e}")
|
||||
else:
|
||||
warning("⚠️ Mem0 未初始化,跳过记忆添加")
|
||||
|
||||
log_state_change("summarize", state, "离开")
|
||||
return {"turns_since_last_summary": 0}
|
||||
|
||||
return summarize_conversation
|
||||
90
app/nodes/tool_call.py
Normal file
90
app/nodes/tool_call.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
工具执行节点模块
|
||||
负责执行 AI 调用的工具函数
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info
|
||||
|
||||
|
||||
def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
"""
|
||||
工厂函数:创建工具执行节点
|
||||
|
||||
Args:
|
||||
tools_by_name: 名称到工具函数的映射字典
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
async def call_tools(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
"""
|
||||
工具执行节点(异步方法)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
|
||||
Returns:
|
||||
包含 ToolMessage 的状态更新
|
||||
"""
|
||||
log_state_change("tool_node", state, "进入")
|
||||
|
||||
last_message = state['messages'][-1]
|
||||
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
|
||||
log_state_change("tool_node", state, "离开(无工具调用)")
|
||||
return {"messages": []}
|
||||
|
||||
results = []
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")
|
||||
|
||||
for tool_call in last_message.tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
tool_id = tool_call["id"]
|
||||
tool_func = tools_by_name.get(tool_name)
|
||||
|
||||
debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")
|
||||
|
||||
if tool_func is None:
|
||||
err_msg = f"Tool {tool_name} not found"
|
||||
debug(f" └─ ❌ {err_msg}")
|
||||
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
|
||||
continue
|
||||
|
||||
try:
|
||||
# 修复闭包问题:将变量作为默认参数传入 lambda
|
||||
# 如果工具支持异步 (ainvoke),优先使用异步调用
|
||||
if hasattr(tool_func, 'ainvoke'):
|
||||
observation = await tool_func.ainvoke(tool_args)
|
||||
else:
|
||||
observation = await loop.run_in_executor(
|
||||
None,
|
||||
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
|
||||
)
|
||||
|
||||
# 字符打印
|
||||
result_preview = str(observation).replace("\n", " ")
|
||||
debug(f" └─ ✅ 结果: {result_preview}")
|
||||
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
||||
except Exception as e:
|
||||
debug(f" └─ ❌ 异常: {e}")
|
||||
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
||||
|
||||
info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
|
||||
|
||||
result = {"messages": results}
|
||||
log_state_change("tool_node", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
return call_tools
|
||||
Reference in New Issue
Block a user