添加长期记忆
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 27s

This commit is contained in:
2026-04-14 17:34:12 +08:00
parent 1bea2491c5
commit 8dd94c6c19
12 changed files with 953 additions and 197 deletions

View File

@@ -7,17 +7,23 @@ import os
import uuid
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres.aio import AsyncPostgresStore
from app.agent import AIAgentService
from app.logger import debug, info, warning, error
# PostgreSQL 连接字符串(优先从环境变量读取,适配 Docker 和本地开发)
# 加载 .env 文件
load_dotenv()
# PostgreSQL 连接字符串配置
# 优先级:环境变量 DB_URI > Docker 内部服务名 > 本地开发地址
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:mysecretpassword@postgres:5432/langgraph_db?sslmode=disable"
"postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable"
)
@@ -25,11 +31,15 @@ DB_URI = os.getenv(
async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务"""
# 1. 创建数据库连接池并初始化表
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
async with (
AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer,
AsyncPostgresStore.from_conn_string(DB_URI) as store
):
await checkpointer.setup()
await store.setup()
# 2. 构建 AI Agent 服务
agent_service = AIAgentService(checkpointer)
agent_service = AIAgentService(checkpointer,store)
await agent_service.initialize()
# 3. 将服务实例存入 app.state
@@ -39,7 +49,7 @@ async def lifespan(app: FastAPI):
yield
# 4. 关闭时自动清理数据库连接async with 负责)
print("🛑 应用关闭,数据库连接池已释放")
info("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan)
@@ -66,12 +76,17 @@ class ChatRequest(BaseModel):
message: str
thread_id: str | None = None
model: str = "zhipu"
user_id: str = "default_user"
class ChatResponse(BaseModel):
reply: str
thread_id: str
model_used: str
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
elapsed_time: float = 0.0
# ========== 依赖注入函数 ==========
@@ -91,11 +106,27 @@ async def chat_endpoint(
raise HTTPException(status_code=400, detail="message required")
thread_id = request.thread_id or str(uuid.uuid4())
reply = await agent_service.process_message(
request.message, thread_id, request.model
result = await agent_service.process_message(
request.message, thread_id, request.model, request.user_id
)
# 提取 token 统计信息
token_usage = result.get("token_usage", {})
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
elapsed_time = result.get("elapsed_time", 0.0)
actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
return ChatResponse(reply=reply, thread_id=thread_id, model_used=actual_model)
return ChatResponse(
reply=result["reply"],
thread_id=thread_id,
model_used=actual_model,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
elapsed_time=elapsed_time
)
# ========== WebSocket 端点(可选) ==========
@@ -111,10 +142,11 @@ async def websocket_endpoint(
message = data.get("message")
thread_id = data.get("thread_id", str(uuid.uuid4()))
model = data.get("model", "zhipu")
user_id = data.get("user_id", "default_user")
if not message:
await websocket.send_json({"error": "missing message"})
continue
reply = await agent_service.process_message(message, thread_id, model)
reply = await agent_service.process_message(message, thread_id, model, user_id)
actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
await websocket.send_json({"reply": reply, "thread_id": thread_id, "model_used": actual_model})
except WebSocketDisconnect: