前端修改
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 18s

This commit is contained in:
2026-04-16 03:21:38 +08:00
parent a5b8820d13
commit 626bae54ff
22 changed files with 2968 additions and 138 deletions

View File

@@ -137,7 +137,10 @@ class AIAgentService:
raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}")
graph = self.graphs[model]
config = {"configurable": {"thread_id": thread_id}}
config = {
"configurable": {"thread_id": thread_id},
"metadata": {"user_id": user_id} # 写入 metadata 供历史查询使用
}
input_state = {"messages": [{"role": "user", "content": message}]}
context = GraphContext(user_id=user_id)
@@ -152,3 +155,63 @@ class AIAgentService:
"token_usage": token_usage,
"elapsed_time": elapsed_time
}
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
    """
    Process a message in streaming mode; returns an async generator of events.

    Args:
        message: user message text
        thread_id: conversation thread ID
        model_name: requested model; falls back to the first registered model
        user_id: user ID (written into checkpoint metadata for history queries)

    Yields:
        dicts tagged by "type": "token" (incremental LLM output),
        "tool_start" / "tool_end" (tool invocations), and a final "done"
        carrying reply, token_usage and elapsed_time.

    Raises:
        RuntimeError: when no models are registered at all.
    """
    graph = self.graphs.get(model_name)
    if not graph:
        # Consistency fix: match process_message's explicit error instead of
        # letting next() raise StopIteration (which an async generator would
        # surface as an opaque RuntimeError per PEP 479).
        if not self.graphs:
            raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}")
        warning(f"警告: 模型 '{model_name}' 不可用,使用默认模型")
        model_name = next(iter(self.graphs.keys()))
        graph = self.graphs[model_name]
    config = {
        "configurable": {"thread_id": thread_id},
        "metadata": {"user_id": user_id}  # exposed to the history query service
    }
    input_state = {"messages": [{"role": "user", "content": message}]}
    context = GraphContext(user_id=user_id)
    # astream_events (v2) emits fine-grained events for tokens, tools and chain end
    async for event in graph.astream_events(input_state, config=config, context=context, version="v2"):
        kind = event["event"]
        if kind == "on_chat_model_stream":
            # incremental chat-model token
            content = event["data"]["chunk"].content
            if content:
                yield {"type": "token", "content": content}
        elif kind == "on_tool_start":
            yield {"type": "tool_start", "tool": event["name"]}
        elif kind == "on_tool_end":
            yield {"type": "tool_end", "tool": event["name"]}
        elif kind == "on_chain_end" and event["name"] == "LangGraph":
            # top-level graph finished: emit the final aggregated result
            output = event["data"]["output"]
            reply = output["messages"][-1].content if output.get("messages") else ""
            yield {
                "type": "done",
                "reply": reply,
                "token_usage": output.get("last_token_usage", {}),
                "elapsed_time": output.get("last_elapsed_time", 0.0)
            }

View File

@@ -5,14 +5,17 @@ FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
import os
import uuid
import json
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
from app.history import ThreadHistoryService
from app.logger import debug, info, warning, error
# 加载 .env 文件
@@ -37,13 +40,17 @@ async def lifespan(app: FastAPI):
agent_service = AIAgentService(checkpointer)
await agent_service.initialize()
# 3. 将服务实例存入 app.state
# 3. 创建历史查询服务
history_service = ThreadHistoryService(checkpointer)
# 4. 将服务实例存入 app.state
app.state.agent_service = agent_service
app.state.history_service = history_service
# 应用运行中...
yield
# 4. 关闭时自动清理数据库连接(async with 负责)
# 5. 关闭时自动清理数据库连接(async with 负责)
info("🛑 应用关闭,数据库连接池已释放")
@@ -90,6 +97,11 @@ def get_agent_service(request: Request) -> AIAgentService:
return request.app.state.agent_service
def get_history_service(request: Request) -> ThreadHistoryService:
    """Dependency provider: return the shared ThreadHistoryService kept on app.state."""
    services = request.app.state
    return services.history_service
