添加长期记忆
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 27s

This commit is contained in:
2026-04-14 17:34:12 +08:00
parent 1bea2491c5
commit 8dd94c6c19
12 changed files with 953 additions and 197 deletions

View File

@@ -11,8 +11,11 @@ from langchain_openai import ChatOpenAI
from pydantic import SecretStr
# 本地模块
from app.graph_builder import GraphBuilder
from app.graph_builder import GraphBuilder, GraphContext
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.logger import debug, info, warning, error
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres.aio import AsyncPostgresStore
load_dotenv()
@@ -20,13 +23,15 @@ load_dotenv()
class AIAgentService:
"""异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
def __init__(self, checkpointer: AsyncPostgresSaver, store: AsyncPostgresStore):
    """
    Initialize the service.

    Args:
        checkpointer: an already-initialized AsyncPostgresSaver used to
            persist per-thread conversation state.
        store: an already-initialized AsyncPostgresStore used for
            long-term (cross-thread) memory.
    """
    self.checkpointer = checkpointer
    self.store = store
    # Populated by initialize(): model name -> compiled graph instance.
    self.graphs = {}
def _create_zhipu_llm(self):
@@ -39,49 +44,108 @@ class AIAgentService:
api_key=api_key,
temperature=0.1,
max_tokens=4096,
timeout=60.0, # 请求超时时间(秒)
max_retries=2, # 失败后自动重试次数
)
def _create_deepseek_llm(self):
    """Create a DeepSeek LLM through its OpenAI-compatible API.

    Raises:
        ValueError: when DEEPSEEK_API_KEY is missing from the environment.
    """
    secret = os.getenv("DEEPSEEK_API_KEY")
    if not secret:
        raise ValueError("DEEPSEEK_API_KEY not set in environment")
    # deepseek-chat = non-thinking mode, deepseek-reasoner = thinking mode.
    return ChatOpenAI(
        base_url="https://api.deepseek.com",
        model="deepseek-reasoner",
        api_key=SecretStr(secret),
        temperature=0.1,
        max_tokens=4096,
        timeout=60.0,   # request timeout in seconds
        max_retries=2,  # automatic retries on failure
    )
def _create_local_llm(self):
    """Create an LLM backed by the local vLLM service.

    The base URL is read from VLLM_BASE_URL so the same code works inside
    Docker, behind FRP tunneling, and in local development.
    """
    # Default points at the FRP-exposed public address; override via env.
    vllm_base_url = os.getenv(
        "VLLM_BASE_URL",
        "http://115.190.121.151:18000/v1"
    )
    return ChatOpenAI(
        base_url=vllm_base_url,
        api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")),
        model="gemma-4-E2B-it",
        timeout=60.0,   # request timeout in seconds
        max_retries=2,  # automatic retries on failure
    )
async def initialize(self):
    """Pre-compile one graph per configured model, sharing the injected
    checkpointer and store.

    Models that fail to initialize are logged and skipped.

    Returns:
        self, so the call can be chained after construction.

    Raises:
        RuntimeError: if no model at all could be initialized.
    """
    model_configs = {
        "zhipu": self._create_zhipu_llm,
        "deepseek": self._create_deepseek_llm,
        "local": self._create_local_llm,
    }
    for model_name, llm_creator in model_configs.items():
        try:
            info(f"🔄 正在初始化模型 '{model_name}'...")
            llm = llm_creator()
            # Debug hints about the endpoint about to be used
            # (no network call is made here).
            if model_name == "local":
                debug(f" 测试 vLLM 连接: {os.getenv('VLLM_BASE_URL', '未设置')}")
            elif model_name == "deepseek":
                debug(f" 测试 DeepSeek API 连接: https://api.deepseek.com")
            builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
            graph = builder.compile(checkpointer=self.checkpointer, store=self.store)
            self.graphs[model_name] = graph
            info(f"✅ 模型 '{model_name}' 初始化成功")
        except Exception as e:
            import traceback
            error_detail = traceback.format_exc()
            warning(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
            debug(f" 详细错误:\n{error_detail}")
    if not self.graphs:
        raise RuntimeError("没有可用的模型,请检查配置。可能的原因:\n"
                           "1. ZHIPUAI_API_KEY 未配置或无效\n"
                           "2. DEEPSEEK_API_KEY 未配置或无效\n"
                           "3. vLLM 服务未启动或地址错误 (VLLM_BASE_URL)\n"
                           "4. 网络连接问题")
    return self
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
    """
    Process one user message and return the reply plus call statistics.

    Args:
        message: the user's input text.
        thread_id: conversation thread identifier (checkpointer key).
        model: preferred model name; silently falls back to any
            available model when the requested one is missing.
        user_id: namespace key for this user's long-term memories.

    Returns:
        dict: {
            "reply": str,          # AI reply content
            "token_usage": dict,   # token usage details
            "elapsed_time": float  # call duration in seconds
        }
    """
    if model not in self.graphs:
        fallback_model = next(iter(self.graphs.keys()))
        warning(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
        model = fallback_model
    graph = self.graphs[model]
    config = {"configurable": {"thread_id": thread_id}}
    input_state = {"messages": [HumanMessage(content=message)]}
    context = GraphContext(user_id=user_id)
    result = await graph.ainvoke(input_state, config=config, context=context)
    reply = result["messages"][-1].content
    # Statistics are written into the state by the llm_call node; default
    # to empty/zero when the graph ended without an LLM call.
    token_usage = result.get("last_token_usage", {})
    elapsed_time = result.get("last_elapsed_time", 0.0)
    return {
        "reply": reply,
        "token_usage": token_usage,
        "elapsed_time": elapsed_time
    }

View File

@@ -7,17 +7,23 @@ import os
import uuid
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres.aio import AsyncPostgresStore
from app.agent import AIAgentService
from app.logger import debug, info, warning, error
# Load the .env file first so DB_URI can be configured there.
load_dotenv()

# PostgreSQL connection string.
# Priority: DB_URI environment variable > local-development default.
DB_URI = os.getenv(
    "DB_URI",
    "postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable"
)
@@ -25,11 +31,15 @@ DB_URI = os.getenv(
async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务"""
# 1. 创建数据库连接池并初始化表
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
async with (
AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer,
AsyncPostgresStore.from_conn_string(DB_URI) as store
):
await checkpointer.setup()
await store.setup()
# 2. 构建 AI Agent 服务
agent_service = AIAgentService(checkpointer)
agent_service = AIAgentService(checkpointer,store)
await agent_service.initialize()
# 3. 将服务实例存入 app.state
@@ -39,7 +49,7 @@ async def lifespan(app: FastAPI):
yield
# 4. 关闭时自动清理数据库连接async with 负责)
print("🛑 应用关闭,数据库连接池已释放")
info("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan)
@@ -66,12 +76,17 @@ class ChatRequest(BaseModel):
message: str
thread_id: str | None = None
model: str = "zhipu"
user_id: str = "default_user"
class ChatResponse(BaseModel):
    """Response payload for the chat endpoint."""
    reply: str
    thread_id: str
    model_used: str
    # Token accounting for the LLM call; 0 when the provider reports none.
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    # Wall-clock duration of the agent invocation, in seconds.
    elapsed_time: float = 0.0
# ========== 依赖注入函数 ==========
@@ -91,11 +106,27 @@ async def chat_endpoint(
raise HTTPException(status_code=400, detail="message required")
thread_id = request.thread_id or str(uuid.uuid4())
reply = await agent_service.process_message(
request.message, thread_id, request.model
result = await agent_service.process_message(
request.message, thread_id, request.model, request.user_id
)
# 提取 token 统计信息
token_usage = result.get("token_usage", {})
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
elapsed_time = result.get("elapsed_time", 0.0)
actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
return ChatResponse(reply=reply, thread_id=thread_id, model_used=actual_model)
return ChatResponse(
reply=result["reply"],
thread_id=thread_id,
model_used=actual_model,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
elapsed_time=elapsed_time
)
# ========== WebSocket 端点(可选) ==========
@@ -111,10 +142,11 @@ async def websocket_endpoint(
message = data.get("message")
thread_id = data.get("thread_id", str(uuid.uuid4()))
model = data.get("model", "zhipu")
user_id = data.get("user_id", "default_user")
if not message:
await websocket.send_json({"error": "missing message"})
continue
reply = await agent_service.process_message(message, thread_id, model)
reply = await agent_service.process_message(message, thread_id, model, user_id)
actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
await websocket.send_json({"reply": reply, "thread_id": thread_id, "model_used": actual_model})
except WebSocketDisconnect:

View File

@@ -4,19 +4,34 @@ LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
import operator
import asyncio
import time
from typing import Literal, Annotated, Any
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from langgraph.store.postgres.aio import AsyncPostgresStore
from langgraph.runtime import Runtime
from dataclasses import dataclass
import uuid
# 本地模块
from app.logger import debug, info, warning, error
class MessagesState(TypedDict):
    """Conversation state flowing through the graph."""
    # Appended (operator.add) across node updates rather than replaced.
    messages: Annotated[list[AnyMessage], operator.add]
    # Number of LLM invocations performed so far in this run.
    llm_calls: int
    # Long-term memories retrieved for the current user, injected into
    # the system prompt.
    memory_context: str
    # Token usage details of the latest LLM call.
    last_token_usage: dict
    # Duration of the latest LLM call, in seconds.
    last_elapsed_time: float
@dataclass
class GraphContext:
    """Per-invocation runtime context handed to every graph node."""
    # Key used to namespace this user's long-term memories.
    user_id: str
    # Extend with additional per-invocation context fields as needed.
class GraphBuilder:
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
@@ -42,33 +57,132 @@ class GraphBuilder:
"""创建系统提示模板(静态方法,无需访问实例)"""
return ChatPromptTemplate.from_messages([
SystemMessage(content=(
"你是一个个人生活助手和数据分析助手。请说中文。"
"用户询问天气或温度时,使用get_current_temperature工具获取信息。"
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。"
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
"用户背景信息】\n"
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
"{memory_context}\n"
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
"【可用工具与使用规则】\n"
"- 获取温度/天气:`get_current_temperature`\n"
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`\n"
"- 读取PDF摘要`read_pdf_summary`(限定目录 `./user_docs`\n"
"- 读取Excel表格`read_excel_as_markdown`(限定目录 `./user_docs`\n"
"- 抓取网页内容:`fetch_webpage_content`\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
"2. 优先利用已知用户信息进行个性化回复。\n"
"3. 若无信息可依,礼貌询问或提供通用帮助。"
)),
MessagesPlaceholder(variable_name="message")
MessagesPlaceholder(variable_name="messages")
])
async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
    """
    LLM invocation node.

    Formats the system prompt with the retrieved memory context, runs the
    synchronous chain in a thread-pool executor so the event loop is not
    blocked, and records token usage plus wall-clock duration.

    Returns:
        Partial state update containing the response message, incremented
        llm_calls, last_token_usage and last_elapsed_time. On failure a
        friendly AIMessage is returned instead of raising.
    """
    memory_context = state.get("memory_context", "暂无用户信息")
    # Rebuild the fully formatted input message list (for debug printing).
    system_prompt = self._prompt.messages[0]  # SystemMessage
    if isinstance(system_prompt, SystemMessage):
        system_content = system_prompt.content.format(memory_context=memory_context)
    else:
        system_content = str(system_prompt.content)
    input_messages = [SystemMessage(content=system_content)] + state["messages"]
    # Dump the exact input sent to the model.
    debug("\n" + "="*80)
    debug("📤 [LLM输入] 发送给大模型的完整消息:")
    debug(f" 总消息数: {len(input_messages)}")
    for i, msg in enumerate(input_messages):
        content_preview = str(msg.content)  # full content, no truncation
        debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
    debug("="*80 + "\n")
    # get_running_loop() is the non-deprecated way to obtain the loop
    # from inside a coroutine (get_event_loop() is deprecated here).
    loop = asyncio.get_running_loop()
    start_time = time.time()
    try:
        response = await loop.run_in_executor(
            None,
            lambda: self._chain.invoke({
                "messages": state["messages"],
                "memory_context": memory_context
            })
        )
        elapsed_time = time.time() - start_time
        # Extract token usage; metadata layout differs across providers.
        token_usage = {}
        input_tokens = 0
        output_tokens = 0
        # First try response_metadata.
        if hasattr(response, 'response_metadata') and response.response_metadata:
            meta = response.response_metadata
            if 'token_usage' in meta:
                token_usage = meta['token_usage']
            elif 'usage' in meta:
                token_usage = meta['usage']
        # Fall back to additional_kwargs.
        if not token_usage and hasattr(response, 'additional_kwargs'):
            add_kwargs = response.additional_kwargs
            if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
                token_usage = add_kwargs['llm_output']['token_usage']
        # Pull out the concrete counters.
        if token_usage:
            input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
            output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
        # Call statistics.
        info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
        info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
        if token_usage:
            debug(f"📋 [LLM统计] 详细用量: {token_usage}")
        # Dump the complete model response.
        debug("\n" + "="*80)
        debug("📥 [LLM输出] 大模型返回的完整响应:")
        debug(f" 消息类型: {response.type.upper()}")
        debug(f" 内容长度: {len(str(response.content))} 字符")
        debug("-"*80)
        debug(f"{response.content}")
        debug("="*80 + "\n")
        return {
            "messages": [response],
            "llm_calls": state.get('llm_calls', 0) + 1,
            "last_token_usage": token_usage,
            "last_elapsed_time": elapsed_time
        }
    except Exception as e:
        elapsed_time = time.time() - start_time
        error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
        error(f" 错误类型: {type(e).__name__}")
        error(f" 错误信息: {str(e)}")
        import traceback
        traceback.print_exc()
        debug("="*80 + "\n")
        # Surface a friendly error message instead of propagating.
        error_response = AIMessage(
            content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
        )
        return {
            "messages": [error_response],
            "llm_calls": state.get('llm_calls', 0),
            "last_token_usage": {},
            "last_elapsed_time": elapsed_time
        }
async def call_tools(self, state: MessageState) -> dict:
async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
"""
工具执行节点(异步方法)
对于每个工具调用,在线程池中执行同步工具函数
@@ -91,11 +205,15 @@ class GraphBuilder:
continue
try:
# 同步工具函数在线程池中执行
observation = await loop.run_in_executor(
None,
lambda: tool_func.invoke(tool_args)
)
# 修复闭包问题:将变量作为默认参数传入 lambda
# 如果工具支持异步 (ainvoke),优先使用异步调用
if hasattr(tool_func, 'ainvoke'):
observation = await tool_func.ainvoke(tool_args)
else:
observation = await loop.run_in_executor(
None,
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
)
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
except Exception as e:
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
@@ -103,25 +221,101 @@ class GraphBuilder:
return {"messages": results}
@staticmethod
def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
    """Route after the LLM node: run tools, save memory, or end.

    Returns one of the keys mapped in build()'s conditional edges.
    """
    last_message = state["messages"][-1]
    # 1. Pending tool calls take priority.
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return 'tool_node'
    # 2. A final AI reply (no tool calls) flows into the memory-saving
    #    node, which decides for itself whether anything is stored.
    if isinstance(last_message, AIMessage):
        return 'save_memory'
    # 3. Anything else (e.g. only user messages) ends the run.
    return 'END'
async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
    """Search the user's long-term memories and expose them to the prompt
    via the "memory_context" state key (empty string when nothing found
    or the lookup fails)."""
    uid = runtime.context.user_id
    namespace = ("memories", uid)
    search_query = str(state["messages"][-1].content)
    debug(f"\n{'='*60}")
    debug(f"🔎 [记忆检索] 开始检索")
    debug(f" ├─ 用户ID: {uid}")
    debug(f" ├─ 命名空间: {namespace}")
    debug(f" ├─ 查询内容: '{search_query}'")
    debug(f" └─ 消息总数: {len(state['messages'])}")
    try:
        hits = await runtime.store.asearch(namespace, query=search_query)
        debug(f"✅ [记忆检索] 检索完成,找到 {len(hits)} 条相关记忆")
        if not hits:
            debug(f"⚠️ [记忆检索] 未找到相关记忆")
            debug(f"{'='*60}\n")
            return {"memory_context": ""}
        joined = "\n".join(hit.value["data"] for hit in hits)
        debug(f"📚 [记忆内容]")
        for idx, hit in enumerate(hits, 1):
            debug(f" [{idx}] {hit.value['data']}")
        debug(f"{'='*60}\n")
        return {"memory_context": joined}
    except Exception as exc:
        error(f"❌ [记忆检索] 检索失败: {exc}")
        import traceback
        traceback.print_exc()
        debug(f"{'='*60}\n")
        return {"memory_context": ""}
async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
    """Heuristically extract and persist a long-term memory.

    Triggered when the latest human message contains a remember-style
    keyword. Bug fix: the original lowercased the message BEFORE storing
    it, mangling the remembered text; now lowering is applied only for
    keyword matching and the original casing is stored.

    Returns:
        An empty state update (this node only has the side effect of
        writing to the store).
    """
    # Latest human message is the source of the memory, if any.
    user_messages = [msg for msg in state["messages"] if msg.type == "human"]
    if not user_messages:
        return {}
    last_user_msg = user_messages[-1].content
    # Simple trigger: the message mentions "remember"/"save" keywords.
    # Lowercase only for matching so the stored text keeps its casing.
    if any(keyword in str(last_user_msg).lower() for keyword in ["记住", "保存", "别忘了"]):
        # Placeholder extraction; a real system could use an LLM here.
        memory_content = f"用户说过:{last_user_msg}"
        user_id = runtime.context.user_id
        namespace = ("memories", user_id)
        await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
        info(f"✅ 长期记忆已保存:{memory_content}")
    return {}
def build(self) -> StateGraph:
    """
    Build the (uncompiled) state graph.

    Flow: START -> retrieve_memory -> llm_call, then should_continue
    routes to tool_node (which loops back into llm_call), save_memory
    (which ends the run), or END directly.
    """
    builder = StateGraph(MessagesState, context_schema=GraphContext)
    builder.add_node("retrieve_memory", self.retrieve_memory)
    builder.add_node("llm_call", self.call_llm)
    builder.add_node("tool_node", self.call_tools)
    builder.add_node("save_memory", self.save_memory)
    builder.add_edge(START, "retrieve_memory")
    builder.add_edge("retrieve_memory", "llm_call")
    builder.add_conditional_edges(
        "llm_call",
        self.should_continue,
        {
            # Map should_continue()'s return values to node names.
            "tool_node": "tool_node",
            "save_memory": "save_memory",
            "END": END
        }
    )
    builder.add_edge("tool_node", "llm_call")
    builder.add_edge("save_memory", END)
    return builder

55
app/logger.py Normal file
View File

@@ -0,0 +1,55 @@
"""
统一的日志模块 - 基于环境变量控制日志级别
类似 C# 的条件编译效果,开发时打印详细调试信息,生产环境只输出关键信息
"""
import os
import logging
from typing import Any
from dotenv import load_dotenv
# 先加载环境变量
load_dotenv()
# 从环境变量读取日志级别,默认 INFO
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
# 根据环境变量控制是否显示详细调试信息
DEBUG_MODE = os.getenv("DEBUG", "false").lower() == "true"
# 创建统一的日志器
logger = logging.getLogger("ai_agent")
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
# 避免重复添加 handler
if not logger.handlers:
handler = logging.StreamHandler()
# 重要handler 也需要设置级别,否则可能继承根 logger 的级别
handler.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
def debug(msg: Any, *args, **kwargs):
    """Debug log; emitted only when the DEBUG env var is "true"."""
    if not DEBUG_MODE:
        return
    logger.debug(msg, *args, **kwargs)
def info(msg: Any, *args, **kwargs):
    """Log *msg* at INFO level on the shared application logger."""
    logger.log(logging.INFO, msg, *args, **kwargs)
def warning(msg: Any, *args, **kwargs):
    """Log *msg* at WARNING level on the shared application logger."""
    logger.log(logging.WARNING, msg, *args, **kwargs)
def error(msg: Any, *args, **kwargs):
    """Log *msg* at ERROR level on the shared application logger."""
    logger.log(logging.ERROR, msg, *args, **kwargs)