Files
ailine/app/backend.py
root e81663ab0b
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 1m16s
fix: 添加健康检查端点和 Secrets 验证步骤
2026-04-14 02:02:42 +08:00

126 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
采用依赖注入模式,优雅管理资源生命周期
"""
import os
import uuid
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
# PostgreSQL connection string. The DB_URI environment variable takes
# precedence; the fallback targets the Docker / local-development setup.
# NOTE(review): the fallback embeds a password in source — suitable for
# development only; confirm production always sets DB_URI.
_DEFAULT_DB_URI = "postgresql://postgres:mysecretpassword@postgres:5432/langgraph_db?sslmode=disable"
DB_URI = os.environ.get("DB_URI", _DEFAULT_DB_URI)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown around the serving phase.

    On startup, opens the PostgreSQL checkpointer from ``DB_URI``, runs its
    ``setup()``, builds and initializes the ``AIAgentService``, and publishes
    it as ``app.state.agent_service``. Control is then yielded to the
    application; leaving the ``async with`` block on shutdown closes the
    checkpointer.

    Args:
        app: The FastAPI application whose ``state`` receives the service.
    """
    # 1. Open the database-backed checkpointer and initialize its tables.
    async with AsyncPostgresSaver.from_conn_string(DB_URI) as saver:
        await saver.setup()

        # 2. Build the AI agent service on top of the checkpointer.
        service = AIAgentService(saver)
        await service.initialize()

        # 3. Publish the instance on app.state for request handlers.
        app.state.agent_service = service

        # The application serves requests while suspended here.
        yield

    # 4. Database cleanup on shutdown is handled by the `async with` above.
    print("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan)

# CORS middleware so a front end served from another origin can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation — confirm this is intended
# and consider listing explicit origins.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# ========== 健康检查端点 ==========
@app.get("/health")
async def health_check():
    """Return a static liveness payload for Docker and CI/CD monitoring."""
    payload = {"status": "ok", "service": "ai-agent-backend"}
    return payload
# ========== Pydantic 模型 ==========
class ChatRequest(BaseModel):
    # Request payload accepted by the POST /chat endpoint.

    # User message text; chat_endpoint answers 400 when it is empty.
    message: str
    # Conversation thread identifier; chat_endpoint generates a UUID when it
    # is omitted.
    thread_id: str | None = None
    # Key of the model requested for the reply; defaults to "zhipu".
    model: str = "zhipu"
class ChatResponse(BaseModel):
    # Response payload returned by the POST /chat endpoint.

    # Reply produced by the agent service for the submitted message.
    reply: str
    # Thread identifier used for this exchange (client-supplied or generated).
    thread_id: str
    # Model key reported as used: the requested one when it is present in
    # agent_service.graphs, otherwise the first key of agent_service.graphs.
    model_used: str
# ========== 依赖注入函数 ==========
def get_agent_service(request: Request) -> AIAgentService:
    """Return the shared AIAgentService instance stored on ``app.state``.

    Meant to be used through ``Depends``; the instance is placed on
    ``app.state.agent_service`` by the application lifespan handler.

    Args:
        request: The incoming request, used only to reach the application.

    Returns:
        The application-wide ``AIAgentService``.
    """
    app_state = request.app.state
    return app_state.agent_service
# ========== HTTP 端点 ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
    request: ChatRequest,
    agent_service: AIAgentService = Depends(get_agent_service)
):
    """Handle one request/response chat turn, with optional model selection.

    Args:
        request: Chat payload (message, optional thread id, model key).
        agent_service: Shared agent service injected from ``app.state``.

    Returns:
        A ``ChatResponse`` with the reply, the thread id used, and the model
        key reported as used.

    Raises:
        HTTPException: 400 when ``message`` is empty.
    """
    if not request.message:
        raise HTTPException(status_code=400, detail="message required")

    # Reuse the caller's thread when given, otherwise start a new one.
    if request.thread_id:
        thread_id = request.thread_id
    else:
        thread_id = str(uuid.uuid4())

    reply = await agent_service.process_message(
        request.message, thread_id, request.model
    )

    # Report the requested model when it is registered; otherwise report the
    # first registered graph key.
    if request.model in agent_service.graphs:
        model_used = request.model
    else:
        model_used = next(iter(agent_service.graphs.keys()))

    return ChatResponse(reply=reply, thread_id=thread_id, model_used=model_used)
# ========== WebSocket 端点(可选) ==========
@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    agent_service: AIAgentService = Depends(get_agent_service)
):
    """Serve a multi-turn chat session over a WebSocket connection.

    Each incoming frame must be a JSON object with:
      - ``message``: required, non-empty text to send to the agent.
      - ``thread_id``: optional; a new UUID is generated when missing, null
        or empty (same rule as the POST /chat endpoint).
      - ``model``: optional; ``"zhipu"`` is used when missing or null.

    For every valid frame a JSON object with ``reply``, ``thread_id`` and
    ``model_used`` is sent back. Invalid frames get an ``{"error": ...}``
    response and the connection stays open.

    Args:
        websocket: The client connection.
        agent_service: Shared agent service injected from ``app.state``.
    """
    await websocket.accept()
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except ValueError:
                # json.JSONDecodeError subclasses ValueError: a malformed
                # frame should not tear down the whole session.
                await websocket.send_json({"error": "invalid JSON"})
                continue

            if not isinstance(data, dict):
                # Valid JSON but not an object (e.g. a list or a string).
                await websocket.send_json({"error": "expected a JSON object"})
                continue

            message = data.get("message")
            if not message:
                await websocket.send_json({"error": "missing message"})
                continue

            # `or` instead of a .get() default so an explicit null/empty
            # thread_id also gets a fresh id, matching the /chat endpoint.
            thread_id = data.get("thread_id") or str(uuid.uuid4())

            model = data.get("model")
            if model is None:
                model = "zhipu"

            reply = await agent_service.process_message(message, thread_id, model)

            # Report the requested model when it is registered; otherwise
            # report the first registered graph key.
            if model in agent_service.graphs:
                actual_model = model
            else:
                actual_model = next(iter(agent_service.graphs.keys()))

            await websocket.send_json(
                {"reply": reply, "thread_id": thread_id, "model_used": actual_model}
            )
    except WebSocketDisconnect:
        # The client closed the connection; nothing left to do.
        pass
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8001}
    uvicorn.run(app, **server_options)