Files
ailine/frontend/api_client.py

192 lines
6.0 KiB
Python
Raw Normal View History

2026-04-16 03:21:38 +08:00
"""
API 客户端模块
封装所有与后端的通信支持流式响应
"""
import json
from typing import List, Dict, Any, Generator
import requests
# 使用绝对导入
from frontend.config import config
from frontend.logger import error, warning
class APIClient:
"""后端 API 客户端 - 统一封装所有 HTTP 请求"""
def __init__(self, base_url: str = None):
"""
初始化 API 客户端
Args:
base_url: 后端 API 地址默认从配置读取
"""
self.base_url = (base_url or config.api_base).rstrip("/")
# ==================== 历史管理接口 ====================
def get_user_threads(self, user_id: str, limit: int = None) -> List[Dict[str, Any]]:
"""
获取用户的历史对话列表
Args:
user_id: 用户 ID
limit: 返回数量限制默认使用配置值
Returns:
线程列表每个元素包含 thread_id, summary, message_count, last_updated
"""
try:
resp = requests.get(
f"{self.base_url}/threads",
params={
"user_id": user_id,
"limit": limit or config.history_limit
},
timeout=10
)
if resp.status_code == 200:
return resp.json().get("threads", [])
else:
warning(f"获取历史列表失败: HTTP {resp.status_code}")
return []
except Exception as e:
error(f"获取历史列表异常: {e}")
return []
def get_thread_messages(self, thread_id: str, user_id: str) -> List[Dict[str, str]]:
"""
获取指定线程的完整消息历史
Args:
thread_id: 线程 ID
user_id: 用户 ID
Returns:
消息列表每个元素包含 role content
"""
try:
resp = requests.get(
f"{self.base_url}/thread/{thread_id}/messages",
params={"user_id": user_id},
timeout=10
)
if resp.status_code == 200:
return resp.json().get("messages", [])
else:
warning(f"获取消息历史失败: HTTP {resp.status_code}")
return []
except Exception as e:
error(f"获取消息历史异常: {e}")
return []
def get_thread_summary(self, thread_id: str, user_id: str) -> Dict[str, Any]:
"""
获取指定线程的摘要信息
Args:
thread_id: 线程 ID
user_id: 用户 ID
Returns:
摘要信息字典
"""
try:
resp = requests.get(
f"{self.base_url}/thread/{thread_id}/summary",
params={"user_id": user_id},
timeout=10
)
if resp.status_code == 200:
return resp.json()
else:
warning(f"获取线程摘要失败: HTTP {resp.status_code}")
return {"summary": "加载失败", "message_count": 0}
except Exception as e:
error(f"获取线程摘要异常: {e}")
return {"summary": "加载失败", "message_count": 0}
# ==================== 聊天接口 ====================
def chat_stream(
self,
message: str,
thread_id: str,
model: str,
user_id: str
) -> Generator[Dict[str, Any], None, None]:
"""
流式对话接口SSE
Args:
message: 用户消息
thread_id: 线程 ID
model: 模型名称
user_id: 用户 ID
Yields:
SSE 事件字典类型包括
- token: 逐字输出 {type: "token", content: "..."}
- tool_start: 工具调用开始 {type: "tool_start", tool: "..."}
- tool_end: 工具调用完成 {type: "tool_end", tool: "..."}
- done: 对话完成 {type: "done", token_usage: {...}, elapsed_time: ...}
- error: 错误信息 {type: "error", message: "..."}
"""
payload = {
"message": message,
"thread_id": thread_id,
"model": model,
"user_id": user_id
}
try:
with requests.post(
f"{self.base_url}/chat/stream",
json=payload,
stream=True,
timeout=config.stream_timeout
) as response:
if response.status_code != 200:
yield {
"type": "error",
"message": f"请求失败: HTTP {response.status_code}"
}
return
for line in response.iter_lines():
if line:
line = line.decode('utf-8')
if line.startswith("data: "):
data_str = line[6:]
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
yield data
except json.JSONDecodeError as e:
warning(f"JSON 解析失败: {e}")
except requests.exceptions.Timeout:
yield {
"type": "error",
"message": "请求超时,请检查网络连接"
}
except Exception as e:
error(f"流式对话异常: {e}")
yield {
"type": "error",
"message": f"请求失败: {str(e)}"
}
# 全局 API 客户端实例(单例模式)
api_client = APIClient()