文件变更
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
from app.nodes.router import should_continue
|
||||
from app.nodes.llm_call import create_llm_call_node
|
||||
from app.nodes.tool_call import create_tool_call_node
|
||||
from app.nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from app.graph.retrieve_memory import create_retrieve_memory_node
|
||||
from app.nodes.summarize import create_summarize_node
|
||||
from app.nodes.finalize import finalize_node
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from langgraph.runtime import Runtime
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import info, error
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from langchain_core.runnables import RunnableLambda
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change, print_llm_input
|
||||
from app.logger import debug, info, error
|
||||
@@ -30,7 +30,7 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
异步节点函数
|
||||
"""
|
||||
# 构建调用链
|
||||
prompt = create_system_prompt()
|
||||
prompt = create_system_prompt(tools)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
|
||||
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
"""
|
||||
记忆检索节点模块
|
||||
负责从 Mem0 检索相关长期记忆
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug
|
||||
|
||||
|
||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
"""
|
||||
工厂函数:创建记忆检索节点
|
||||
|
||||
Args:
|
||||
mem0_client: Mem0 客户端实例
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
记忆检索节点 - 使用 Mem0
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
config: 运行时配置
|
||||
|
||||
Returns:
|
||||
包含 memory_context 的状态更新
|
||||
"""
|
||||
log_state_change("retrieve_memory", state, "进入")
|
||||
|
||||
# 从 metadata 中获取 user_id
|
||||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||||
|
||||
# 兼容 dict 和对象两种消息格式
|
||||
last_msg = state["messages"][-1]
|
||||
if isinstance(last_msg, dict):
|
||||
query = str(last_msg.get("content", ""))
|
||||
else:
|
||||
query = str(last_msg.content)
|
||||
memory_text_parts = []
|
||||
|
||||
# 确保 Mem0 已初始化(懒加载)
|
||||
if not mem0_client._initialized:
|
||||
await mem0_client.initialize()
|
||||
|
||||
if mem0_client.mem0:
|
||||
try:
|
||||
# 异步调用 Mem0 语义检索
|
||||
facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
|
||||
|
||||
if facts:
|
||||
memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts))
|
||||
else:
|
||||
debug("🔍 [记忆检索] 未找到相关记忆")
|
||||
except Exception as e:
|
||||
from app.logger import warning
|
||||
warning(f"⚠️ Mem0 检索失败: {e}")
|
||||
else:
|
||||
from app.logger import warning
|
||||
warning("⚠️ Mem0 未初始化,跳过记忆检索")
|
||||
|
||||
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
|
||||
result = {"memory_context": memory_context}
|
||||
log_state_change("retrieve_memory", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
return retrieve_memory
|
||||
@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
|
||||
from app.state import MessagesState
|
||||
from app.graph.state import MessagesState
|
||||
from app.logger import info
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Dict
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error, warning
|
||||
|
||||
@@ -10,7 +10,7 @@ from langgraph.runtime import Runtime
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info
|
||||
|
||||
|
||||
Reference in New Issue
Block a user