This commit is contained in:
32
.env
32
.env
@@ -1,3 +1,33 @@
|
||||
# =============================================================================
|
||||
# 本地开发环境配置
|
||||
# 用于 python app/backend.py 和 streamlit run frontend/frontend.py
|
||||
# =============================================================================
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# AI 模型 API 密钥
|
||||
# -----------------------------------------------------------------------------
|
||||
ZHIPUAI_API_KEY=4d568a4367f1442bbc226cc0daf84566.44SsKVWkVIM2Mkeg
|
||||
DEEPSEEK_API_KEY=sk-e74b13ac778f4b7eb29afa418a14421e
|
||||
VLLM_LOCAL_KEY=token-abc123
|
||||
EOF
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# vLLM 服务配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# 本地开发时,vLLM 通常在 localhost 运行
|
||||
VLLM_BASE_URL=http://localhost:8000/v1
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 数据库配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# 本地开发时,数据库在 localhost 运行
|
||||
DB_URI=postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 前端配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# 本地开发时,后端也在 localhost 运行
|
||||
API_URL=http://localhost:8001/chat
|
||||
|
||||
# 本地开发 - 显示所有调试信息
|
||||
LOG_LEVEL=DEBUG
|
||||
DEBUG=true
|
||||
35
.env.docker
Normal file
35
.env.docker
Normal file
@@ -0,0 +1,35 @@
|
||||
# =============================================================================
|
||||
# Docker Compose 服务器部署配置
|
||||
# 用法: cp .env.docker .env 然后修改 API Key
|
||||
# =============================================================================
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# AI 模型 API 密钥(必需 - 请修改为真实值)
|
||||
# -----------------------------------------------------------------------------
|
||||
ZHIPUAI_API_KEY=your_zhipuai_api_key_here
|
||||
DEEPSEEK_API_KEY=your_deepseek_api_key_here
|
||||
VLLM_LOCAL_KEY=token-abc123
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# vLLM 服务配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# Docker 部署时,如果 vLLM 在宿主机运行,使用 FRP 穿透地址或宿主机 IP
|
||||
# 如果 vLLM 也在 Docker 中,使用 Docker 服务名或容器 IP
|
||||
VLLM_BASE_URL=http://115.190.121.151:18000/v1
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 数据库配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# Docker Compose 内部网络,使用服务名 'postgres'
|
||||
DB_URI=postgresql://postgres:mysecretpassword@postgres:5432/langgraph_db?sslmode=disable
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 前端配置(通过 docker-compose.yml 注入,此处仅作文档说明)
|
||||
# -----------------------------------------------------------------------------
|
||||
# 注意:API_URL 在 docker-compose.yml 中已配置为 http://backend:8001/chat
|
||||
# 本地无需设置,Docker 容器启动时会自动注入
|
||||
# API_URL=http://backend:8001/chat
|
||||
|
||||
# 生产环境 - 仅显示关键信息
|
||||
LOG_LEVEL=WARNING
|
||||
DEBUG=false
|
||||
29
.env.example
29
.env.example
@@ -1,29 +0,0 @@
|
||||
# =============================================================================
|
||||
# 环境变量配置模板
|
||||
# 复制此文件为 .env 并填入真实值:cp .env.example .env
|
||||
# =============================================================================
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# AI 模型 API 密钥(必需)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# 智谱 AI API 密钥(用于在线模型调用 GLM-4.7-Flash)
|
||||
# 获取地址: https://open.bigmodel.cn/
|
||||
ZHIPUAI_API_KEY=4d568a4367f1442bbc226cc0daf84566.44SsKVWkVIM2Mkeg
|
||||
|
||||
# 本地 vLLM 服务认证 Token(用于本地 Gemma 模型调用)
|
||||
# 如果使用本地 vLLM 容器,需要设置此值与 vLLM 容器的 --api-key 参数一致
|
||||
VLLM_LOCAL_KEY=token-abc123
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 数据库配置(可选 - 代码中有默认值)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# PostgreSQL 数据库连接字符串
|
||||
# Docker Compose 部署时使用服务名 'postgres':
|
||||
# DB_URI=postgresql://postgres:mysecretpassword@postgres:5432/langgraph_db?sslmode=disable
|
||||
|
||||
# 本地开发时使用 localhost:
|
||||
# DB_URI=postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable
|
||||
|
||||
# 如果不设置,代码将使用默认值(Docker 环境指向 postgres 服务)
|
||||
132
QUICKSTART.md
132
QUICKSTART.md
@@ -13,8 +13,8 @@
|
||||
#### 1. 配置环境变量
|
||||
|
||||
```bash
|
||||
# 复制模板文件
|
||||
cp .env.example .env
|
||||
# 复制 Docker 部署模板
|
||||
cp .env.docker .env
|
||||
|
||||
# 编辑 .env 文件,填入真实的 API Key
|
||||
vim .env # 或使用你喜欢的编辑器
|
||||
@@ -25,7 +25,10 @@ vim .env # 或使用你喜欢的编辑器
|
||||
- `VLLM_LOCAL_KEY` - 本地 vLLM 服务认证 Token(与 vLLM 容器的 `--api-key` 参数一致)
|
||||
|
||||
**可选配置项**:
|
||||
- `DB_URI` - PostgreSQL 连接字符串(默认已配置,通常无需修改)
|
||||
- `VLLM_BASE_URL` - vLLM 服务地址(默认已配置为 FRP 穿透地址)
|
||||
- `DB_URI` - PostgreSQL 连接字符串(默认已配置,使用 Docker 服务名 `postgres`)
|
||||
|
||||
**注意**:Docker Compose 部署时,`API_URL` 由 `docker-compose.yml` 自动注入,无需在 `.env` 中配置。
|
||||
|
||||
#### 2. 启动服务
|
||||
|
||||
@@ -91,19 +94,24 @@ pip install -r requirement.txt
|
||||
|
||||
复制并编辑 `.env` 文件:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
# 基于 Docker 模板创建,然后修改为本地配置
|
||||
cp .env.docker .env
|
||||
vim .env
|
||||
```
|
||||
|
||||
**本地开发需要额外配置数据库连接**:
|
||||
**本地开发需要修改以下配置**:
|
||||
|
||||
```env
|
||||
ZHIPUAI_API_KEY=your_api_key_here
|
||||
VLLM_LOCAL_KEY=token-abc123
|
||||
|
||||
# 本地开发时,数据库主机改为 localhost
|
||||
# 本地开发时,vLLM 和数据库都在 localhost
|
||||
VLLM_BASE_URL=http://localhost:8000/v1
|
||||
DB_URI=postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable
|
||||
|
||||
# 本地开发时,后端也在 localhost
|
||||
API_URL=http://localhost:8001/chat
|
||||
```
|
||||
|
||||
#### 4. 启动服务
|
||||
@@ -310,20 +318,70 @@ lsof -i :8001
|
||||
- 端口 8001 被占用
|
||||
- 依赖包缺失
|
||||
|
||||
#### 3. 前端无法连接后端
|
||||
#### 3. 前端无法连接后端(NameResolutionError)
|
||||
|
||||
```bash
|
||||
# 检查后端是否正常运行
|
||||
curl http://localhost:8001/
|
||||
|
||||
# 检查网络连接
|
||||
docker compose exec frontend ping backend
|
||||
**错误信息:**
|
||||
```
|
||||
HTTPConnectionPool(host='backend', port=8001): Max retries exceeded with url: /chat
|
||||
(Caused by NameResolutionError("HTTPConnection(host='backend', port=8001): Failed to resolve 'backend'"))
|
||||
```
|
||||
|
||||
**原因分析:**
|
||||
- 前端容器和后端容器不在同一个 Docker 网络中
|
||||
- docker-compose.yml 中的服务名配置错误
|
||||
- 环境变量 `API_URL` 配置不正确
|
||||
|
||||
**解决方案:**
|
||||
- 确认后端服务已启动
|
||||
- 检查防火墙设置
|
||||
- 重启前端服务
|
||||
|
||||
1. **检查容器是否在同一网络中:**
|
||||
```bash
|
||||
# 查看所有 Docker 网络
|
||||
docker network ls
|
||||
|
||||
# 检查 ai-network 网络中的容器
|
||||
docker network inspect docker_ai-network
|
||||
```
|
||||
|
||||
2. **确认服务名正确:**
|
||||
```bash
|
||||
# 查看运行中的容器
|
||||
docker compose ps
|
||||
|
||||
# 应该看到:ai-backend, ai-frontend, ai-postgres
|
||||
```
|
||||
|
||||
3. **验证环境变量配置:**
|
||||
```bash
|
||||
# 进入前端容器检查环境变量
|
||||
docker compose exec frontend env | grep API_URL
|
||||
|
||||
# 应该输出:API_URL=http://backend:8001/chat
|
||||
```
|
||||
|
||||
4. **重启服务:**
|
||||
```bash
|
||||
# 完全停止并重新启动所有服务
|
||||
docker compose down
|
||||
docker compose up -d --build
|
||||
|
||||
# 查看启动日志
|
||||
docker compose logs -f
|
||||
```
|
||||
|
||||
5. **测试网络连通性:**
|
||||
```bash
|
||||
# 从前端容器 ping 后端服务
|
||||
docker compose exec frontend ping backend
|
||||
|
||||
# 从前端容器访问后端 API
|
||||
docker compose exec frontend curl http://backend:8001/health
|
||||
```
|
||||
|
||||
**重要提示:**
|
||||
- Docker Compose 会自动创建名为 `<项目目录>_ai-network` 的网络
|
||||
- 容器间通过**服务名**(而非容器名)进行通信
|
||||
- 在 `docker-compose.yml` 中,服务名是 `backend`、`frontend`、`postgres`
|
||||
- 确保所有服务都连接到同一个自定义网络(`ai-network`)
|
||||
|
||||
#### 4. 模型初始化失败
|
||||
|
||||
@@ -337,6 +395,46 @@ docker compose logs backend | grep -i "model\|error"
|
||||
- vLLM 容器未启动(如使用本地模型)
|
||||
- 网络连接问题
|
||||
|
||||
#### 5. 环境变量未生效
|
||||
|
||||
**症状:**
|
||||
- 服务启动时提示缺少必需的环境变量
|
||||
- API Key 为空或使用默认值
|
||||
|
||||
**解决方案:**
|
||||
|
||||
1. **检查 .env 文件格式:**
|
||||
```bash
|
||||
# 确保文件末尾没有多余字符(如 EOF)
|
||||
cat -A .env
|
||||
|
||||
# 正确格式应该是每行一个变量,无多余空格或特殊字符
|
||||
```
|
||||
|
||||
2. **验证环境变量已加载:**
|
||||
```bash
|
||||
# 检查后端容器的环境变量
|
||||
docker compose exec backend env | grep ZHIPUAI_API_KEY
|
||||
|
||||
# 检查前端容器的环境变量
|
||||
docker compose exec frontend env | grep API_URL
|
||||
```
|
||||
|
||||
3. **重新构建容器:**
|
||||
```bash
|
||||
# 修改 .env 后需要重新创建容器
|
||||
docker compose down
|
||||
docker compose up -d --build
|
||||
```
|
||||
|
||||
4. **确认 .env 文件位置:**
|
||||
```bash
|
||||
# .env 文件应该在项目根目录(与 docker-compose.yml 的父目录同级)
|
||||
ls -la .env
|
||||
|
||||
# docker-compose.yml 中使用了 context: .. ,所以 .env 应该在上一级目录
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📊 监控和维护
|
||||
|
||||
55
README.md
55
README.md
@@ -226,30 +226,57 @@ model_configs = {
|
||||
|
||||
## ⚙️ 环境配置
|
||||
|
||||
### 配置文件说明
|
||||
|
||||
项目使用两个环境配置文件:
|
||||
|
||||
| 文件 | 用途 | 是否提交 Git |
|
||||
|------|------|------------|
|
||||
| `.env` | 实际使用的配置 | ❌ 否(已忽略) |
|
||||
| `.env.docker` | Docker 部署模板 | ✅ 是 |
|
||||
|
||||
**使用方法:**
|
||||
|
||||
- **本地开发**:手动创建 `.env`,配置 `localhost` 相关地址
|
||||
- **Docker 部署**:`cp .env.docker .env`,然后修改 API Key
|
||||
|
||||
### 必需的环境变量
|
||||
|
||||
在 `.env` 文件中配置:
|
||||
代码中所有使用 `os.getenv()` 的地方都必须在 `.env` 文件中定义:
|
||||
|
||||
```
|
||||
# 智谱 AI API Key(必需)
|
||||
| 变量名 | 说明 | 本地开发示例 | Docker 部署示例 |
|
||||
|--------|------|------------|----------------|
|
||||
| `ZHIPUAI_API_KEY` | 智谱 AI API 密钥 | `your_key_here` | `your_key_here` |
|
||||
| `VLLM_LOCAL_KEY` | vLLM 认证 Token | `token-abc123` | `token-abc123` |
|
||||
| `VLLM_BASE_URL` | vLLM 服务地址 | `http://localhost:8000/v1` | `http://115.190.121.151:18000/v1` |
|
||||
| `DB_URI` | PostgreSQL 连接字符串 | `postgresql://...@localhost:5432/...` | `postgresql://...@postgres:5432/...` |
|
||||
| `API_URL` | 后端 API 地址 | `http://localhost:8001/chat` | (由 docker-compose.yml 注入) |
|
||||
|
||||
### 配置示例
|
||||
|
||||
#### 本地开发 (.env)
|
||||
```bash
|
||||
ZHIPUAI_API_KEY=your_api_key_here
|
||||
|
||||
# vLLM 本地模型 Token(可选)
|
||||
VLLM_LOCAL_KEY=token-abc123
|
||||
VLLM_BASE_URL=http://localhost:8000/v1
|
||||
DB_URI=postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable
|
||||
API_URL=http://localhost:8001/chat
|
||||
```
|
||||
|
||||
### 数据库配置
|
||||
|
||||
默认使用 PostgreSQL,连接字符串:
|
||||
```
|
||||
postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db
|
||||
#### Docker 部署 (.env.docker)
|
||||
```bash
|
||||
ZHIPUAI_API_KEY=your_api_key_here
|
||||
VLLM_LOCAL_KEY=token-abc123
|
||||
VLLM_BASE_URL=http://115.190.121.151:18000/v1
|
||||
DB_URI=postgresql://postgres:mysecretpassword@postgres:5432/langgraph_db?sslmode=disable
|
||||
# API_URL 在 docker-compose.yml 中配置为 http://backend:8001/chat
|
||||
```
|
||||
|
||||
**注意**:
|
||||
- **本地开发模式**:使用 `localhost` 或 `127.0.0.1`
|
||||
- **Docker Compose 部署**:后端容器内应使用服务名 `postgres`(通过环境变量 `DB_URI` 自动配置)
|
||||
### 注意事项
|
||||
|
||||
如使用 Docker Compose,数据库会在内部网络中自动配置。
|
||||
- ⚠️ **不要硬编码敏感信息**:所有 API Key 必须通过环境变量配置
|
||||
- ⚠️ **Docker 网络差异**:容器内使用服务名(如 `postgres`、`backend`),本地使用 `localhost`
|
||||
- ⚠️ **修改后重启**:修改 `.env` 后,Docker 部署需要执行 `docker compose down && docker compose up -d --build`
|
||||
|
||||
---
|
||||
|
||||
|
||||
94
app/agent.py
94
app/agent.py
@@ -11,8 +11,11 @@ from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
# 本地模块
|
||||
from app.graph_builder import GraphBuilder
|
||||
from app.graph_builder import GraphBuilder, GraphContext
|
||||
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from app.logger import debug, info, warning, error
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -20,13 +23,15 @@ load_dotenv()
|
||||
class AIAgentService:
|
||||
"""异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
|
||||
|
||||
def __init__(self, checkpointer):
|
||||
def __init__(self, checkpointer: AsyncPostgresSaver, store: AsyncPostgresStore):
|
||||
"""
|
||||
初始化服务
|
||||
Args:
|
||||
checkpointer: 已经初始化的 AsyncPostgresSaver 实例
|
||||
store: 已经初始化的 AsyncPostgresStore 实例
|
||||
"""
|
||||
self.checkpointer = checkpointer
|
||||
self.store = store
|
||||
self.graphs = {} # 存储不同模型对应的 graph 实例
|
||||
|
||||
def _create_zhipu_llm(self):
|
||||
@@ -39,49 +44,108 @@ class AIAgentService:
|
||||
api_key=api_key,
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
)
|
||||
|
||||
def _create_deepseek_llm(self):
|
||||
"""创建 DeepSeek LLM(使用 OpenAI 兼容 API)"""
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("DEEPSEEK_API_KEY not set in environment")
|
||||
return ChatOpenAI(
|
||||
base_url="https://api.deepseek.com",
|
||||
api_key=SecretStr(api_key),
|
||||
model="deepseek-reasoner", # deepseek-chat: 非思考模式, deepseek-reasoner: 思考模式
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
)
|
||||
|
||||
def _create_local_llm(self):
|
||||
"""创建本地 vLLM 服务 LLM"""
|
||||
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
|
||||
vllm_base_url = os.getenv(
|
||||
"VLLM_BASE_URL",
|
||||
"http://115.190.121.151:18000/v1"
|
||||
)
|
||||
|
||||
return ChatOpenAI(
|
||||
# 原来是 http://localhost:8000/v1
|
||||
# 改为 FRP 穿透后的公网地址
|
||||
base_url = "http://115.190.121.151:18000/v1",
|
||||
base_url=vllm_base_url,
|
||||
api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")),
|
||||
model="gemma-4-E2B-it",
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""预编译所有模型的 graph(使用传入的 checkpointer)"""
|
||||
"""预编译所有模型的 graph(使用传入的 checkpointer 和 store)"""
|
||||
model_configs = {
|
||||
"zhipu": self._create_zhipu_llm,
|
||||
"deepseek": self._create_deepseek_llm,
|
||||
"local": self._create_local_llm,
|
||||
}
|
||||
|
||||
for model_name, llm_creator in model_configs.items():
|
||||
try:
|
||||
info(f"🔄 正在初始化模型 '{model_name}'...")
|
||||
llm = llm_creator()
|
||||
|
||||
# 测试 LLM 连接(可选,用于调试)
|
||||
if model_name == "local":
|
||||
debug(f" 测试 vLLM 连接: {os.getenv('VLLM_BASE_URL', '未设置')}")
|
||||
elif model_name == "deepseek":
|
||||
debug(f" 测试 DeepSeek API 连接: https://api.deepseek.com")
|
||||
|
||||
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
|
||||
graph = builder.compile(checkpointer=self.checkpointer)
|
||||
graph = builder.compile(checkpointer=self.checkpointer, store=self.store)
|
||||
self.graphs[model_name] = graph
|
||||
print(f"✅ 模型 '{model_name}' 初始化成功")
|
||||
info(f"✅ 模型 '{model_name}' 初始化成功")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
|
||||
import traceback
|
||||
error_detail = traceback.format_exc()
|
||||
warning(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
|
||||
debug(f" 详细错误:\n{error_detail}")
|
||||
|
||||
if not self.graphs:
|
||||
raise RuntimeError("没有可用的模型,请检查配置")
|
||||
raise RuntimeError("没有可用的模型,请检查配置。可能的原因:\n"
|
||||
"1. ZHIPUAI_API_KEY 未配置或无效\n"
|
||||
"2. DEEPSEEK_API_KEY 未配置或无效\n"
|
||||
"3. vLLM 服务未启动或地址错误 (VLLM_BASE_URL)\n"
|
||||
"4. 网络连接问题")
|
||||
|
||||
return self
|
||||
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str:
|
||||
"""处理用户消息,返回最终答案"""
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
|
||||
"""
|
||||
处理用户消息,返回包含回复、token统计和耗时的字典
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"reply": str, # AI 回复内容
|
||||
"token_usage": dict, # Token 使用详情
|
||||
"elapsed_time": float # 调用耗时(秒)
|
||||
}
|
||||
"""
|
||||
if model not in self.graphs:
|
||||
fallback_model = next(iter(self.graphs.keys()))
|
||||
print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
|
||||
warning(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
|
||||
model = fallback_model
|
||||
|
||||
graph = self.graphs[model]
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
input_state = {"messages": [HumanMessage(content=message)]}
|
||||
result = await graph.ainvoke(input_state, config=config)
|
||||
return result["messages"][-1].content
|
||||
context = GraphContext(user_id=user_id)
|
||||
|
||||
result = await graph.ainvoke(input_state, config=config, context=context)
|
||||
|
||||
reply = result["messages"][-1].content
|
||||
token_usage = result.get("last_token_usage", {})
|
||||
elapsed_time = result.get("last_elapsed_time", 0.0)
|
||||
|
||||
return {
|
||||
"reply": reply,
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@@ -7,17 +7,23 @@ import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore
|
||||
from app.agent import AIAgentService
|
||||
from app.logger import debug, info, warning, error
|
||||
|
||||
# PostgreSQL 连接字符串(优先从环境变量读取,适配 Docker 和本地开发)
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
# PostgreSQL 连接字符串配置
|
||||
# 优先级:环境变量 DB_URI > Docker 内部服务名 > 本地开发地址
|
||||
DB_URI = os.getenv(
|
||||
"DB_URI",
|
||||
"postgresql://postgres:mysecretpassword@postgres:5432/langgraph_db?sslmode=disable"
|
||||
"postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable"
|
||||
)
|
||||
|
||||
|
||||
@@ -25,11 +31,15 @@ DB_URI = os.getenv(
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理:创建并注入全局服务"""
|
||||
# 1. 创建数据库连接池并初始化表
|
||||
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
||||
async with (
|
||||
AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer,
|
||||
AsyncPostgresStore.from_conn_string(DB_URI) as store
|
||||
):
|
||||
await checkpointer.setup()
|
||||
await store.setup()
|
||||
|
||||
# 2. 构建 AI Agent 服务
|
||||
agent_service = AIAgentService(checkpointer)
|
||||
agent_service = AIAgentService(checkpointer,store)
|
||||
await agent_service.initialize()
|
||||
|
||||
# 3. 将服务实例存入 app.state
|
||||
@@ -39,7 +49,7 @@ async def lifespan(app: FastAPI):
|
||||
yield
|
||||
|
||||
# 4. 关闭时自动清理数据库连接(async with 负责)
|
||||
print("🛑 应用关闭,数据库连接池已释放")
|
||||
info("🛑 应用关闭,数据库连接池已释放")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
@@ -66,12 +76,17 @@ class ChatRequest(BaseModel):
|
||||
message: str
|
||||
thread_id: str | None = None
|
||||
model: str = "zhipu"
|
||||
user_id: str = "default_user"
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
reply: str
|
||||
thread_id: str
|
||||
model_used: str
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
elapsed_time: float = 0.0
|
||||
|
||||
|
||||
# ========== 依赖注入函数 ==========
|
||||
@@ -91,11 +106,27 @@ async def chat_endpoint(
|
||||
raise HTTPException(status_code=400, detail="message required")
|
||||
|
||||
thread_id = request.thread_id or str(uuid.uuid4())
|
||||
reply = await agent_service.process_message(
|
||||
request.message, thread_id, request.model
|
||||
result = await agent_service.process_message(
|
||||
request.message, thread_id, request.model, request.user_id
|
||||
)
|
||||
|
||||
# 提取 token 统计信息
|
||||
token_usage = result.get("token_usage", {})
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
elapsed_time = result.get("elapsed_time", 0.0)
|
||||
|
||||
actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
|
||||
return ChatResponse(reply=reply, thread_id=thread_id, model_used=actual_model)
|
||||
|
||||
return ChatResponse(
|
||||
reply=result["reply"],
|
||||
thread_id=thread_id,
|
||||
model_used=actual_model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=input_tokens + output_tokens,
|
||||
elapsed_time=elapsed_time
|
||||
)
|
||||
|
||||
|
||||
# ========== WebSocket 端点(可选) ==========
|
||||
@@ -111,10 +142,11 @@ async def websocket_endpoint(
|
||||
message = data.get("message")
|
||||
thread_id = data.get("thread_id", str(uuid.uuid4()))
|
||||
model = data.get("model", "zhipu")
|
||||
user_id = data.get("user_id", "default_user")
|
||||
if not message:
|
||||
await websocket.send_json({"error": "missing message"})
|
||||
continue
|
||||
reply = await agent_service.process_message(message, thread_id, model)
|
||||
reply = await agent_service.process_message(message, thread_id, model, user_id)
|
||||
actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
|
||||
await websocket.send_json({"reply": reply, "thread_id": thread_id, "model_used": actual_model})
|
||||
except WebSocketDisconnect:
|
||||
|
||||
@@ -4,19 +4,34 @@ LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
|
||||
|
||||
import operator
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Literal, Annotated, Any
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore
|
||||
from langgraph.runtime import Runtime
|
||||
from dataclasses import dataclass
|
||||
import uuid
|
||||
|
||||
# 本地模块
|
||||
from app.logger import debug, info, warning, error
|
||||
|
||||
|
||||
class MessageState(TypedDict):
|
||||
class MessagesState(TypedDict):
|
||||
"""对话状态类型定义"""
|
||||
messages: Annotated[list[AnyMessage], operator.add]
|
||||
llm_calls: int
|
||||
memory_context:str
|
||||
last_token_usage: dict # 本次调用的 token 使用详情
|
||||
last_elapsed_time: float # 本次调用耗时(秒)
|
||||
|
||||
@dataclass
|
||||
class GraphContext:
|
||||
user_id: str
|
||||
# 可扩展更多上下文信息
|
||||
|
||||
class GraphBuilder:
|
||||
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
|
||||
@@ -42,33 +57,132 @@ class GraphBuilder:
|
||||
"""创建系统提示模板(静态方法,无需访问实例)"""
|
||||
return ChatPromptTemplate.from_messages([
|
||||
SystemMessage(content=(
|
||||
"你是一个个人生活助手和数据分析助手。请说中文。"
|
||||
"当用户询问天气或温度时,使用get_current_temperature工具获取信息。"
|
||||
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
|
||||
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。"
|
||||
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
|
||||
"【用户背景信息】\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
|
||||
"{memory_context}\n"
|
||||
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
|
||||
"【可用工具与使用规则】\n"
|
||||
"- 获取温度/天气:`get_current_temperature`\n"
|
||||
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
|
||||
"- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
|
||||
"- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
|
||||
"- 抓取网页内容:`fetch_webpage_content`\n"
|
||||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||||
"【回答要求(必须遵守)】\n"
|
||||
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
|
||||
"2. 优先利用已知用户信息进行个性化回复。\n"
|
||||
"3. 若无信息可依,礼貌询问或提供通用帮助。"
|
||||
)),
|
||||
MessagesPlaceholder(variable_name="message")
|
||||
MessagesPlaceholder(variable_name="messages")
|
||||
])
|
||||
|
||||
async def call_llm(self, state: MessageState) -> dict:
|
||||
async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._chain.invoke({"message": state["messages"]})
|
||||
)
|
||||
return {
|
||||
"messages": [response],
|
||||
"llm_calls": state.get('llm_calls', 0) + 1
|
||||
}
|
||||
memory_context = state.get("memory_context", "暂无用户信息")
|
||||
|
||||
async def call_tools(self, state: MessageState) -> dict:
|
||||
# 构建完整的输入消息列表(用于调试打印)
|
||||
system_prompt = self._prompt.messages[0] # SystemMessage
|
||||
if isinstance(system_prompt, SystemMessage):
|
||||
system_content = system_prompt.content.format(memory_context=memory_context)
|
||||
else:
|
||||
system_content = str(system_prompt.content)
|
||||
|
||||
input_messages = [SystemMessage(content=system_content)] + state["messages"]
|
||||
|
||||
# 打印发送给大模型的最终输入
|
||||
debug("\n" + "="*80)
|
||||
debug("📤 [LLM输入] 发送给大模型的完整消息:")
|
||||
debug(f" 总消息数: {len(input_messages)}")
|
||||
for i, msg in enumerate(input_messages):
|
||||
content_preview = str(msg.content) # 不截断,完整输出
|
||||
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._chain.invoke({
|
||||
"messages": state["messages"],
|
||||
"memory_context": memory_context
|
||||
})
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||||
token_usage = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
# 尝试从 response_metadata 中提取
|
||||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
meta = response.response_metadata
|
||||
if 'token_usage' in meta:
|
||||
token_usage = meta['token_usage']
|
||||
elif 'usage' in meta:
|
||||
token_usage = meta['usage']
|
||||
|
||||
# 尝试从 additional_kwargs 中提取
|
||||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||||
add_kwargs = response.additional_kwargs
|
||||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||||
token_usage = add_kwargs['llm_output']['token_usage']
|
||||
|
||||
# 提取具体的 token 数值
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
debug(f" 消息类型: {response.type.upper()}")
|
||||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||||
debug("-"*80)
|
||||
debug(f"{response.content}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
return {
|
||||
"messages": [response],
|
||||
"llm_calls": state.get('llm_calls', 0) + 1,
|
||||
"last_token_usage": token_usage,
|
||||
"last_elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f" 错误类型: {type(e).__name__}")
|
||||
error(f" 错误信息: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
debug("="*80 + "\n")
|
||||
|
||||
# 返回一个友好的错误消息
|
||||
error_response = AIMessage(
|
||||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||||
)
|
||||
return {
|
||||
"messages": [error_response],
|
||||
"llm_calls": state.get('llm_calls', 0),
|
||||
"last_token_usage": {},
|
||||
"last_elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""
|
||||
工具执行节点(异步方法)
|
||||
对于每个工具调用,在线程池中执行同步工具函数
|
||||
@@ -91,11 +205,15 @@ class GraphBuilder:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 同步工具函数在线程池中执行
|
||||
observation = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: tool_func.invoke(tool_args)
|
||||
)
|
||||
# 修复闭包问题:将变量作为默认参数传入 lambda
|
||||
# 如果工具支持异步 (ainvoke),优先使用异步调用
|
||||
if hasattr(tool_func, 'ainvoke'):
|
||||
observation = await tool_func.ainvoke(tool_args)
|
||||
else:
|
||||
observation = await loop.run_in_executor(
|
||||
None,
|
||||
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
|
||||
)
|
||||
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
||||
except Exception as e:
|
||||
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
||||
@@ -103,25 +221,101 @@ class GraphBuilder:
|
||||
return {"messages": results}
|
||||
|
||||
@staticmethod
|
||||
def should_continue(state: MessageState) -> Literal['tool_node', END]:
|
||||
"""
|
||||
条件边判断(静态方法)
|
||||
决定下一步是进入工具节点还是结束
|
||||
"""
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
|
||||
"""决定下一步:工具调用、保存记忆还是结束"""
|
||||
last_message = state["messages"][-1]
|
||||
if isinstance(last_message, AIMessage) and bool(last_message.tool_calls):
|
||||
|
||||
# 1. 如果需要调用工具,优先进入工具节点
|
||||
if isinstance(last_message, AIMessage) and last_message.tool_calls:
|
||||
return 'tool_node'
|
||||
return END
|
||||
|
||||
# 2. 如果是 AI 的最终回复,可以考虑进入记忆保存节点(可增加判断逻辑)
|
||||
# 这里简单处理:只要没有工具调用,且是 AI 消息,就尝试保存记忆。
|
||||
if isinstance(last_message, AIMessage):
|
||||
return 'save_memory'
|
||||
|
||||
# 3. 其他情况(如只有用户消息)直接结束
|
||||
return 'END'
|
||||
|
||||
async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""搜索并返回长期记忆"""
|
||||
user_id = runtime.context.user_id
|
||||
namespace = ("memories", user_id)
|
||||
query = str(state["messages"][-1].content)
|
||||
|
||||
debug(f"\n{'='*60}")
|
||||
debug(f"🔎 [记忆检索] 开始检索")
|
||||
debug(f" ├─ 用户ID: {user_id}")
|
||||
debug(f" ├─ 命名空间: {namespace}")
|
||||
debug(f" ├─ 查询内容: '{query}'")
|
||||
debug(f" └─ 消息总数: {len(state['messages'])}")
|
||||
|
||||
try:
|
||||
memories = await runtime.store.asearch(namespace, query=query)
|
||||
debug(f"✅ [记忆检索] 检索完成,找到 {len(memories)} 条相关记忆")
|
||||
|
||||
if memories:
|
||||
memory_text = "\n".join([m.value["data"] for m in memories])
|
||||
debug(f"📚 [记忆内容]")
|
||||
for i, memory in enumerate(memories, 1):
|
||||
debug(f" [{i}] {memory.value['data']}")
|
||||
debug(f"{'='*60}\n")
|
||||
return {"memory_context": memory_text}
|
||||
else:
|
||||
debug(f"⚠️ [记忆检索] 未找到相关记忆")
|
||||
debug(f"{'='*60}\n")
|
||||
return {"memory_context": ""}
|
||||
|
||||
except Exception as e:
|
||||
error(f"❌ [记忆检索] 检索失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
debug(f"{'='*60}\n")
|
||||
return {"memory_context": ""}
|
||||
|
||||
async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""尝试从对话中提取并保存长期记忆"""
|
||||
# 获取最后一条用户消息(通常是要记住的内容的来源)
|
||||
user_messages = [msg for msg in state["messages"] if msg.type == "human"]
|
||||
if not user_messages:
|
||||
return {}
|
||||
|
||||
last_user_msg = user_messages[-1].content.lower()
|
||||
|
||||
# 简单触发逻辑:包含"记住"或"保存"等关键词
|
||||
if any(keyword in last_user_msg for keyword in ["记住", "保存", "别忘了"]):
|
||||
# 提取记忆内容(这里仅作示例,实际可用 LLM 提取)
|
||||
memory_content = f"用户说过:{last_user_msg}"
|
||||
user_id = runtime.context.user_id
|
||||
namespace = ("memories", user_id)
|
||||
await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
|
||||
info(f"✅ 长期记忆已保存:{memory_content}")
|
||||
|
||||
return {}
|
||||
|
||||
def build(self) -> StateGraph:
|
||||
"""
|
||||
构建未编译的状态图(返回 StateGraph 实例)
|
||||
图中节点直接使用实例方法 call_llm, call_tools
|
||||
"""
|
||||
builder = StateGraph(MessageState)
|
||||
builder = StateGraph(MessagesState,context_schema=GraphContext)
|
||||
builder.add_node("retrieve_memory", self.retrieve_memory)
|
||||
builder.add_node("llm_call", self.call_llm)
|
||||
builder.add_node("tool_node", self.call_tools)
|
||||
builder.add_edge(START, "llm_call")
|
||||
builder.add_conditional_edges("llm_call", self.should_continue, ["tool_node", END])
|
||||
builder.add_node("save_memory", self.save_memory)
|
||||
|
||||
builder.add_edge(START, "retrieve_memory")
|
||||
builder.add_edge("retrieve_memory", "llm_call")
|
||||
builder.add_conditional_edges(
|
||||
"llm_call",
|
||||
self.should_continue,
|
||||
{
|
||||
"tool_node": "tool_node",
|
||||
"save_memory": "save_memory",
|
||||
'END': END
|
||||
}
|
||||
)
|
||||
builder.add_edge("tool_node", "llm_call")
|
||||
builder.add_edge("save_memory", END)
|
||||
|
||||
return builder
|
||||
55
app/logger.py
Normal file
55
app/logger.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
统一的日志模块 - 基于环境变量控制日志级别
|
||||
类似 C# 的条件编译效果,开发时打印详细调试信息,生产环境只输出关键信息
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 先加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 从环境变量读取日志级别,默认 INFO
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
|
||||
# 根据环境变量控制是否显示详细调试信息
|
||||
DEBUG_MODE = os.getenv("DEBUG", "false").lower() == "true"
|
||||
|
||||
# 创建统一的日志器
|
||||
logger = logging.getLogger("ai_agent")
|
||||
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
|
||||
|
||||
# 避免重复添加 handler
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
# 重要:handler 也需要设置级别,否则可能继承根 logger 的级别
|
||||
handler.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
def debug(msg: Any, *args, **kwargs):
|
||||
"""调试日志,仅在 DEBUG 环境变量为 true 时打印"""
|
||||
if DEBUG_MODE:
|
||||
logger.debug(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def info(msg: Any, *args, **kwargs):
|
||||
"""信息日志"""
|
||||
logger.info(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warning(msg: Any, *args, **kwargs):
|
||||
"""警告日志"""
|
||||
logger.warning(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def error(msg: Any, *args, **kwargs):
|
||||
"""错误日志"""
|
||||
logger.error(msg, *args, **kwargs)
|
||||
@@ -46,7 +46,8 @@ services:
|
||||
dockerfile: docker/Dockerfile.frontend
|
||||
container_name: ai-frontend
|
||||
environment:
|
||||
- API_URL=http://backend:8001/chat # Docker 内部使用服务名解析
|
||||
# Docker 内部网络使用服务名 'backend' 解析后端服务
|
||||
- API_URL=http://backend:8001/chat
|
||||
ports:
|
||||
- "8501:8501"
|
||||
networks:
|
||||
|
||||
@@ -7,12 +7,16 @@ import os
|
||||
import uuid
|
||||
|
||||
# 第三方库
|
||||
from dotenv import load_dotenv
|
||||
import requests
|
||||
import streamlit as st
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
# 后端 API 地址:优先从环境变量读取,Docker 环境使用服务名,本地开发可覆盖
|
||||
API_URL = os.getenv("API_URL", "http://backend:8001/chat")
|
||||
# 后端 API 地址配置
|
||||
# 优先级:环境变量 API_URL > Docker 内部服务名 > 本地开发地址
|
||||
API_URL = os.getenv("API_URL", "http://localhost:8001/chat")
|
||||
|
||||
st.set_page_config(page_title="AI 个人助手", page_icon="🤖")
|
||||
st.title("🤖 个人生活与数据分析助手")
|
||||
@@ -20,6 +24,7 @@ st.title("🤖 个人生活与数据分析助手")
|
||||
# 模型选项(与后端支持的模型名称一致)
|
||||
MODEL_OPTIONS = {
|
||||
"zhipu": "智谱 GLM-4.7-Flash(在线)",
|
||||
"deepseek": "DeepSeek V3.2(在线)",
|
||||
"local": "本地 vLLM(Gemma-4)"
|
||||
}
|
||||
|
||||
@@ -82,12 +87,20 @@ if prompt := st.chat_input("请输入您的问题..."):
|
||||
data = response.json()
|
||||
reply = data["reply"]
|
||||
model_used = data["model_used"]
|
||||
input_tokens = data.get("input_tokens", 0)
|
||||
output_tokens = data.get("output_tokens", 0)
|
||||
total_tokens = data.get("total_tokens", 0)
|
||||
elapsed_time = data.get("elapsed_time", 0.0)
|
||||
|
||||
# 显示回复
|
||||
st.markdown(reply)
|
||||
|
||||
# 显示使用的模型(小字提示)
|
||||
st.caption(f"🤖 使用模型: {MODEL_OPTIONS.get(model_used, model_used)}")
|
||||
# 显示使用的模型和性能指标
|
||||
stats_text = f"🤖 模型: {MODEL_OPTIONS.get(model_used, model_used)}"
|
||||
stats_text += f" | ⏱️ 耗时: {elapsed_time:.2f}s"
|
||||
if total_tokens > 0:
|
||||
stats_text += f" | 📊 Tokens: {input_tokens}(输入) + {output_tokens}(输出) = {total_tokens}(总计)"
|
||||
st.caption(stats_text)
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": reply})
|
||||
except Exception as e:
|
||||
|
||||
380
scripts/start.sh
380
scripts/start.sh
@@ -1,7 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
# AI Agent 启动脚本
|
||||
# 用法: ./start.sh [backend|frontend|both]
|
||||
# =============================================================================
|
||||
# AI Agent 启动与管理脚本
|
||||
# 用法: ./start.sh [check|backend|frontend|both|docker-up|docker-down]
|
||||
# =============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
@@ -12,100 +13,296 @@ RED='\033[0;31m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# 项目根目录
|
||||
PROJECT_DIR="/home/huang/Study/AIProject/Agent1"
|
||||
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo -e "${BLUE} AI Agent - 个人生活助手启动脚本${NC}"
|
||||
echo -e "${BLUE} AI Agent - 个人生活助手${NC}"
|
||||
echo -e "${BLUE}========================================${NC}"
|
||||
echo ""
|
||||
|
||||
# 检查 vLLM 容器是否运行
|
||||
check_vllm() {
|
||||
if ! docker ps --format '{{.Names}}' | grep -q "^gemma4-server$"; then
|
||||
echo -e "${YELLOW}⚠️ vLLM 容器未运行!${NC}"
|
||||
echo "正在启动 vLLM 容器(Gemma-4 模型)..."
|
||||
# =============================================================================
|
||||
# 配置检查函数
|
||||
# =============================================================================
|
||||
check_config() {
|
||||
echo -e "${BLUE}📋 开始环境配置检查...${NC}"
|
||||
echo ""
|
||||
|
||||
# 检查模型文件是否存在
|
||||
if [ ! -d "/home/huang/Study/AIModel/gemma-4-E2B-it" ]; then
|
||||
echo -e "${RED}✗ 错误:模型目录不存在: /home/huang/Study/AIModel/gemma-4-E2B-it${NC}"
|
||||
echo "请先下载模型或修改模型路径"
|
||||
exit 1
|
||||
PASS=0
|
||||
FAIL=0
|
||||
WARN=0
|
||||
|
||||
# 辅助函数
|
||||
check_pass() {
|
||||
echo -e "${GREEN}✓${NC} $1"
|
||||
((PASS++))
|
||||
}
|
||||
|
||||
check_fail() {
|
||||
echo -e "${RED}✗${NC} $1"
|
||||
((FAIL++))
|
||||
}
|
||||
|
||||
check_warn() {
|
||||
echo -e "${YELLOW}⚠${NC} $1"
|
||||
((WARN++))
|
||||
}
|
||||
|
||||
# 1. 检查 .env 文件
|
||||
echo "🔍 检查配置文件..."
|
||||
if [ -f "$PROJECT_DIR/.env" ]; then
|
||||
check_pass ".env 文件存在"
|
||||
|
||||
# 检查文件格式
|
||||
if grep -q "^EOF" .env 2>/dev/null; then
|
||||
check_fail ".env 文件格式错误:发现多余的 EOF 标记"
|
||||
else
|
||||
check_pass ".env 文件格式正确"
|
||||
fi
|
||||
|
||||
docker run -d \
|
||||
--name gemma4-server \
|
||||
--group-add=video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
-v /home/huang/Study/AIModel/gemma-4-E2B-it:/models/gemma-4-E2B-it \
|
||||
-e VLLM_ROCM_USE_AITER=0 \
|
||||
-e HF_TOKEN="${HF_TOKEN}" \
|
||||
-p 8000:8000 \
|
||||
--ipc=host \
|
||||
--entrypoint vllm \
|
||||
my-vllm-gemma4:working \
|
||||
serve /models/gemma-4-E2B-it \
|
||||
--served-model-name gemma-4-E2B-it \
|
||||
--dtype auto \
|
||||
--api-key token-abc123 \
|
||||
--trust-remote-code \
|
||||
--port 8000 \
|
||||
--gpu-memory-utilization 0.85 \
|
||||
--max-model-len 8192
|
||||
|
||||
echo -e "${GREEN}✓ vLLM 容器已启动${NC}"
|
||||
echo -e "${YELLOW}⏳ 等待模型加载(可能需要几分钟)...${NC}"
|
||||
sleep 10
|
||||
else
|
||||
echo -e "${GREEN}✓ vLLM 容器正在运行${NC}"
|
||||
check_fail ".env 文件不存在"
|
||||
echo " 提示: 请创建 .env 文件并配置环境变量"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# 2. 检查必需的环境变量
|
||||
echo ""
|
||||
echo "🔑 检查环境变量..."
|
||||
|
||||
# 检查 ZHIPUAI_API_KEY
|
||||
if grep -q "^ZHIPUAI_API_KEY=" "$PROJECT_DIR/.env" 2>/dev/null; then
|
||||
API_KEY=$(grep "^ZHIPUAI_API_KEY=" "$PROJECT_DIR/.env" | head -1 | cut -d'=' -f2- | tr -d '[:space:]')
|
||||
if [ ${#API_KEY} -gt 10 ]; then
|
||||
check_pass "ZHIPUAI_API_KEY 已配置(长度: ${#API_KEY})"
|
||||
else
|
||||
check_fail "ZHIPUAI_API_KEY 配置可能无效(过短)"
|
||||
fi
|
||||
else
|
||||
check_fail "ZHIPUAI_API_KEY 未配置或格式错误"
|
||||
fi
|
||||
|
||||
# 检查 VLLM_LOCAL_KEY
|
||||
if grep -q "^VLLM_LOCAL_KEY=" "$PROJECT_DIR/.env" 2>/dev/null; then
|
||||
check_pass "VLLM_LOCAL_KEY 已配置"
|
||||
else
|
||||
check_warn "VLLM_LOCAL_KEY 未配置(如不使用本地模型可忽略)"
|
||||
fi
|
||||
|
||||
# 检查 DB_URI
|
||||
if grep -q "^DB_URI=" "$PROJECT_DIR/.env" 2>/dev/null; then
|
||||
check_pass "DB_URI 已配置"
|
||||
else
|
||||
check_warn "DB_URI 未配置(将使用默认值)"
|
||||
fi
|
||||
|
||||
# 3. 检查 Docker 环境
|
||||
echo ""
|
||||
echo "🐳 检查 Docker 环境..."
|
||||
|
||||
if command -v docker &> /dev/null; then
|
||||
check_pass "Docker 已安装"
|
||||
|
||||
if docker info &> /dev/null; then
|
||||
check_pass "Docker 守护进程正在运行"
|
||||
else
|
||||
check_fail "Docker 守护进程未运行"
|
||||
echo " 提示: sudo systemctl start docker"
|
||||
fi
|
||||
else
|
||||
check_fail "Docker 未安装"
|
||||
fi
|
||||
|
||||
if command -v docker compose version &> /dev/null || command -v docker-compose &> /dev/null; then
|
||||
check_pass "Docker Compose 已安装"
|
||||
else
|
||||
check_fail "Docker Compose 未安装"
|
||||
fi
|
||||
|
||||
# 4. 检查端口占用
|
||||
echo ""
|
||||
echo "🔌 检查端口占用..."
|
||||
|
||||
for port in 8001 8501; do
|
||||
if lsof -i :$port &> /dev/null; then
|
||||
check_warn "端口 $port 已被占用"
|
||||
else
|
||||
check_pass "端口 $port 可用"
|
||||
fi
|
||||
done
|
||||
|
||||
# 总结
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " 检查结果汇总"
|
||||
echo "=========================================="
|
||||
echo -e "${GREEN}通过: $PASS${NC}"
|
||||
echo -e "${RED}失败: $FAIL${NC}"
|
||||
echo -e "${YELLOW}警告: $WARN${NC}"
|
||||
echo ""
|
||||
|
||||
if [ $FAIL -eq 0 ]; then
|
||||
echo -e "${GREEN}✅ 配置检查通过!${NC}"
|
||||
return 0
|
||||
else
|
||||
echo -e "${RED}❌ 发现 $FAIL 个错误,请修复后重试${NC}"
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Docker 容器检查函数
|
||||
# =============================================================================
|
||||
check_vllm() {
|
||||
echo -e "${BLUE}🔍 检查 vLLM 容器...${NC}"
|
||||
if ! docker ps --format '{{.Names}}' | grep -q "^gemma4-server$"; then
|
||||
echo -e "${YELLOW}⚠️ vLLM 容器未运行${NC}"
|
||||
return 1
|
||||
else
|
||||
echo -e "${GREEN}✓ vLLM 容器正在运行 (gemma4-server)${NC}"
|
||||
return 0
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查 PostgreSQL 容器是否运行
|
||||
check_postgres() {
|
||||
if ! docker ps | grep -q postgres-langgraph; then
|
||||
echo -e "${YELLOW}⚠️ PostgreSQL 容器未运行!${NC}"
|
||||
echo "正在启动 PostgreSQL 容器..."
|
||||
docker run -d \
|
||||
--name postgres-langgraph \
|
||||
-e POSTGRES_PASSWORD=mysecretpassword \
|
||||
-e POSTGRES_DB=langgraph_db \
|
||||
-p 5432:5432 \
|
||||
-v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \
|
||||
postgres:16
|
||||
|
||||
echo -e "${GREEN}✓ PostgreSQL 容器已启动${NC}"
|
||||
sleep 3
|
||||
echo -e "${BLUE}🔍 检查 PostgreSQL 容器...${NC}"
|
||||
if ! docker ps --format '{{.Names}}' | grep -q "^postgres-langgraph$"; then
|
||||
echo -e "${YELLOW}⚠️ PostgreSQL 容器未运行${NC}"
|
||||
return 1
|
||||
else
|
||||
echo -e "${GREEN}✓ PostgreSQL 容器正在运行${NC}"
|
||||
echo -e "${GREEN}✓ PostgreSQL 容器正在运行 (postgres-langgraph)${NC}"
|
||||
return 0
|
||||
fi
|
||||
}
|
||||
|
||||
# 启动后端
|
||||
# =============================================================================
|
||||
# 启动 Docker 依赖服务
|
||||
# =============================================================================
|
||||
start_vllm() {
|
||||
echo -e "${BLUE}🚀 启动 vLLM 容器...${NC}"
|
||||
|
||||
# 检查模型文件
|
||||
if [ ! -d "/home/huang/Study/AIModel/gemma-4-E2B-it" ]; then
|
||||
echo -e "${RED}✗ 错误:模型目录不存在: /home/huang/Study/AIModel/gemma-4-E2B-it${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run -d \
|
||||
--name gemma4-server \
|
||||
--group-add=video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
-v /home/huang/Study/AIModel/gemma-4-E2B-it:/models/gemma-4-E2B-it \
|
||||
-e VLLM_ROCM_USE_AITER=0 \
|
||||
-e HF_TOKEN="${HF_TOKEN:-}" \
|
||||
-p 8000:8000 \
|
||||
--ipc=host \
|
||||
--entrypoint vllm \
|
||||
my-vllm-gemma4:working \
|
||||
serve /models/gemma-4-E2B-it \
|
||||
--served-model-name gemma-4-E2B-it \
|
||||
--dtype auto \
|
||||
--api-key token-abc123 \
|
||||
--trust-remote-code \
|
||||
--port 8000 \
|
||||
--gpu-memory-utilization 0.85 \
|
||||
--max-model-len 8192
|
||||
|
||||
echo -e "${GREEN}✓ vLLM 容器已启动${NC}"
|
||||
echo -e "${YELLOW}⏳ 等待模型加载(可能需要几分钟)...${NC}"
|
||||
sleep 10
|
||||
}
|
||||
|
||||
start_postgres() {
|
||||
echo -e "${BLUE}🚀 启动 PostgreSQL 容器...${NC}"
|
||||
|
||||
docker run -d \
|
||||
--name postgres-langgraph \
|
||||
-e POSTGRES_PASSWORD=mysecretpassword \
|
||||
-e POSTGRES_DB=langgraph_db \
|
||||
-p 5432:5432 \
|
||||
-v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \
|
||||
postgres:16
|
||||
|
||||
echo -e "${GREEN}✓ PostgreSQL 容器已启动${NC}"
|
||||
sleep 3
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# 启动 Python 服务
|
||||
# =============================================================================
|
||||
start_backend() {
|
||||
echo -e "\n${BLUE}🚀 启动后端服务 (端口 8001)...${NC}"
|
||||
cd /home/huang/Study/AIProject/Agent1
|
||||
export PYTHONPATH=$(pwd)
|
||||
cd "$PROJECT_DIR"
|
||||
|
||||
# 加载 .env 文件中的环境变量
|
||||
set -a
|
||||
source .env 2>/dev/null || true
|
||||
set +a
|
||||
|
||||
export PYTHONPATH="$PROJECT_DIR"
|
||||
python app/backend.py &
|
||||
BACKEND_PID=$!
|
||||
echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"
|
||||
sleep 2
|
||||
}
|
||||
|
||||
# 启动前端
|
||||
start_frontend() {
|
||||
echo -e "\n${BLUE}🎨 启动前端界面...${NC}"
|
||||
cd /home/huang/Study/AIProject/Agent1
|
||||
export PYTHONPATH=$(pwd)
|
||||
echo -e "\n${BLUE}🎨 启动前端界面 (端口 8501)...${NC}"
|
||||
cd "$PROJECT_DIR"
|
||||
|
||||
# 加载 .env 文件中的环境变量
|
||||
set -a
|
||||
source .env 2>/dev/null || true
|
||||
set +a
|
||||
|
||||
export PYTHONPATH="$PROJECT_DIR"
|
||||
streamlit run frontend/frontend.py &
|
||||
FRONTEND_PID=$!
|
||||
echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}"
|
||||
echo -e "${GREEN}✓ 请在浏览器中打开: http://localhost:8501(本地开发)或 http://your-domain.com(Nginx 代理)${NC}"
|
||||
echo -e "${GREEN}✓ 访问地址:${NC}"
|
||||
echo -e " 本地开发: http://localhost:8501"
|
||||
echo -e " Nginx代理: http://your-domain.com"
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Docker Compose 管理
|
||||
# =============================================================================
|
||||
docker_up() {
|
||||
echo -e "${BLUE}🐳 使用 Docker Compose 启动所有服务...${NC}"
|
||||
cd "$PROJECT_DIR"
|
||||
|
||||
# 检查 .env 文件
|
||||
if [ ! -f ".env" ]; then
|
||||
echo -e "${RED}✗ 错误:.env 文件不存在${NC}"
|
||||
echo " 请先复制配置文件:"
|
||||
echo " cp .env.docker .env # 服务器部署"
|
||||
echo " 或"
|
||||
echo " cp .env.local .env # 本地开发"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker compose -f docker/docker-compose.yml up -d --build
|
||||
|
||||
echo -e "\n${GREEN}✓ Docker Compose 服务已启动${NC}"
|
||||
echo -e "${BLUE}📊 查看服务状态:${NC} docker compose -f docker/docker-compose.yml ps"
|
||||
echo -e "${BLUE}📝 查看日志:${NC} docker compose -f docker/docker-compose.yml logs -f"
|
||||
echo -e "${BLUE}🌐 访问应用:${NC} http://localhost:8501"
|
||||
}
|
||||
|
||||
docker_down() {
|
||||
echo -e "${BLUE}🛑 停止 Docker Compose 服务...${NC}"
|
||||
cd "$PROJECT_DIR"
|
||||
docker compose -f docker/docker-compose.yml down
|
||||
echo -e "${GREEN}✓ 服务已停止${NC}"
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# 清理函数
|
||||
# =============================================================================
|
||||
cleanup() {
|
||||
echo -e "\n${RED}🛑 正在停止所有服务...${NC}"
|
||||
echo -e "\n${RED}🛑 正在停止 Python 服务...${NC}"
|
||||
if [ ! -z "$BACKEND_PID" ]; then
|
||||
kill $BACKEND_PID 2>/dev/null || true
|
||||
echo -e "${GREEN}✓ 后端服务已停止${NC}"
|
||||
@@ -117,33 +314,72 @@ cleanup() {
|
||||
echo -e "${YELLOW}💡 提示:Docker 容器需要手动停止${NC}"
|
||||
echo -e " 停止 vLLM: docker stop gemma4-server"
|
||||
echo -e " 停止 PostgreSQL: docker stop postgres-langgraph"
|
||||
echo -e " 或使用: $0 docker-down"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# 捕获 Ctrl+C
|
||||
trap cleanup SIGINT SIGTERM
|
||||
|
||||
# =============================================================================
|
||||
# 主逻辑
|
||||
case "${1:-both}" in
|
||||
# =============================================================================
|
||||
case "${1:-help}" in
|
||||
check)
|
||||
check_config
|
||||
;;
|
||||
|
||||
backend)
|
||||
check_vllm
|
||||
check_postgres
|
||||
check_config || exit 1
|
||||
check_vllm || start_vllm
|
||||
check_postgres || start_postgres
|
||||
start_backend
|
||||
echo -e "\n${GREEN}后端服务正在运行,按 Ctrl+C 停止${NC}"
|
||||
wait $BACKEND_PID
|
||||
;;
|
||||
|
||||
frontend)
|
||||
check_config || exit 1
|
||||
start_frontend
|
||||
echo -e "\n${GREEN}前端服务正在运行,按 Ctrl+C 停止${NC}"
|
||||
wait $FRONTEND_PID
|
||||
;;
|
||||
both|*)
|
||||
check_vllm
|
||||
check_postgres
|
||||
|
||||
both)
|
||||
check_config || exit 1
|
||||
check_vllm || start_vllm
|
||||
check_postgres || start_postgres
|
||||
start_backend
|
||||
start_frontend
|
||||
echo -e "\n${GREEN}所有服务正在运行,按 Ctrl+C 停止 Python 服务${NC}"
|
||||
echo -e "${YELLOW}注意:Docker 容器会在后台继续运行${NC}"
|
||||
wait
|
||||
;;
|
||||
|
||||
docker-up)
|
||||
check_config || exit 1
|
||||
docker_up
|
||||
;;
|
||||
|
||||
docker-down)
|
||||
docker_down
|
||||
;;
|
||||
|
||||
help|*)
|
||||
echo -e "${BLUE}用法:${NC} $0 [command]"
|
||||
echo ""
|
||||
echo -e "${BLUE}命令:${NC}"
|
||||
echo " check 检查环境配置"
|
||||
echo " backend 仅启动后端服务"
|
||||
echo " frontend 仅启动前端服务"
|
||||
echo " both 启动前后端服务(默认)"
|
||||
echo " docker-up 使用 Docker Compose 启动所有服务"
|
||||
echo " docker-down 停止 Docker Compose 服务"
|
||||
echo " help 显示此帮助信息"
|
||||
echo ""
|
||||
echo -e "${BLUE}示例:${NC}"
|
||||
echo " $0 check # 检查配置"
|
||||
echo " $0 both # 启动本地开发环境"
|
||||
echo " $0 docker-up # 启动 Docker 部署环境"
|
||||
;;
|
||||
esac
|
||||
|
||||
Reference in New Issue
Block a user