This commit is contained in:
65
app/agent.py
65
app/agent.py
@@ -137,7 +137,10 @@ class AIAgentService:
|
||||
raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}")
|
||||
|
||||
graph = self.graphs[model]
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"metadata": {"user_id": user_id} # 写入 metadata 供历史查询使用
|
||||
}
|
||||
input_state = {"messages": [{"role": "user", "content": message}]}
|
||||
context = GraphContext(user_id=user_id)
|
||||
|
||||
@@ -152,3 +155,63 @@ class AIAgentService:
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
|
||||
"""
|
||||
流式处理消息,返回异步生成器
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
thread_id: 线程 ID
|
||||
model_name: 模型名称
|
||||
user_id: 用户 ID
|
||||
|
||||
Yields:
|
||||
字典,包含事件类型和数据
|
||||
"""
|
||||
graph = self.graphs.get(model_name)
|
||||
if not graph:
|
||||
warning(f"警告: 模型 '{model_name}' 不可用,使用默认模型")
|
||||
model_name = next(iter(self.graphs.keys()))
|
||||
graph = self.graphs[model_name]
|
||||
|
||||
config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"metadata": {"user_id": user_id}
|
||||
}
|
||||
input_state = {"messages": [{"role": "user", "content": message}]}
|
||||
context = GraphContext(user_id=user_id)
|
||||
|
||||
# 使用 astream_events 获取流式事件
|
||||
async for event in graph.astream_events(input_state, config=config, context=context, version="v2"):
|
||||
kind = event["event"]
|
||||
|
||||
# 聊天模型流式输出
|
||||
if kind == "on_chat_model_stream":
|
||||
content = event["data"]["chunk"].content
|
||||
if content:
|
||||
yield {"type": "token", "content": content}
|
||||
|
||||
# 工具调用开始
|
||||
elif kind == "on_tool_start":
|
||||
tool_name = event["name"]
|
||||
yield {"type": "tool_start", "tool": tool_name}
|
||||
|
||||
# 工具调用结束
|
||||
elif kind == "on_tool_end":
|
||||
tool_name = event["name"]
|
||||
yield {"type": "tool_end", "tool": tool_name}
|
||||
|
||||
# 链结束,获取最终结果
|
||||
elif kind == "on_chain_end" and event["name"] == "LangGraph":
|
||||
output = event["data"]["output"]
|
||||
reply = output["messages"][-1].content if output.get("messages") else ""
|
||||
token_usage = output.get("last_token_usage", {})
|
||||
elapsed_time = output.get("last_elapsed_time", 0.0)
|
||||
|
||||
yield {
|
||||
"type": "done",
|
||||
"reply": reply,
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@@ -5,14 +5,17 @@ FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from app.agent import AIAgentService
|
||||
from app.history import ThreadHistoryService
|
||||
from app.logger import debug, info, warning, error
|
||||
|
||||
# 加载 .env 文件
|
||||
@@ -37,13 +40,17 @@ async def lifespan(app: FastAPI):
|
||||
agent_service = AIAgentService(checkpointer)
|
||||
await agent_service.initialize()
|
||||
|
||||
# 3. 将服务实例存入 app.state
|
||||
# 3. 创建历史查询服务
|
||||
history_service = ThreadHistoryService(checkpointer)
|
||||
|
||||
# 4. 将服务实例存入 app.state
|
||||
app.state.agent_service = agent_service
|
||||
app.state.history_service = history_service
|
||||
|
||||
# 应用运行中...
|
||||
yield
|
||||
|
||||
# 4. 关闭时自动清理数据库连接(async with 负责)
|
||||
# 5. 关闭时自动清理数据库连接(async with 负责)
|
||||
info("🛑 应用关闭,数据库连接池已释放")
|
||||
|
||||
|
||||
@@ -90,6 +97,11 @@ def get_agent_service(request: Request) -> AIAgentService:
|
||||
return request.app.state.agent_service
|
||||
|
||||
|
||||
def get_history_service(request: Request) -> ThreadHistoryService:
|
||||
"""从 app.state 中获取全局 ThreadHistoryService 实例"""
|
||||
return request.app.state.history_service
|
||||
|
||||
|
||||
# ========== HTTP 端点 ==========
|
||||
@app.post("/chat", response_model=ChatResponse)
|
||||
async def chat_endpoint(
|
||||
@@ -124,6 +136,75 @@ async def chat_endpoint(
|
||||
)
|
||||
|
||||
|
||||
# ========== 历史查询接口 ==========
|
||||
@app.get("/threads")
|
||||
async def list_threads(
|
||||
user_id: str = Query("default_user", description="用户 ID"),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
|
||||
history_service: ThreadHistoryService = Depends(get_history_service)
|
||||
):
|
||||
"""获取当前用户的对话历史列表"""
|
||||
threads = await history_service.get_user_threads(user_id, limit)
|
||||
return {"threads": threads}
|
||||
|
||||
|
||||
@app.get("/thread/{thread_id}/messages")
|
||||
async def get_thread_messages(
|
||||
thread_id: str,
|
||||
user_id: str = Query("default_user", description="用户 ID"),
|
||||
history_service: ThreadHistoryService = Depends(get_history_service)
|
||||
):
|
||||
"""获取指定线程的完整消息历史"""
|
||||
messages = await history_service.get_thread_messages(thread_id)
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@app.get("/thread/{thread_id}/summary")
|
||||
async def get_thread_summary(
|
||||
thread_id: str,
|
||||
user_id: str = Query("default_user", description="用户 ID"),
|
||||
history_service: ThreadHistoryService = Depends(get_history_service)
|
||||
):
|
||||
"""获取指定线程的摘要信息"""
|
||||
summary = await history_service.get_thread_summary(thread_id)
|
||||
return summary
|
||||
|
||||
|
||||
# ========== 流式对话接口 ==========
|
||||
@app.post("/chat/stream")
|
||||
async def chat_stream_endpoint(
|
||||
request: ChatRequest,
|
||||
agent_service: AIAgentService = Depends(get_agent_service)
|
||||
):
|
||||
"""流式对话接口(SSE)"""
|
||||
if not request.message:
|
||||
raise HTTPException(status_code=400, detail="message required")
|
||||
|
||||
thread_id = request.thread_id or str(uuid.uuid4())
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
async for chunk in agent_service.process_message_stream(
|
||||
request.message, thread_id, request.model, request.user_id
|
||||
):
|
||||
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
error(f"流式响应异常: {e}")
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(e)}, ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ========== WebSocket 端点(可选) ==========
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
|
||||
178
app/history.py
Normal file
178
app/history.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
历史对话查询模块
|
||||
利用 LangGraph 的 checkpointer 获取对话历史和摘要
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
from app.logger import error # 保持兼容,或者替换为 logger
|
||||
|
||||
|
||||
class ThreadHistoryService:
|
||||
"""线程历史查询服务"""
|
||||
|
||||
def __init__(self, checkpointer):
|
||||
self.checkpointer = checkpointer
|
||||
|
||||
async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户的所有线程摘要信息
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
线程列表,每个包含 thread_id, last_updated, summary, message_count
|
||||
"""
|
||||
try:
|
||||
# 查询 checkpoints 表获取用户的线程列表
|
||||
async with self.checkpointer.conn.cursor() as cur:
|
||||
# 查询每个线程的最新 checkpoint 和创建时间
|
||||
query = """
|
||||
SELECT
|
||||
thread_id,
|
||||
MAX(created_at) as last_updated
|
||||
FROM checkpoints
|
||||
WHERE metadata->>'user_id' = %s
|
||||
GROUP BY thread_id
|
||||
ORDER BY last_updated DESC
|
||||
LIMIT %s
|
||||
"""
|
||||
await cur.execute(query, (user_id, limit))
|
||||
rows = await cur.fetchall()
|
||||
|
||||
threads = []
|
||||
for row in rows:
|
||||
thread_id = row['thread_id']
|
||||
|
||||
# 获取该线程的状态
|
||||
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
|
||||
|
||||
if state and state.values:
|
||||
messages = state.values.get("messages", [])
|
||||
summary = self._extract_summary(messages)
|
||||
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
|
||||
|
||||
threads.append({
|
||||
"thread_id": thread_id,
|
||||
"last_updated": row['last_updated'].isoformat() if row['last_updated'] else "",
|
||||
"summary": summary,
|
||||
"message_count": message_count
|
||||
})
|
||||
|
||||
return threads
|
||||
|
||||
except Exception as e:
|
||||
error(f"获取用户线程列表失败 (user_id={user_id}): {e}")
|
||||
return []
|
||||
|
||||
async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
获取指定线程的完整消息历史
|
||||
|
||||
Args:
|
||||
thread_id: 线程 ID
|
||||
|
||||
Returns:
|
||||
消息列表,格式 [{"role": "user/assistant", "content": "..."}]
|
||||
"""
|
||||
try:
|
||||
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
|
||||
|
||||
if state is None or not state.values:
|
||||
return []
|
||||
|
||||
messages = state.values.get("messages", [])
|
||||
|
||||
# 转换 LangChain 消息对象为字典
|
||||
result = []
|
||||
for msg in messages:
|
||||
# 跳过 system 消息
|
||||
if hasattr(msg, 'type') and msg.type == "system":
|
||||
continue
|
||||
|
||||
if hasattr(msg, 'type'):
|
||||
role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else msg.type
|
||||
result.append({
|
||||
"role": role,
|
||||
"content": msg.content
|
||||
})
|
||||
elif isinstance(msg, dict):
|
||||
role = msg.get("role", msg.get("type", "unknown"))
|
||||
if role in ["human", "user"]:
|
||||
role = "user"
|
||||
elif role in ["ai", "assistant"]:
|
||||
role = "assistant"
|
||||
result.append({
|
||||
"role": role,
|
||||
"content": msg.get("content", "")
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error(f"获取线程消息历史失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取线程摘要(用于历史列表展示)
|
||||
|
||||
Args:
|
||||
thread_id: 线程 ID
|
||||
|
||||
Returns:
|
||||
包含摘要信息的字典
|
||||
"""
|
||||
try:
|
||||
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
|
||||
|
||||
if state is None or not state.values:
|
||||
return {"thread_id": thread_id, "summary": "空对话", "message_count": 0}
|
||||
|
||||
messages = state.values.get("messages", [])
|
||||
summary = self._extract_summary(messages)
|
||||
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
|
||||
|
||||
# 获取最后更新时间
|
||||
last_updated = ""
|
||||
if state.metadata and "created_at" in state.metadata:
|
||||
last_updated = state.metadata["created_at"].isoformat()
|
||||
|
||||
return {
|
||||
"thread_id": thread_id,
|
||||
"summary": summary,
|
||||
"message_count": message_count,
|
||||
"last_updated": last_updated
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error(f"获取线程摘要失败: {e}")
|
||||
return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0}
|
||||
|
||||
def _extract_summary(self, messages: List) -> str:
|
||||
"""
|
||||
从消息列表中提取摘要
|
||||
|
||||
策略:
|
||||
1. 如果有 summarize 节点生成的 summary,优先使用
|
||||
2. 否则使用第一条用户消息的前 50 字
|
||||
"""
|
||||
# 查找是否有 summary 字段
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('summary'):
|
||||
return msg.additional_kwargs['summary']
|
||||
elif isinstance(msg, dict) and msg.get('summary'):
|
||||
return msg['summary']
|
||||
|
||||
# 使用第一条用户消息作为摘要
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and msg.type == "human":
|
||||
content = msg.content
|
||||
return content[:50] + "..." if len(content) > 50 else content
|
||||
elif isinstance(msg, dict) and msg.get("role") in ["user", "human"]:
|
||||
content = msg.get("content", "")
|
||||
return content[:50] + "..." if len(content) > 50 else content
|
||||
|
||||
return "空对话"
|
||||
Reference in New Issue
Block a user