149 lines
4.1 KiB
Python
149 lines
4.1 KiB
Python
"""
|
|
中间聊天区组件
|
|
包含模型选择、消息显示和输入框
|
|
"""
|
|
|
|
import streamlit as st
|
|
|
|
# 使用绝对导入
|
|
from frontend.state import AppState
|
|
from frontend.api_client import api_client
|
|
from frontend.config import config
|
|
|
|
|
|
def render_chat_area():
    """Render the central chat column.

    Sections are drawn top to bottom in this order: the model picker, a
    horizontal divider, the scrollable message history, and finally the
    chat input box.
    """
    sections = (
        _render_model_selector,
        st.divider,
        _render_chat_container,
        _render_input_box,
    )
    for draw_section in sections:
        draw_section()
|
|
|
|
|
|
def _render_model_selector():
    """Draw the model picker and persist the user's choice.

    The row is split into two columns with relative widths 2:3. The picker
    goes in the left column and the right column is left empty as spacing.
    The option labels come from ``config.model_options`` (key -> display
    text). The selected key is written back with
    ``AppState.set_selected_model`` on every script run.
    """
    picker_col, _ = st.columns([2, 3])

    with picker_col:
        option_keys = list(config.model_options.keys())
        chosen_key = st.selectbox(
            "🧠 选择模型",
            options=option_keys,
            format_func=lambda key: config.model_options[key],
            index=_get_model_index(),
        )
        AppState.set_selected_model(chosen_key)
|
|
|
|
|
|
def _get_model_index() -> int:
    """Locate the currently selected model among the configured options.

    Returns:
        Position of ``AppState.get_selected_model()`` within the keys of
        ``config.model_options``, or 0 when the stored model is not one of
        those keys.
    """
    current_model = AppState.get_selected_model()
    keys = list(config.model_options.keys())
    try:
        return keys.index(current_model)
    except ValueError:
        # Stored model is not among the configured options: fall back to
        # the first one.
        return 0
|
|
|
|
|
|
def _render_chat_container():
    """Replay the stored conversation inside a 500px-high scrollable box.

    Each entry of ``AppState.get_messages()`` is drawn as a chat bubble for
    its ``"role"``, with its ``"content"`` rendered as Markdown.
    """
    with st.container(height=500):
        for message in AppState.get_messages():
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
|
|
|
|
|
|
def _render_input_box():
    """Render the chat input and hand a submitted prompt to the handler.

    Nothing happens on runs where ``st.chat_input`` returns a falsy value,
    i.e. when no prompt was submitted.
    """
    prompt = st.chat_input("请输入您的问题...", key="chat_input")
    if not prompt:
        return
    _handle_user_message(prompt)
|
|
|
|
|
|
def _handle_user_message(prompt: str):
    """Echo the user's prompt, record it, then stream the assistant reply.

    Args:
        prompt: Text the user just submitted.
    """
    user_bubble = st.chat_message("user")
    with user_bubble:
        st.markdown(prompt)
    AppState.add_message("user", prompt)

    # The reply handler reads the prompt back from the last stored message,
    # so it must be recorded before the handler runs.
    _handle_ai_response()
|
|
|
|
|
|
def _handle_ai_response():
    """Stream the assistant's reply to the latest stored message.

    The last entry of ``AppState.get_messages()`` is sent to
    ``api_client.chat_stream``, together with the current thread id, the
    selected model and the user id. That entry is the prompt that
    ``_handle_user_message`` records just before calling this function.
    Each event dict yielded by the stream is dispatched on its ``"type"``
    key:

    * ``"token"``: append ``content`` to the reply and redraw it with a
      trailing ``▌`` cursor.
    * ``"tool_start"`` / ``"tool_end"``: show an info / success notice for
      the named ``tool`` in a dedicated status slot.
    * ``"done"``: render usage statistics via ``_show_completion_stats``.
    * ``"error"``: show the event's ``message`` as an error.

    Event types other than these are ignored. When the stream is exhausted:
    the reply is redrawn without the cursor, it is stored via
    ``AppState.add_message``, and the tool status slot is cleared.
    """
    with st.chat_message("assistant"):
        # Separate placeholders so tool notices can be updated independently
        # of the streamed reply text.
        message_placeholder = st.empty()
        tool_status_placeholder = st.empty()
        full_response = ""

        stream = api_client.chat_stream(
            message=AppState.get_messages()[-1]["content"],
            thread_id=AppState.get_current_thread_id(),
            model=AppState.get_selected_model(),
            user_id=AppState.get_user_id(),
        )

        for event in stream:
            event_type = event.get("type")

            if event_type == "token":
                # ``or ""`` also covers an explicit null ``content``, which
                # would otherwise raise TypeError on concatenation.
                full_response += event.get("content") or ""
                message_placeholder.markdown(full_response + "▌")

            elif event_type == "tool_start":
                tool_name = event.get("tool", "")
                tool_status_placeholder.info(f"🔧 调用工具: {tool_name}...")

            elif event_type == "tool_end":
                tool_name = event.get("tool", "")
                # Keep the success notice visible. It is replaced by the next
                # tool_start notice or cleared when the stream finishes.
                tool_status_placeholder.success(f"✅ 工具 {tool_name} 完成")

            elif event_type == "done":
                _show_completion_stats(event)

            elif event_type == "error":
                st.error(f"❌ 错误: {event.get('message', '未知错误')}")

        # Final render without the typing cursor, then persist the reply.
        message_placeholder.markdown(full_response)
        AppState.add_message("assistant", full_response)
        tool_status_placeholder.empty()
|
|
|
|
|
|
def _show_completion_stats(event: dict):
    """Render a one-line usage summary for a finished reply.

    Nothing is rendered when ``token_usage`` is missing, null or empty.

    Args:
        event: The completion event. Its optional ``token_usage`` mapping
            supplies ``total_tokens``. Its optional ``elapsed_time`` is shown
            with two decimals and an "s" suffix.
    """
    # ``or`` also covers keys that are present with a null value. In
    # particular, formatting ``None`` with ``:.2f`` would raise TypeError.
    token_usage = event.get("token_usage") or {}
    elapsed = event.get("elapsed_time") or 0

    if token_usage:
        total_tokens = token_usage.get("total_tokens", 0)
        st.caption(f"📊 消耗 {total_tokens} tokens | ⏱️ {elapsed:.2f}s")
|