修改引用逻辑,修改长期记忆bug
This commit is contained in:
@@ -4,15 +4,13 @@
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.graph.state import MessagesState
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import info, error
|
||||
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
|
||||
@@ -3,21 +3,17 @@ LLM 调用节点模块
|
||||
负责调用大语言模型并处理响应
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change, print_llm_input
|
||||
from app.graph.state import MessagesState
|
||||
from app.agent.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error
|
||||
|
||||
|
||||
def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
"""
|
||||
工厂函数:创建 LLM 调用节点
|
||||
|
||||
38
app/nodes/memory_trigger.py
Normal file
38
app/nodes/memory_trigger.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Any, Dict
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from app.graph.state import MessagesState
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.logger import info
|
||||
|
||||
# 全局变量,在 GraphBuilder 中注入
|
||||
_mem0_client: Mem0Client = None
|
||||
|
||||
def set_mem0_client(client: Mem0Client):
|
||||
global _mem0_client
|
||||
_mem0_client = client
|
||||
|
||||
async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储"""
|
||||
if _mem0_client is None:
|
||||
return {}
|
||||
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return {}
|
||||
|
||||
last_msg = messages[-1]
|
||||
content = last_msg.content if hasattr(last_msg, 'content') else str(last_msg)
|
||||
|
||||
# 触发词(可自行扩展)
|
||||
trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"]
|
||||
if any(word in content for word in trigger_words):
|
||||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||||
# 确保 Mem0 已初始化
|
||||
if not _mem0_client._initialized:
|
||||
await _mem0_client.initialize()
|
||||
# 将用户消息作为事实来源提交给 Mem0
|
||||
mem0_messages = [{"role": "user", "content": content}]
|
||||
await _mem0_client.add_memories(mem0_messages, user_id=user_id)
|
||||
info(f"📌 检测到记忆指令,已主动触发 Mem0 存储")
|
||||
|
||||
return {} # 不修改状态
|
||||
@@ -4,15 +4,13 @@
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.graph.state import MessagesState
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error, warning
|
||||
|
||||
|
||||
def create_summarize_node(mem0_client: Mem0Client):
|
||||
"""
|
||||
工厂函数:创建记忆存储节点
|
||||
|
||||
@@ -6,15 +6,13 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.graph.state import MessagesState
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info
|
||||
|
||||
|
||||
def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
"""
|
||||
工厂函数:创建工具执行节点
|
||||
|
||||
Reference in New Issue
Block a user