410 lines
13 KiB
Python
410 lines
13 KiB
Python
"""
|
||
右侧栏组件:工具状态和统计信息
|
||
"""
|
||
import streamlit as st
|
||
|
||
|
||
def render_info_panel():
    """Render the right-hand column: active thread id, message counts by role, usage tips."""
    st.header("📊 会话信息")

    # Active thread: only the first 8 characters of the id are shown.
    st.subheader("当前对话")
    short_thread_id = st.session_state.current_thread_id[:8] + "..."
    st.code(short_thread_id, language=None)

    st.divider()

    # Message statistics: collect every role once, then count per role.
    st.subheader("消息统计")
    roles = [entry["role"] for entry in st.session_state.messages]
    st.metric("用户消息", roles.count("user"))
    st.metric("AI 回复", roles.count("assistant"))

    st.divider()

    # Static usage tips rendered as a markdown bullet list.
    st.subheader("💡 使用提示")
    tips = (
        "左侧可切换历史对话",
        '点击"新对话"开始新话题',
        "登录后对话历史隔离",
        "支持流式实时响应",
        "模型可随时切换",
    )
    st.markdown("\n" + "\n".join(f"- {tip}" for tip in tips) + "\n")
|
||
"""
|
||
中间栏组件:聊天区域
|
||
"""
|
||
import streamlit as st
|
||
from ..config import config
|
||
from ..api_client import stream_chat
|
||
|
||
|
||
def render_chat_area():
    """Render the centre column: model picker, chat history, prompt box and streamed reply.

    Reads and writes ``st.session_state.selected_model`` and
    ``st.session_state.messages``. A submitted prompt is sent to the backend via
    ``stream_chat`` and the reply is rendered token by token. The assistant turn
    is stored in the history only when it produced some text.
    """
    # Model selector takes the left 2/5 of the row; the right column is spacing only.
    col_model, _ = st.columns([2, 3])
    with col_model:
        model_keys = list(config.model_options.keys())
        current_model = st.session_state.selected_model
        selected_model_key = st.selectbox(
            "🧠 选择模型",
            options=model_keys,
            format_func=lambda x: config.model_options[x],
            # Fall back to the first option when the stored key is not configured.
            index=model_keys.index(current_model) if current_model in config.model_options else 0,
        )
        st.session_state.selected_model = selected_model_key

    st.divider()

    # Message history inside a fixed-height (500 px) container.
    chat_container = st.container(height=500)
    with chat_container:
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.markdown(msg["content"])

    if prompt := st.chat_input("请输入您的问题...", key="chat_input"):
        # Echo the user's prompt and record it.
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            tool_status_placeholder = st.empty()

            # stream_chat is a generator function: the call always returns a
            # generator object, and request failures only surface while iterating.
            stream_gen = stream_chat(
                message=prompt,
                thread_id=st.session_state.current_thread_id,
                model=st.session_state.selected_model,
                user_id=st.session_state.user_id,
            )
            full_response = _consume_chat_stream(
                stream_gen, message_placeholder, tool_status_placeholder
            )

            # Replace the streaming view (with its cursor) by the final text.
            message_placeholder.markdown(full_response)
            tool_status_placeholder.empty()

        # Skip storing an empty assistant turn (failed request / no tokens);
        # it would otherwise be re-rendered as a blank bubble on every rerun.
        if full_response:
            st.session_state.messages.append({"role": "assistant", "content": full_response})


