Implement long-term memory with a vector database
commit a92a220ff3 (parent de68916c5a)
2026-04-15 23:52:13 +08:00
24 changed files with 1237 additions and 713 deletions

app/agent.py (modified)

@@ -6,16 +6,15 @@ AI Agent 服务类 - 支持多模型动态切换
import os
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
+from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
# 本地模块
from app.graph_builder import GraphBuilder, GraphContext
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.logger import debug, info, warning, error
-from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
-from langgraph.store.postgres.aio import AsyncPostgresStore
load_dotenv()
@@ -23,15 +22,13 @@ load_dotenv()
class AIAgentService:
"""异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
-    def __init__(self, checkpointer: AsyncPostgresSaver, store: AsyncPostgresStore):
+    def __init__(self, checkpointer: AsyncPostgresSaver):
"""
初始化服务
Args:
checkpointer: 已经初始化的 AsyncPostgresSaver 实例
-            store: 已经初始化的 AsyncPostgresStore 实例
"""
self.checkpointer = checkpointer
-        self.store = store
self.graphs = {} # 存储不同模型对应的 graph 实例
def _create_zhipu_llm(self):
@@ -68,19 +65,19 @@ class AIAgentService:
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
vllm_base_url = os.getenv(
"VLLM_BASE_URL",
"http://115.190.121.151:18000/v1"
"http://localhost:8081/v1"
)
return ChatOpenAI(
base_url=vllm_base_url,
api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")),
api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
model="gemma-4-E2B-it",
timeout=60.0, # 请求超时时间(秒)
max_retries=2, # 失败后自动重试次数
)
async def initialize(self):
"""预编译所有模型的 graph使用传入的 checkpointer 和 store"""
"""预编译所有模型的 graph使用传入的 checkpointer"""
model_configs = {
"zhipu": self._create_zhipu_llm,
"deepseek": self._create_deepseek_llm,
@@ -92,7 +89,7 @@ class AIAgentService:
info(f"🔄 正在初始化模型 '{model_name}'...")
llm = llm_creator()
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
-                graph = builder.compile(checkpointer=self.checkpointer, store=self.store)
+                graph = builder.compile(checkpointer=self.checkpointer)
self.graphs[model_name] = graph
info(f"✅ 模型 '{model_name}' 初始化成功")
except Exception as e:
@@ -121,14 +118,27 @@ class AIAgentService:
"elapsed_time": float # 调用耗时(秒)
}
"""
# 尝试使用指定模型,如果不可用则循环尝试其他模型
if model not in self.graphs:
-            fallback_model = next(iter(self.graphs.keys()))
-            warning(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
-            model = fallback_model
+            warning(f"警告: 模型 '{model}' 不可用,尝试切换到其他可用模型")
+            found = False
+            for available_model in self.graphs.keys():
+                try:
+                    # 这里可以添加额外的模型可用性检查逻辑
+                    model = available_model
+                    found = True
+                    info(f"已切换到可用模型: '{model}'")
+                    break
+                except Exception as e:
+                    warning(f"模型 '{available_model}' 也不可用: {str(e)}")
+                    continue
+            if not found:
+                raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}")
graph = self.graphs[model]
config = {"configurable": {"thread_id": thread_id}}
input_state = {"messages": [HumanMessage(content=message)]}
input_state = {"messages": [{"role": "user", "content": message}]}
context = GraphContext(user_id=user_id)
result = await graph.ainvoke(input_state, config=config, context=context)
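
For orientation, a minimal calling sketch (not part of the commit). The enclosing method's name sits outside this hunk, so `chat` below is an assumption; only the `elapsed_time` key is confirmed by the docstring fragment above.

async def demo(agent_service):
    # Hypothetical usage -- the method name `chat` is assumed, not shown in the diff.
    result = await agent_service.chat(
        message="记住我喜欢喝美式咖啡",
        model="zhipu",          # falls back to a registered model if unavailable
        thread_id="thread-001", # conversation key used by the checkpointer
        user_id="user-42",      # forwarded to the memory nodes via GraphContext
    )
    print(result["elapsed_time"])  # call duration in seconds, per the docstring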

app/main.py (modified)

@@ -12,7 +12,6 @@ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depe
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
-from langgraph.store.postgres.aio import AsyncPostgresStore
from app.agent import AIAgentService
from app.logger import debug, info, warning, error
@@ -23,23 +22,19 @@ load_dotenv()
# 优先级:环境变量 DB_URI > Docker 内部服务名 > 本地开发地址
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable"
"postgresql://postgres:mysecretpassword@ai-postgres:5432/langgraph_db?sslmode=disable"
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务"""
-    # 1. 创建数据库连接池并初始化表
-    async with (
-        AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer,
-        AsyncPostgresStore.from_conn_string(DB_URI) as store
-    ):
+    # 1. 创建数据库连接池并初始化表(仅 checkpointer)
+    async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
await checkpointer.setup()
-        await store.setup()
# 2. 构建 AI Agent 服务
-    agent_service = AIAgentService(checkpointer, store)
+    agent_service = AIAgentService(checkpointer)
await agent_service.initialize()
# 3. 将服务实例存入 app.state
@@ -155,4 +150,6 @@ async def websocket_endpoint(
if __name__ == "__main__":
import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8001)
+    # 使用环境变量或默认端口 8003(避免与 vLLM 的 8001 端口冲突)
+    port = int(os.getenv("BACKEND_PORT", "8003"))
+    uvicorn.run(app, host="0.0.0.0", port=port)

app/config.py (new file, 23 lines)

@@ -0,0 +1,23 @@
"""
环境变量集中管理模块
所有配置项统一定义,避免散落在各个文件中
"""
import os
# ========== Graph 执行追踪配置 ==========
# 是否启用 Graph 流转追踪(通过环境变量控制)
ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"
# ========== 记忆提取配置 ==========
# 记忆提取间隔:每 N 轮对话生成一次摘要
MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
# ========== Mem0 记忆层配置 ==========
# Qdrant 向量数据库地址
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
# vLLM Embedding 服务地址 (用于 Mem0 的向量化)
VLLM_EMBEDDING_URL = os.getenv("VLLM_EMBEDDING_URL", "http://localhost:8082/v1")
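
For comparison, mem0_client.py below splits QDRANT_URL into host and port by hand; the same parsing can be expressed with the standard library (illustrative sketch, not part of the commit):

from urllib.parse import urlparse

def qdrant_host_port(url: str = QDRANT_URL) -> tuple[str, int]:
    # "http://localhost:6333" -> ("localhost", 6333)
    parsed = urlparse(url)
    return parsed.hostname or "localhost", parsed.port or 6333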

app/graph_builder.py (modified)

@@ -1,95 +1,27 @@
"""
-LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
+LangGraph 状态图构建模块 - 精简版,仅负责组装图
+所有节点逻辑已拆分到独立模块
"""
-import operator
-import asyncio
-import time
-import os
-from typing import Literal, Annotated, Any
from langchain_core.language_models import BaseLLM
-from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.runnables import RunnableLambda
from langgraph.graph import StateGraph, START, END
-from typing_extensions import TypedDict
-from langgraph.store.postgres.aio import AsyncPostgresStore
-from langgraph.runtime import Runtime
-from dataclasses import dataclass
-import uuid
-from langchain_core.prompt_values import ChatPromptValue
# 本地模块
from app.logger import debug, info, warning, error
+from app.state import MessagesState, GraphContext
+from app.nodes import (
+    create_llm_call_node,
+    create_tool_call_node,
+    create_retrieve_memory_node,
+    create_summarize_node,
+    should_continue
+)
+from app.memory import Mem0Client
-# 是否启用 Graph 流转追踪(通过环境变量控制)
-ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"
-class MessagesState(TypedDict):
-    """对话状态类型定义"""
-    messages: Annotated[list[AnyMessage], operator.add]
-    llm_calls: int
-    memory_context: str
-    last_token_usage: dict  # 本次调用的 token 使用详情
-    last_elapsed_time: float  # 本次调用耗时(秒)
-@dataclass
-class GraphContext:
-    user_id: str
-    # 可扩展更多上下文信息
-def _log_state_change(node_name: str, state: MessagesState, prefix: str = "进入"):
-    """
-    通用辅助函数:打印节点状态变化
-    Args:
-        node_name: 节点名称
-        state: 当前状态
-        prefix: 前缀("进入"或"离开")
-    """
-    if not ENABLE_GRAPH_TRACE:
-        return
-    messages = state.get("messages", [])
-    msg_count = len(messages)
-    last_msg = messages[-1] if messages else None
-    last_info = ""
-    if last_msg:
-        content_preview = str(last_msg.content)[:100].replace("\n", " ")
-        last_info = f"{last_msg.type.upper()}: {content_preview}"
-    info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}")
-def _print_llm_input(prompt_value: ChatPromptValue) -> ChatPromptValue:
-    """
-    RunnableLambda 回调函数:打印格式化后发送给 LLM 的完整消息
-    Args:
-        prompt_value: ChatPromptValue 对象,包含格式化后的消息列表
-    Returns:
-        原样返回 prompt_value(不影响链式调用)
-    """
-    if not ENABLE_GRAPH_TRACE:
-        return prompt_value
-    messages = prompt_value.messages  # ChatPromptValue 提供 .messages 属性
-    debug("\n" + "=" * 80)
-    debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:")
-    debug(f"  总消息数: {len(messages)}")
-    debug("-" * 80)
-    for i, msg in enumerate(messages):
-        content_preview = str(msg.content)  # 完整输出
-        debug(f"  [{i}] {msg.type.upper():10s}: {content_preview}")
-    debug("\n" + "=" * 80 + "\n")
-    return prompt_value
class GraphBuilder:
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
"""LangGraph 状态图构建器 - 仅负责组装图"""
def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]):
def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
"""
初始化构建器
@@ -101,304 +33,44 @@ class GraphBuilder:
self.llm = llm
self.tools = tools
self.tools_by_name = tools_by_name
-        self._llm_with_tools = llm.bind_tools(tools)
-        self._prompt = self._create_prompt()
-        self._chain = (
-            self._prompt
-            | RunnableLambda(_print_llm_input)
-            | self._llm_with_tools
-        )
-    @staticmethod
-    def _create_prompt() -> ChatPromptTemplate:
-        """创建系统提示模板(静态方法,无需访问实例)"""
-        system_template = (
-            "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
-            "【用户背景信息】\n"
-            "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
-            "{memory_context}\n"
-            "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
-            "【可用工具与使用规则】\n"
-            "- 获取温度/天气:`get_current_temperature`\n"
-            "- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
-            "- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
-            "- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
-            "- 抓取网页内容:`fetch_webpage_content`\n"
-            "工具调用时请直接返回所需参数,无需额外说明。\n\n"
-            "【回答要求(必须遵守)】\n"
-            "1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
-            "2. 优先利用已知用户信息进行个性化回复。\n"
-            "3. 若无信息可依,礼貌询问或提供通用帮助。"
-        )
-        return ChatPromptTemplate.from_messages([
-            ("system", system_template),
-            MessagesPlaceholder(variable_name="messages")
-        ])
-    async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
-        """
-        LLM 调用节点(异步方法)
-        注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
-        """
-        _log_state_change("llm_call", state, "进入")
-        memory_context = state.get("memory_context", "暂无用户信息")
-        loop = asyncio.get_event_loop()
-        start_time = time.time()
-        try:
-            response = await loop.run_in_executor(
-                None,
-                lambda: self._chain.invoke({
-                    "messages": state["messages"],
-                    "memory_context": memory_context
-                })
-            )
-            elapsed_time = time.time() - start_time
-            # 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
-            token_usage = {}
-            input_tokens = 0
-            output_tokens = 0
-            # 尝试从 response_metadata 中提取
-            if hasattr(response, 'response_metadata') and response.response_metadata:
-                meta = response.response_metadata
-                if 'token_usage' in meta:
-                    token_usage = meta['token_usage']
-                elif 'usage' in meta:
-                    token_usage = meta['usage']
-            # 尝试从 additional_kwargs 中提取
-            if not token_usage and hasattr(response, 'additional_kwargs'):
-                add_kwargs = response.additional_kwargs
-                if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
-                    token_usage = add_kwargs['llm_output']['token_usage']
-            # 提取具体的 token 数值
-            if token_usage:
-                input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
-                output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
-            # 打印响应统计信息
-            info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
-            info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
-            if token_usage:
-                debug(f"📋 [LLM统计] 详细用量: {token_usage}")
-            # 打印 LLM 的完整输出
-            debug("\n" + "="*80)
-            debug("📥 [LLM输出] 大模型返回的完整响应:")
-            debug(f"  消息类型: {response.type.upper()}")
-            debug(f"  内容长度: {len(str(response.content))} 字符")
-            debug("-"*80)
-            debug(f"{response.content}")
-            debug("="*80 + "\n")
-            result = {
-                "messages": [response],
-                "llm_calls": state.get('llm_calls', 0) + 1,
-                "last_token_usage": token_usage,
-                "last_elapsed_time": elapsed_time
-            }
-            _log_state_change("llm_call", {**state, **result}, "离开")
-            return result
-        except Exception as e:
-            elapsed_time = time.time() - start_time
-            error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
-            error(f"  错误类型: {type(e).__name__}")
-            error(f"  错误信息: {str(e)}")
-            import traceback
-            traceback.print_exc()
-            debug("="*80 + "\n")
-            # 返回一个友好的错误消息
-            error_response = AIMessage(
-                content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
-            )
-            error_result = {
-                "messages": [error_response],
-                "llm_calls": state.get('llm_calls', 0),
-                "last_token_usage": {},
-                "last_elapsed_time": elapsed_time
-            }
-            _log_state_change("llm_call", state, "离开(异常)")
-            return error_result
-    async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
-        """
-        工具执行节点(异步方法)
-        对于每个工具调用,在线程池中执行同步工具函数
-        """
-        _log_state_change("tool_node", state, "进入")
-        last_message = state['messages'][-1]
-        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
-            _log_state_change("tool_node", state, "离开(无工具调用)")
-            return {"messages": []}
-        results = []
-        loop = asyncio.get_event_loop()
-        info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")
-        for tool_call in last_message.tool_calls:
-            tool_name = tool_call["name"]
-            tool_args = tool_call["args"]
-            tool_id = tool_call["id"]
-            tool_func = self.tools_by_name.get(tool_name)
-            debug(f"  ├─ 调用工具: {tool_name} 参数: {tool_args}")
-            if tool_func is None:
-                err_msg = f"Tool {tool_name} not found"
-                debug(f"  └─ ❌ {err_msg}")
-                results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
-                continue
-            try:
-                # 修复闭包问题:将变量作为默认参数传入 lambda
-                # 如果工具支持异步 (ainvoke),优先使用异步调用
-                if hasattr(tool_func, 'ainvoke'):
-                    observation = await tool_func.ainvoke(tool_args)
-                else:
-                    observation = await loop.run_in_executor(
-                        None,
-                        lambda args=tool_args: tool_func.invoke(args)  # 默认参数捕获当前值
-                    )
-                # 字符打印
-                result_preview = str(observation).replace("\n", " ")
-                debug(f"  └─ ✅ 结果: {result_preview}")
-                results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
-            except Exception as e:
-                debug(f"  └─ ❌ 异常: {e}")
-                results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
-        info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
-        result = {"messages": results}
-        _log_state_change("tool_node", {**state, **result}, "离开")
-        return result
-    @staticmethod
-    def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
-        """决定下一步:工具调用、保存记忆还是结束"""
-        last_message = state["messages"][-1]
-        # 1. 如果需要调用工具,优先进入工具节点
-        if isinstance(last_message, AIMessage) and last_message.tool_calls:
-            if ENABLE_GRAPH_TRACE:
-                info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
-            return 'tool_node'
-        # 2. 如果是 AI 的最终回复,可以考虑进入记忆保存节点(可增加判断逻辑)
-        # 这里简单处理:只要没有工具调用,且是 AI 消息,就尝试保存记忆。
-        if isinstance(last_message, AIMessage):
-            if ENABLE_GRAPH_TRACE:
-                info(f"🔀 [路由决策] 收到 AI 最终回复(无工具调用) → 转向 'save_memory'")
-            return 'save_memory'
-        # 3. 其他情况(如只有用户消息)直接结束
-        if ENABLE_GRAPH_TRACE:
-            info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
-        return 'END'
-    async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
-        """搜索并返回长期记忆"""
-        _log_state_change("retrieve_memory", state, "进入")
-        user_id = runtime.context.user_id
-        namespace = ("memories", user_id)
-        query = str(state["messages"][-1].content)
-        debug(f"\n{'='*60}")
-        debug(f"🔎 [记忆检索] 开始检索")
-        debug(f"  ├─ 用户ID: {user_id}")
-        debug(f"  ├─ 命名空间: {namespace}")
-        debug(f"  ├─ 查询内容: '{query}'")
-        debug(f"  └─ 消息总数: {len(state['messages'])}")
-        try:
-            memories = await runtime.store.asearch(namespace, query=query)
-            debug(f"✅ [记忆检索] 检索完成,找到 {len(memories)} 条相关记忆")
-            if memories:
-                memory_text = "\n".join([m.value["data"] for m in memories])
-                debug(f"📚 [记忆内容]")
-                for i, memory in enumerate(memories, 1):
-                    debug(f"  [{i}] {memory.value['data']}")
-                debug(f"{'='*60}\n")
-                result = {"memory_context": memory_text}
-                _log_state_change("retrieve_memory", {**state, **result}, "离开")
-                return result
-            else:
-                debug(f"⚠️ [记忆检索] 未找到相关记忆")
-                debug(f"{'='*60}\n")
-                result = {"memory_context": ""}
-                _log_state_change("retrieve_memory", {**state, **result}, "离开")
-                return result
-        except Exception as e:
-            error(f"❌ [记忆检索] 检索失败: {e}")
-            import traceback
-            traceback.print_exc()
-            debug(f"{'='*60}\n")
-            result = {"memory_context": ""}
-            _log_state_change("retrieve_memory", {**state, **result}, "离开(异常)")
-            return result
-    async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
-        """尝试从对话中提取并保存长期记忆"""
-        _log_state_change("save_memory", state, "进入")
-        # 获取最后一条用户消息(通常是要记住的内容的来源)
-        user_messages = [msg for msg in state["messages"] if msg.type == "human"]
-        if not user_messages:
-            _log_state_change("save_memory", state, "离开(无用户消息)")
-            return {}
-        last_user_msg = user_messages[-1].content.lower()
-        # 简单触发逻辑:包含"记住"或"保存"等关键词
-        if any(keyword in last_user_msg for keyword in ["记住", "保存", "别忘了"]):
-            # 提取记忆内容(这里仅作示例,实际可用 LLM 提取)
-            memory_content = f"用户说过:{last_user_msg}"
-            user_id = runtime.context.user_id
-            namespace = ("memories", user_id)
-            await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
-            info(f"✅ 长期记忆已保存:{memory_content}")
-        _log_state_change("save_memory", state, "离开")
-        return {}
+        # ⭐ 创建 Mem0 客户端(懒加载,首次使用时初始化)
+        self.mem0_client = Mem0Client(llm)
def build(self) -> StateGraph:
"""
-        构建未编译的状态图(返回 StateGraph 实例)
-        图中节点直接使用实例方法 call_llm, call_tools
+        构建未编译的状态图
Returns:
StateGraph 实例
"""
-        builder = StateGraph(MessagesState,context_schema=GraphContext)
-        builder.add_node("retrieve_memory", self.retrieve_memory)
-        builder.add_node("llm_call", self.call_llm)
-        builder.add_node("tool_node", self.call_tools)
-        builder.add_node("save_memory", self.save_memory)
+        builder = StateGraph(MessagesState, context_schema=GraphContext)
+        # ⭐ 通过工厂函数创建节点(依赖注入)
+        retrieve_memory_node = create_retrieve_memory_node(self.mem0_client)
+        llm_call_node = create_llm_call_node(self.llm, self.tools)
+        tool_call_node = create_tool_call_node(self.tools_by_name)
+        summarize_node = create_summarize_node(self.mem0_client)
+        # 添加节点
+        builder.add_node("retrieve_memory", retrieve_memory_node)
+        builder.add_node("llm_call", llm_call_node)
+        builder.add_node("tool_node", tool_call_node)
+        builder.add_node("summarize", summarize_node)
+        # 添加边
builder.add_edge(START, "retrieve_memory")
builder.add_edge("retrieve_memory", "llm_call")
builder.add_conditional_edges(
"llm_call",
-            self.should_continue,
+            should_continue,
{
"tool_node": "tool_node",
"save_memory": "save_memory",
"summarize": "summarize",
'END': END
}
)
builder.add_edge("tool_node", "llm_call")
builder.add_edge("save_memory", END)
builder.add_edge("summarize", END)
        return builder

app/memory/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
"""
Mem0 记忆层模块
"""
from app.memory.mem0_client import Mem0Client
__all__ = ["Mem0Client"]

app/memory/mem0_client.py (new file, 144 lines)

@@ -0,0 +1,144 @@
"""
Mem0 记忆层客户端封装模块
负责 Mem0 的初始化、检索和存储
"""
import os
from typing import Optional, List, Dict, Any
from mem0 import AsyncMemory
# 本地模块
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, VLLM_EMBEDDING_URL
from app.logger import info, warning, error
class Mem0Client:
"""Mem0 异步客户端封装类"""
def __init__(self, llm_instance):
"""
初始化 Mem0 客户端
Args:
llm_instance: LangChain LLM 实例(用于事实提取)
"""
self.llm = llm_instance
self.mem0: Optional[AsyncMemory] = None
self._initialized = False
async def initialize(self):
"""异步初始化 Mem0 客户端"""
if self._initialized:
return
try:
# 检查 Qdrant 是否可达 (可选)
import requests
try:
resp = requests.get(f"{QDRANT_URL}/collections", timeout=2)
if resp.status_code == 200:
info(f"✅ Qdrant 服务正常: {QDRANT_URL}")
except Exception:
warning(f"⚠️ 无法连接到 Qdrant: {QDRANT_URL}Mem0 将尝试自动连接")
config = {
# 向量存储:复用 Qdrant 实例
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": QDRANT_COLLECTION_NAME,
"host": QDRANT_URL.split("://")[1].split(":")[0] if "://" in QDRANT_URL else "localhost",
"port": int(QDRANT_URL.split(":")[-1]) if ":" in QDRANT_URL.split("://")[-1] else 6333,
"embedding_model_dims": 768, # embeddinggemma-300m 输出 768 维
}
},
# 事实提取 LLM直接复用传入的 LangChain 实例
"llm": {
"provider": "langchain",
"config": {
"model": self.llm # 直接传入 LangChain 模型实例
}
},
# Embedding指向 vLLM 服务
"embedder": {
"provider": "openai",
"embedding_dims": 768, # 关键:将维度参数提升到顶层
"config": {
"model": "google/embeddinggemma-300m",
"api_key": "EMPTY",
"api_base": VLLM_EMBEDDING_URL,
# 注意:不要在此处传递 dimensions 参数,避免与 vLLM v0.7.2 不兼容
}
},
"version": "v1.1"
}
self.mem0 = AsyncMemory.from_config(config)
self._initialized = True
info(f"✅ Mem0 初始化成功 (Embedding: vLLM@8002, Vector: Qdrant, LLM: 复用现有实例)")
except Exception as e:
error(f"❌ Mem0 初始化失败: {e}")
import traceback
traceback.print_exc()
self.mem0 = None
async def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[str]:
"""
检索相关记忆
Args:
query: 查询文本
user_id: 用户 ID
limit: 返回结果数量限制
Returns:
List[str]: 记忆事实列表
"""
if not self.mem0:
warning("⚠️ Mem0 未初始化,跳过记忆检索")
return []
try:
memories = await self.mem0.search(query, user_id=user_id, limit=limit)
if memories and "results" in memories:
facts = [m["memory"] for m in memories["results"] if m.get("memory")]
if facts:
info(f"🔍 [记忆检索] Mem0 返回 {len(facts)} 条记忆")
return facts
info("🔍 [记忆检索] 未找到相关记忆")
return []
except Exception as e:
warning(f"⚠️ Mem0 检索失败: {e}")
return []
async def add_memories(self, messages: List[Dict[str, str]], user_id: str) -> bool:
"""
添加记忆(自动提取事实并存储)
Args:
messages: 消息列表,格式为 [{"role": "user/assistant/system", "content": "..."}]
user_id: 用户 ID
Returns:
bool: 是否成功
"""
if not self.mem0:
warning("⚠️ Mem0 未初始化,跳过记忆添加")
return False
try:
result = await self.mem0.add(
messages,
user_id=user_id,
metadata={"type": "conversation"}
)
info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取")
return True
except Exception as e:
error(f"❌ Mem0 记忆添加失败: {e}")
return False
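
A usage sketch for the class above (illustrative; assumes `llm` is any LangChain chat model, e.g. the ChatOpenAI instance the agent service already builds):

import asyncio

async def demo(llm):
    client = Mem0Client(llm)
    await client.initialize()  # idempotent; guarded by _initialized
    await client.add_memories(
        [{"role": "user", "content": "我叫小明,喜欢美式咖啡"},
         {"role": "assistant", "content": "好的,小明!"}],
        user_id="user-42",
    )
    facts = await client.search_memories("他喜欢什么饮料?", user_id="user-42")
    print(facts)  # extracted facts once Mem0 has processed the messages

# asyncio.run(demo(llm)) with a real model instance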

app/nodes/__init__.py (new file, 17 lines)

@@ -0,0 +1,17 @@
"""
节点模块 - 导出所有 LangGraph 节点函数
"""
from app.nodes.router import should_continue
from app.nodes.llm_call import create_llm_call_node
from app.nodes.tool_call import create_tool_call_node
from app.nodes.retrieve_memory import create_retrieve_memory_node
from app.nodes.summarize import create_summarize_node
__all__ = [
"should_continue",
"create_llm_call_node",
"create_tool_call_node",
"create_retrieve_memory_node",
"create_summarize_node",
]

app/nodes/llm_call.py (new file, 139 lines)

@@ -0,0 +1,139 @@
"""
LLM 调用节点模块
负责调用大语言模型并处理响应
"""
import asyncio
import time
from typing import Any, Dict
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from langgraph.runtime import Runtime
# 本地模块
from app.state import MessagesState, GraphContext
from app.prompts import create_system_prompt
from app.utils.logging import log_state_change, print_llm_input
from app.logger import debug, info, error
def create_llm_call_node(llm: BaseLLM, tools: list):
"""
工厂函数:创建 LLM 调用节点
Args:
llm: LangChain LLM 实例
tools: 工具列表
Returns:
异步节点函数
"""
# 构建调用链
prompt = create_system_prompt()
llm_with_tools = llm.bind_tools(tools)
chain = prompt | RunnableLambda(print_llm_input) | llm_with_tools
async def call_llm(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
"""
        LLM 调用节点(异步方法)
        注意:chain.invoke 是同步方法,因此用 run_in_executor 避免阻塞事件循环
Args:
state: 当前对话状态
runtime: LangGraph 运行时上下文
Returns:
更新后的状态字典
"""
log_state_change("llm_call", state, "进入")
memory_context = state.get("memory_context", "暂无用户信息")
loop = asyncio.get_event_loop()
start_time = time.time()
try:
response = await loop.run_in_executor(
None,
lambda: chain.invoke({
"messages": state["messages"],
"memory_context": memory_context
})
)
elapsed_time = time.time() - start_time
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
token_usage = {}
input_tokens = 0
output_tokens = 0
# 尝试从 response_metadata 中提取
if hasattr(response, 'response_metadata') and response.response_metadata:
meta = response.response_metadata
if 'token_usage' in meta:
token_usage = meta['token_usage']
elif 'usage' in meta:
token_usage = meta['usage']
# 尝试从 additional_kwargs 中提取
if not token_usage and hasattr(response, 'additional_kwargs'):
add_kwargs = response.additional_kwargs
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
token_usage = add_kwargs['llm_output']['token_usage']
# 提取具体的 token 数值
if token_usage:
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
# 打印响应统计信息
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}")
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
if token_usage:
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
# 打印 LLM 的完整输出
debug("\n" + "="*80)
debug("📥 [LLM输出] 大模型返回的完整响应:")
debug(f" 消息类型: {response.type.upper()}")
debug(f" 内容长度: {len(str(response.content))} 字符")
debug("-"*80)
debug(f"{response.content}")
debug("="*80 + "\n")
result = {
"messages": [response],
"llm_calls": state.get('llm_calls', 0) + 1,
"last_token_usage": token_usage,
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 递增计数器
}
log_state_change("llm_call", {**state, **result}, "离开")
return result
except Exception as e:
elapsed_time = time.time() - start_time
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
error(f" 错误类型: {type(e).__name__}")
error(f" 错误信息: {str(e)}")
import traceback
traceback.print_exc()
debug("="*80 + "\n")
# 返回一个友好的错误消息
error_response = AIMessage(
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
)
error_result = {
"messages": [error_response],
"llm_calls": state.get('llm_calls', 0),
"last_token_usage": {},
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 即使出错也递增计数器
}
log_state_change("llm_call", state, "离开(异常)")
return error_result
return call_llm

app/nodes/retrieve_memory.py (new file)

@@ -0,0 +1,75 @@
"""
记忆检索节点模块
负责从 Mem0 检索相关长期记忆
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块
from app.state import MessagesState, GraphContext
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug, warning
def create_retrieve_memory_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆检索节点
Args:
mem0_client: Mem0 客户端实例
Returns:
异步节点函数
"""
async def retrieve_memory(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
"""
记忆检索节点 - 使用 Mem0
Args:
state: 当前对话状态
runtime: LangGraph 运行时上下文
Returns:
包含 memory_context 的状态更新
"""
log_state_change("retrieve_memory", state, "进入")
user_id = runtime.context.user_id
# 兼容 dict 和对象两种消息格式
last_msg = state["messages"][-1]
if isinstance(last_msg, dict):
query = str(last_msg.get("content", ""))
else:
query = str(last_msg.content)
memory_text_parts = []
# 确保 Mem0 已初始化(懒加载)
if not mem0_client._initialized:
await mem0_client.initialize()
if mem0_client.mem0:
try:
# 异步调用 Mem0 语义检索
facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
if facts:
memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts))
else:
debug("🔍 [记忆检索] 未找到相关记忆")
            except Exception as e:
                warning(f"⚠️ Mem0 检索失败: {e}")
        else:
            warning("⚠️ Mem0 未初始化,跳过记忆检索")
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
result = {"memory_context": memory_context}
log_state_change("retrieve_memory", {**state, **result}, "离开")
return result
return retrieve_memory

app/nodes/router.py (new file, 48 lines)

@@ -0,0 +1,48 @@
"""
路由决策节点
根据当前状态决定下一步走向
"""
from typing import Literal
from langchain_core.messages import AIMessage
# 本地模块
from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
from app.state import MessagesState
from app.logger import info
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'END']:
"""
决定下一步:工具调用、生成摘要还是结束
Args:
state: 当前对话状态
Returns:
下一个节点名称或 END
"""
last_message = state["messages"][-1]
# 1. 如果需要调用工具,优先进入工具节点
if isinstance(last_message, AIMessage) and last_message.tool_calls:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
return 'tool_node'
# 2. 如果是 AI 的最终回复,判断是否达到摘要生成阈值
if isinstance(last_message, AIMessage):
turns = state.get("turns_since_last_summary", 0)
if turns >= MEMORY_SUMMARIZE_INTERVAL:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 收到 AI 最终回复,已达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 转向 'summarize'")
return 'summarize'
else:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")
return 'END'
# 3. 其他情况(如只有用户消息)直接结束
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
return 'END'
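
A quick illustration of the two main branches (not part of the commit; assumes the default MEMORY_SUMMARIZE_INTERVAL of 10 from app/config.py):

from langchain_core.messages import AIMessage

# An AI message carrying tool calls routes to the tool node:
with_tools = AIMessage(content="", tool_calls=[
    {"name": "get_current_temperature", "args": {}, "id": "t1", "type": "tool_call"}
])
assert should_continue({"messages": [with_tools], "turns_since_last_summary": 3}) == "tool_node"

# A plain final answer past the threshold routes to the summarize node:
final = AIMessage(content="完成")
assert should_continue({"messages": [final], "turns_since_last_summary": 12}) == "summarize"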

app/nodes/summarize.py (new file, 86 lines)

@@ -0,0 +1,86 @@
"""
记忆存储节点模块
负责将对话历史提交给 Mem0 进行事实提取和存储
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块
from app.state import MessagesState, GraphContext
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆存储节点
Args:
mem0_client: Mem0 客户端实例
Returns:
异步节点函数
"""
async def summarize_conversation(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
"""
记忆存储节点 - 使用 Mem0
Args:
state: 当前对话状态
runtime: LangGraph 运行时上下文
Returns:
重置计数器的状态更新
"""
log_state_change("summarize", state, "进入")
messages = state["messages"]
if len(messages) < 4:
debug("📝 [记忆添加] 对话过短,跳过")
return {"turns_since_last_summary": 0}
user_id = runtime.context.user_id
# 确保 Mem0 已初始化(懒加载)
if not mem0_client._initialized:
await mem0_client.initialize()
# 将整个对话历史转换为 Mem0 需要的消息格式
mem0_messages = []
for msg in messages:
# 兼容 dict 和对象两种格式
if isinstance(msg, dict):
msg_type = msg.get("type", "")
msg_content = msg.get("content", "")
else:
msg_type = getattr(msg, 'type', '')
msg_content = getattr(msg, 'content', '')
if msg_type == "human":
mem0_messages.append({"role": "user", "content": msg_content})
elif msg_type == "ai":
mem0_messages.append({"role": "assistant", "content": msg_content})
elif msg_type == "tool":
mem0_messages.append({"role": "system", "content": f"[工具返回] {msg_content}"})
if mem0_client.mem0:
try:
# 异步调用 Mem0 自动提取并存储事实
success = await mem0_client.add_memories(
mem0_messages,
user_id=user_id
)
if success:
info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取")
except Exception as e:
error(f"❌ Mem0 记忆添加失败: {e}")
else:
warning("⚠️ Mem0 未初始化,跳过记忆添加")
log_state_change("summarize", state, "离开")
return {"turns_since_last_summary": 0}
return summarize_conversation

app/nodes/tool_call.py (new file, 90 lines)

@@ -0,0 +1,90 @@
"""
工具执行节点模块
负责执行 AI 调用的工具函数
"""
import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
# 本地模块
from app.state import MessagesState, GraphContext
from app.utils.logging import log_state_change
from app.logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
"""
工厂函数:创建工具执行节点
Args:
tools_by_name: 名称到工具函数的映射字典
Returns:
异步节点函数
"""
async def call_tools(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
"""
工具执行节点(异步方法)
Args:
state: 当前对话状态
runtime: LangGraph 运行时上下文
Returns:
包含 ToolMessage 的状态更新
"""
log_state_change("tool_node", state, "进入")
last_message = state['messages'][-1]
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
log_state_change("tool_node", state, "离开(无工具调用)")
return {"messages": []}
results = []
loop = asyncio.get_event_loop()
info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")
for tool_call in last_message.tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool_id = tool_call["id"]
tool_func = tools_by_name.get(tool_name)
debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")
if tool_func is None:
err_msg = f"Tool {tool_name} not found"
debug(f" └─ ❌ {err_msg}")
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
continue
try:
# 修复闭包问题:将变量作为默认参数传入 lambda
# 如果工具支持异步 (ainvoke),优先使用异步调用
if hasattr(tool_func, 'ainvoke'):
observation = await tool_func.ainvoke(tool_args)
else:
observation = await loop.run_in_executor(
None,
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
)
# 字符打印
result_preview = str(observation).replace("\n", " ")
debug(f" └─ ✅ 结果: {result_preview}")
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
except Exception as e:
debug(f" └─ ❌ 异常: {e}")
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
result = {"messages": results}
log_state_change("tool_node", {**state, **result}, "离开")
return result
return call_tools

app/prompts.py (new file, 38 lines)

@@ -0,0 +1,38 @@
"""
提示模板管理模块
所有系统提示和对话模板统一定义
"""
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
def create_system_prompt() -> ChatPromptTemplate:
"""
创建系统提示模板
Returns:
ChatPromptTemplate: 包含系统指令和消息占位符的提示模板
"""
system_template = (
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
"【用户背景信息】\n"
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
"{memory_context}\n"
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
"【可用工具与使用规则】\n"
"- 获取温度/天气:`get_current_temperature`\n"
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`\n"
"- 读取PDF摘要`read_pdf_summary`(限定目录 `./user_docs`\n"
"- 读取Excel表格`read_excel_as_markdown`(限定目录 `./user_docs`\n"
"- 抓取网页内容:`fetch_webpage_content`\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
"2. 优先利用已知用户信息进行个性化回复。\n"
"3. 若无信息可依,礼貌询问或提供通用帮助。"
)
return ChatPromptTemplate.from_messages([
("system", system_template),
MessagesPlaceholder(variable_name="messages")
])
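
A small rendering check (illustrative) showing the two variables the llm_call node supplies:

from langchain_core.messages import HumanMessage

prompt = create_system_prompt()
rendered = prompt.invoke({
    "memory_context": "- 用户叫小明\n- 喜欢美式咖啡",
    "messages": [HumanMessage(content="推荐一杯咖啡")],
})
print(rendered.messages[0].content)  # system prompt with the memories injected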

app/state.py (new file, 27 lines)

@@ -0,0 +1,27 @@
"""
LangGraph 状态定义模块
包含 MessagesState 和 GraphContext
"""
import operator
from typing import Annotated, Any
from typing_extensions import TypedDict
from dataclasses import dataclass
from langchain_core.messages import AnyMessage
class MessagesState(TypedDict):
"""对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add]
llm_calls: int
memory_context: str
last_token_usage: dict # 本次调用的 token 使用详情
last_elapsed_time: float # 本次调用耗时(秒)
turns_since_last_summary: int # 距离上次生成摘要的轮数
@dataclass
class GraphContext:
"""图执行上下文"""
user_id: str
# 可扩展更多上下文信息
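
One point worth spelling out: the operator.add annotation on messages is a reducer, so a node's returned messages are appended to the existing list rather than replacing it. Conceptually:

import operator

existing = ["msg-1", "msg-2"]  # messages already in the state
update = ["msg-3"]             # what a node just returned
assert operator.add(existing, update) == ["msg-1", "msg-2", "msg-3"]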

app/utils/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
"""
工具模块
"""
from app.utils.logging import log_state_change, print_llm_input
__all__ = ["log_state_change", "print_llm_input"]

app/utils/logging.py (new file, 61 lines)

@@ -0,0 +1,61 @@
"""
LangGraph 节点日志工具模块
提供状态流转追踪和 LLM 输入输出打印功能
"""
from app.config import ENABLE_GRAPH_TRACE
from app.logger import debug, info
def log_state_change(node_name: str, state: dict, prefix: str = "进入"):
"""
记录状态变化日志
Args:
node_name: 节点名称
state: 当前状态
        prefix: 日志前缀("进入"或"离开")
"""
messages = state.get("messages", [])
msg_count = len(messages)
last_msg = messages[-1] if messages else None
last_info = ""
if last_msg:
# 兼容 dict 和对象两种格式
if isinstance(last_msg, dict):
content_preview = str(last_msg.get("content", ""))[:100].replace("\n", " ")
msg_type = last_msg.get("type", "unknown")
else:
content_preview = str(last_msg.content)[:100].replace("\n", " ")
msg_type = getattr(last_msg, 'type', 'unknown')
last_info = f"{msg_type.upper()}: {content_preview}"
info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}")
def print_llm_input(prompt_value):
"""
RunnableLambda 回调函数:打印格式化后发送给 LLM 的完整消息
Args:
prompt_value: ChatPromptValue 对象,包含格式化后的消息列表
Returns:
原样返回 prompt_value不影响链式调用
"""
if not ENABLE_GRAPH_TRACE:
return prompt_value
messages = prompt_value.messages # ChatPromptValue 提供 .messages 属性
debug("\n" + "=" * 80)
debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:")
debug(f" 总消息数: {len(messages)}")
debug("-" * 80)
for i, msg in enumerate(messages):
content_preview = str(msg.content) # 完整输出
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
debug("\n" + "=" * 80 + "\n")
return prompt_value