Files
ailine/frontend/state.py

167 lines
4.8 KiB
Python
Raw Normal View History

2026-04-16 03:21:38 +08:00
"""
前端状态管理模块
使用 Streamlit Session State 管理应用状态
"""
import uuid
from typing import List, Dict, Any
import streamlit as st
from .config import config
class AppState:
"""应用状态管理器 - 统一管理所有 session_state"""
@staticmethod
def init():
"""初始化所有状态变量"""
# 用户状态
if "user_id" not in st.session_state:
st.session_state.user_id = config.default_user_id
if "logged_in" not in st.session_state:
st.session_state.logged_in = False
# 对话状态
if "current_thread_id" not in st.session_state:
st.session_state.current_thread_id = str(uuid.uuid4())
if "messages" not in st.session_state:
st.session_state.messages = []
# 历史列表
if "threads" not in st.session_state:
st.session_state.threads = []
if "loading_history" not in st.session_state:
st.session_state.loading_history = False
# 模型选择
if "selected_model" not in st.session_state:
st.session_state.selected_model = config.default_model
# ==================== 用户相关 ====================
@staticmethod
def get_user_id() -> str:
"""获取当前用户 ID"""
return st.session_state.user_id
@staticmethod
def is_logged_in() -> bool:
"""检查是否已登录"""
return st.session_state.logged_in
@staticmethod
def login(username: str):
"""
用户登录
Args:
username: 用户名为空则使用默认用户
"""
st.session_state.user_id = username.strip() if username.strip() else config.default_user_id
st.session_state.logged_in = True
2026-04-17 01:26:05 +08:00
# 登录后必须开启一个干净的新对话
AppState.start_new_thread()
2026-04-16 03:21:38 +08:00
@staticmethod
def logout():
"""用户登出,重置为默认用户"""
st.session_state.logged_in = False
st.session_state.user_id = config.default_user_id
st.session_state.threads = []
2026-04-17 01:26:05 +08:00
# 登出后必须开启一个干净的新对话
AppState.start_new_thread()
2026-04-16 03:21:38 +08:00
# ==================== 线程相关 ====================
@staticmethod
def get_current_thread_id() -> str:
"""获取当前线程 ID"""
return st.session_state.current_thread_id
@staticmethod
def set_current_thread_id(thread_id: str):
"""
设置当前线程 ID
Args:
thread_id: 线程 ID
"""
st.session_state.current_thread_id = thread_id
@staticmethod
def start_new_thread():
"""开始新对话,生成新线程 ID 并清空消息"""
st.session_state.current_thread_id = str(uuid.uuid4())
st.session_state.messages = []
# ==================== 消息相关 ====================
@staticmethod
def get_messages() -> List[Dict[str, str]]:
"""获取消息列表"""
return st.session_state.messages
@staticmethod
def add_message(role: str, content: str):
"""
添加消息
Args:
role: 消息角色 (user/assistant)
content: 消息内容
"""
st.session_state.messages.append({"role": role, "content": content})
@staticmethod
def clear_messages():
"""清空消息列表"""
st.session_state.messages = []
@staticmethod
def get_message_stats() -> Dict[str, int]:
"""
获取消息统计
Returns:
包含 user assistant 消息数量的字典
"""
messages = st.session_state.messages
return {
"user": len([m for m in messages if m["role"] == "user"]),
"assistant": len([m for m in messages if m["role"] == "assistant"])
}
# ==================== 历史列表相关 ====================
@staticmethod
def get_threads() -> List[Dict[str, Any]]:
"""获取历史线程列表"""
return st.session_state.threads
@staticmethod
def set_threads(threads: List[Dict[str, Any]]):
"""
设置历史线程列表
Args:
threads: 线程列表
"""
st.session_state.threads = threads
# ==================== 模型相关 ====================
@staticmethod
def get_selected_model() -> str:
"""获取选中的模型"""
return st.session_state.selected_model
@staticmethod
def set_selected_model(model: str):
"""
设置选中的模型
Args:
model: 模型标识符
"""
st.session_state.selected_model = model