This commit is contained in:
178
app/history.py
Normal file
178
app/history.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
历史对话查询模块
|
||||
利用 LangGraph 的 checkpointer 获取对话历史和摘要
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
from app.logger import error # 保持兼容,或者替换为 logger
|
||||
|
||||
|
||||
class ThreadHistoryService:
|
||||
"""线程历史查询服务"""
|
||||
|
||||
def __init__(self, checkpointer):
|
||||
self.checkpointer = checkpointer
|
||||
|
||||
async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户的所有线程摘要信息
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
线程列表,每个包含 thread_id, last_updated, summary, message_count
|
||||
"""
|
||||
try:
|
||||
# 查询 checkpoints 表获取用户的线程列表
|
||||
async with self.checkpointer.conn.cursor() as cur:
|
||||
# 查询每个线程的最新 checkpoint 和创建时间
|
||||
query = """
|
||||
SELECT
|
||||
thread_id,
|
||||
MAX(created_at) as last_updated
|
||||
FROM checkpoints
|
||||
WHERE metadata->>'user_id' = %s
|
||||
GROUP BY thread_id
|
||||
ORDER BY last_updated DESC
|
||||
LIMIT %s
|
||||
"""
|
||||
await cur.execute(query, (user_id, limit))
|
||||
rows = await cur.fetchall()
|
||||
|
||||
threads = []
|
||||
for row in rows:
|
||||
thread_id = row['thread_id']
|
||||
|
||||
# 获取该线程的状态
|
||||
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
|
||||
|
||||
if state and state.values:
|
||||
messages = state.values.get("messages", [])
|
||||
summary = self._extract_summary(messages)
|
||||
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
|
||||
|
||||
threads.append({
|
||||
"thread_id": thread_id,
|
||||
"last_updated": row['last_updated'].isoformat() if row['last_updated'] else "",
|
||||
"summary": summary,
|
||||
"message_count": message_count
|
||||
})
|
||||
|
||||
return threads
|
||||
|
||||
except Exception as e:
|
||||
error(f"获取用户线程列表失败 (user_id={user_id}): {e}")
|
||||
return []
|
||||
|
||||
async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
获取指定线程的完整消息历史
|
||||
|
||||
Args:
|
||||
thread_id: 线程 ID
|
||||
|
||||
Returns:
|
||||
消息列表,格式 [{"role": "user/assistant", "content": "..."}]
|
||||
"""
|
||||
try:
|
||||
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
|
||||
|
||||
if state is None or not state.values:
|
||||
return []
|
||||
|
||||
messages = state.values.get("messages", [])
|
||||
|
||||
# 转换 LangChain 消息对象为字典
|
||||
result = []
|
||||
for msg in messages:
|
||||
# 跳过 system 消息
|
||||
if hasattr(msg, 'type') and msg.type == "system":
|
||||
continue
|
||||
|
||||
if hasattr(msg, 'type'):
|
||||
role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else msg.type
|
||||
result.append({
|
||||
"role": role,
|
||||
"content": msg.content
|
||||
})
|
||||
elif isinstance(msg, dict):
|
||||
role = msg.get("role", msg.get("type", "unknown"))
|
||||
if role in ["human", "user"]:
|
||||
role = "user"
|
||||
elif role in ["ai", "assistant"]:
|
||||
role = "assistant"
|
||||
result.append({
|
||||
"role": role,
|
||||
"content": msg.get("content", "")
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error(f"获取线程消息历史失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取线程摘要(用于历史列表展示)
|
||||
|
||||
Args:
|
||||
thread_id: 线程 ID
|
||||
|
||||
Returns:
|
||||
包含摘要信息的字典
|
||||
"""
|
||||
try:
|
||||
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
|
||||
|
||||
if state is None or not state.values:
|
||||
return {"thread_id": thread_id, "summary": "空对话", "message_count": 0}
|
||||
|
||||
messages = state.values.get("messages", [])
|
||||
summary = self._extract_summary(messages)
|
||||
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
|
||||
|
||||
# 获取最后更新时间
|
||||
last_updated = ""
|
||||
if state.metadata and "created_at" in state.metadata:
|
||||
last_updated = state.metadata["created_at"].isoformat()
|
||||
|
||||
return {
|
||||
"thread_id": thread_id,
|
||||
"summary": summary,
|
||||
"message_count": message_count,
|
||||
"last_updated": last_updated
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error(f"获取线程摘要失败: {e}")
|
||||
return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0}
|
||||
|
||||
def _extract_summary(self, messages: List) -> str:
|
||||
"""
|
||||
从消息列表中提取摘要
|
||||
|
||||
策略:
|
||||
1. 如果有 summarize 节点生成的 summary,优先使用
|
||||
2. 否则使用第一条用户消息的前 50 字
|
||||
"""
|
||||
# 查找是否有 summary 字段
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('summary'):
|
||||
return msg.additional_kwargs['summary']
|
||||
elif isinstance(msg, dict) and msg.get('summary'):
|
||||
return msg['summary']
|
||||
|
||||
# 使用第一条用户消息作为摘要
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and msg.type == "human":
|
||||
content = msg.content
|
||||
return content[:50] + "..." if len(content) > 50 else content
|
||||
elif isinstance(msg, dict) and msg.get("role") in ["user", "human"]:
|
||||
content = msg.get("content", "")
|
||||
return content[:50] + "..." if len(content) > 50 else content
|
||||
|
||||
return "空对话"
|
||||
Reference in New Issue
Block a user