167 lines
4.8 KiB
Python
167 lines
4.8 KiB
Python
"""
|
|
前端状态管理模块
|
|
使用 Streamlit Session State 管理应用状态
|
|
"""
|
|
|
|
import uuid
|
|
from typing import List, Dict, Any
|
|
import streamlit as st
|
|
|
|
from .config import config
|
|
|
|
|
|
class AppState:
|
|
"""应用状态管理器 - 统一管理所有 session_state"""
|
|
|
|
@staticmethod
|
|
def init():
|
|
"""初始化所有状态变量"""
|
|
# 用户状态
|
|
if "user_id" not in st.session_state:
|
|
st.session_state.user_id = config.default_user_id
|
|
if "logged_in" not in st.session_state:
|
|
st.session_state.logged_in = False
|
|
|
|
# 对话状态
|
|
if "current_thread_id" not in st.session_state:
|
|
st.session_state.current_thread_id = str(uuid.uuid4())
|
|
if "messages" not in st.session_state:
|
|
st.session_state.messages = []
|
|
|
|
# 历史列表
|
|
if "threads" not in st.session_state:
|
|
st.session_state.threads = []
|
|
if "loading_history" not in st.session_state:
|
|
st.session_state.loading_history = False
|
|
|
|
# 模型选择
|
|
if "selected_model" not in st.session_state:
|
|
st.session_state.selected_model = config.default_model
|
|
|
|
# ==================== 用户相关 ====================
|
|
|
|
@staticmethod
|
|
def get_user_id() -> str:
|
|
"""获取当前用户 ID"""
|
|
return st.session_state.user_id
|
|
|
|
@staticmethod
|
|
def is_logged_in() -> bool:
|
|
"""检查是否已登录"""
|
|
return st.session_state.logged_in
|
|
|
|
@staticmethod
|
|
def login(username: str):
|
|
"""
|
|
用户登录
|
|
|
|
Args:
|
|
username: 用户名,为空则使用默认用户
|
|
"""
|
|
st.session_state.user_id = username.strip() if username.strip() else config.default_user_id
|
|
st.session_state.logged_in = True
|
|
# 登录后必须开启一个干净的新对话
|
|
AppState.start_new_thread()
|
|
|
|
@staticmethod
|
|
def logout():
|
|
"""用户登出,重置为默认用户"""
|
|
st.session_state.logged_in = False
|
|
st.session_state.user_id = config.default_user_id
|
|
st.session_state.threads = []
|
|
# 登出后必须开启一个干净的新对话
|
|
AppState.start_new_thread()
|
|
|
|
# ==================== 线程相关 ====================
|
|
|
|
@staticmethod
|
|
def get_current_thread_id() -> str:
|
|
"""获取当前线程 ID"""
|
|
return st.session_state.current_thread_id
|
|
|
|
@staticmethod
|
|
def set_current_thread_id(thread_id: str):
|
|
"""
|
|
设置当前线程 ID
|
|
|
|
Args:
|
|
thread_id: 线程 ID
|
|
"""
|
|
st.session_state.current_thread_id = thread_id
|
|
|
|
@staticmethod
|
|
def start_new_thread():
|
|
"""开始新对话,生成新线程 ID 并清空消息"""
|
|
st.session_state.current_thread_id = str(uuid.uuid4())
|
|
st.session_state.messages = []
|
|
|
|
# ==================== 消息相关 ====================
|
|
|
|
@staticmethod
|
|
def get_messages() -> List[Dict[str, str]]:
|
|
"""获取消息列表"""
|
|
return st.session_state.messages
|
|
|
|
@staticmethod
|
|
def add_message(role: str, content: str):
|
|
"""
|
|
添加消息
|
|
|
|
Args:
|
|
role: 消息角色 (user/assistant)
|
|
content: 消息内容
|
|
"""
|
|
st.session_state.messages.append({"role": role, "content": content})
|
|
|
|
@staticmethod
|
|
def clear_messages():
|
|
"""清空消息列表"""
|
|
st.session_state.messages = []
|
|
|
|
@staticmethod
|
|
def get_message_stats() -> Dict[str, int]:
|
|
"""
|
|
获取消息统计
|
|
|
|
Returns:
|
|
包含 user 和 assistant 消息数量的字典
|
|
"""
|
|
messages = st.session_state.messages
|
|
return {
|
|
"user": len([m for m in messages if m["role"] == "user"]),
|
|
"assistant": len([m for m in messages if m["role"] == "assistant"])
|
|
}
|
|
|
|
# ==================== 历史列表相关 ====================
|
|
|
|
@staticmethod
|
|
def get_threads() -> List[Dict[str, Any]]:
|
|
"""获取历史线程列表"""
|
|
return st.session_state.threads
|
|
|
|
@staticmethod
|
|
def set_threads(threads: List[Dict[str, Any]]):
|
|
"""
|
|
设置历史线程列表
|
|
|
|
Args:
|
|
threads: 线程列表
|
|
"""
|
|
st.session_state.threads = threads
|
|
|
|
# ==================== 模型相关 ====================
|
|
|
|
@staticmethod
|
|
def get_selected_model() -> str:
|
|
"""获取选中的模型"""
|
|
return st.session_state.selected_model
|
|
|
|
@staticmethod
|
|
def set_selected_model(model: str):
|
|
"""
|
|
设置选中的模型
|
|
|
|
Args:
|
|
model: 模型标识符
|
|
"""
|
|
st.session_state.selected_model = model |