# ========== HTTP 端点 ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
@@ -124,6 +136,75 @@ async def chat_endpoint(
)
# ========== 历史查询接口 ==========
@app.get("/threads")
async def list_threads(
    user_id: str = Query("default_user", description="用户 ID"),
    limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
    history_service: ThreadHistoryService = Depends(get_history_service)
):
    """List the calling user's conversation threads (summaries, newest first)."""
    thread_list = await history_service.get_user_threads(user_id, limit)
    return {"threads": thread_list}
@app.get("/thread/{thread_id}/messages")
async def get_thread_messages(
    thread_id: str,
    user_id: str = Query("default_user", description="用户 ID"),
    history_service: ThreadHistoryService = Depends(get_history_service)
):
    """Return the full message history of the given thread.

    NOTE(review): `user_id` is accepted but never checked against the
    thread's owner, so any caller can read any thread_id — confirm whether
    ownership enforcement is required here.
    """
    messages = await history_service.get_thread_messages(thread_id)
    return {"messages": messages}
@app.get("/thread/{thread_id}/summary")
async def get_thread_summary(
    thread_id: str,
    user_id: str = Query("default_user", description="用户 ID"),
    history_service: ThreadHistoryService = Depends(get_history_service)
):
    """Return summary info (title snippet, message count) for the thread.

    NOTE(review): like /thread/{thread_id}/messages, `user_id` is accepted
    but not verified against the thread owner — confirm intended access model.
    """
    summary = await history_service.get_thread_summary(thread_id)
    return summary
# ========== 流式对话接口 ==========
@app.post("/chat/stream")
async def chat_stream_endpoint(
    request: ChatRequest,
    agent_service: AIAgentService = Depends(get_agent_service)
):
    """Streaming chat endpoint (Server-Sent Events)."""
    if not request.message:
        raise HTTPException(status_code=400, detail="message required")

    thread_id = request.thread_id or str(uuid.uuid4())

    async def event_generator():
        """Relay agent chunks as SSE `data:` frames; always terminate with [DONE]."""
        try:
            stream = agent_service.process_message_stream(
                request.message, thread_id, request.model, request.user_id
            )
            async for chunk in stream:
                payload = json.dumps(chunk, ensure_ascii=False)
                yield f"data: {payload}\n\n"
        except Exception as e:
            error(f"流式响应异常: {e}")
            err_payload = json.dumps({'type': 'error', 'message': str(e)}, ensure_ascii=False)
            yield f"data: {err_payload}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable Nginx response buffering so frames flush immediately
        }
    )
