This commit is contained in:
191
frontend/api_client.py
Normal file
191
frontend/api_client.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
API 客户端模块
|
||||
封装所有与后端的通信,支持流式响应
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Dict, Any, Generator
|
||||
import requests
|
||||
|
||||
# 使用绝对导入
|
||||
from frontend.config import config
|
||||
from frontend.logger import error, warning
|
||||
|
||||
|
||||
class APIClient:
|
||||
"""后端 API 客户端 - 统一封装所有 HTTP 请求"""
|
||||
|
||||
def __init__(self, base_url: str = None):
|
||||
"""
|
||||
初始化 API 客户端
|
||||
|
||||
Args:
|
||||
base_url: 后端 API 地址(默认从配置读取)
|
||||
"""
|
||||
self.base_url = (base_url or config.api_base).rstrip("/")
|
||||
|
||||
# ==================== 历史管理接口 ====================
|
||||
|
||||
def get_user_threads(self, user_id: str, limit: int = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户的历史对话列表
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
limit: 返回数量限制(默认使用配置值)
|
||||
|
||||
Returns:
|
||||
线程列表,每个元素包含 thread_id, summary, message_count, last_updated
|
||||
"""
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{self.base_url}/threads",
|
||||
params={
|
||||
"user_id": user_id,
|
||||
"limit": limit or config.history_limit
|
||||
},
|
||||
timeout=10
|
||||
)
|
||||
|
||||
if resp.status_code == 200:
|
||||
return resp.json().get("threads", [])
|
||||
else:
|
||||
warning(f"获取历史列表失败: HTTP {resp.status_code}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
error(f"获取历史列表异常: {e}")
|
||||
return []
|
||||
|
||||
def get_thread_messages(self, thread_id: str, user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
获取指定线程的完整消息历史
|
||||
|
||||
Args:
|
||||
thread_id: 线程 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
消息列表,每个元素包含 role 和 content
|
||||
"""
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{self.base_url}/thread/{thread_id}/messages",
|
||||
params={"user_id": user_id},
|
||||
timeout=10
|
||||
)
|
||||
|
||||
if resp.status_code == 200:
|
||||
return resp.json().get("messages", [])
|
||||
else:
|
||||
warning(f"获取消息历史失败: HTTP {resp.status_code}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
error(f"获取消息历史异常: {e}")
|
||||
return []
|
||||
|
||||
def get_thread_summary(self, thread_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取指定线程的摘要信息
|
||||
|
||||
Args:
|
||||
thread_id: 线程 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
摘要信息字典
|
||||
"""
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{self.base_url}/thread/{thread_id}/summary",
|
||||
params={"user_id": user_id},
|
||||
timeout=10
|
||||
)
|
||||
|
||||
if resp.status_code == 200:
|
||||
return resp.json()
|
||||
else:
|
||||
warning(f"获取线程摘要失败: HTTP {resp.status_code}")
|
||||
return {"summary": "加载失败", "message_count": 0}
|
||||
|
||||
except Exception as e:
|
||||
error(f"获取线程摘要异常: {e}")
|
||||
return {"summary": "加载失败", "message_count": 0}
|
||||
|
||||
# ==================== 聊天接口 ====================
|
||||
|
||||
def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
thread_id: str,
|
||||
model: str,
|
||||
user_id: str
|
||||
) -> Generator[Dict[str, Any], None, None]:
|
||||
"""
|
||||
流式对话接口(SSE)
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
thread_id: 线程 ID
|
||||
model: 模型名称
|
||||
user_id: 用户 ID
|
||||
|
||||
Yields:
|
||||
SSE 事件字典,类型包括:
|
||||
- token: 逐字输出 {type: "token", content: "..."}
|
||||
- tool_start: 工具调用开始 {type: "tool_start", tool: "..."}
|
||||
- tool_end: 工具调用完成 {type: "tool_end", tool: "..."}
|
||||
- done: 对话完成 {type: "done", token_usage: {...}, elapsed_time: ...}
|
||||
- error: 错误信息 {type: "error", message: "..."}
|
||||
"""
|
||||
payload = {
|
||||
"message": message,
|
||||
"thread_id": thread_id,
|
||||
"model": model,
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
try:
|
||||
with requests.post(
|
||||
f"{self.base_url}/chat/stream",
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=config.stream_timeout
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
yield {
|
||||
"type": "error",
|
||||
"message": f"请求失败: HTTP {response.status_code}"
|
||||
}
|
||||
return
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode('utf-8')
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
yield data
|
||||
except json.JSONDecodeError as e:
|
||||
warning(f"JSON 解析失败: {e}")
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
yield {
|
||||
"type": "error",
|
||||
"message": "请求超时,请检查网络连接"
|
||||
}
|
||||
except Exception as e:
|
||||
error(f"流式对话异常: {e}")
|
||||
yield {
|
||||
"type": "error",
|
||||
"message": f"请求失败: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# 全局 API 客户端实例(单例模式)
|
||||
api_client = APIClient()
|
||||
Reference in New Issue
Block a user