def _consume_chat_stream(events, message_placeholder, tool_status_placeholder) -> str:
    """Render backend stream events into the placeholders and return the accumulated reply.

    Each event is a dict dispatched on its ``"type"`` key:
    ``token`` appends ``content`` to the reply; ``tool_start``/``tool_end`` update
    the tool status line; ``done`` shows token usage and elapsed time; ``error``
    shows an error. Other types and non-dict items are ignored. ``st.caption``
    and ``st.error`` render into whatever Streamlit container is active when this
    is called (the caller invokes it inside the assistant chat message).
    """
    full_response = ""
    for data in events or ():
        if not isinstance(data, dict):
            continue
        event_type = data.get("type")

        if event_type == "token":
            full_response += str(data.get("content") or "")
            # Trailing block character acts as a typing cursor while streaming.
            message_placeholder.markdown(full_response + "▌")

        elif event_type == "tool_start":
            tool_status_placeholder.info(f"🔧 调用工具: {data.get('tool')}...")

        elif event_type == "tool_end":
            # Keep the notice visible; it is replaced by the next tool_start or
            # cleared by the caller once the stream has finished.
            tool_status_placeholder.success(f"✅ 工具 {data.get('tool')} 完成")

        elif event_type == "done":
            token_usage = data.get("token_usage") or {}
            # `or 0` also guards against an explicit null, which would break :.2f.
            elapsed = data.get("elapsed_time") or 0
            if token_usage:
                st.caption(f"📊 消耗 {token_usage.get('total_tokens', 0)} tokens | ⏱️ {elapsed:.2f}s")

        elif event_type == "error":
            st.error(f"❌ 错误: {data.get('message')}")

    return full_response
|
||
"""
|
||
左侧栏组件:用户登录 + 历史对话列表
|
||
"""
|
||
from datetime import datetime
|
||
import streamlit as st
|
||
from ..state import AppState
|
||
from ..api_client import refresh_threads, load_thread_history
|
||
|
||
|
||
def render_sidebar():
    """Render the left column: login / user-switch controls and the conversation history list.

    Clicking a history entry loads that thread via ``load_thread_history``; the
    new-conversation button starts a fresh thread via ``AppState.start_new_thread``.
    """
    st.header("👤 用户")

    # Login area
    if not st.session_state.logged_in:
        username = st.text_input(
            "输入用户名(可选)",
            key="login_input",
            placeholder="留空使用默认用户",
            help="未登录将使用 default_user,可能导致对话污染",
        )

        if st.button("✅ 进入", type="primary", use_container_width=True):
            # AppState.login() ends with st.rerun(), which halts this script run
            # immediately, so a refresh placed after it never executes. Load the
            # thread list first, for the same id login() stores (the stripped name,
            # or "default_user" when blank); the list survives the rerun in session state.
            refresh_threads(username.strip() or "default_user")
            AppState.login(username)

        st.info("💡 建议登录以隔离对话历史")
    else:
        st.success(f"✅ 当前用户: `{st.session_state.user_id}`")

        if st.button("🔄 切换用户", use_container_width=True):
            AppState.reset_login()

    st.divider()

    # Conversation history
    st.header("📚 对话历史")

    if st.button("🔄 刷新列表", use_container_width=True):
        refresh_threads(st.session_state.user_id)

    if st.button("➕ 新对话", type="primary", use_container_width=True):
        AppState.start_new_thread()

    st.divider()

    if not st.session_state.threads:
        st.info("暂无对话历史")
        return

    for thread in st.session_state.threads:
        thread_id = thread["thread_id"]
        # `or` also covers a summary that is present but None/empty.
        summary = str(thread.get("summary") or "空对话")
        message_count = thread.get("message_count", 0)
        time_str = _format_thread_time(thread.get("last_updated", ""))

        # Highlight the thread that is currently open.
        is_current = thread_id == st.session_state.current_thread_id
        button_type = "primary" if is_current else "secondary"

        if st.button(
            f"💬 {summary[:30]}{'...' if len(summary) > 30 else ''}\n\n🕐 {time_str} | {message_count}条",
            key=f"thread_{thread_id}",
            use_container_width=True,
            type=button_type,
        ):
            load_thread_history(thread_id, st.session_state.user_id)


def _format_thread_time(last_updated) -> str:
    """Format an ISO-8601 timestamp as ``MM-DD HH:MM``.

    Returns "未知" when the value is empty, or its first 10 characters when it
    cannot be parsed.
    """
    if not last_updated:
        return "未知"
    text = str(last_updated)
    try:
        # datetime.fromisoformat() before Python 3.11 rejects a trailing "Z".
        dt = datetime.fromisoformat(text.replace("Z", "+00:00"))
    except ValueError:
        return text[:10]
    return dt.strftime("%m-%d %H:%M")
