192 lines
6.0 KiB
Python
192 lines
6.0 KiB
Python
"""
|
||
API 客户端模块
|
||
封装所有与后端的通信,支持流式响应
|
||
"""
|
||
|
||
import json
|
||
from typing import List, Dict, Any, Generator
|
||
import requests
|
||
|
||
# 使用绝对导入
|
||
from frontend.config import config
|
||
from frontend.logger import error, warning
|
||
|
||
|
||
class APIClient:
|
||
"""后端 API 客户端 - 统一封装所有 HTTP 请求"""
|
||
|
||
def __init__(self, base_url: str = None):
|
||
"""
|
||
初始化 API 客户端
|
||
|
||
Args:
|
||
base_url: 后端 API 地址(默认从配置读取)
|
||
"""
|
||
self.base_url = (base_url or config.api_base).rstrip("/")
|
||
|
||
# ==================== 历史管理接口 ====================
|
||
|
||
def get_user_threads(self, user_id: str, limit: int = None) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取用户的历史对话列表
|
||
|
||
Args:
|
||
user_id: 用户 ID
|
||
limit: 返回数量限制(默认使用配置值)
|
||
|
||
Returns:
|
||
线程列表,每个元素包含 thread_id, summary, message_count, last_updated
|
||
"""
|
||
try:
|
||
resp = requests.get(
|
||
f"{self.base_url}/threads",
|
||
params={
|
||
"user_id": user_id,
|
||
"limit": limit or config.history_limit
|
||
},
|
||
timeout=10
|
||
)
|
||
|
||
if resp.status_code == 200:
|
||
return resp.json().get("threads", [])
|
||
else:
|
||
warning(f"获取历史列表失败: HTTP {resp.status_code}")
|
||
return []
|
||
|
||
except Exception as e:
|
||
error(f"获取历史列表异常: {e}")
|
||
return []
|
||
|
||
def get_thread_messages(self, thread_id: str, user_id: str) -> List[Dict[str, str]]:
|
||
"""
|
||
获取指定线程的完整消息历史
|
||
|
||
Args:
|
||
thread_id: 线程 ID
|
||
user_id: 用户 ID
|
||
|
||
Returns:
|
||
消息列表,每个元素包含 role 和 content
|
||
"""
|
||
try:
|
||
resp = requests.get(
|
||
f"{self.base_url}/thread/{thread_id}/messages",
|
||
params={"user_id": user_id},
|
||
timeout=10
|
||
)
|
||
|
||
if resp.status_code == 200:
|
||
return resp.json().get("messages", [])
|
||
else:
|
||
warning(f"获取消息历史失败: HTTP {resp.status_code}")
|
||
return []
|
||
|
||
except Exception as e:
|
||
error(f"获取消息历史异常: {e}")
|
||
return []
|
||
|
||
def get_thread_summary(self, thread_id: str, user_id: str) -> Dict[str, Any]:
|
||
"""
|
||
获取指定线程的摘要信息
|
||
|
||
Args:
|
||
thread_id: 线程 ID
|
||
user_id: 用户 ID
|
||
|
||
Returns:
|
||
摘要信息字典
|
||
"""
|
||
try:
|
||
resp = requests.get(
|
||
f"{self.base_url}/thread/{thread_id}/summary",
|
||
params={"user_id": user_id},
|
||
timeout=10
|
||
)
|
||
|
||
if resp.status_code == 200:
|
||
return resp.json()
|
||
else:
|
||
warning(f"获取线程摘要失败: HTTP {resp.status_code}")
|
||
return {"summary": "加载失败", "message_count": 0}
|
||
|
||
except Exception as e:
|
||
error(f"获取线程摘要异常: {e}")
|
||
return {"summary": "加载失败", "message_count": 0}
|
||
|
||
# ==================== 聊天接口 ====================
|
||
|
||
def chat_stream(
|
||
self,
|
||
message: str,
|
||
thread_id: str,
|
||
model: str,
|
||
user_id: str
|
||
) -> Generator[Dict[str, Any], None, None]:
|
||
"""
|
||
流式对话接口(SSE)
|
||
|
||
Args:
|
||
message: 用户消息
|
||
thread_id: 线程 ID
|
||
model: 模型名称
|
||
user_id: 用户 ID
|
||
|
||
Yields:
|
||
SSE 事件字典,类型包括:
|
||
- token: 逐字输出 {type: "token", content: "..."}
|
||
- tool_start: 工具调用开始 {type: "tool_start", tool: "..."}
|
||
- tool_end: 工具调用完成 {type: "tool_end", tool: "..."}
|
||
- done: 对话完成 {type: "done", token_usage: {...}, elapsed_time: ...}
|
||
- error: 错误信息 {type: "error", message: "..."}
|
||
"""
|
||
payload = {
|
||
"message": message,
|
||
"thread_id": thread_id,
|
||
"model": model,
|
||
"user_id": user_id
|
||
}
|
||
|
||
try:
|
||
with requests.post(
|
||
f"{self.base_url}/chat/stream",
|
||
json=payload,
|
||
stream=True,
|
||
timeout=config.stream_timeout
|
||
) as response:
|
||
if response.status_code != 200:
|
||
yield {
|
||
"type": "error",
|
||
"message": f"请求失败: HTTP {response.status_code}"
|
||
}
|
||
return
|
||
|
||
for line in response.iter_lines():
|
||
if line:
|
||
line = line.decode('utf-8')
|
||
if line.startswith("data: "):
|
||
data_str = line[6:]
|
||
if data_str == "[DONE]":
|
||
break
|
||
|
||
try:
|
||
data = json.loads(data_str)
|
||
yield data
|
||
except json.JSONDecodeError as e:
|
||
warning(f"JSON 解析失败: {e}")
|
||
|
||
except requests.exceptions.Timeout:
|
||
yield {
|
||
"type": "error",
|
||
"message": "请求超时,请检查网络连接"
|
||
}
|
||
except Exception as e:
|
||
error(f"流式对话异常: {e}")
|
||
yield {
|
||
"type": "error",
|
||
"message": f"请求失败: {str(e)}"
|
||
}
|
||
|
||
|
||
# 全局 API 客户端实例(单例模式)
|
||
api_client = APIClient()
|