Implement a frontend/backend-separated agent

2  .env
@@ -1,4 +1,4 @@
LOCAL_MODEL_PATH=glm-4.7-flash
LOCAL_MODEL_PATH=gemma-4-E2B-it
ZHIPUAI_API_KEY=4d568a4367f1442bbc226cc0daf84566.44SsKVWkVIM2Mkeg
VLLM_LOCAL_KEY=token-abc123
EOF
5  .vscode/settings.json  (vendored, new file)
@@ -0,0 +1,5 @@
{
    "editor.fontSize": 24,
    "editor.formatOnSave": true,
    "files.autoSave": "onWindowChange"
}
245  QUICKSTART.md  (new file)
@@ -0,0 +1,245 @@
# Quick Start Guide - Multi-Model Switching

## 🚀 5-Minute Quick Start

### Step 1: Start the required containers

```bash
# Use the provided start script (recommended)
./start.sh

# Or start the containers manually
# 1. Start vLLM (only needed for the local model)
docker run -d --rm \
  --group-add=video \
  --cap-add=SYS_PTRACE \
  --security-opt seccomp=unconfined \
  --device=/dev/kfd \
  --device=/dev/dri \
  -v /home/huang/Study/AIModel/gemma-4-E2B-it:/models/gemma-4-E2B-it \
  -e VLLM_ROCM_USE_AITER=0 \
  -e HF_TOKEN="$HF_TOKEN" \
  -p 8000:8000 \
  --ipc=host \
  --entrypoint vllm \
  my-vllm-gemma4:working \
  serve /models/gemma-4-E2B-it \
  --served-model-name gemma-4-E2B-it \
  --dtype auto \
  --api-key token-abc123 \
  --trust-remote-code \
  --port 8000 \
  --gpu-memory-utilization 0.85 \
  --max-model-len 8192

# 2. Start PostgreSQL
docker run -d \
  --name postgres-langgraph \
  -e POSTGRES_PASSWORD=mysecretpassword \
  -e POSTGRES_DB=langgraph_db \
  -p 5432:5432 \
  -v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \
  postgres:16
```

### Step 2: Configure environment variables

Edit the `.env` file:

```env
ZHIPUAI_API_KEY=your_actual_zhipuai_api_key
VLLM_LOCAL_KEY=token-abc123
```

### Step 3: Start the services

```bash
# Option 1: use the start script (recommended)
./start.sh

# Option 2: start manually
# Terminal 1: start the backend
python backend.py

# Terminal 2: start the frontend
streamlit run frontend.py
```

### Step 4: Open the app

Open `http://localhost:8501` in your browser.

---

## 🎯 Using Multi-Model Switching

### Switching models in the frontend

1. **Open the sidebar**: click the menu icon in the top-left corner
2. **Pick a model**: in the "选择大模型" dropdown, choose one of:
   - 智谱 GLM-4.7-Flash(在线)
   - 本地 vLLM(Gemma-4)
3. **Start chatting**: type your question and the selected model handles it

### Features

✅ **Live switching**: you can change models at any point in a conversation
✅ **Shared memory**: under the same session ID, all models share the conversation history
✅ **Automatic fallback**: if the selected model is unavailable, the backend switches to an available one (see the sketch below)
✅ **Model indicator**: each reply shows the model that was actually used
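
The fallback behaviour is visible in the API response. A minimal sketch, assuming the backend from this commit is running on port 8001: request one model, then check which model actually served the reply.

```python
import requests

# Ask for the local model; if it is unavailable the backend falls back
# to the first model that initialized successfully.
resp = requests.post(
    "http://localhost:8001/chat",
    json={"message": "你好", "model": "local"},
    timeout=60,
)
data = resp.json()
if data["model_used"] != "local":
    print(f"Fell back to: {data['model_used']}")
```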

---

## 🧪 Testing

### Run the automated tests

```bash
# Make sure the backend is running first
python test_multi_model.py
```

The tests cover:
- availability of each model
- cross-model session memory
- API response format validation

### Manual testing

1. **Test the Zhipu model**:
   - select "智谱 GLM-4.7-Flash"
   - ask: "你好,请介绍一下自己"
   - check response speed and answer quality

2. **Test the local model**:
   - select "本地 vLLM(Gemma-4)"
   - ask the same question
   - compare the two models' replies

3. **Test memory**:
   - round 1 (Zhipu model): "我叫小明,记住我的名字"
   - round 2 (local model): "我叫什么名字?"
   - verify that the name is recalled correctly

---

## 🔧 FAQ

### Q1: What if a model fails to initialize?

**A:** The backend skips failed models and serves requests with the remaining ones. Check the logs for the cause:
- Zhipu model: confirm that `ZHIPUAI_API_KEY` is correct
- local model: confirm that the vLLM container is running

### Q2: How do I add a new model?

**A:** Add the following to `agent.py`:

```python
def _create_new_model_llm(self):
    """创建新模型的 LLM"""
    return YourChatModel(
        model="model-name",
        api_key="your-key",
        # ... other parameters
    )

# add the entry to model_configs in the initialize() method
model_configs = {
    "zhipu": self._create_zhipu_llm,
    "local": self._create_local_llm,
    "new_model": self._create_new_model_llm,  # new
}
```

Then add the matching entry to `MODEL_OPTIONS` in the frontend (`frontend.py`).
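
A sketch of that frontend change; the `"new_model"` key must match the key used in `agent.py`, and the "新模型(示例)" label is a placeholder, not something defined in this commit:

```python
# frontend.py - keys must match the keys in model_configs in agent.py
MODEL_OPTIONS = {
    "zhipu": "智谱 GLM-4.7-Flash(在线)",
    "local": "本地 vLLM(Gemma-4)",
    "new_model": "新模型(示例)",  # hypothetical label for the new entry
}
```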

### Q3: How does session memory work?

**A:**
- conversation history is stored in PostgreSQL
- messages of the same session are linked by their `thread_id`
- all models share the same checkpointer, so context carries across model switches (see the sketch below)
- clicking the "新会话" button generates a new `thread_id`
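
A minimal sketch of that behaviour, assuming the backend is running locally on port 8001: both requests reuse one `thread_id`, so the second model sees the history written by the first.

```python
import uuid
import requests

API_URL = "http://localhost:8001/chat"
thread_id = str(uuid.uuid4())          # one shared session

# Round 1 handled by the Zhipu model
requests.post(API_URL, json={
    "message": "我叫小明,记住我的名字",
    "thread_id": thread_id,
    "model": "zhipu",
}, timeout=60)

# Round 2 handled by the local model; same thread_id -> shared history
r = requests.post(API_URL, json={
    "message": "我叫什么名字?",
    "thread_id": thread_id,
    "model": "local",
}, timeout=60)
print(r.json()["reply"])
```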

### Q4: Performance tips

**A:**
- Zhipu model: good for fast responses, no local GPU required
- local model: good when data privacy matters, requires a GPU
- for long conversations, start a new session periodically to keep the context short

---

## 📊 Architecture Highlights

### Precompiled graphs

Each model gets its own LangGraph, precompiled at startup (see the sketch after this list):
- ✅ no recompilation per request, better performance
- ✅ models are isolated from each other
- ✅ hot-pluggable: models can be added or removed dynamically
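
A condensed sketch of the pattern `agent.py` uses in this commit: build and compile one graph per model at startup, then serve each request with a dictionary lookup. Names are abridged from `agent.py` (the factories are underscored methods there), so this is an illustration rather than a drop-in excerpt.

```python
# Precompile one LangGraph per model; skip models that fail to initialize.
graphs = {}
model_configs = {
    "zhipu": create_zhipu_llm,   # factory functions as in agent.py
    "local": create_local_llm,
}
for name, create_llm in model_configs.items():
    try:
        llm = create_llm()
        builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
        graphs[name] = builder.compile(checkpointer=checkpointer)
    except Exception as exc:
        print(f"model '{name}' skipped: {exc}")   # not fatal, just skipped

# At request time, switching models is just a lookup with a fallback.
requested = "local"
graph = graphs.get(requested) or next(iter(graphs.values()))
```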

### Automatic fallback

If the selected model is unavailable:
1. the backend switches to the first available model
2. the response includes a `model_used` field
3. the frontend shows the model that was actually used
4. the user experience stays smooth

### One interface for all models

Whichever model is selected:
- the API stays the same
- tools are called the same way
- session memory works the same way
- the frontend behaves the same way

---

## 🎓 Advanced Usage

### Fixed session ID

To continue the same session across browsers or devices:

```python
# change in frontend.py
st.session_state.thread_id = "my_fixed_session_id"
```

### Custom timeout

```python
# change the timeout parameter in frontend.py
response = requests.post(
    API_URL,
    json={...},
    timeout=120  # increase to 120 seconds
)
```

### Batch testing

```python
# a simple test script
import requests

API_URL = "http://localhost:8001/chat"  # backend /chat endpoint

messages = ["问题1", "问题2", "问题3"]
for msg in messages:
    response = requests.post(API_URL, json={"message": msg, "model": "zhipu"})
    print(response.json()["reply"])
```

---

## 📞 Getting Help

- Full documentation: [README.md](README.md)
- Project layout: see the project structure section in [README.md](README.md)
- Report problems: open an Issue and attach the log output

---

**Enjoy!** 🎉
268  README.md
@@ -1,2 +1,268 @@
# ailine
# AI Agent - Personal Life Assistant and Data Analysis Assistant

## Overview

An AI assistant built on LangGraph, LangChain, and FastAPI that handles weather queries, file reading, web scraping, and similar tasks. It uses a frontend/backend-separated architecture and persists conversation memory in PostgreSQL.

## Project structure

```
Agent1/
├── tools.py           # tool definitions (pure functions, @tool)
├── graph_builder.py   # LangGraph state-graph construction (state, nodes, edges)
├── agent.py           # AIAgentService class (model init, graph management, message handling)
├── backend.py         # FastAPI app (routes, WebSocket, lifespan)
├── frontend.py        # Streamlit frontend (calls the backend over HTTP)
├── .env               # environment variables (ZHIPUAI_API_KEY, etc.)
├── requirement.txt    # Python dependency list
└── user_docs/         # directory of documents allowed to be read
    ├── a.txt
    ├── b.pdf
    └── c.xlsx
```

## Core features

- 🌤️ **Weather lookup**: get the current temperature for a location
- 📄 **Text file reading**: read `.txt`, `.md`, and other text files
- 📑 **PDF reading**: parse PDF files and extract their text
- 📊 **Excel processing**: read Excel files and render them as Markdown tables
- 🌐 **Web scraping**: fetch the main text content of a web page
- 💾 **Persistent memory**: conversation history is stored in PostgreSQL, supporting multi-turn context
- 🔄 **Dynamic multi-model switching**: pick a model in the frontend; the backend switches automatically

## Tech stack

- **Backend**: FastAPI + Uvicorn
- **Frontend**: Streamlit
- **AI framework**: LangGraph + LangChain
- **Database**: PostgreSQL (persistent conversation memory)
- **Supported LLMs**:
  - 智谱 AI (glm-4.7-flash): hosted service, fast responses
  - local vLLM (gemma-4-E2B-it): local deployment, better data privacy

The system supports multiple LLMs that can be switched from the frontend. Each model gets its own precompiled LangGraph at startup for best performance. If a model fails to initialize (for example, a missing API key), the system automatically falls back to an available one.

## Requirements

- Python 3.10+
- PostgreSQL 16+
- Docker (optional, for running PostgreSQL)

## Installation

### 1. Start the PostgreSQL container

```bash
docker run -d \
  --name postgres-langgraph \
  -e POSTGRES_PASSWORD=mysecretpassword \
  -e POSTGRES_DB=langgraph_db \
  -p 5432:5432 \
  -v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \
  postgres:16
```

### 2. Install the Python dependencies

```bash
pip install fastapi uvicorn streamlit requests psycopg[binary,pool] \
    langgraph langgraph-checkpoint-postgres langchain langchain-community \
    langchain-openai python-dotenv pypdf pandas beautifulsoup4
```

Or install from requirement.txt:

```bash
pip install -r requirement.txt
```

### 3. Configure environment variables

Edit the `.env` file and set your API keys:

```env
ZHIPUAI_API_KEY=your_zhipuai_api_key_here
VLLM_LOCAL_KEY=token-abc123  # if using the local model
```

## Running

### 1. Start the backend

```bash
python backend.py
```

When you see `Uvicorn running on http://0.0.0.0:8001`, the backend is up.

### 2. Start the frontend (in a new terminal)

```bash
streamlit run frontend.py
```

The browser opens `http://localhost:8501` automatically and you can start chatting.

## API

### POST /chat

Synchronous chat endpoint with model selection. A usage sketch follows the field reference below.

**Request body:**
```json
{
  "message": "今天北京天气怎么样?",
  "thread_id": "optional-thread-id",
  "model": "zhipu"  // optional: "zhipu" or "local"
}
```

**Response:**
```json
{
  "reply": "当前北京的温度为25℃",
  "thread_id": "generated-or-provided-thread-id",
  "model_used": "zhipu"  // the model actually used
}
```

**Model options:**
- `zhipu`: 智谱 GLM-4.7-Flash (hosted)
- `local`: local vLLM Gemma-4 (requires the vLLM container)
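
A minimal client sketch against this endpoint, assuming the backend from this commit is running on localhost:8001:

```python
import requests

resp = requests.post(
    "http://localhost:8001/chat",
    json={
        "message": "今天北京天气怎么样?",
        "thread_id": None,     # omit or None -> the backend generates one
        "model": "zhipu",      # or "local"
    },
    timeout=60,
)
resp.raise_for_status()
data = resp.json()
print(data["reply"], data["thread_id"], data["model_used"])
```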

### WebSocket /ws

Streaming-style chat endpoint (optional extension). It accepts the same fields as `/chat` as one JSON message per turn; a client sketch is shown below.
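
A minimal client sketch using the `websockets` package (listed in requirement.txt); the message fields mirror the handler in `backend.py`:

```python
import asyncio
import json
import websockets  # pip install websockets

async def main():
    async with websockets.connect("ws://localhost:8001/ws") as ws:
        await ws.send(json.dumps({
            "message": "你好,请介绍一下自己",
            "model": "zhipu",   # optional, defaults to "zhipu"
        }))
        reply = json.loads(await ws.recv())
        print(reply["reply"], reply["model_used"])

asyncio.run(main())
```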

## Usage

### Tool call examples

1. **Weather lookup**:
   ```
   用户:今天上海天气怎么样?
   ```

2. **Read a text file**:
   ```
   用户:请读取 a.txt 文件的内容
   ```

3. **Read a PDF**:
   ```
   用户:帮我总结一下 b.pdf 的内容
   ```

4. **Read an Excel file**:
   ```
   用户:显示 c.xlsx 的数据
   ```

5. **Scrape a web page**:
   ```
   用户:请抓取 https://example.com 的内容
   ```

### Session memory

- a unique `thread_id` is generated automatically for each session
- conversations with the same `thread_id` share their history
- history survives backend restarts because it lives in PostgreSQL
- to pin a session, set `st.session_state.thread_id` to a fixed string in the frontend

### Multi-model switching

**Frontend:**
1. pick a model in the "选择大模型" dropdown in the sidebar
2. you can switch models at any time, even within the same session
3. the "🔄 新会话" button clears the current conversation and starts a new session

**Backend:**
- at startup, a LangGraph is precompiled for every available model
- models that fail to initialize (e.g. missing API key) are skipped automatically
- if the requested model is unavailable, the request falls back to the first available model
- the response carries a `model_used` field with the model that actually served it

**Adding a new model:**
Add the new model to the `model_configs` dict in `agent.py`'s `initialize()` method:
```python
model_configs = {
    "zhipu": self._create_zhipu_llm,
    "local": self._create_local_llm,
    "new_model": self._create_new_model_llm,  # new model
}
```

## Architecture

### Module responsibilities

- **tools.py**: standalone tool module with all `@tool`-decorated pure functions; no external dependencies, testable on its own
- **graph_builder.py**: LangGraph state-graph builder defining the state, node functions, and conditional edges
- **agent.py**: the AIAgentService class, responsible for model initialization and graph compilation, using `AsyncPostgresSaver`
- **frontend.py**: Streamlit frontend calling the backend API over HTTP
- **backend.py**: FastAPI app exposing the REST API and WebSocket endpoints on port 8001 (how the pieces are wired together is sketched below)
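
A condensed sketch of the startup wiring, abridged from the `lifespan` handler in `backend.py` and `agent.py` in this commit (details omitted, not a complete file):

```python
# backend.py lifespan (abridged)
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
    await checkpointer.setup()                  # create checkpoint tables
    service = AIAgentService(checkpointer)      # agent.py
    await service.initialize()                  # GraphBuilder(...).build().compile(...) per model
    app.state.agent_service = service           # injected into routes via Depends
```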

### Data flow

```
User input → Streamlit frontend → FastAPI backend → AIAgentService
          → LangGraph StateGraph → LLM + Tools → PostgreSQL (memory)
          → response → rendered in the frontend
```

## Notes

1. **File safety**: all file reads are restricted to the `./user_docs` directory to prevent path traversal
2. **Port conflicts**: the backend uses port 8001 to avoid clashing with the local vLLM service on port 8000
3. **API keys**: keep the keys in `.env` safe and do not commit them to version control
4. **Database persistence**: the PostgreSQL data volume is mounted at `~/docker_volumes/postgres_data`

## Troubleshooting

### Problem: cannot connect to PostgreSQL

**Fix:**
```bash
# check whether the container is running
docker ps | grep postgres-langgraph

# check the container logs
docker logs postgres-langgraph

# restart the container
docker restart postgres-langgraph
```

### Problem: the backend fails to start

**Fix:**
- make sure port 8001 is free
- check that the API keys in `.env` are configured correctly
- make sure all dependencies are installed
- check the startup log and confirm that at least one model initialized successfully

### Problem: no response after switching models

**Fix:**
- check the configuration of the selected model (e.g. the Zhipu API key)
- confirm the vLLM container is running (when using the local model)
- check the backend log to see whether the model initialized
- try switching to another model

### Problem: tool calls fail

**Fix:**
- make sure the file is under `./user_docs`
- check that the file format is correct
- check the backend log for details

## License

This project is released under the MIT License. See [LICENSE](LICENSE).

## Contributing

Issues and pull requests are welcome!
224  agent.py
@@ -1,187 +1,85 @@
from bs4 import BeautifulSoup
from langchain.agents import create_agent
import requests
import pypdf
import pandas as pd
from dotenv import load_dotenv
"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer,不负责管理连接生命周期
"""

import os
import time
from pathlib import Path
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_huggingface import HuggingFacePipeline,ChatHuggingFace
from transformers import AutoTokenizer, AutoModelForCausalLM,pipeline
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from transformers import BitsAndBytesConfig
from langchain_openai import ChatOpenAI
from pydantic import SecretStr

##--基础定义
# 本地模块
from graph_builder import GraphBuilder
from tools import AVAILABLE_TOOLS, TOOLS_BY_NAME

load_dotenv()

LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH","glm-4.7-flash")
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY")
VLLM_LOCAL_KEY = os.getenv("VLLM_LOCAL_KEY", "")
DEVICE = os.getenv("DEVICE")

##加载模型
local_llm = None
online_llm = None
class AIAgentService:
    """异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""

def get_local_llm():
    global local_llm
    if local_llm is None:
        local_llm = ChatOpenAI(
            base_url="http://localhost:8000/v1",
            api_key=SecretStr(VLLM_LOCAL_KEY),
            model="gemma-4-E2B-it",
        )
    return local_llm
    def __init__(self, checkpointer):
        """
        初始化服务
        Args:
            checkpointer: 已经初始化的 AsyncPostgresSaver 实例
        """
        self.checkpointer = checkpointer
        self.graphs = {}  # 存储不同模型对应的 graph 实例

def get_online_llm():
    global online_llm
    if online_llm is None:
        online_llm = ChatZhipuAI(
    def _create_zhipu_llm(self):
        """创建智谱在线 LLM"""
        api_key = os.getenv("ZHIPUAI_API_KEY")
        if not api_key:
            raise ValueError("ZHIPUAI_API_KEY not set in environment")
        return ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=ZHIPUAI_API_KEY,
            api_key=api_key,
            temperature=0.1,
            max_tokens=4096,
        )
    return online_llm

##工具调用
    def _create_local_llm(self):
        """创建本地 vLLM 服务 LLM"""
        return ChatOpenAI(
            base_url="http://localhost:8000/v1",
            api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")),
            model="gemma-4-E2B-it",
        )

@tool
def get_currenttemperature(location: str) -> str:
    """获取指定地点的当前温度,当用户询问天气或温度时使用此工具。"""
    return f'当前{location}的温度为25℃'
    async def initialize(self):
        """预编译所有模型的 graph(使用传入的 checkpointer)"""
        model_configs = {
            "zhipu": self._create_zhipu_llm,
            "local": self._create_local_llm,
        }

# sym:file_allow_check
def file_allow_check(filename: str) -> Path:
    """
    检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。
    返回合法的 Path 对象,若不合法则抛出异常。
    """
    allowed_dir = Path("./user_docs").resolve()
    allowed_dir.mkdir(exist_ok=True)

    file_path = (allowed_dir / filename).resolve()
    if not str(file_path).startswith(str(allowed_dir)):
        raise ValueError("错误:非法文件路径。")

    if not file_path.exists():
        raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")

    return file_path


@tool
def read_local_file(filename: str) -> str:
    """
    读取用户指定名称的本地文本文件内容并返回摘要。
    参数 filename: 文件名,例如 'project_plan.txt' 或 'notes.md'。
    """
        for model_name, llm_creator in model_configs.items():
    try:
        file_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as e:
        return str(e)

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        # 2. 内容过长时,可以在此处增加一个简单的摘要逻辑,或者直接返回前N个字符
        # 为了演示,这里返回前1000个字符
        return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..."
                llm = llm_creator()
                builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
                graph = builder.compile(checkpointer=self.checkpointer)
                self.graphs[model_name] = graph
                print(f"✅ 模型 '{model_name}' 初始化成功")
    except Exception as e:
        return f"读取文件时出错:{str(e)}"
                print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")

        if not self.graphs:
            raise RuntimeError("没有可用的模型,请检查配置")

@tool
def read_pdf_summary(filename: str) -> str:
    """
    读取PDF文件并返回内容文本。参数 filename: PDF文件名,例如 'report.pdf'。
    """
    try:
        file_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as e:
        return str(e)
    try:
        text = ""
        with open(file_path, 'rb') as f:
            reader = pypdf.PdfReader(f)
            for page in reader.pages[:3]:
                text += page.extract_text()
        return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..."
    except Exception as e:
        return f"读取PDF出错:{e}"
        return self

@tool
def read_excel_as_markdown(filename: str) -> str:
    """
    读取Excel文件,并将其主要数据转换为Markdown表格格式。参数 filename: Excel文件名,例如 'data.xlsx'。
    """
    try:
        file_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as e:
        return str(e)

    try:
        df = pd.read_excel(file_path)
        markdown_table = df.head(10).to_markdown(index=False)
        return f"Excel文件 '{filename}' 的数据预览(前10行):\n{markdown_table}"
    except Exception as e:
        return f"读取Excel出错:{e}"

@tool
def fetch_webpage_content(url: str) -> str:
    """
    抓取给定URL的网页正文内容,并返回清晰的纯文本。
    参数 url: 完整的网页地址,例如 'https://example.com/article'。
    """
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        # 简单的正文提取,去除脚本和样式
        for script in soup(["script", "style"]):
            script.decompose()
        text = soup.get_text()
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        text = '\n'.join(chunk for chunk in chunks if chunk)
        return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..."
    except Exception as e:
        return f"抓取网页时出错:{str(e)}"

#使用langgraph
agent=create_agent(
    model=get_local_llm(),
    tools=[get_currenttemperature,read_local_file,fetch_webpage_content,read_pdf_summary,read_excel_as_markdown],
    system_prompt=(
        "你是一个个人生活助手和数据分析助手。请说中文。"
        "当用户询问天气或温度时,使用get_currenttemperature工具获取信息。"
        "当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
        "当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
        "当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
        "当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
        "当用户要求分析文档时,请使用合适的工具读取内容,然后:1. 总结核心发现。2. 如果涉及数据,请以Markdown表格或列表的形式清晰地呈现。"
        "重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述、<think>标记或内部推理。直接给出最终答案或工具调用指令。"
    )
)

while True:
    user_input = input("请输入: ")
    if user_input.lower() == "exit":
        break
    # 记录开始时间
    start_time = time.time()
    response=agent.invoke({"messages":[HumanMessage(content=user_input)]})
    # 计算思考时间
    thinking_time = time.time() - start_time
    # 提取回答内容
    final_answer=response["messages"][-1].content
    # 打印回答和统计信息
    print(f"\n{final_answer}")
    print(f"思考时间: {thinking_time:.2f}秒")
    print("-" * 50)
    async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str:
        """处理用户消息,返回最终答案"""
        if model not in self.graphs:
            fallback_model = next(iter(self.graphs.keys()))
            print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
            model = fallback_model

        graph = self.graphs[model]
        config = {"configurable": {"thread_id": thread_id}}
        input_state = {"messages": [HumanMessage(content=message)]}
        result = await graph.ainvoke(input_state, config=config)
        return result["messages"][-1].content
115  backend.py  (new file)
@@ -0,0 +1,115 @@
"""
FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
采用依赖注入模式,优雅管理资源生命周期
"""

import uuid
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

from agent import AIAgentService

# PostgreSQL 连接字符串
DB_URI = "postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """应用生命周期管理:创建并注入全局服务"""
    # 1. 创建数据库连接池并初始化表
    async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
        await checkpointer.setup()

        # 2. 构建 AI Agent 服务
        agent_service = AIAgentService(checkpointer)
        await agent_service.initialize()

        # 3. 将服务实例存入 app.state
        app.state.agent_service = agent_service

        # 应用运行中...
        yield

    # 4. 关闭时自动清理数据库连接(async with 负责)
    print("🛑 应用关闭,数据库连接池已释放")


app = FastAPI(lifespan=lifespan)

# CORS 中间件(允许前端跨域)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ========== Pydantic 模型 ==========
class ChatRequest(BaseModel):
    message: str
    thread_id: str | None = None
    model: str = "zhipu"


class ChatResponse(BaseModel):
    reply: str
    thread_id: str
    model_used: str


# ========== 依赖注入函数 ==========
def get_agent_service(request: Request) -> AIAgentService:
    """从 app.state 中获取全局 AIAgentService 实例"""
    return request.app.state.agent_service


# ========== HTTP 端点 ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
    request: ChatRequest,
    agent_service: AIAgentService = Depends(get_agent_service)
):
    """同步对话接口,支持模型选择"""
    if not request.message:
        raise HTTPException(status_code=400, detail="message required")

    thread_id = request.thread_id or str(uuid.uuid4())
    reply = await agent_service.process_message(
        request.message, thread_id, request.model
    )
    actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
    return ChatResponse(reply=reply, thread_id=thread_id, model_used=actual_model)


# ========== WebSocket 端点(可选) ==========
@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    agent_service: AIAgentService = Depends(get_agent_service)
):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            message = data.get("message")
            thread_id = data.get("thread_id", str(uuid.uuid4()))
            model = data.get("model", "zhipu")
            if not message:
                await websocket.send_json({"error": "missing message"})
                continue
            reply = await agent_service.process_message(message, thread_id, model)
            actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
            await websocket.send_json({"reply": reply, "thread_id": thread_id, "model_used": actual_model})
    except WebSocketDisconnect:
        pass


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
94  frontend.py  (new file)
@@ -0,0 +1,94 @@
"""
Streamlit 前端 - 支持模型选择
"""

# 标准库
import uuid

# 第三方库
import requests
import streamlit as st

# 后端 API 地址(端口 8001)
API_URL = "http://localhost:8001/chat"

st.set_page_config(page_title="AI 个人助手", page_icon="🤖")
st.title("🤖 个人生活与数据分析助手")

# 模型选项(与后端支持的模型名称一致)
MODEL_OPTIONS = {
    "zhipu": "智谱 GLM-4.7-Flash(在线)",
    "local": "本地 vLLM(Gemma-4)"
}

# 初始化会话状态
if "messages" not in st.session_state:
    st.session_state.messages = []
if "thread_id" not in st.session_state:
    st.session_state.thread_id = str(uuid.uuid4())
if "selected_model" not in st.session_state:
    st.session_state.selected_model = "zhipu"

# 侧边栏:模型选择和会话管理
with st.sidebar:
    st.header("⚙️ 设置")

    # 模型选择
    selected_model_key = st.selectbox(
        "选择大模型",
        options=list(MODEL_OPTIONS.keys()),
        format_func=lambda x: MODEL_OPTIONS[x],
        index=0
    )
    st.session_state.selected_model = selected_model_key

    # 会话信息显示
    st.write(f"当前会话 ID: `{st.session_state.thread_id[:8]}...`")

    # 新会话按钮
    if st.button("🔄 新会话"):
        st.session_state.thread_id = str(uuid.uuid4())
        st.session_state.messages = []
        st.rerun()

# 显示历史消息
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# 用户输入
if prompt := st.chat_input("请输入您的问题..."):
    # 显示用户消息
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # 调用后端 API(携带模型参数)
    with st.chat_message("assistant"):
        with st.spinner("思考中..."):
            try:
                response = requests.post(
                    API_URL,
                    json={
                        "message": prompt,
                        "thread_id": st.session_state.thread_id,
                        "model": st.session_state.selected_model
                    },
                    timeout=60
                )
                response.raise_for_status()
                data = response.json()
                reply = data["reply"]
                model_used = data["model_used"]

                # 显示回复
                st.markdown(reply)

                # 显示使用的模型(小字提示)
                st.caption(f"🤖 使用模型: {MODEL_OPTIONS.get(model_used, model_used)}")

                st.session_state.messages.append({"role": "assistant", "content": reply})
            except Exception as e:
                error_msg = f"请求失败: {e}"
                st.error(error_msg)
                st.session_state.messages.append({"role": "assistant", "content": error_msg})
127  graph_builder.py  (new file)
@@ -0,0 +1,127 @@
"""
LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
"""

import operator
import asyncio
from typing import Literal, Annotated, Any
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict


class MessageState(TypedDict):
    """对话状态类型定义"""
    messages: Annotated[list[AnyMessage], operator.add]
    llm_calls: int


class GraphBuilder:
    """LangGraph 状态图构建器 - 所有节点均为类方法"""

    def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]):
        """
        初始化构建器

        Args:
            llm: 大语言模型实例
            tools: 工具列表
            tools_by_name: 名称到工具函数的映射
        """
        self.llm = llm
        self.tools = tools
        self.tools_by_name = tools_by_name
        self._llm_with_tools = llm.bind_tools(tools)
        self._prompt = self._create_prompt()
        self._chain = self._prompt | self._llm_with_tools

    @staticmethod
    def _create_prompt() -> ChatPromptTemplate:
        """创建系统提示模板(静态方法,无需访问实例)"""
        return ChatPromptTemplate.from_messages([
            SystemMessage(content=(
                "你是一个个人生活助手和数据分析助手。请说中文。"
                "当用户询问天气或温度时,使用get_current_temperature工具获取信息。"
                "当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
                "当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
                "当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
                "当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
                "重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。"
            )),
            MessagesPlaceholder(variable_name="message")
        ])

    async def call_llm(self, state: MessageState) -> dict:
        """
        LLM 调用节点(异步方法)
        注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
        """
        loop = asyncio.get_event_loop()
        response = await loop.run_in_executor(
            None,
            lambda: self._chain.invoke({"message": state["messages"]})
        )
        return {
            "messages": [response],
            "llm_calls": state.get('llm_calls', 0) + 1
        }

    async def call_tools(self, state: MessageState) -> dict:
        """
        工具执行节点(异步方法)
        对于每个工具调用,在线程池中执行同步工具函数
        """
        last_message = state['messages'][-1]
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            return {"messages": []}

        results = []
        loop = asyncio.get_event_loop()

        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool_id = tool_call["id"]
            tool_func = self.tools_by_name.get(tool_name)

            if tool_func is None:
                results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id))
                continue

            try:
                # 同步工具函数在线程池中执行
                observation = await loop.run_in_executor(
                    None,
                    lambda: tool_func.invoke(tool_args)
                )
                results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
            except Exception as e:
                results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))

        return {"messages": results}

    @staticmethod
    def should_continue(state: MessageState) -> Literal['tool_node', END]:
        """
        条件边判断(静态方法)
        决定下一步是进入工具节点还是结束
        """
        last_message = state["messages"][-1]
        if isinstance(last_message, AIMessage) and bool(last_message.tool_calls):
            return 'tool_node'
        return END

    def build(self) -> StateGraph:
        """
        构建未编译的状态图(返回 StateGraph 实例)
        图中节点直接使用实例方法 call_llm, call_tools
        """
        builder = StateGraph(MessageState)
        builder.add_node("llm_call", self.call_llm)
        builder.add_node("tool_node", self.call_tools)
        builder.add_edge(START, "llm_call")
        builder.add_conditional_edges("llm_call", self.should_continue, ["tool_node", END])
        builder.add_edge("tool_node", "llm_call")
        return builder
requirement.txt
@@ -13,11 +13,28 @@ langchain-huggingface>=0.0.3
langchain-core>=0.1.0
langchain-openai>=0.0.5

# LangGraph
langgraph>=0.0.30
langgraph-checkpoint-postgres>=0.0.5

# ZhipuAI (智谱AI)
zhipuai>=1.0.0

# Backend
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
websockets>=12.0

# Frontend
streamlit>=1.30.0

# Database
psycopg[binary,pool]>=3.1.0

# Pydantic
pydantic>=2.0.0

# Utilities
python-dotenv>=1.0.0
typing-extensions>=4.9.0
ipython>=8.0.0
145  start.sh  (new file)
@@ -0,0 +1,145 @@
#!/bin/bash

# AI Agent 启动脚本
# 用法: ./start.sh [backend|frontend|both]

set -e

# 颜色定义
GREEN='\033[0;32m'
BLUE='\033[0;34m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

echo -e "${BLUE}========================================${NC}"
echo -e "${BLUE}  AI Agent - 个人生活助手启动脚本${NC}"
echo -e "${BLUE}========================================${NC}"
echo ""

# 检查 vLLM 容器是否运行
check_vllm() {
    if ! docker ps --format '{{.Names}}' | grep -q "^gemma4-server$"; then
        echo -e "${YELLOW}⚠️ vLLM 容器未运行!${NC}"
        echo "正在启动 vLLM 容器(Gemma-4 模型)..."

        # 检查模型文件是否存在
        if [ ! -d "/home/huang/Study/AIModel/gemma-4-E2B-it" ]; then
            echo -e "${RED}✗ 错误:模型目录不存在: /home/huang/Study/AIModel/gemma-4-E2B-it${NC}"
            echo "请先下载模型或修改模型路径"
            exit 1
        fi

        docker run -d \
            --name gemma4-server \
            --group-add=video \
            --cap-add=SYS_PTRACE \
            --security-opt seccomp=unconfined \
            --device=/dev/kfd \
            --device=/dev/dri \
            -v /home/huang/Study/AIModel/gemma-4-E2B-it:/models/gemma-4-E2B-it \
            -e VLLM_ROCM_USE_AITER=0 \
            -e HF_TOKEN="${HF_TOKEN}" \
            -p 8000:8000 \
            --ipc=host \
            --entrypoint vllm \
            my-vllm-gemma4:working \
            serve /models/gemma-4-E2B-it \
            --served-model-name gemma-4-E2B-it \
            --dtype auto \
            --api-key token-abc123 \
            --trust-remote-code \
            --port 8000 \
            --gpu-memory-utilization 0.85 \
            --max-model-len 8192

        echo -e "${GREEN}✓ vLLM 容器已启动${NC}"
        echo -e "${YELLOW}⏳ 等待模型加载(可能需要几分钟)...${NC}"
        sleep 10
    else
        echo -e "${GREEN}✓ vLLM 容器正在运行${NC}"
    fi
}

# 检查 PostgreSQL 容器是否运行
check_postgres() {
    if ! docker ps | grep -q postgres-langgraph; then
        echo -e "${YELLOW}⚠️ PostgreSQL 容器未运行!${NC}"
        echo "正在启动 PostgreSQL 容器..."
        docker run -d \
            --name postgres-langgraph \
            -e POSTGRES_PASSWORD=mysecretpassword \
            -e POSTGRES_DB=langgraph_db \
            -p 5432:5432 \
            -v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \
            postgres:16

        echo -e "${GREEN}✓ PostgreSQL 容器已启动${NC}"
        sleep 3
    else
        echo -e "${GREEN}✓ PostgreSQL 容器正在运行${NC}"
    fi
}

# 启动后端
start_backend() {
    echo -e "\n${BLUE}🚀 启动后端服务 (端口 8001)...${NC}"
    python backend.py &
    BACKEND_PID=$!
    echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"
    sleep 2
}

# 启动前端
start_frontend() {
    echo -e "\n${BLUE}🎨 启动前端界面...${NC}"
    streamlit run frontend.py &
    FRONTEND_PID=$!
    echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}"
    echo -e "${GREEN}✓ 请在浏览器中打开: http://localhost:8501${NC}"
}

# 清理函数
cleanup() {
    echo -e "\n${RED}🛑 正在停止所有服务...${NC}"
    if [ ! -z "$BACKEND_PID" ]; then
        kill $BACKEND_PID 2>/dev/null || true
        echo -e "${GREEN}✓ 后端服务已停止${NC}"
    fi
    if [ ! -z "$FRONTEND_PID" ]; then
        kill $FRONTEND_PID 2>/dev/null || true
        echo -e "${GREEN}✓ 前端服务已停止${NC}"
    fi
    echo -e "${YELLOW}💡 提示:Docker 容器需要手动停止${NC}"
    echo -e "  停止 vLLM: docker stop gemma4-server"
    echo -e "  停止 PostgreSQL: docker stop postgres-langgraph"
    exit 0
}

# 捕获 Ctrl+C
trap cleanup SIGINT SIGTERM

# 主逻辑
case "${1:-both}" in
    backend)
        check_vllm
        check_postgres
        start_backend
        echo -e "\n${GREEN}后端服务正在运行,按 Ctrl+C 停止${NC}"
        wait $BACKEND_PID
        ;;
    frontend)
        start_frontend
        echo -e "\n${GREEN}前端服务正在运行,按 Ctrl+C 停止${NC}"
        wait $FRONTEND_PID
        ;;
    both|*)
        check_vllm
        check_postgres
        start_backend
        start_frontend
        echo -e "\n${GREEN}所有服务正在运行,按 Ctrl+C 停止 Python 服务${NC}"
        echo -e "${YELLOW}注意:Docker 容器会在后台继续运行${NC}"
        wait
        ;;
esac
@@ -1,20 +0,0 @@
from openai import OpenAI

# 连接本地 vLLM 服务
client = OpenAI(
    base_url="http://localhost:8000/v1",  # 容器映射的地址
    api_key="token-abc123",  # 与你启动命令中的 --api-key 一致
)

# 发起对话
response = client.chat.completions.create(
    model="gemma-4-E2B-it",  # --served-model-name 指定的名称
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "请用中文介绍一下你自己"}
    ],
    temperature=0.7,
    max_tokens=512,
)

print(response.choices[0].message.content)
134  test_multi_model.py  (new file)
@@ -0,0 +1,134 @@
"""
多模型切换功能测试脚本
用于验证后端是否正确支持多模型动态切换
"""

import requests
import json

API_URL = "http://localhost:8001/chat"


def test_model_switching():
    """测试模型切换功能"""

    print("=" * 60)
    print("测试多模型切换功能")
    print("=" * 60)

    # 测试消息
    test_message = "你好,请简单介绍一下自己"

    # 测试不同的模型
    models_to_test = ["zhipu", "local"]

    for model in models_to_test:
        print(f"\n📤 测试模型: {model}")
        print("-" * 60)

        try:
            response = requests.post(
                API_URL,
                json={
                    "message": test_message,
                    "model": model
                },
                timeout=30
            )

            if response.status_code == 200:
                data = response.json()
                print(f"✅ 成功!")
                print(f"  使用的模型: {data['model_used']}")
                print(f"  会话 ID: {data['thread_id'][:8]}...")
                print(f"  回复预览: {data['reply'][:100]}...")
            else:
                print(f"❌ 失败! 状态码: {response.status_code}")
                print(f"  错误信息: {response.text}")

        except requests.exceptions.Timeout:
            print(f"⏰ 超时! 模型 '{model}' 响应时间过长")
        except requests.exceptions.ConnectionError:
            print(f"🔌 连接失败! 请确认后端服务正在运行 (python backend.py)")
        except Exception as e:
            print(f"💥 异常: {str(e)}")

    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)


def test_conversation_memory():
    """测试跨模型的会话记忆"""

    print("\n" + "=" * 60)
    print("测试跨模型会话记忆")
    print("=" * 60)

    import uuid
    thread_id = str(uuid.uuid4())

    print(f"\n📝 使用固定会话 ID: {thread_id[:8]}...")

    # 第一轮对话 - 使用 zhipu 模型
    print("\n📤 第1轮 - 使用 zhipu 模型")
    try:
        response1 = requests.post(
            API_URL,
            json={
                "message": "我叫小明,记住我的名字",
                "thread_id": thread_id,
                "model": "zhipu"
            },
            timeout=30
        )
        if response1.status_code == 200:
            data1 = response1.json()
            print(f"  ✅ 回复: {data1['reply'][:100]}...")
            print(f"  🤖 使用模型: {data1['model_used']}")
    except Exception as e:
        print(f"  ❌ 失败: {e}")
        return

    # 第二轮对话 - 切换到 local 模型,测试是否记得名字
    print("\n📤 第2轮 - 切换到 local 模型")
    try:
        response2 = requests.post(
            API_URL,
            json={
                "message": "我叫什么名字?",
                "thread_id": thread_id,
                "model": "local"
            },
            timeout=30
        )
        if response2.status_code == 200:
            data2 = response2.json()
            print(f"  ✅ 回复: {data2['reply'][:100]}...")
            print(f"  🤖 使用模型: {data2['model_used']}")

            # 检查是否记得名字
            if "小明" in data2['reply']:
                print("  🎉 成功!跨模型记忆功能正常")
            else:
                print("  ⚠️ 注意:模型可能没有正确回忆上下文")
    except Exception as e:
        print(f"  ❌ 失败: {e}")

    print("\n" + "=" * 60)
    print("会话记忆测试完成!")
    print("=" * 60)


if __name__ == "__main__":
    print("\n⚠️ 请确保后端服务正在运行 (python backend.py)\n")

    # 运行基本测试
    test_model_switching()

    # 询问是否运行记忆测试
    choice = input("\n是否运行会话记忆测试?(y/n): ").strip().lower()
    if choice == 'y':
        test_conversation_memory()

    print("\n✨ 所有测试完成!")
103  tools.py  (new file)
@@ -0,0 +1,103 @@
"""
工具定义模块 - 纯函数工具,无依赖 AIAgent 类
"""

# 标准库
import os
from pathlib import Path

# 第三方库
import pandas as pd
import pypdf
import requests
from bs4 import BeautifulSoup
from langchain_core.tools import tool


def _file_allow_check(filename: str) -> Path:
    """检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
    allowed_dir = Path("./user_docs").resolve()
    allowed_dir.mkdir(exist_ok=True)

    file_path = (allowed_dir / filename).resolve()
    if not str(file_path).startswith(str(allowed_dir)):
        raise ValueError("错误:非法文件路径。")

    if not file_path.exists():
        raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")

    return file_path


@tool
def get_current_temperature(location: str) -> str:
    """获取指定地点的当前温度。"""
    return f'当前{location}的温度为25℃'


@tool
def read_local_file(filename: str) -> str:
    """读取用户指定名称的本地文本文件内容并返回摘要。"""
    try:
        file_path = _file_allow_check(filename)
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..."
    except Exception as e:
        return f"读取文件时出错:{str(e)}"


@tool
def read_pdf_summary(filename: str) -> str:
    """读取PDF文件并返回内容文本摘要。"""
    try:
        file_path = _file_allow_check(filename)
        text = ""
        with open(file_path, 'rb') as f:
            reader = pypdf.PdfReader(f)
            for page in reader.pages[:3]:
                text += page.extract_text()
        return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..."
    except Exception as e:
        return f"读取PDF出错:{e}"


@tool
def read_excel_as_markdown(filename: str) -> str:
    """读取Excel文件,并将其主要数据转换为Markdown表格格式。"""
    try:
        file_path = _file_allow_check(filename)
        df = pd.read_excel(file_path)
        markdown_table = df.head(10).to_markdown(index=False)
        return f"Excel文件 '{filename}' 的数据预览(前10行):\n{markdown_table}"
    except Exception as e:
        return f"读取Excel出错:{e}"


@tool
def fetch_webpage_content(url: str) -> str:
    """抓取给定URL的网页正文内容,并返回清晰的纯文本。"""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        for script in soup(["script", "style"]):
            script.decompose()
        text = soup.get_text()
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        text = '\n'.join(chunk for chunk in chunks if chunk)
        return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..."
    except Exception as e:
        return f"抓取网页时出错:{str(e)}"


# 工具列表和映射(全局常量)
AVAILABLE_TOOLS = [
    get_current_temperature,
    read_local_file,
    fetch_webpage_content,
    read_pdf_summary,
    read_excel_as_markdown
]
TOOLS_BY_NAME = {tool.name: tool for tool in AVAILABLE_TOOLS}