|
||
# Components package
|
||
"""
|
||
后端 API 客户端封装
|
||
"""
|
||
import json
|
||
import requests
|
||
import streamlit as st
|
||
from .config import config
|
||
|
||
|
||
def refresh_threads(user_id: str):
    """Fetch up to 50 threads for *user_id* into ``st.session_state.threads``.

    On a non-200 status or any exception (network, JSON decoding, missing
    ``"threads"`` key) an error is shown with ``st.error`` and the existing
    thread list is left unchanged.
    """
    try:
        resp = requests.get(
            f"{config.api_base}/threads",
            params={"user_id": user_id, "limit": 50},
            timeout=10,
        )
        if resp.status_code != 200:
            st.error(f"加载历史列表失败: HTTP {resp.status_code}")
            return
        st.session_state.threads = resp.json()["threads"]
    except Exception as e:
        st.error(f"加载历史列表失败: {e}")
|
||
|
||
|
||
def load_thread_history(thread_id: str, user_id: str):
    """Load all messages of *thread_id* for *user_id* and make it the active thread.

    On success ``st.session_state.messages`` and ``current_thread_id`` are
    replaced and the app reruns. On a non-200 status or any request/decoding
    error, an error is shown and session state is left unchanged.
    """
    try:
        resp = requests.get(
            f"{config.api_base}/thread/{thread_id}/messages",
            params={"user_id": user_id},
            timeout=10,
        )
        if resp.status_code != 200:
            st.error(f"加载对话失败: HTTP {resp.status_code}")
            return
        # Decode fully before touching session state so a bad payload changes nothing.
        messages = resp.json()["messages"]
    except Exception as e:
        st.error(f"加载对话失败: {e}")
        return

    st.session_state.messages = messages
    st.session_state.current_thread_id = thread_id
    # st.rerun() stops the run by raising a control-flow exception, so it is kept
    # outside the broad `except Exception` above.
    st.rerun()
|
||
|
||
|
||
def stream_chat(message: str, thread_id: str, model: str, user_id: str):
    """Stream a chat reply from ``POST {api_base}/chat/stream``.

    This is a generator: it yields one dict per JSON payload found on an SSE
    ``data:`` line, and stops at a ``[DONE]`` payload or when the server closes
    the stream. Lines that are not ``data:`` fields, and payloads that are not
    valid JSON, are skipped. A non-200 status or any exception is reported with
    ``st.error`` and ends the stream; nothing is raised to the caller.
    """
    payload = {
        "message": message,
        "thread_id": thread_id,
        "model": model,
        "user_id": user_id,
    }

    try:
        with requests.post(
            f"{config.api_base}/chat/stream",
            json=payload,
            stream=True,
            timeout=120,
        ) as response:
            if response.status_code != 200:
                st.error(f"请求失败: HTTP {response.status_code}")
                return

            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines separate SSE events
                # Decode explicitly as UTF-8 rather than relying on the response charset.
                line = raw_line.decode("utf-8")
                if not line.startswith("data:"):
                    continue  # comments (":...") and other SSE fields
                data_str = line[len("data:"):]
                # SSE allows an optional single space after the colon.
                if data_str.startswith(" "):
                    data_str = data_str[1:]
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue  # best-effort: skip malformed payloads
                yield data

    except Exception as e:
        st.error(f"请求失败: {e}")
|
||
"""
|
||
Session State 管理
|
||
"""
|
||
import uuid
|
||
import streamlit as st
|
||
|
||
|
||
class AppState:
    """Static helpers that initialise and update ``st.session_state``."""

    @staticmethod
    def init():
        """Create any missing session-state keys with their defaults.

        Keys that already exist are left untouched, so this is safe to call on
        every script rerun.
        """
        if "user_id" not in st.session_state:
            st.session_state.user_id = "default_user"
        if "logged_in" not in st.session_state:
            st.session_state.logged_in = False
        if "threads" not in st.session_state:
            st.session_state.threads = []
        if "current_thread_id" not in st.session_state:
            st.session_state.current_thread_id = str(uuid.uuid4())
        if "messages" not in st.session_state:
            st.session_state.messages = []
        if "selected_model" not in st.session_state:
            st.session_state.selected_model = "zhipu"
        # NOTE(review): not read anywhere in the code visible here; kept for compatibility.
        if "loading_history" not in st.session_state:
            st.session_state.loading_history = False

    @staticmethod
    def reset_login():
        """Log out: revert to ``default_user``, clear the thread list, start a new conversation, rerun.

        The active thread id and messages belong to the previous user. They are
        replaced so the next user neither sees that conversation nor continues
        it on the same thread id.
        """
        st.session_state.logged_in = False
        st.session_state.user_id = "default_user"
        st.session_state.threads = []
        st.session_state.current_thread_id = str(uuid.uuid4())
        st.session_state.messages = []
        st.rerun()

    @staticmethod
    def login(username: str):
        """Log in as the stripped *username* (``default_user`` if blank), then rerun."""
        name = (username or "").strip()
        st.session_state.user_id = name or "default_user"
        st.session_state.logged_in = True
        st.rerun()

    @staticmethod
    def start_new_thread():
        """Start a new conversation with a fresh UUID thread id and no messages, then rerun."""
        st.session_state.current_thread_id = str(uuid.uuid4())
        st.session_state.messages = []
        st.rerun()
