This commit is contained in:
@@ -7,6 +7,7 @@ from app.nodes.llm_call import create_llm_call_node
|
||||
from app.nodes.tool_call import create_tool_call_node
|
||||
from app.nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from app.nodes.summarize import create_summarize_node
|
||||
from app.nodes.finalize import finalize_node
|
||||
|
||||
__all__ = [
|
||||
"should_continue",
|
||||
@@ -14,4 +15,5 @@ __all__ = [
|
||||
"create_tool_call_node",
|
||||
"create_retrieve_memory_node",
|
||||
"create_summarize_node",
|
||||
"finalize_node",
|
||||
]
|
||||
|
||||
47
app/nodes/finalize.py
Normal file
47
app/nodes/finalize.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
完成事件节点模块
|
||||
负责发送完成事件,包含token使用情况和耗时信息
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import info, error
|
||||
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """
    Finalization node: emits a "done" stream event carrying the token-usage
    and elapsed-time figures recorded earlier in the run state.

    Args:
        state: Current conversation state.
        config: Runtime configuration injected by LangChain/LangGraph.

    Returns:
        An empty dict — this is a terminal node and produces no state updates.
    """
    log_state_change("finalize", state, "进入")

    # Best-effort emission: failing to stream the completion event must not
    # abort the graph run, so any exception is logged and swallowed here.
    try:
        done_payload = {
            "type": "done",
            "token_usage": state.get("last_token_usage", {}),
            "elapsed_time": state.get("last_elapsed_time", 0.0),
        }
        emit = get_stream_writer()
        emit({"type": "custom", "data": done_payload})
        info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息")
    except Exception as e:
        error(f"❌ [完成事件] 发送完成事件时发生异常: {e}")

    log_state_change("finalize", state, "离开")
    return {}
|
||||
@@ -32,15 +32,19 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
# 构建调用链
|
||||
prompt = create_system_prompt()
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
chain = prompt | RunnableLambda(print_llm_input) | llm_with_tools
|
||||
|
||||
async def call_llm(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
|
||||
chain = prompt | llm_with_tools
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息
|
||||
|
||||
Returns:
|
||||
更新后的状态字典
|
||||
@@ -48,17 +52,28 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
log_state_change("llm_call", state, "进入")
|
||||
|
||||
memory_context = state.get("memory_context", "暂无用户信息")
|
||||
loop = asyncio.get_event_loop()
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: chain.invoke({
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chunks = []
|
||||
async for chunk in chain.astream(
|
||||
{
|
||||
"messages": state["messages"],
|
||||
"memory_context": memory_context
|
||||
})
|
||||
)
|
||||
},
|
||||
config=config
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
for chunk in chunks[1:]:
|
||||
response = response + chunk
|
||||
else:
|
||||
response = AIMessage(content="")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
@@ -85,13 +100,7 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
@@ -99,6 +108,12 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||||
debug("-"*80)
|
||||
debug(f"{response.content}")
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
result = {
|
||||
|
||||
@@ -24,20 +24,23 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
async def retrieve_memory(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
记忆检索节点 - 使用 Mem0
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
config: 运行时配置
|
||||
|
||||
Returns:
|
||||
包含 memory_context 的状态更新
|
||||
"""
|
||||
log_state_change("retrieve_memory", state, "进入")
|
||||
|
||||
user_id = runtime.context.user_id
|
||||
# 从 metadata 中获取 user_id
|
||||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||||
|
||||
# 兼容 dict 和对象两种消息格式
|
||||
last_msg = state["messages"][-1]
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.state import MessagesState
|
||||
from app.logger import info
|
||||
|
||||
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'END']:
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'finalize']:
|
||||
"""
|
||||
决定下一步:工具调用、生成摘要还是结束
|
||||
|
||||
@@ -20,7 +20,7 @@ def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', '
|
||||
state: 当前对话状态
|
||||
|
||||
Returns:
|
||||
下一个节点名称或 END
|
||||
下一个节点名称
|
||||
"""
|
||||
last_message = state["messages"][-1]
|
||||
|
||||
@@ -40,9 +40,9 @@ def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', '
|
||||
else:
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")
|
||||
return 'END'
|
||||
return 'finalize'
|
||||
|
||||
# 3. 其他情况(如只有用户消息)直接结束
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
|
||||
return 'END'
|
||||
return 'finalize'
|
||||
|
||||
@@ -24,13 +24,15 @@ def create_summarize_node(mem0_client: Mem0Client):
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
async def summarize_conversation(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
记忆存储节点 - 使用 Mem0
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
config: 运行时配置
|
||||
|
||||
Returns:
|
||||
重置计数器的状态更新
|
||||
@@ -42,7 +44,8 @@ def create_summarize_node(mem0_client: Mem0Client):
|
||||
debug("📝 [记忆添加] 对话过短,跳过")
|
||||
return {"turns_since_last_summary": 0}
|
||||
|
||||
user_id = runtime.context.user_id
|
||||
# 从 metadata 中获取 user_id
|
||||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||||
|
||||
# 确保 Mem0 已初始化(懒加载)
|
||||
if not mem0_client._initialized:
|
||||
@@ -83,4 +86,4 @@ def create_summarize_node(mem0_client: Mem0Client):
|
||||
log_state_change("summarize", state, "离开")
|
||||
return {"turns_since_last_summary": 0}
|
||||
|
||||
return summarize_conversation
|
||||
return summarize_conversation
|
||||
@@ -7,6 +7,7 @@ import asyncio
|
||||
from typing import Any, Dict
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
@@ -25,13 +26,15 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
async def call_tools(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def call_tools(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
工具执行节点(异步方法)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
config: 运行时配置
|
||||
|
||||
Returns:
|
||||
包含 ToolMessage 的状态更新
|
||||
@@ -62,6 +65,10 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
|
||||
continue
|
||||
|
||||
# 获取流式写入器并发送工具调用开始事件
|
||||
writer = get_stream_writer()
|
||||
writer({"type": "custom", "data": {"type": "tool_start", "tool": tool_name}})
|
||||
|
||||
try:
|
||||
# 修复闭包问题:将变量作为默认参数传入 lambda
|
||||
# 如果工具支持异步 (ainvoke),优先使用异步调用
|
||||
@@ -77,9 +84,15 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
result_preview = str(observation).replace("\n", " ")
|
||||
debug(f" └─ ✅ 结果: {result_preview}")
|
||||
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
||||
|
||||
# 发送工具调用完成事件
|
||||
writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": True}})
|
||||
except Exception as e:
|
||||
debug(f" └─ ❌ 异常: {e}")
|
||||
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
||||
|
||||
# 发送工具调用失败事件
|
||||
writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": False, "error": str(e)}})
|
||||
|
||||
info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
|
||||
|
||||
@@ -87,4 +100,4 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
log_state_change("tool_node", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
return call_tools
|
||||
return call_tools
|
||||
Reference in New Issue
Block a user