95 lines
2.9 KiB
Python
95 lines
2.9 KiB
Python
|
|
"""
|
|||
|
|
Streamlit 前端 - 支持模型选择
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
# 标准库
|
|||
|
|
import uuid
|
|||
|
|
|
|||
|
|
# 第三方库
|
|||
|
|
import requests
|
|||
|
|
import streamlit as st
|
|||
|
|
|
|||
|
|
# 后端 API 地址(端口 8001)
|
|||
|
|
API_URL = "http://localhost:8001/chat"
|
|||
|
|
|
|||
|
|
st.set_page_config(page_title="AI 个人助手", page_icon="🤖")
|
|||
|
|
st.title("🤖 个人生活与数据分析助手")
|
|||
|
|
|
|||
|
|
# 模型选项(与后端支持的模型名称一致)
|
|||
|
|
MODEL_OPTIONS = {
|
|||
|
|
"zhipu": "智谱 GLM-4.7-Flash(在线)",
|
|||
|
|
"local": "本地 vLLM(Gemma-4)"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# 初始化会话状态
|
|||
|
|
if "messages" not in st.session_state:
|
|||
|
|
st.session_state.messages = []
|
|||
|
|
if "thread_id" not in st.session_state:
|
|||
|
|
st.session_state.thread_id = str(uuid.uuid4())
|
|||
|
|
if "selected_model" not in st.session_state:
|
|||
|
|
st.session_state.selected_model = "zhipu"
|
|||
|
|
|
|||
|
|
# 侧边栏:模型选择和会话管理
|
|||
|
|
with st.sidebar:
|
|||
|
|
st.header("⚙️ 设置")
|
|||
|
|
|
|||
|
|
# 模型选择
|
|||
|
|
selected_model_key = st.selectbox(
|
|||
|
|
"选择大模型",
|
|||
|
|
options=list(MODEL_OPTIONS.keys()),
|
|||
|
|
format_func=lambda x: MODEL_OPTIONS[x],
|
|||
|
|
index=0
|
|||
|
|
)
|
|||
|
|
st.session_state.selected_model = selected_model_key
|
|||
|
|
|
|||
|
|
# 会话信息显示
|
|||
|
|
st.write(f"当前会话 ID: `{st.session_state.thread_id[:8]}...`")
|
|||
|
|
|
|||
|
|
# 新会话按钮
|
|||
|
|
if st.button("🔄 新会话"):
|
|||
|
|
st.session_state.thread_id = str(uuid.uuid4())
|
|||
|
|
st.session_state.messages = []
|
|||
|
|
st.rerun()
|
|||
|
|
|
|||
|
|
# 显示历史消息
|
|||
|
|
for msg in st.session_state.messages:
|
|||
|
|
with st.chat_message(msg["role"]):
|
|||
|
|
st.markdown(msg["content"])
|
|||
|
|
|
|||
|
|
# 用户输入
|
|||
|
|
if prompt := st.chat_input("请输入您的问题..."):
|
|||
|
|
# 显示用户消息
|
|||
|
|
with st.chat_message("user"):
|
|||
|
|
st.markdown(prompt)
|
|||
|
|
st.session_state.messages.append({"role": "user", "content": prompt})
|
|||
|
|
|
|||
|
|
# 调用后端 API(携带模型参数)
|
|||
|
|
with st.chat_message("assistant"):
|
|||
|
|
with st.spinner("思考中..."):
|
|||
|
|
try:
|
|||
|
|
response = requests.post(
|
|||
|
|
API_URL,
|
|||
|
|
json={
|
|||
|
|
"message": prompt,
|
|||
|
|
"thread_id": st.session_state.thread_id,
|
|||
|
|
"model": st.session_state.selected_model
|
|||
|
|
},
|
|||
|
|
timeout=60
|
|||
|
|
)
|
|||
|
|
response.raise_for_status()
|
|||
|
|
data = response.json()
|
|||
|
|
reply = data["reply"]
|
|||
|
|
model_used = data["model_used"]
|
|||
|
|
|
|||
|
|
# 显示回复
|
|||
|
|
st.markdown(reply)
|
|||
|
|
|
|||
|
|
# 显示使用的模型(小字提示)
|
|||
|
|
st.caption(f"🤖 使用模型: {MODEL_OPTIONS.get(model_used, model_used)}")
|
|||
|
|
|
|||
|
|
st.session_state.messages.append({"role": "assistant", "content": reply})
|
|||
|
|
except Exception as e:
|
|||
|
|
error_msg = f"请求失败: {e}"
|
|||
|
|
st.error(error_msg)
|
|||
|
|
st.session_state.messages.append({"role": "assistant", "content": error_msg})
|