# ========== WebSocket 端点(可选) ==========
@app.websocket("/ws")
async def websocket_endpoint(

178
app/history.py Normal file
View File

@@ -0,0 +1,178 @@
"""
历史对话查询模块
利用 LangGraph 的 checkpointer 获取对话历史和摘要
"""
from typing import List, Dict, Any, Optional
import logging
from app.logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService:
    """Query service for conversation-thread history built on a LangGraph checkpointer."""

    def __init__(self, checkpointer):
        # Expected: AsyncPostgresSaver (or compatible) exposing .conn and .aget_tuple
        self.checkpointer = checkpointer

    async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """
        List thread summaries for a user, most recently updated first.

        Args:
            user_id: user ID (matched against checkpoint metadata->>'user_id')
            limit: maximum number of threads to return

        Returns:
            List of dicts with thread_id, last_updated, summary, message_count.
            Returns [] on any error (best-effort endpoint).
        """
        try:
            # NOTE(review): assumes self.checkpointer.conn yields dict-style rows
            # (row['thread_id']); with psycopg's default tuple row factory this
            # would fail — confirm the row_factory used by the saver.
            async with self.checkpointer.conn.cursor() as cur:
                query = """
                SELECT
                    thread_id,
                    MAX(created_at) as last_updated
                FROM checkpoints
                WHERE metadata->>'user_id' = %s
                GROUP BY thread_id
                ORDER BY last_updated DESC
                LIMIT %s
                """
                await cur.execute(query, (user_id, limit))
                rows = await cur.fetchall()

            threads = []
            for row in rows:
                thread_id = row['thread_id']
                # N+1 pattern: one state fetch per thread. Acceptable at
                # limit <= 200, but worth batching if this becomes hot.
                state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
                if state and state.values:
                    messages = state.values.get("messages", [])
                    threads.append({
                        "thread_id": thread_id,
                        "last_updated": self._to_iso(row['last_updated']),
                        "summary": self._extract_summary(messages),
                        "message_count": self._count_chat_messages(messages)
                    })
            return threads
        except Exception as e:
            error(f"获取用户线程列表失败 (user_id={user_id}): {e}")
            return []

    async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
        """
        Full message history for a thread as role/content dicts.

        Args:
            thread_id: thread ID

        Returns:
            [{"role": "user"|"assistant"|..., "content": "..."}];
            [] when the thread is unknown or on error.
        """
        try:
            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
            if state is None or not state.values:
                return []
            result = []
            for msg in state.values.get("messages", []):
                # LangChain message objects expose .type; raw dicts carry "role"
                if hasattr(msg, 'type'):
                    if msg.type == "system":
                        continue  # system prompts are internal, never shown to users
                    role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else msg.type
                    result.append({"role": role, "content": msg.content})
                elif isinstance(msg, dict):
                    role = msg.get("role", msg.get("type", "unknown"))
                    if role in ["human", "user"]:
                        role = "user"
                    elif role in ["ai", "assistant"]:
                        role = "assistant"
                    result.append({"role": role, "content": msg.get("content", "")})
            return result
        except Exception as e:
            error(f"获取线程消息历史失败: {e}")
            return []

    async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
        """
        Compact summary of a thread for list displays.

        Returns:
            Dict with thread_id, summary, message_count and (when available)
            last_updated; on error a placeholder with summary "加载失败".
        """
        try:
            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
            if state is None or not state.values:
                return {"thread_id": thread_id, "summary": "空对话", "message_count": 0}
            messages = state.values.get("messages", [])
            last_updated = ""
            if state.metadata and "created_at" in state.metadata:
                # Fix: checkpoint metadata is JSON-decoded, so created_at is
                # typically a plain string; the old .isoformat() call raised
                # AttributeError and the broad except mislabeled the thread
                # as "加载失败". _to_iso handles both datetime and str.
                last_updated = self._to_iso(state.metadata["created_at"])
            return {
                "thread_id": thread_id,
                "summary": self._extract_summary(messages),
                "message_count": self._count_chat_messages(messages),
                "last_updated": last_updated
            }
        except Exception as e:
            error(f"获取线程摘要失败: {e}")
            return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0}

    @staticmethod
    def _to_iso(value) -> str:
        """Render a timestamp (datetime, ISO string, or None/empty) as an ISO string."""
        if value is None or value == "":
            return ""
        if hasattr(value, "isoformat"):
            return value.isoformat()
        return str(value)

    @staticmethod
    def _count_chat_messages(messages: List) -> int:
        """Count human/ai turns, accepting both message objects and raw dicts.

        (The original counted only object messages, undercounting threads whose
        history was stored as plain dicts — inconsistent with get_thread_messages.)
        """
        count = 0
        for m in messages:
            if hasattr(m, 'type'):
                if m.type in ["human", "ai"]:
                    count += 1
            elif isinstance(m, dict):
                if m.get("role", m.get("type")) in ["human", "user", "ai", "assistant"]:
                    count += 1
        return count

    def _extract_summary(self, messages: List) -> str:
        """
        Derive a short summary for a message list.

        Strategy:
            1. prefer an explicit 'summary' (e.g. from a summarize node) if present;
            2. otherwise the first human/user message, truncated to 50 chars;
            3. "空对话" when neither exists.
        """
        for msg in messages:
            if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('summary'):
                return msg.additional_kwargs['summary']
            elif isinstance(msg, dict) and msg.get('summary'):
                return msg['summary']
        for msg in messages:
            content = None
            if hasattr(msg, 'type') and msg.type == "human":
                content = msg.content
            elif isinstance(msg, dict) and msg.get("role") in ["user", "human"]:
                content = msg.get("content", "")
            if content is not None:
                # NOTE(review): assumes content is a str; multimodal message
                # content can be a list — confirm upstream normalization.
                return content[:50] + "..." if len(content) > 50 else content
        return "空对话"