178 lines
6.8 KiB
Python
178 lines
6.8 KiB
Python
"""
|
||
历史对话查询模块
|
||
利用 LangGraph 的 checkpointer 获取对话历史和摘要
|
||
"""
|
||
|
||
from typing import List, Dict, Any, Optional
|
||
import logging
|
||
from app.logger import error # 保持兼容,或者替换为 logger
|
||
|
||
|
||
class ThreadHistoryService:
    """Query conversation threads stored by a LangGraph checkpointer.

    ``checkpointer`` must provide the async ``aget_tuple(config)`` method of
    LangGraph checkpoint savers. :meth:`get_user_threads` additionally reads
    ``checkpointer.conn`` and runs raw SQL against a ``checkpoints`` table, so
    it assumes a Postgres-backed saver whose cursor yields dict-like rows.
    """

    # Normalized roles that count as conversation turns in ``message_count``.
    _CONVERSATION_ROLES = ("user", "assistant")
    # Number of characters of the first user message kept as a fallback summary.
    _SUMMARY_MAX_CHARS = 50

    def __init__(self, checkpointer):
        """Store the checkpointer used by every query method."""
        self.checkpointer = checkpointer

    async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Return summary info for the most recently updated threads of a user.

        Args:
            user_id: User ID, matched against ``metadata->>'user_id'`` of the
                stored checkpoints.
            limit: Maximum number of threads to return.

        Returns:
            A list of dicts with keys ``thread_id``, ``last_updated``,
            ``summary`` and ``message_count``, ordered newest first. Threads
            whose latest checkpoint holds no channel values are skipped.
            On any failure the error is logged and an empty list is returned.
        """
        try:
            async with self.checkpointer.conn.cursor() as cur:
                # One row per thread with its latest timestamp.
                # NOTE(review): the stock langgraph-checkpoint-postgres
                # ``checkpoints`` table has no ``created_at`` column (the
                # timestamp lives in ``checkpoint->>'ts'``); confirm this
                # deployment's schema provides it.
                query = """
                    SELECT
                        thread_id,
                        MAX(created_at) as last_updated
                    FROM checkpoints
                    WHERE metadata->>'user_id' = %s
                    GROUP BY thread_id
                    ORDER BY last_updated DESC
                    LIMIT %s
                """
                await cur.execute(query, (user_id, limit))
                rows = await cur.fetchall()

            threads = []
            for row in rows:
                thread_id = row['thread_id']
                # Load the latest checkpoint of the thread to read its messages.
                state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
                values = self._get_channel_values(state)
                if not values:
                    continue

                messages = values.get("messages") or []
                threads.append({
                    "thread_id": thread_id,
                    "last_updated": self._to_iso(row['last_updated']),
                    "summary": self._extract_summary(messages),
                    "message_count": self._count_messages(messages),
                })

            return threads

        except Exception as e:
            error(f"获取用户线程列表失败 (user_id={user_id}): {e}")
            return []

    async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
        """Return the full message history of a thread.

        Args:
            thread_id: Thread ID.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts. ``human``/``user``
            map to ``"user"`` and ``ai``/``assistant`` map to ``"assistant"``;
            other roles (e.g. ``"tool"``) are passed through unchanged. System
            messages are omitted. ``content`` is returned as stored, so it may
            be a list of content blocks for multimodal messages. On any
            failure the error is logged and an empty list is returned.
        """
        try:
            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
            values = self._get_channel_values(state)
            if not values:
                return []

            result = []
            for msg in values.get("messages") or []:
                role = self._normalize_role(msg)
                # Skip entries that are not messages, and system prompts
                # (both object- and dict-shaped).
                if role is None or role == "system":
                    continue
                result.append({
                    "role": role,
                    "content": self._message_content(msg),
                })

            return result

        except Exception as e:
            error(f"获取线程消息历史失败 (thread_id={thread_id}): {e}")
            return []

    async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
        """Return summary info for one thread (used by the history list view).

        Args:
            thread_id: Thread ID.

        Returns:
            A dict with keys ``thread_id``, ``summary``, ``message_count`` and
            ``last_updated`` (ISO string, ``""`` when unknown). An empty thread
            yields the summary ``"空对话"``; a failure is logged and yields
            ``"加载失败"``.
        """
        try:
            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
            values = self._get_channel_values(state)
            if not values:
                return {"thread_id": thread_id, "summary": "空对话", "message_count": 0, "last_updated": ""}

            messages = values.get("messages") or []
            return {
                "thread_id": thread_id,
                "summary": self._extract_summary(messages),
                "message_count": self._count_messages(messages),
                "last_updated": self._get_last_updated(state),
            }

        except Exception as e:
            error(f"获取线程摘要失败 (thread_id={thread_id}): {e}")
            return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0, "last_updated": ""}

    def _extract_summary(self, messages: List) -> str:
        """Derive a short summary for a list of messages.

        Strategy:
        1. If any message carries a ``summary`` (in ``additional_kwargs`` for
           message objects, or as a top-level key for dicts), return the first one.
        2. Otherwise return the text of the first user message, truncated to
           ``_SUMMARY_MAX_CHARS`` characters with ``"..."`` appended.
        3. Otherwise return ``"空对话"``.
        """
        for msg in messages:
            if isinstance(msg, dict):
                if msg.get('summary'):
                    return msg['summary']
            else:
                extra = getattr(msg, 'additional_kwargs', None)
                if isinstance(extra, dict) and extra.get('summary'):
                    return extra['summary']

        limit = self._SUMMARY_MAX_CHARS
        for msg in messages:
            if self._normalize_role(msg) == "user":
                # Content may be a list of content blocks; flatten to text
                # before slicing so the result is always a string.
                text = self._content_to_text(self._message_content(msg))
                return text[:limit] + "..." if len(text) > limit else text

        return "空对话"

    @staticmethod
    def _get_channel_values(state: Any) -> Dict[str, Any]:
        """Return the channel values (graph state) held by a checkpoint lookup result.

        Accepts a ``StateSnapshot``-like object exposing a ``values`` dict, or
        a ``CheckpointTuple`` whose ``checkpoint["channel_values"]`` holds the
        state. Returns ``{}`` when nothing is available.
        """
        if state is None:
            return {}
        values = getattr(state, "values", None)
        if isinstance(values, dict):
            return values
        checkpoint = getattr(state, "checkpoint", None)
        if isinstance(checkpoint, dict):
            channel_values = checkpoint.get("channel_values")
            if isinstance(channel_values, dict):
                return channel_values
        return {}

    @staticmethod
    def _normalize_role(msg: Any) -> Optional[str]:
        """Map a message object or dict to a chat role.

        Returns ``"user"`` for human/user, ``"assistant"`` for ai/assistant,
        the raw role/type otherwise (``"unknown"`` for a dict with neither),
        and ``None`` for values that are neither dicts nor objects with ``type``.
        """
        if isinstance(msg, dict):
            role = msg.get("role", msg.get("type", "unknown"))
        elif hasattr(msg, "type"):
            role = msg.type
        else:
            return None
        if role in ("human", "user"):
            return "user"
        if role in ("ai", "assistant"):
            return "assistant"
        return role

    @staticmethod
    def _message_content(msg: Any) -> Any:
        """Return the raw ``content`` of a message object or dict (``""`` if absent)."""
        if isinstance(msg, dict):
            return msg.get("content", "")
        return getattr(msg, "content", "")

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content (a str, or a list of str / ``{"type": "text"}`` blocks) to plain text."""
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    parts.append(str(part.get("text", "")))
            return " ".join(p for p in parts if p)
        return str(content)

    @classmethod
    def _count_messages(cls, messages: List) -> int:
        """Count user and assistant messages (objects and dicts alike)."""
        return sum(1 for m in messages if cls._normalize_role(m) in cls._CONVERSATION_ROLES)

    @staticmethod
    def _to_iso(value: Any) -> str:
        """Render a timestamp as text: ``datetime`` via ``isoformat()``, strings as-is, falsy as ``""``."""
        if not value:
            return ""
        isoformat = getattr(value, "isoformat", None)
        if callable(isoformat):
            return isoformat()
        return str(value)

    @classmethod
    def _get_last_updated(cls, state: Any) -> str:
        """Best-effort timestamp of a checkpoint lookup result as an ISO string.

        Tries ``metadata["created_at"]``, then a ``created_at`` attribute
        (``StateSnapshot``), then ``checkpoint["ts"]`` (``CheckpointTuple``).
        """
        metadata = getattr(state, "metadata", None)
        if isinstance(metadata, dict) and metadata.get("created_at"):
            return cls._to_iso(metadata["created_at"])
        created_at = getattr(state, "created_at", None)
        if created_at:
            return cls._to_iso(created_at)
        checkpoint = getattr(state, "checkpoint", None)
        if isinstance(checkpoint, dict):
            return cls._to_iso(checkpoint.get("ts"))
        return ""