|
||
"""
|
||
应用配置
|
||
"""
|
||
import os
|
||
from dataclasses import dataclass
|
||
|
||
|
||
@dataclass
class AppConfig:
    """Frontend settings: page metadata, backend base URL and selectable models."""

    page_title: str = "AI 个人助手"
    page_icon: str = "🤖"
    layout: str = "wide"
    # Backend base URL: the API_URL environment variable (read when this class is
    # defined), else the local development address. Trailing slashes and a
    # trailing "/chat" segment are stripped in __post_init__ so endpoint paths
    # such as "/threads" can be appended directly.
    api_base: str = os.getenv("API_URL", "http://localhost:8001")
    # Model key -> display label; None selects the built-in defaults.
    model_options: dict = None

    def __post_init__(self):
        self.api_base = self._normalize_api_base(self.api_base)
        if self.model_options is None:
            self.model_options = {
                "zhipu": "智谱 GLM-4.7-Flash(在线)",
                "deepseek": "DeepSeek V3.2(在线)",
                "local": "本地 vLLM(Gemma-4)",
            }

    @staticmethod
    def _normalize_api_base(url: str) -> str:
        """Remove trailing slashes and one trailing ``/chat`` segment from *url*.

        Only the suffix is removed, so hosts or paths that merely contain "/chat"
        (e.g. ``http://chat-backend:8001``) are preserved.
        """
        base = url.rstrip("/")
        if base.endswith("/chat"):
            base = base[: -len("/chat")].rstrip("/")
        return base


# Module-level singleton that the other frontend modules import as `config`.
config = AppConfig()
|
||
"""
|
||
AI Agent 前端主入口
|
||
采用模块化架构,仅负责组装各组件
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
|
||
# 添加项目根目录到 Python 路径,支持绝对导入
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
import streamlit as st
|
||
|
||
# 使用绝对导入
|
||
from frontend.config import config
|
||
from frontend.state import AppState
|
||
from frontend.components.sidebar import render_sidebar
|
||
from frontend.components.chat_area import render_chat_area
|
||
from frontend.components.info_panel import render_info_panel
|
||
|
||
|
||
# =============================================================================
|
||
# 页面配置
|
||
# =============================================================================
|
||
# Page title, icon and layout come from the shared AppConfig instance; this is
# called once, before any layout is rendered.
st.set_page_config(
    **{
        "page_title": config.page_title,
        "page_icon": config.page_icon,
        "layout": config.layout,
    }
)


# =============================================================================
# State initialisation
# =============================================================================
# Fill in default session-state keys; keys that already exist are left as-is.
AppState.init()
|
||
|
||
|
||
# =============================================================================
|
||
# 主界面
|
||
# =============================================================================
|
||
def main():
    """Render the page title and the three-column layout: history | chat | session info."""
    st.title("🤖 个人生活与数据分析助手")

    # Relative widths 1:3:1 -- login/history, model picker + chat, stats + tips.
    left, centre, right = st.columns([1, 3, 1])
    panels = (
        (left, render_sidebar),
        (centre, render_chat_area),
        (right, render_info_panel),
    )
    for column, render in panels:
        with column:
            render()


if __name__ == "__main__":
    main()
|