""" 前端状态管理模块 使用 Streamlit Session State 管理应用状态 """ import uuid from typing import List, Dict, Any import streamlit as st from frontend.config import config class AppState: """应用状态管理器 - 统一管理所有 session_state""" @staticmethod def init(): """初始化所有状态变量""" # 用户状态 if "user_id" not in st.session_state: st.session_state.user_id = config.default_user_id if "logged_in" not in st.session_state: st.session_state.logged_in = False # 对话状态 if "current_thread_id" not in st.session_state: st.session_state.current_thread_id = str(uuid.uuid4()) if "messages" not in st.session_state: st.session_state.messages = [] # 历史列表 if "threads" not in st.session_state: st.session_state.threads = [] if "loading_history" not in st.session_state: st.session_state.loading_history = False # 模型选择 if "selected_model" not in st.session_state: st.session_state.selected_model = config.default_model # ==================== 用户相关 ==================== @staticmethod def get_user_id() -> str: """获取当前用户 ID""" return st.session_state.user_id @staticmethod def is_logged_in() -> bool: """检查是否已登录""" return st.session_state.logged_in @staticmethod def login(username: str): """ 用户登录 Args: username: 用户名,为空则使用默认用户 """ st.session_state.user_id = username.strip() if username.strip() else config.default_user_id st.session_state.logged_in = True # 登录后必须开启一个干净的新对话 AppState.start_new_thread() @staticmethod def logout(): """用户登出,重置为默认用户""" st.session_state.logged_in = False st.session_state.user_id = config.default_user_id st.session_state.threads = [] # 登出后必须开启一个干净的新对话 AppState.start_new_thread() # ==================== 线程相关 ==================== @staticmethod def get_current_thread_id() -> str: """获取当前线程 ID""" return st.session_state.current_thread_id @staticmethod def set_current_thread_id(thread_id: str): """ 设置当前线程 ID Args: thread_id: 线程 ID """ st.session_state.current_thread_id = thread_id @staticmethod def start_new_thread(): """开始新对话,生成新线程 ID 并清空消息""" st.session_state.current_thread_id = str(uuid.uuid4()) st.session_state.messages = [] # ==================== 消息相关 ==================== @staticmethod def get_messages() -> List[Dict[str, str]]: """获取消息列表""" return st.session_state.messages @staticmethod def add_message(role: str, content: str): """ 添加消息 Args: role: 消息角色 (user/assistant) content: 消息内容 """ st.session_state.messages.append({"role": role, "content": content}) @staticmethod def clear_messages(): """清空消息列表""" st.session_state.messages = [] @staticmethod def get_message_stats() -> Dict[str, int]: """ 获取消息统计 Returns: 包含 user 和 assistant 消息数量的字典 """ messages = st.session_state.messages return { "user": len([m for m in messages if m["role"] == "user"]), "assistant": len([m for m in messages if m["role"] == "assistant"]) } # ==================== 历史列表相关 ==================== @staticmethod def get_threads() -> List[Dict[str, Any]]: """获取历史线程列表""" return st.session_state.threads @staticmethod def set_threads(threads: List[Dict[str, Any]]): """ 设置历史线程列表 Args: threads: 线程列表 """ st.session_state.threads = threads # ==================== 模型相关 ==================== @staticmethod def get_selected_model() -> str: """获取选中的模型""" return st.session_state.selected_model @staticmethod def set_selected_model(model: str): """ 设置选中的模型 Args: model: 模型标识符 """ st.session_state.selected_model = model