This commit is contained in:
@@ -1,109 +1,409 @@
|
||||
"""
|
||||
Streamlit 前端 - 支持模型选择
|
||||
右侧栏组件:工具状态和统计信息
|
||||
"""
|
||||
|
||||
# 标准库
|
||||
import os
|
||||
import uuid
|
||||
|
||||
# 第三方库
|
||||
from dotenv import load_dotenv
|
||||
import requests
|
||||
import streamlit as st
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
# 后端 API 地址配置
|
||||
# 优先级:环境变量 API_URL > Docker 内部服务名 > 本地开发地址
|
||||
API_URL = os.getenv("API_URL", "http://localhost:8001/chat")
|
||||
|
||||
st.set_page_config(page_title="AI 个人助手", page_icon="🤖")
|
||||
st.title("🤖 个人生活与数据分析助手")
|
||||
|
||||
# 模型选项(与后端支持的模型名称一致)
|
||||
MODEL_OPTIONS = {
|
||||
"zhipu": "智谱 GLM-4.7-Flash(在线)",
|
||||
"deepseek": "DeepSeek V3.2(在线)",
|
||||
"local": "本地 vLLM(Gemma-4)"
|
||||
}
|
||||
|
||||
# 初始化会话状态
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
if "thread_id" not in st.session_state:
|
||||
st.session_state.thread_id = str(uuid.uuid4())
|
||||
if "selected_model" not in st.session_state:
|
||||
st.session_state.selected_model = "zhipu"
|
||||
|
||||
# 侧边栏:模型选择和会话管理
|
||||
with st.sidebar:
|
||||
st.header("⚙️ 设置")
|
||||
def render_info_panel():
|
||||
st.header("📊 会话信息")
|
||||
|
||||
# 模型选择
|
||||
selected_model_key = st.selectbox(
|
||||
"选择大模型",
|
||||
options=list(MODEL_OPTIONS.keys()),
|
||||
format_func=lambda x: MODEL_OPTIONS[x],
|
||||
index=0
|
||||
)
|
||||
st.session_state.selected_model = selected_model_key
|
||||
# 当前线程信息
|
||||
st.subheader("当前对话")
|
||||
st.code(st.session_state.current_thread_id[:8] + "...", language=None)
|
||||
|
||||
# 会话信息显示
|
||||
st.write(f"当前会话 ID: `{st.session_state.thread_id[:8]}...`")
|
||||
st.divider()
|
||||
|
||||
# 新会话按钮
|
||||
if st.button("🔄 新会话"):
|
||||
st.session_state.thread_id = str(uuid.uuid4())
|
||||
st.session_state.messages = []
|
||||
# 消息统计
|
||||
st.subheader("消息统计")
|
||||
user_msgs = len([m for m in st.session_state.messages if m["role"] == "user"])
|
||||
assistant_msgs = len([m for m in st.session_state.messages if m["role"] == "assistant"])
|
||||
|
||||
st.metric("用户消息", user_msgs)
|
||||
st.metric("AI 回复", assistant_msgs)
|
||||
|
||||
st.divider()
|
||||
|
||||
# 使用提示
|
||||
st.subheader("💡 使用提示")
|
||||
st.markdown("""
|
||||
- 左侧可切换历史对话
|
||||
- 点击"新对话"开始新话题
|
||||
- 登录后对话历史隔离
|
||||
- 支持流式实时响应
|
||||
- 模型可随时切换
|
||||
""")
|
||||
"""
|
||||
中间栏组件:聊天区域
|
||||
"""
|
||||
import streamlit as st
|
||||
from ..config import config
|
||||
from ..api_client import stream_chat
|
||||
|
||||
|
||||
def render_chat_area():
|
||||
# 模型选择器
|
||||
col_model, col_empty = st.columns([2, 3])
|
||||
with col_model:
|
||||
selected_model_key = st.selectbox(
|
||||
"🧠 选择模型",
|
||||
options=list(config.model_options.keys()),
|
||||
format_func=lambda x: config.model_options[x],
|
||||
index=list(config.model_options.keys()).index(st.session_state.selected_model) if st.session_state.selected_model in config.model_options else 0
|
||||
)
|
||||
st.session_state.selected_model = selected_model_key
|
||||
|
||||
st.divider()
|
||||
|
||||
# 显示消息历史
|
||||
chat_container = st.container(height=500)
|
||||
with chat_container:
|
||||
for msg in st.session_state.messages:
|
||||
with st.chat_message(msg["role"]):
|
||||
st.markdown(msg["content"])
|
||||
|
||||
# 输入框
|
||||
if prompt := st.chat_input("请输入您的问题...", key="chat_input"):
|
||||
# 显示用户消息
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 流式调用后端
|
||||
with st.chat_message("assistant"):
|
||||
message_placeholder = st.empty()
|
||||
tool_status_placeholder = st.empty()
|
||||
full_response = ""
|
||||
|
||||
stream_gen = stream_chat(
|
||||
message=prompt,
|
||||
thread_id=st.session_state.current_thread_id,
|
||||
model=st.session_state.selected_model,
|
||||
user_id=st.session_state.user_id
|
||||
)
|
||||
|
||||
if stream_gen:
|
||||
for data in stream_gen:
|
||||
if data["type"] == "token":
|
||||
full_response += data["content"]
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
|
||||
elif data["type"] == "tool_start":
|
||||
tool_status_placeholder.info(f"🔧 调用工具: {data['tool']}...")
|
||||
|
||||
elif data["type"] == "tool_end":
|
||||
tool_status_placeholder.success(f"✅ 工具 {data['tool']} 完成")
|
||||
tool_status_placeholder.empty()
|
||||
|
||||
elif data["type"] == "done":
|
||||
# 最终响应
|
||||
token_usage = data.get("token_usage", {})
|
||||
elapsed = data.get("elapsed_time", 0)
|
||||
if token_usage:
|
||||
st.caption(f"📊 消耗 {token_usage.get('total_tokens', 0)} tokens | ⏱️ {elapsed:.2f}s")
|
||||
|
||||
elif data["type"] == "error":
|
||||
st.error(f"❌ 错误: {data['message']}")
|
||||
|
||||
# 显示完整响应
|
||||
message_placeholder.markdown(full_response)
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
tool_status_placeholder.empty()
|
||||
"""
|
||||
左侧栏组件:用户登录 + 历史对话列表
|
||||
"""
|
||||
from datetime import datetime
|
||||
import streamlit as st
|
||||
from ..state import AppState
|
||||
from ..api_client import refresh_threads, load_thread_history
|
||||
|
||||
|
||||
def render_sidebar():
|
||||
st.header("👤 用户")
|
||||
|
||||
# 用户登录区域
|
||||
if not st.session_state.logged_in:
|
||||
username = st.text_input(
|
||||
"输入用户名(可选)",
|
||||
key="login_input",
|
||||
placeholder="留空使用默认用户",
|
||||
help="未登录将使用 default_user,可能导致对话污染"
|
||||
)
|
||||
|
||||
if st.button("✅ 进入", type="primary", use_container_width=True):
|
||||
AppState.login(username)
|
||||
refresh_threads(st.session_state.user_id)
|
||||
|
||||
st.info("💡 建议登录以隔离对话历史")
|
||||
else:
|
||||
st.success(f"✅ 当前用户: `{st.session_state.user_id}`")
|
||||
|
||||
if st.button("🔄 切换用户", use_container_width=True):
|
||||
AppState.reset_login()
|
||||
|
||||
st.divider()
|
||||
|
||||
# 历史对话列表
|
||||
st.header("📚 对话历史")
|
||||
|
||||
# 刷新按钮
|
||||
if st.button("🔄 刷新列表", use_container_width=True):
|
||||
refresh_threads(st.session_state.user_id)
|
||||
|
||||
# 新对话按钮
|
||||
if st.button("➕ 新对话", type="primary", use_container_width=True):
|
||||
AppState.start_new_thread()
|
||||
|
||||
st.divider()
|
||||
|
||||
# 显示历史列表
|
||||
if st.session_state.threads:
|
||||
for thread in st.session_state.threads:
|
||||
thread_id = thread["thread_id"]
|
||||
summary = thread.get("summary", "空对话")
|
||||
message_count = thread.get("message_count", 0)
|
||||
last_updated = thread.get("last_updated", "")
|
||||
|
||||
# 格式化时间
|
||||
if last_updated:
|
||||
try:
|
||||
dt = datetime.fromisoformat(last_updated.replace("Z", "+00:00"))
|
||||
time_str = dt.strftime("%m-%d %H:%M")
|
||||
except:
|
||||
time_str = last_updated[:10]
|
||||
else:
|
||||
time_str = "未知"
|
||||
|
||||
# 按钮样式
|
||||
is_current = thread_id == st.session_state.current_thread_id
|
||||
button_type = "primary" if is_current else "secondary"
|
||||
|
||||
if st.button(
|
||||
f"💬 {summary[:30]}{'...' if len(summary) > 30 else ''}\n\n🕐 {time_str} | {message_count}条",
|
||||
key=f"thread_{thread_id}",
|
||||
use_container_width=True,
|
||||
type=button_type
|
||||
):
|
||||
load_thread_history(thread_id, st.session_state.user_id)
|
||||
else:
|
||||
st.info("暂无对话历史")
|
||||
# Components package
|
||||
"""
|
||||
后端 API 客户端封装
|
||||
"""
|
||||
import json
|
||||
import requests
|
||||
import streamlit as st
|
||||
from .config import config
|
||||
|
||||
|
||||
def refresh_threads(user_id: str):
|
||||
"""刷新用户的历史对话列表"""
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{config.api_base}/threads",
|
||||
params={"user_id": user_id, "limit": 50},
|
||||
timeout=10
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
st.session_state.threads = resp.json()["threads"]
|
||||
else:
|
||||
st.error(f"加载历史列表失败: HTTP {resp.status_code}")
|
||||
except Exception as e:
|
||||
st.error(f"加载历史列表失败: {e}")
|
||||
|
||||
|
||||
def load_thread_history(thread_id: str, user_id: str):
|
||||
"""加载指定线程的完整消息历史"""
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{config.api_base}/thread/{thread_id}/messages",
|
||||
params={"user_id": user_id},
|
||||
timeout=10
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
st.session_state.messages = resp.json()["messages"]
|
||||
st.session_state.current_thread_id = thread_id
|
||||
st.rerun()
|
||||
else:
|
||||
st.error(f"加载对话失败: HTTP {resp.status_code}")
|
||||
except Exception as e:
|
||||
st.error(f"加载对话失败: {e}")
|
||||
|
||||
|
||||
def stream_chat(message: str, thread_id: str, model: str, user_id: str):
|
||||
"""流式调用后端聊天接口"""
|
||||
payload = {
|
||||
"message": message,
|
||||
"thread_id": thread_id,
|
||||
"model": model,
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
try:
|
||||
with requests.post(
|
||||
f"{config.api_base}/chat/stream",
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=120
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
st.error(f"请求失败: HTTP {response.status_code}")
|
||||
return None
|
||||
|
||||
full_response = ""
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode('utf-8')
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
yield data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return full_response
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"请求失败: {e}")
|
||||
return None
|
||||
"""
|
||||
Session State 管理
|
||||
"""
|
||||
import uuid
|
||||
import streamlit as st
|
||||
|
||||
|
||||
class AppState:
|
||||
"""管理 Streamlit Session State"""
|
||||
|
||||
@staticmethod
|
||||
def init():
|
||||
"""初始化必要的 session state 变量"""
|
||||
if "user_id" not in st.session_state:
|
||||
st.session_state.user_id = "default_user"
|
||||
if "logged_in" not in st.session_state:
|
||||
st.session_state.logged_in = False
|
||||
if "threads" not in st.session_state:
|
||||
st.session_state.threads = []
|
||||
if "current_thread_id" not in st.session_state:
|
||||
st.session_state.current_thread_id = str(uuid.uuid4())
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
if "selected_model" not in st.session_state:
|
||||
st.session_state.selected_model = "zhipu"
|
||||
if "loading_history" not in st.session_state:
|
||||
st.session_state.loading_history = False
|
||||
|
||||
@staticmethod
|
||||
def reset_login():
|
||||
"""重置登录状态"""
|
||||
st.session_state.logged_in = False
|
||||
st.session_state.user_id = "default_user"
|
||||
st.session_state.threads = []
|
||||
st.rerun()
|
||||
|
||||
# 显示历史消息
|
||||
for msg in st.session_state.messages:
|
||||
with st.chat_message(msg["role"]):
|
||||
st.markdown(msg["content"])
|
||||
@staticmethod
|
||||
def login(username: str):
|
||||
"""执行登录"""
|
||||
st.session_state.user_id = username.strip() if username.strip() else "default_user"
|
||||
st.session_state.logged_in = True
|
||||
st.rerun()
|
||||
|
||||
# 用户输入
|
||||
if prompt := st.chat_input("请输入您的问题..."):
|
||||
# 显示用户消息
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
@staticmethod
|
||||
def start_new_thread():
|
||||
"""开始新对话"""
|
||||
st.session_state.current_thread_id = str(uuid.uuid4())
|
||||
st.session_state.messages = []
|
||||
st.rerun()
|
||||
"""
|
||||
应用配置
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
# 调用后端 API(携带模型参数)
|
||||
with st.chat_message("assistant"):
|
||||
with st.spinner("思考中..."):
|
||||
try:
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
json={
|
||||
"message": prompt,
|
||||
"thread_id": st.session_state.thread_id,
|
||||
"model": st.session_state.selected_model
|
||||
},
|
||||
timeout=60
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
reply = data["reply"]
|
||||
model_used = data["model_used"]
|
||||
input_tokens = data.get("input_tokens", 0)
|
||||
output_tokens = data.get("output_tokens", 0)
|
||||
total_tokens = data.get("total_tokens", 0)
|
||||
elapsed_time = data.get("elapsed_time", 0.0)
|
||||
|
||||
# 显示回复
|
||||
st.markdown(reply)
|
||||
|
||||
# 显示使用的模型和性能指标
|
||||
stats_text = f"🤖 模型: {MODEL_OPTIONS.get(model_used, model_used)}"
|
||||
stats_text += f" | ⏱️ 耗时: {elapsed_time:.2f}s"
|
||||
if total_tokens > 0:
|
||||
stats_text += f" | 📊 Tokens: {input_tokens}(输入) + {output_tokens}(输出) = {total_tokens}(总计)"
|
||||
st.caption(stats_text)
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": reply})
|
||||
except Exception as e:
|
||||
error_msg = f"请求失败: {e}"
|
||||
st.error(error_msg)
|
||||
st.session_state.messages.append({"role": "assistant", "content": error_msg})
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
page_title: str = "AI 个人助手"
|
||||
page_icon: str = "🤖"
|
||||
layout: str = "wide"
|
||||
# 后端 API 地址配置
|
||||
# 优先级:环境变量 API_URL > Docker 内部服务名 > 本地开发地址
|
||||
api_base: str = os.getenv("API_URL", "http://localhost:8001").replace("/chat", "")
|
||||
|
||||
model_options: dict = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.model_options is None:
|
||||
self.model_options = {
|
||||
"zhipu": "智谱 GLM-4.7-Flash(在线)",
|
||||
"deepseek": "DeepSeek V3.2(在线)",
|
||||
"local": "本地 vLLM(Gemma-4)"
|
||||
}
|
||||
|
||||
config = AppConfig()
|
||||
"""
|
||||
AI Agent 前端主入口
|
||||
采用模块化架构,仅负责组装各组件
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到 Python 路径,支持绝对导入
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import streamlit as st
|
||||
|
||||
# 使用绝对导入
|
||||
from frontend.config import config
|
||||
from frontend.state import AppState
|
||||
from frontend.components.sidebar import render_sidebar
|
||||
from frontend.components.chat_area import render_chat_area
|
||||
from frontend.components.info_panel import render_info_panel
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 页面配置
|
||||
# =============================================================================
|
||||
st.set_page_config(
|
||||
page_title=config.page_title,
|
||||
page_icon=config.page_icon,
|
||||
layout=config.layout
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 初始化状态
|
||||
# =============================================================================
|
||||
AppState.init()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 主界面
|
||||
# =============================================================================
|
||||
def main():
|
||||
"""主界面渲染 - 三栏布局"""
|
||||
# 标题
|
||||
st.title("🤖 个人生活与数据分析助手")
|
||||
|
||||
# 三栏布局:左侧栏(1) + 中间栏(3) + 右侧栏(1)
|
||||
col_sidebar, col_chat, col_info = st.columns([1, 3, 1])
|
||||
|
||||
# 左侧栏:用户登录 + 历史对话
|
||||
with col_sidebar:
|
||||
render_sidebar()
|
||||
|
||||
# 中间栏:模型选择 + 聊天区域 + 输入框
|
||||
with col_chat:
|
||||
render_chat_area()
|
||||
|
||||
# 右侧栏:会话信息 + 统计 + 使用提示
|
||||
with col_info:
|
||||
render_info_panel()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user