"""
FastAPI backend: supports dynamic model switching and persists conversation
memory in PostgreSQL through a LangGraph checkpointer.

Shared resources are created in the application lifespan and handed to the
endpoints through dependency injection.
"""
import os
import uuid
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import HTTPConnection
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

from agent import AIAgentService

# PostgreSQL connection string. Read from the DB_URI environment variable so
# credentials do not have to live in source control; the fallback is the
# original local-development value.
DB_URI = os.environ.get(
    "DB_URI",
    "postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable",
)

# Model key used when a client does not name one.
DEFAULT_MODEL = "zhipu"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared services on startup and release them on shutdown.

    The checkpointer is owned by the ``async with`` block, so its database
    connection is closed automatically once the application stops serving.
    """
    # 1. Open the checkpointer's database connection and create its tables.
    async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
        await checkpointer.setup()

        # 2. Build the AI agent service on top of the checkpointer.
        agent_service = AIAgentService(checkpointer)
        await agent_service.initialize()

        # 3. Publish the instance on app.state; get_agent_service() reads it.
        app.state.agent_service = agent_service

        # The application serves requests while suspended here.
        yield

    # 4. The async with block has exited, so the DB connection is closed.
    print("🛑 应用关闭,数据库连接池已释放")


app = FastAPI(lifespan=lifespan)

# CORS middleware so a browser frontend served from another origin can call us.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed requests; restrict allow_origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ========== Pydantic models ==========
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    message: str
    # When omitted, the endpoint starts a new conversation thread.
    thread_id: str | None = None
    model: str = DEFAULT_MODEL


class ChatResponse(BaseModel):
    """Response body for ``POST /chat``."""

    reply: str
    thread_id: str
    model_used: str


# ========== Dependency-injection helpers ==========
def get_agent_service(request: HTTPConnection) -> AIAgentService:
    """Return the shared AIAgentService stored on ``app.state`` by the lifespan.

    The parameter is typed as ``HTTPConnection`` (the common base of
    ``Request`` and ``WebSocket``) so FastAPI can inject it for both the HTTP
    route and the WebSocket route; a ``Request`` annotation is only filled in
    for HTTP requests.
    """
    return request.app.state.agent_service


def _resolve_model(agent_service: AIAgentService, model: str) -> str:
    """Return the model name to report back to the client.

    This is the requested ``model`` if ``agent_service.graphs`` contains it,
    otherwise the first key of ``graphs``. It falls back to the requested name
    rather than raising ``StopIteration`` when ``graphs`` is empty.
    NOTE(review): assumes process_message falls back the same way; confirm in
    agent.py.
    """
    graphs = agent_service.graphs
    if model in graphs:
        return model
    return next(iter(graphs), model)


# ========== HTTP endpoints ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
    request: ChatRequest,
    agent_service: AIAgentService = Depends(get_agent_service),
):
    """Synchronous chat endpoint with per-request model selection.

    Raises:
        HTTPException: 400 if ``message`` is empty or whitespace only.
    """
    if not request.message.strip():
        raise HTTPException(status_code=400, detail="message required")

    thread_id = request.thread_id or str(uuid.uuid4())
    reply = await agent_service.process_message(request.message, thread_id, request.model)
    return ChatResponse(
        reply=reply,
        thread_id=thread_id,
        model_used=_resolve_model(agent_service, request.model),
    )


# ========== WebSocket endpoint (optional) ==========
@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    agent_service: AIAgentService = Depends(get_agent_service),
):
    """Chat over a WebSocket: one JSON object in, one JSON object out, repeated.

    Each incoming frame is a JSON object with ``message`` (required) and
    optional ``thread_id`` / ``model``. A bad frame is answered with an
    ``{"error": ...}`` object and the connection stays open.
    """
    await websocket.accept()
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except ValueError:
                # json.JSONDecodeError is a ValueError; report it and keep going.
                await websocket.send_json({"error": "invalid json"})
                continue
            if not isinstance(data, dict):
                await websocket.send_json({"error": "invalid payload"})
                continue

            message = data.get("message")
            if not message:
                await websocket.send_json({"error": "missing message"})
                continue

            # `or` rather than a .get() default so an explicit null / "" is also
            # replaced, matching the HTTP endpoint.
            thread_id = data.get("thread_id") or str(uuid.uuid4())
            model = data.get("model") or DEFAULT_MODEL

            reply = await agent_service.process_message(message, thread_id, model)
            await websocket.send_json(
                {
                    "reply": reply,
                    "thread_id": thread_id,
                    "model_used": _resolve_model(agent_service, model),
                }
            )
    except WebSocketDisconnect:
        # The client closed the connection; nothing else to clean up here.
        pass


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)