Files
ailine/app/history.py
root 404efde282
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
添加长期存储,流式检查
2026-04-17 01:26:05 +08:00

187 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
历史对话查询模块
利用 LangGraph 的 checkpointer 获取对话历史和摘要
"""
from typing import List, Dict, Any, Optional
import logging
from app.logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService:
    """Query conversation history and summaries through a LangGraph checkpointer.

    The service only reads: it asks the checkpointer for the latest checkpoint
    of a thread (``aget_tuple``) and, for the per-user listing, queries the
    ``checkpoints`` table through ``checkpointer.conn``.
    """

    # Number of characters of the first user message kept as a fallback summary.
    _SUMMARY_MAX_CHARS = 50

    # Maps LangChain message types and plain role names onto API role names.
    _ROLE_ALIASES = {
        "human": "user",
        "user": "user",
        "ai": "assistant",
        "assistant": "assistant",
    }

    def __init__(self, checkpointer):
        """Store the checkpointer used for every lookup.

        Args:
            checkpointer: Object exposing ``aget_tuple(config)`` and, for
                ``get_user_threads``, a ``conn`` with an async ``cursor()``.
        """
        self.checkpointer = checkpointer

    async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Return summary entries for the threads that belong to a user.

        Args:
            user_id: User ID matched against ``metadata->>'user_id'``.
            limit: Maximum number of threads to return.

        Returns:
            A list of dicts with ``thread_id``, ``last_updated``, ``summary``
            and ``message_count``. Threads without messages are omitted. An
            empty list is returned (and the error is logged) on any failure.
        """
        try:
            # Per the original author's note: in newer LangGraph versions the
            # checkpoints table created by AsyncPostgresSaver has no created_at
            # column, so checkpoint_id is used as the ordering key, and the
            # user id lives in the metadata JSONB column.
            query = """
                SELECT
                    thread_id,
                    MAX(checkpoint_id) as last_updated
                FROM checkpoints
                WHERE metadata->>'user_id' = %s
                GROUP BY thread_id
                ORDER BY last_updated DESC
                LIMIT %s
            """
            async with self.checkpointer.conn.cursor() as cur:
                await cur.execute(query, (user_id, limit))
                rows = await cur.fetchall()

            # The cursor is released before the per-thread lookups, which go
            # through the checkpointer rather than through this cursor.
            threads = []
            for row in rows:
                # assumes rows are mapping-like (dict row factory) — TODO confirm
                thread_id = row['thread_id']
                state = await self.checkpointer.aget_tuple(
                    {"configurable": {"thread_id": thread_id}}
                )
                messages = self._messages_from_state(state)
                if not messages:
                    continue
                threads.append({
                    "thread_id": thread_id,
                    # checkpoint_id is an opaque, sortable identifier; it is
                    # returned as-is rather than converted to a timestamp.
                    "last_updated": row['last_updated'] or "",
                    "summary": self._extract_summary(messages),
                    "message_count": self._count_dialog_messages(messages),
                })
            return threads
        except Exception as e:
            error(f"获取用户线程列表失败 (user_id={user_id}): {e}")
            return []

    async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
        """Return the full message history of a thread.

        Args:
            thread_id: Thread ID.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts where human/ai
            messages are reported as ``user``/``assistant``; other roles keep
            their own name. System messages are skipped. An empty list is
            returned when the thread has no state or on any failure.
        """
        try:
            state = await self.checkpointer.aget_tuple(
                {"configurable": {"thread_id": thread_id}}
            )
            result = []
            for msg in self._messages_from_state(state):
                role = self._message_role(msg)
                # Entries that are neither message objects nor dicts are
                # ignored; system prompts are never exposed to the caller.
                if role is None or role == "system":
                    continue
                result.append({
                    "role": role,
                    "content": self._message_content(msg),
                })
            return result
        except Exception as e:
            error(f"获取线程消息历史失败: {e}")
            return []

    async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
        """Return the summary entry of one thread (for history list display).

        Args:
            thread_id: Thread ID.

        Returns:
            A dict with ``thread_id``, ``summary``, ``message_count`` and
            ``last_updated``. When the thread has no stored state the
            placeholder summary is returned without ``last_updated``; on
            failure a "load failed" placeholder is returned.
        """
        try:
            state = await self.checkpointer.aget_tuple(
                {"configurable": {"thread_id": thread_id}}
            )
            # Read the state the same way as the other methods: the tuple
            # returned by aget_tuple carries the channel values inside its
            # ``checkpoint`` dict, not in a ``values`` attribute.
            values = self._channel_values(state)
            if not values:
                return {"thread_id": thread_id, "summary": "空对话", "message_count": 0}

            messages = values.get("messages") or []
            return {
                "thread_id": thread_id,
                "summary": self._extract_summary(messages),
                "message_count": self._count_dialog_messages(messages),
                "last_updated": self._last_updated(state),
            }
        except Exception as e:
            error(f"获取线程摘要失败: {e}")
            return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0}

    @staticmethod
    def _channel_values(state) -> Dict[str, Any]:
        """Return the channel-values dict of a checkpointer result, or ``{}``."""
        if state is None:
            return {}
        checkpoint = getattr(state, "checkpoint", None)
        if isinstance(checkpoint, dict):
            values = checkpoint.get("channel_values")
            return values if isinstance(values, dict) else {}
        # Fallback for snapshot-like objects that expose a ``values`` dict.
        values = getattr(state, "values", None)
        return values if isinstance(values, dict) else {}

    @classmethod
    def _messages_from_state(cls, state) -> List:
        """Return the ``messages`` channel of a checkpointer result, or ``[]``."""
        return cls._channel_values(state).get("messages") or []

    @staticmethod
    def _last_updated(state) -> str:
        """Return a timestamp string for the state, or ``""`` when unknown.

        ``metadata["created_at"]`` is preferred when present; otherwise the
        checkpoint's own ``ts`` field is used.
        """
        metadata = getattr(state, "metadata", None)
        created_at = metadata.get("created_at") if isinstance(metadata, dict) else None
        if created_at:
            if hasattr(created_at, "isoformat"):
                return created_at.isoformat()
            return str(created_at)
        checkpoint = getattr(state, "checkpoint", None)
        ts = checkpoint.get("ts") if isinstance(checkpoint, dict) else None
        return str(ts) if ts else ""

    @classmethod
    def _message_role(cls, msg) -> Optional[str]:
        """Return the normalized role of a message, or ``None`` if unrecognized.

        Message objects are identified by their ``type`` attribute, dict
        messages by ``role`` (falling back to ``type``, then ``"unknown"``).
        """
        if hasattr(msg, 'type'):
            raw = msg.type
        elif isinstance(msg, dict):
            raw = msg.get("role", msg.get("type", "unknown"))
        else:
            return None
        if isinstance(raw, str):
            return cls._ROLE_ALIASES.get(raw, raw)
        return raw

    @staticmethod
    def _message_content(msg) -> Any:
        """Return the raw content of a message object or dict message."""
        if isinstance(msg, dict):
            return msg.get("content", "")
        return getattr(msg, "content", "")

    @classmethod
    def _count_dialog_messages(cls, messages: List) -> int:
        """Count the user and assistant messages in ``messages``."""
        return sum(
            1 for msg in messages
            if cls._message_role(msg) in ("user", "assistant")
        )

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content to plain text.

        Content may be a string or a list of blocks (strings or dicts with a
        ``text`` entry); non-text blocks are dropped.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "".join(parts)
        return str(content)

    def _extract_summary(self, messages: List) -> str:
        """Derive a short summary from a message list.

        Strategy:
            1. If any message carries a ``summary`` (in ``additional_kwargs``
               for message objects, or as a key for dict messages), use it.
            2. Otherwise use the first user message, truncated to 50
               characters with a trailing ``...``.
            3. Otherwise return the "empty conversation" placeholder.
        """
        # Pass 1: an explicit summary wins over any derived one.
        for msg in messages:
            if hasattr(msg, 'additional_kwargs'):
                stored = (msg.additional_kwargs or {}).get('summary')
                if stored:
                    return stored
            elif isinstance(msg, dict) and msg.get('summary'):
                return msg['summary']

        # Pass 2: fall back to the beginning of the first user message.
        for msg in messages:
            if self._message_role(msg) == "user":
                text = self._content_to_text(self._message_content(msg))
                if len(text) > self._SUMMARY_MAX_CHARS:
                    return text[:self._SUMMARY_MAX_CHARS] + "..."
                return text
        return "空对话"