This commit is contained in:
163
frontend/state.py
Normal file
163
frontend/state.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
前端状态管理模块
|
||||
使用 Streamlit Session State 管理应用状态
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import List, Dict, Any
|
||||
import streamlit as st
|
||||
|
||||
from .config import config
|
||||
|
||||
|
||||
class AppState:
|
||||
"""应用状态管理器 - 统一管理所有 session_state"""
|
||||
|
||||
@staticmethod
|
||||
def init():
|
||||
"""初始化所有状态变量"""
|
||||
# 用户状态
|
||||
if "user_id" not in st.session_state:
|
||||
st.session_state.user_id = config.default_user_id
|
||||
if "logged_in" not in st.session_state:
|
||||
st.session_state.logged_in = False
|
||||
|
||||
# 对话状态
|
||||
if "current_thread_id" not in st.session_state:
|
||||
st.session_state.current_thread_id = str(uuid.uuid4())
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
|
||||
# 历史列表
|
||||
if "threads" not in st.session_state:
|
||||
st.session_state.threads = []
|
||||
if "loading_history" not in st.session_state:
|
||||
st.session_state.loading_history = False
|
||||
|
||||
# 模型选择
|
||||
if "selected_model" not in st.session_state:
|
||||
st.session_state.selected_model = config.default_model
|
||||
|
||||
# ==================== 用户相关 ====================
|
||||
|
||||
@staticmethod
|
||||
def get_user_id() -> str:
|
||||
"""获取当前用户 ID"""
|
||||
return st.session_state.user_id
|
||||
|
||||
@staticmethod
|
||||
def is_logged_in() -> bool:
|
||||
"""检查是否已登录"""
|
||||
return st.session_state.logged_in
|
||||
|
||||
@staticmethod
|
||||
def login(username: str):
|
||||
"""
|
||||
用户登录
|
||||
|
||||
Args:
|
||||
username: 用户名,为空则使用默认用户
|
||||
"""
|
||||
st.session_state.user_id = username.strip() if username.strip() else config.default_user_id
|
||||
st.session_state.logged_in = True
|
||||
|
||||
@staticmethod
|
||||
def logout():
|
||||
"""用户登出,重置为默认用户"""
|
||||
st.session_state.logged_in = False
|
||||
st.session_state.user_id = config.default_user_id
|
||||
st.session_state.threads = []
|
||||
|
||||
# ==================== 线程相关 ====================
|
||||
|
||||
@staticmethod
|
||||
def get_current_thread_id() -> str:
|
||||
"""获取当前线程 ID"""
|
||||
return st.session_state.current_thread_id
|
||||
|
||||
@staticmethod
|
||||
def set_current_thread_id(thread_id: str):
|
||||
"""
|
||||
设置当前线程 ID
|
||||
|
||||
Args:
|
||||
thread_id: 线程 ID
|
||||
"""
|
||||
st.session_state.current_thread_id = thread_id
|
||||
|
||||
@staticmethod
|
||||
def start_new_thread():
|
||||
"""开始新对话,生成新线程 ID 并清空消息"""
|
||||
st.session_state.current_thread_id = str(uuid.uuid4())
|
||||
st.session_state.messages = []
|
||||
|
||||
# ==================== 消息相关 ====================
|
||||
|
||||
@staticmethod
|
||||
def get_messages() -> List[Dict[str, str]]:
|
||||
"""获取消息列表"""
|
||||
return st.session_state.messages
|
||||
|
||||
@staticmethod
|
||||
def add_message(role: str, content: str):
|
||||
"""
|
||||
添加消息
|
||||
|
||||
Args:
|
||||
role: 消息角色 (user/assistant)
|
||||
content: 消息内容
|
||||
"""
|
||||
st.session_state.messages.append({"role": role, "content": content})
|
||||
|
||||
@staticmethod
|
||||
def clear_messages():
|
||||
"""清空消息列表"""
|
||||
st.session_state.messages = []
|
||||
|
||||
@staticmethod
|
||||
def get_message_stats() -> Dict[str, int]:
|
||||
"""
|
||||
获取消息统计
|
||||
|
||||
Returns:
|
||||
包含 user 和 assistant 消息数量的字典
|
||||
"""
|
||||
messages = st.session_state.messages
|
||||
return {
|
||||
"user": len([m for m in messages if m["role"] == "user"]),
|
||||
"assistant": len([m for m in messages if m["role"] == "assistant"])
|
||||
}
|
||||
|
||||
# ==================== 历史列表相关 ====================
|
||||
|
||||
@staticmethod
|
||||
def get_threads() -> List[Dict[str, Any]]:
|
||||
"""获取历史线程列表"""
|
||||
return st.session_state.threads
|
||||
|
||||
@staticmethod
|
||||
def set_threads(threads: List[Dict[str, Any]]):
|
||||
"""
|
||||
设置历史线程列表
|
||||
|
||||
Args:
|
||||
threads: 线程列表
|
||||
"""
|
||||
st.session_state.threads = threads
|
||||
|
||||
# ==================== 模型相关 ====================
|
||||
|
||||
@staticmethod
|
||||
def get_selected_model() -> str:
|
||||
"""获取选中的模型"""
|
||||
return st.session_state.selected_model
|
||||
|
||||
@staticmethod
|
||||
def set_selected_model(model: str):
|
||||
"""
|
||||
设置选中的模型
|
||||
|
||||
Args:
|
||||
model: 模型标识符
|
||||
"""
|
||||
st.session_state.selected_model = model
|
||||
Reference in New Issue
Block a user