192 lines
6.0 KiB
Python
192 lines
6.0 KiB
Python
|
|
"""
|
|||
|
|
API 客户端模块
|
|||
|
|
封装所有与后端的通信,支持流式响应
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
from typing import List, Dict, Any, Generator
|
|||
|
|
import requests
|
|||
|
|
|
|||
|
|
# 使用绝对导入
|
|||
|
|
from frontend.config import config
|
|||
|
|
from frontend.logger import error, warning
|
|||
|
|
|
|||
|
|
|
|||
|
|
class APIClient:
|
|||
|
|
"""后端 API 客户端 - 统一封装所有 HTTP 请求"""
|
|||
|
|
|
|||
|
|
def __init__(self, base_url: str = None):
|
|||
|
|
"""
|
|||
|
|
初始化 API 客户端
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
base_url: 后端 API 地址(默认从配置读取)
|
|||
|
|
"""
|
|||
|
|
self.base_url = (base_url or config.api_base).rstrip("/")
|
|||
|
|
|
|||
|
|
# ==================== 历史管理接口 ====================
|
|||
|
|
|
|||
|
|
def get_user_threads(self, user_id: str, limit: int = None) -> List[Dict[str, Any]]:
|
|||
|
|
"""
|
|||
|
|
获取用户的历史对话列表
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
user_id: 用户 ID
|
|||
|
|
limit: 返回数量限制(默认使用配置值)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
线程列表,每个元素包含 thread_id, summary, message_count, last_updated
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
resp = requests.get(
|
|||
|
|
f"{self.base_url}/threads",
|
|||
|
|
params={
|
|||
|
|
"user_id": user_id,
|
|||
|
|
"limit": limit or config.history_limit
|
|||
|
|
},
|
|||
|
|
timeout=10
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if resp.status_code == 200:
|
|||
|
|
return resp.json().get("threads", [])
|
|||
|
|
else:
|
|||
|
|
warning(f"获取历史列表失败: HTTP {resp.status_code}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
error(f"获取历史列表异常: {e}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
def get_thread_messages(self, thread_id: str, user_id: str) -> List[Dict[str, str]]:
|
|||
|
|
"""
|
|||
|
|
获取指定线程的完整消息历史
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
thread_id: 线程 ID
|
|||
|
|
user_id: 用户 ID
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
消息列表,每个元素包含 role 和 content
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
resp = requests.get(
|
|||
|
|
f"{self.base_url}/thread/{thread_id}/messages",
|
|||
|
|
params={"user_id": user_id},
|
|||
|
|
timeout=10
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if resp.status_code == 200:
|
|||
|
|
return resp.json().get("messages", [])
|
|||
|
|
else:
|
|||
|
|
warning(f"获取消息历史失败: HTTP {resp.status_code}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
error(f"获取消息历史异常: {e}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
def get_thread_summary(self, thread_id: str, user_id: str) -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
获取指定线程的摘要信息
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
thread_id: 线程 ID
|
|||
|
|
user_id: 用户 ID
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
摘要信息字典
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
resp = requests.get(
|
|||
|
|
f"{self.base_url}/thread/{thread_id}/summary",
|
|||
|
|
params={"user_id": user_id},
|
|||
|
|
timeout=10
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if resp.status_code == 200:
|
|||
|
|
return resp.json()
|
|||
|
|
else:
|
|||
|
|
warning(f"获取线程摘要失败: HTTP {resp.status_code}")
|
|||
|
|
return {"summary": "加载失败", "message_count": 0}
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
error(f"获取线程摘要异常: {e}")
|
|||
|
|
return {"summary": "加载失败", "message_count": 0}
|
|||
|
|
|
|||
|
|
# ==================== 聊天接口 ====================
|
|||
|
|
|
|||
|
|
def chat_stream(
|
|||
|
|
self,
|
|||
|
|
message: str,
|
|||
|
|
thread_id: str,
|
|||
|
|
model: str,
|
|||
|
|
user_id: str
|
|||
|
|
) -> Generator[Dict[str, Any], None, None]:
|
|||
|
|
"""
|
|||
|
|
流式对话接口(SSE)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
message: 用户消息
|
|||
|
|
thread_id: 线程 ID
|
|||
|
|
model: 模型名称
|
|||
|
|
user_id: 用户 ID
|
|||
|
|
|
|||
|
|
Yields:
|
|||
|
|
SSE 事件字典,类型包括:
|
|||
|
|
- token: 逐字输出 {type: "token", content: "..."}
|
|||
|
|
- tool_start: 工具调用开始 {type: "tool_start", tool: "..."}
|
|||
|
|
- tool_end: 工具调用完成 {type: "tool_end", tool: "..."}
|
|||
|
|
- done: 对话完成 {type: "done", token_usage: {...}, elapsed_time: ...}
|
|||
|
|
- error: 错误信息 {type: "error", message: "..."}
|
|||
|
|
"""
|
|||
|
|
payload = {
|
|||
|
|
"message": message,
|
|||
|
|
"thread_id": thread_id,
|
|||
|
|
"model": model,
|
|||
|
|
"user_id": user_id
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
with requests.post(
|
|||
|
|
f"{self.base_url}/chat/stream",
|
|||
|
|
json=payload,
|
|||
|
|
stream=True,
|
|||
|
|
timeout=config.stream_timeout
|
|||
|
|
) as response:
|
|||
|
|
if response.status_code != 200:
|
|||
|
|
yield {
|
|||
|
|
"type": "error",
|
|||
|
|
"message": f"请求失败: HTTP {response.status_code}"
|
|||
|
|
}
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
for line in response.iter_lines():
|
|||
|
|
if line:
|
|||
|
|
line = line.decode('utf-8')
|
|||
|
|
if line.startswith("data: "):
|
|||
|
|
data_str = line[6:]
|
|||
|
|
if data_str == "[DONE]":
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
data = json.loads(data_str)
|
|||
|
|
yield data
|
|||
|
|
except json.JSONDecodeError as e:
|
|||
|
|
warning(f"JSON 解析失败: {e}")
|
|||
|
|
|
|||
|
|
except requests.exceptions.Timeout:
|
|||
|
|
yield {
|
|||
|
|
"type": "error",
|
|||
|
|
"message": "请求超时,请检查网络连接"
|
|||
|
|
}
|
|||
|
|
except Exception as e:
|
|||
|
|
error(f"流式对话异常: {e}")
|
|||
|
|
yield {
|
|||
|
|
"type": "error",
|
|||
|
|
"message": f"请求失败: {str(e)}"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 全局 API 客户端实例(单例模式)
|
|||
|
|
api_client = APIClient()
|