Files
ailine/frontend/api_client.py
root 626bae54ff
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 18s
前端修改
2026-04-16 03:21:38 +08:00

192 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
API 客户端模块
封装所有与后端的通信,支持流式响应
"""
import json
from typing import List, Dict, Any, Generator
import requests
# 使用绝对导入
from frontend.config import config
from frontend.logger import error, warning
class APIClient:
"""后端 API 客户端 - 统一封装所有 HTTP 请求"""
def __init__(self, base_url: str = None):
"""
初始化 API 客户端
Args:
base_url: 后端 API 地址(默认从配置读取)
"""
self.base_url = (base_url or config.api_base).rstrip("/")
# ==================== 历史管理接口 ====================
def get_user_threads(self, user_id: str, limit: int = None) -> List[Dict[str, Any]]:
"""
获取用户的历史对话列表
Args:
user_id: 用户 ID
limit: 返回数量限制(默认使用配置值)
Returns:
线程列表,每个元素包含 thread_id, summary, message_count, last_updated
"""
try:
resp = requests.get(
f"{self.base_url}/threads",
params={
"user_id": user_id,
"limit": limit or config.history_limit
},
timeout=10
)
if resp.status_code == 200:
return resp.json().get("threads", [])
else:
warning(f"获取历史列表失败: HTTP {resp.status_code}")
return []
except Exception as e:
error(f"获取历史列表异常: {e}")
return []
def get_thread_messages(self, thread_id: str, user_id: str) -> List[Dict[str, str]]:
"""
获取指定线程的完整消息历史
Args:
thread_id: 线程 ID
user_id: 用户 ID
Returns:
消息列表,每个元素包含 role 和 content
"""
try:
resp = requests.get(
f"{self.base_url}/thread/{thread_id}/messages",
params={"user_id": user_id},
timeout=10
)
if resp.status_code == 200:
return resp.json().get("messages", [])
else:
warning(f"获取消息历史失败: HTTP {resp.status_code}")
return []
except Exception as e:
error(f"获取消息历史异常: {e}")
return []
def get_thread_summary(self, thread_id: str, user_id: str) -> Dict[str, Any]:
"""
获取指定线程的摘要信息
Args:
thread_id: 线程 ID
user_id: 用户 ID
Returns:
摘要信息字典
"""
try:
resp = requests.get(
f"{self.base_url}/thread/{thread_id}/summary",
params={"user_id": user_id},
timeout=10
)
if resp.status_code == 200:
return resp.json()
else:
warning(f"获取线程摘要失败: HTTP {resp.status_code}")
return {"summary": "加载失败", "message_count": 0}
except Exception as e:
error(f"获取线程摘要异常: {e}")
return {"summary": "加载失败", "message_count": 0}
# ==================== 聊天接口 ====================
def chat_stream(
self,
message: str,
thread_id: str,
model: str,
user_id: str
) -> Generator[Dict[str, Any], None, None]:
"""
流式对话接口SSE
Args:
message: 用户消息
thread_id: 线程 ID
model: 模型名称
user_id: 用户 ID
Yields:
SSE 事件字典,类型包括:
- token: 逐字输出 {type: "token", content: "..."}
- tool_start: 工具调用开始 {type: "tool_start", tool: "..."}
- tool_end: 工具调用完成 {type: "tool_end", tool: "..."}
- done: 对话完成 {type: "done", token_usage: {...}, elapsed_time: ...}
- error: 错误信息 {type: "error", message: "..."}
"""
payload = {
"message": message,
"thread_id": thread_id,
"model": model,
"user_id": user_id
}
try:
with requests.post(
f"{self.base_url}/chat/stream",
json=payload,
stream=True,
timeout=config.stream_timeout
) as response:
if response.status_code != 200:
yield {
"type": "error",
"message": f"请求失败: HTTP {response.status_code}"
}
return
for line in response.iter_lines():
if line:
line = line.decode('utf-8')
if line.startswith("data: "):
data_str = line[6:]
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
yield data
except json.JSONDecodeError as e:
warning(f"JSON 解析失败: {e}")
except requests.exceptions.Timeout:
yield {
"type": "error",
"message": "请求超时,请检查网络连接"
}
except Exception as e:
error(f"流式对话异常: {e}")
yield {
"type": "error",
"message": f"请求失败: {str(e)}"
}
# 全局 API 客户端实例(单例模式)
api_client = APIClient()