"""
|
||
Streamlit 前端 - 支持模型选择
|
||
"""
|
||
|
||
# Standard library
import uuid

# Third-party
import requests
import streamlit as st
# Former hard-coded backend address, kept for local testing:
# API_URL = "http://115.190.121.151:8001/chat"

# Relative path, intended to be forwarded by Nginx under the /ai prefix.
# NOTE(review): `requests.post` requires an absolute URL — a bare
# "/ai/api/chat" raises requests.exceptions.MissingSchema at runtime.
# This only works if something rewrites/prefixes it before the call;
# confirm, or point it at the proxy host explicitly
# (e.g. "http://127.0.0.1/ai/api/chat").
API_URL = "/ai/api/chat"
# Basic page chrome: browser-tab title/icon and the main heading.
st.set_page_config(page_title="AI 个人助手", page_icon="🤖")
st.title("🤖 个人生活与数据分析助手")
# Model choices shown in the sidebar. Keys must match the model names the
# backend accepts; values are the human-readable labels displayed to users.
MODEL_OPTIONS = {
    "zhipu": "智谱 GLM-4.7-Flash(在线)",
    "local": "本地 vLLM(Gemma-4)",
}
# Initialise per-session state on first run; Streamlit reruns the whole
# script on every interaction, so each key is guarded by a membership test.
if "messages" not in st.session_state:
    # Chat history as a list of {"role": ..., "content": ...} dicts.
    st.session_state.messages = []
if "thread_id" not in st.session_state:
    # Conversation id sent to the backend so it can keep per-thread context.
    st.session_state.thread_id = str(uuid.uuid4())
if "selected_model" not in st.session_state:
    # Default model key; must be one of MODEL_OPTIONS' keys.
    st.session_state.selected_model = "zhipu"
# Sidebar: model selection and session management.
with st.sidebar:
    st.header("⚙️ 设置")

    # Model picker. index=0 preselects the first key ("zhipu"); the
    # format_func maps the internal key to its display label.
    selected_model_key = st.selectbox(
        "选择大模型",
        options=list(MODEL_OPTIONS.keys()),
        format_func=lambda x: MODEL_OPTIONS[x],
        index=0,
    )
    st.session_state.selected_model = selected_model_key

    # Show a shortened conversation id for reference.
    st.write(f"当前会话 ID: `{st.session_state.thread_id[:8]}...`")

    # Start a fresh conversation: new thread id, cleared history, rerun.
    if st.button("🔄 新会话"):
        st.session_state.thread_id = str(uuid.uuid4())
        st.session_state.messages = []
        st.rerun()
# Replay the stored chat history so it survives Streamlit's rerun cycle.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# Handle a new user message, if any was submitted this run.
if prompt := st.chat_input("请输入您的问题..."):
    # Echo the user's message and record it in the history.
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Call the backend API, passing the selected model along.
    with st.chat_message("assistant"):
        with st.spinner("思考中..."):
            try:
                response = requests.post(
                    API_URL,
                    json={
                        "message": prompt,
                        "thread_id": st.session_state.thread_id,
                        "model": st.session_state.selected_model,
                    },
                    timeout=60,
                )
                response.raise_for_status()
                data = response.json()
                reply = data["reply"]
                model_used = data["model_used"]

                # Show the reply.
                st.markdown(reply)

                # Small caption naming the model that actually answered;
                # fall back to the raw key if it's not in MODEL_OPTIONS.
                st.caption(f"🤖 使用模型: {MODEL_OPTIONS.get(model_used, model_used)}")

                st.session_state.messages.append({"role": "assistant", "content": reply})
            except Exception as e:
                # Broad catch at the UI boundary: any failure (network,
                # HTTP status, bad JSON, missing keys) is surfaced to the
                # user via st.error instead of crashing the app.
                error_msg = f"请求失败: {e}"
                st.error(error_msg)
                st.session_state.messages.append({"role": "assistant", "content": error_msg})