修改引用逻辑,修改长期记忆bug
This commit is contained in:
8
app/graph/__init__.py
Normal file
8
app/graph/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Graph 子模块
|
||||
"""
|
||||
|
||||
from app.graph.graph_builder import GraphBuilder
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
|
||||
__all__ = ["GraphBuilder", "MessagesState", "GraphContext"]
|
||||
@@ -5,18 +5,17 @@ LangGraph 状态图构建模块 - 精简版,仅负责组装图
|
||||
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
# 本地模块
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.nodes import (
|
||||
should_continue,
|
||||
create_llm_call_node,
|
||||
create_tool_call_node,
|
||||
create_retrieve_memory_node,
|
||||
create_summarize_node,
|
||||
should_continue
|
||||
finalize_node,
|
||||
)
|
||||
from app.nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from app.memory import Mem0Client
|
||||
from app.nodes.finalize import finalize_node
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
@@ -45,6 +44,9 @@ class GraphBuilder:
|
||||
Returns:
|
||||
StateGraph 实例
|
||||
"""
|
||||
# 注入全局客户端
|
||||
set_mem0_client(self.mem0_client)
|
||||
|
||||
builder = StateGraph(MessagesState, context_schema=GraphContext)
|
||||
|
||||
# ⭐ 通过工厂函数创建节点(依赖注入)
|
||||
@@ -55,6 +57,7 @@ class GraphBuilder:
|
||||
|
||||
# 添加节点
|
||||
builder.add_node("retrieve_memory", retrieve_memory_node)
|
||||
builder.add_node("memory_trigger", memory_trigger_node)
|
||||
builder.add_node("llm_call", llm_call_node)
|
||||
builder.add_node("tool_node", tool_call_node)
|
||||
builder.add_node("summarize", summarize_node)
|
||||
@@ -62,7 +65,8 @@ class GraphBuilder:
|
||||
|
||||
# 添加边
|
||||
builder.add_edge(START, "retrieve_memory")
|
||||
builder.add_edge("retrieve_memory", "llm_call")
|
||||
builder.add_edge("retrieve_memory", "memory_trigger")
|
||||
builder.add_edge("memory_trigger", "llm_call")
|
||||
builder.add_conditional_edges(
|
||||
"llm_call",
|
||||
should_continue,
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
"""
|
||||
|
||||
# 标准库
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 第三方库
|
||||
@@ -13,7 +12,6 @@ import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
def _file_allow_check(filename: str) -> Path:
|
||||
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
|
||||
allowed_dir = Path("./user_docs").resolve()
|
||||
@@ -28,13 +26,11 @@ def _file_allow_check(filename: str) -> Path:
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
@tool
|
||||
def get_current_temperature(location: str) -> str:
|
||||
"""获取指定地点的当前温度。"""
|
||||
return f'当前{location}的温度为25℃'
|
||||
|
||||
|
||||
@tool
|
||||
def read_local_file(filename: str) -> str:
|
||||
"""读取用户指定名称的本地文本文件内容并返回摘要。"""
|
||||
@@ -46,7 +42,6 @@ def read_local_file(filename: str) -> str:
|
||||
except Exception as e:
|
||||
return f"读取文件时出错:{str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def read_pdf_summary(filename: str) -> str:
|
||||
"""读取PDF文件并返回内容文本摘要。"""
|
||||
@@ -61,7 +56,6 @@ def read_pdf_summary(filename: str) -> str:
|
||||
except Exception as e:
|
||||
return f"读取PDF出错:{e}"
|
||||
|
||||
|
||||
@tool
|
||||
def read_excel_as_markdown(filename: str) -> str:
|
||||
"""读取Excel文件,并将其主要数据转换为Markdown表格格式。"""
|
||||
@@ -73,7 +67,6 @@ def read_excel_as_markdown(filename: str) -> str:
|
||||
except Exception as e:
|
||||
return f"读取Excel出错:{e}"
|
||||
|
||||
|
||||
@tool
|
||||
def fetch_webpage_content(url: str) -> str:
|
||||
"""抓取给定URL的网页正文内容,并返回清晰的纯文本。"""
|
||||
@@ -91,7 +84,6 @@ def fetch_webpage_content(url: str) -> str:
|
||||
except Exception as e:
|
||||
return f"抓取网页时出错:{str(e)}"
|
||||
|
||||
|
||||
# 工具列表和映射(全局常量)
|
||||
AVAILABLE_TOOLS = [
|
||||
get_current_temperature,
|
||||
|
||||
@@ -4,15 +4,13 @@
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# 本地模块
|
||||
from app.graph.state import MessagesState, GraphContext
|
||||
from app.graph.state import MessagesState
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug
|
||||
|
||||
|
||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
"""
|
||||
工厂函数:创建记忆检索节点
|
||||
|
||||
@@ -4,12 +4,11 @@ LangGraph 状态定义模块
|
||||
"""
|
||||
|
||||
import operator
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from dataclasses import dataclass
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
|
||||
class MessagesState(TypedDict):
|
||||
"""对话状态类型定义"""
|
||||
messages: Annotated[list[AnyMessage], operator.add]
|
||||
@@ -19,7 +18,6 @@ class MessagesState(TypedDict):
|
||||
last_elapsed_time: float # 本次调用耗时(秒)
|
||||
turns_since_last_summary: int # 距离上次生成摘要的轮数
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphContext:
|
||||
"""图执行上下文"""
|
||||
|
||||
Reference in New Issue
Block a user