文件变更

This commit is contained in:
2026-04-20 14:05:57 +08:00
parent 3c906e91d9
commit 4e981e9dcf
28 changed files with 474 additions and 490 deletions

View File

@@ -0,0 +1,79 @@
"""
LangGraph 状态图构建模块 - 精简版,仅负责组装图
所有节点逻辑已拆分到独立模块
"""
from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.nodes import (
create_llm_call_node,
create_tool_call_node,
create_retrieve_memory_node,
create_summarize_node,
should_continue
)
from app.memory import Mem0Client
from app.nodes.finalize import finalize_node
class GraphBuilder:
"""LangGraph 状态图构建器 - 仅负责组装图"""
def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
"""
初始化构建器
Args:
llm: 大语言模型实例
tools: 工具列表
tools_by_name: 名称到工具函数的映射
"""
self.llm = llm
self.tools = tools
self.tools_by_name = tools_by_name
# ⭐ 创建 Mem0 客户端(懒加载,首次使用时初始化)
self.mem0_client = Mem0Client(llm)
def build(self) -> StateGraph:
"""
构建未编译的状态图
Returns:
StateGraph 实例
"""
builder = StateGraph(MessagesState, context_schema=GraphContext)
# ⭐ 通过工厂函数创建节点(依赖注入)
retrieve_memory_node = create_retrieve_memory_node(self.mem0_client)
llm_call_node = create_llm_call_node(self.llm, self.tools)
tool_call_node = create_tool_call_node(self.tools_by_name)
summarize_node = create_summarize_node(self.mem0_client)
# 添加节点
builder.add_node("retrieve_memory", retrieve_memory_node)
builder.add_node("llm_call", llm_call_node)
builder.add_node("tool_node", tool_call_node)
builder.add_node("summarize", summarize_node)
builder.add_node("finalize", finalize_node)
# 添加边
builder.add_edge(START, "retrieve_memory")
builder.add_edge("retrieve_memory", "llm_call")
builder.add_conditional_edges(
"llm_call",
should_continue,
{
"tool_node": "tool_node",
"summarize": "summarize",
"finalize": "finalize"
}
)
builder.add_edge("tool_node", "llm_call")
builder.add_edge("summarize", "finalize")
builder.add_edge("finalize", END)
return builder

103
app/graph/graph_tools.py Normal file
View File

@@ -0,0 +1,103 @@
"""
工具定义模块 - 纯函数工具,无依赖 AIAgent 类
"""
# 标准库
import os
from pathlib import Path
# 第三方库
import pandas as pd
import pypdf
import requests
from bs4 import BeautifulSoup
from langchain_core.tools import tool
def _file_allow_check(filename: str) -> Path:
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
allowed_dir = Path("./user_docs").resolve()
allowed_dir.mkdir(exist_ok=True)
file_path = (allowed_dir / filename).resolve()
if not str(file_path).startswith(str(allowed_dir)):
raise ValueError("错误:非法文件路径。")
if not file_path.exists():
raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")
return file_path
@tool
def get_current_temperature(location: str) -> str:
"""获取指定地点的当前温度。"""
return f'当前{location}的温度为25℃'
@tool
def read_local_file(filename: str) -> str:
"""读取用户指定名称的本地文本文件内容并返回摘要。"""
try:
file_path = _file_allow_check(filename)
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..."
except Exception as e:
return f"读取文件时出错:{str(e)}"
@tool
def read_pdf_summary(filename: str) -> str:
"""读取PDF文件并返回内容文本摘要。"""
try:
file_path = _file_allow_check(filename)
text = ""
with open(file_path, 'rb') as f:
reader = pypdf.PdfReader(f)
for page in reader.pages[:3]:
text += page.extract_text()
return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..."
except Exception as e:
return f"读取PDF出错{e}"
@tool
def read_excel_as_markdown(filename: str) -> str:
"""读取Excel文件并将其主要数据转换为Markdown表格格式。"""
try:
file_path = _file_allow_check(filename)
df = pd.read_excel(file_path)
markdown_table = df.head(10).to_markdown(index=False)
return f"Excel文件 '{filename}' 的数据预览前10行\n{markdown_table}"
except Exception as e:
return f"读取Excel出错{e}"
@tool
def fetch_webpage_content(url: str) -> str:
"""抓取给定URL的网页正文内容并返回清晰的纯文本。"""
try:
response = requests.get(url, timeout=10)
response.raise_for_status()
soup = BeautifulSoup(response.text, 'html.parser')
for script in soup(["script", "style"]):
script.decompose()
text = soup.get_text()
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = '\n'.join(chunk for chunk in chunks if chunk)
return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..."
except Exception as e:
return f"抓取网页时出错:{str(e)}"
# 工具列表和映射(全局常量)
AVAILABLE_TOOLS = [
get_current_temperature,
read_local_file,
fetch_webpage_content,
read_pdf_summary,
read_excel_as_markdown
]
TOOLS_BY_NAME = {tool.name: tool for tool in AVAILABLE_TOOLS}

View File

@@ -0,0 +1,78 @@
"""
记忆检索节点模块
负责从 Mem0 检索相关长期记忆
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆检索节点
Args:
mem0_client: Mem0 客户端实例
Returns:
异步节点函数
"""
from langchain_core.runnables.config import RunnableConfig
async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
记忆检索节点 - 使用 Mem0
Args:
state: 当前对话状态
config: 运行时配置
Returns:
包含 memory_context 的状态更新
"""
log_state_change("retrieve_memory", state, "进入")
# 从 metadata 中获取 user_id
user_id = config.get("metadata", {}).get("user_id", "default_user")
# 兼容 dict 和对象两种消息格式
last_msg = state["messages"][-1]
if isinstance(last_msg, dict):
query = str(last_msg.get("content", ""))
else:
query = str(last_msg.content)
memory_text_parts = []
# 确保 Mem0 已初始化(懒加载)
if not mem0_client._initialized:
await mem0_client.initialize()
if mem0_client.mem0:
try:
# 异步调用 Mem0 语义检索
facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
if facts:
memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts))
else:
debug("🔍 [记忆检索] 未找到相关记忆")
except Exception as e:
from app.logger import warning
warning(f"⚠️ Mem0 检索失败: {e}")
else:
from app.logger import warning
warning("⚠️ Mem0 未初始化,跳过记忆检索")
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
result = {"memory_context": memory_context}
log_state_change("retrieve_memory", {**state, **result}, "离开")
return result
return retrieve_memory

27
app/graph/state.py Normal file
View File

@@ -0,0 +1,27 @@
"""
LangGraph 状态定义模块
包含 MessagesState 和 GraphContext
"""
import operator
from typing import Annotated, Any
from typing_extensions import TypedDict
from dataclasses import dataclass
from langchain_core.messages import AnyMessage
class MessagesState(TypedDict):
"""对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add]
llm_calls: int
memory_context: str
last_token_usage: dict # 本次调用的 token 使用详情
last_elapsed_time: float # 本次调用耗时(秒)
turns_since_last_summary: int # 距离上次生成摘要的轮数
@dataclass
class GraphContext:
"""图执行上下文"""
user_id: str
# 可扩展更多上下文信息