实现前后端分离的agent

This commit is contained in:
2026-04-13 19:49:18 +08:00
parent 09a5440045
commit 4385fabc22
13 changed files with 1317 additions and 188 deletions

2
.env
View File

@@ -1,4 +1,4 @@
LOCAL_MODEL_PATH=glm-4.7-flash
LOCAL_MODEL_PATH=gemma-4-E2B-it
ZHIPUAI_API_KEY=4d568a4367f1442bbc226cc0daf84566.44SsKVWkVIM2Mkeg
VLLM_LOCAL_KEY=token-abc123
EOF

5
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,5 @@
{
"editor.fontSize": 24,
"editor.formatOnSave": true,
"files.autoSave": "onWindowChange"
}

245
QUICKSTART.md Normal file
View File

@@ -0,0 +1,245 @@
# 快速开始指南 - 多模型切换功能
## 🚀 5分钟快速启动
### 步骤 1: 启动必要的容器
```bash
# 使用提供的启动脚本(推荐)
./start.sh
# 或者手动启动容器
# 1. 启动 vLLM (如果需要本地模型)
docker run -d --rm \
--group-add=video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
--device=/dev/kfd \
--device=/dev/dri \
-v /home/huang/Study/AIModel/gemma-4-E2B-it:/models/gemma-4-E2B-it \
-e VLLM_ROCM_USE_AITER=0 \
-e HF_TOKEN="$HF_TOKEN" \
-p 8000:8000 \
--ipc=host \
--entrypoint vllm \
my-vllm-gemma4:working \
serve /models/gemma-4-E2B-it \
--served-model-name gemma-4-E2B-it \
--dtype auto \
--api-key token-abc123 \
--trust-remote-code \
--port 8000 \
--gpu-memory-utilization 0.85 \
--max-model-len 8192
# 2. 启动 PostgreSQL
docker run -d \
--name postgres-langgraph \
-e POSTGRES_PASSWORD=mysecretpassword \
-e POSTGRES_DB=langgraph_db \
-p 5432:5432 \
-v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \
postgres:16
```
### 步骤 2: 配置环境变量
编辑 `.env` 文件:
```env
ZHIPUAI_API_KEY=your_actual_zhipuai_api_key
VLLM_LOCAL_KEY=token-abc123
```
### 步骤 3: 启动服务
```bash
# 方式1: 使用启动脚本(推荐)
./start.sh
# 方式2: 手动启动
# 终端1: 启动后端
python backend.py
# 终端2: 启动前端
streamlit run frontend.py
```
### 步骤 4: 访问应用
浏览器打开: `http://localhost:8501`
---
## 🎯 使用多模型切换功能
### 在前端切换模型
1. **打开侧边栏**:点击左上角的菜单图标
2. **选择模型**:在"选择大模型"下拉框中选择:
- 智谱 GLM-4.7-Flash在线
- 本地 vLLMGemma-4
3. **开始对话**:输入您的问题,系统会使用选定的模型处理
### 特性说明
**实时切换**:可以在对话过程中随时切换模型
**记忆共享**:同一会话 ID 下,不同模型共享对话历史
**自动降级**:如果选择的模型不可用,自动切换到可用模型
**状态显示**:每条回复下方会显示实际使用的模型
---
## 🧪 测试功能
### 运行自动化测试
```bash
# 确保后端正在运行
python test_multi_model.py
```
测试内容包括:
- 各模型的可用性测试
- 跨模型会话记忆测试
- API 响应格式验证
### 手动测试
1. **测试智谱模型**
- 选择"智谱 GLM-4.7-Flash"
- 询问:"你好,请介绍一下自己"
- 观察回复速度和内容质量
2. **测试本地模型**
- 选择"本地 vLLMGemma-4"
- 询问相同问题
- 对比两个模型的回复差异
3. **测试记忆功能**
- 第一轮(智谱模型):"我叫小明,记住我的名字"
- 第二轮(本地模型):"我叫什么名字?"
- 验证是否能正确回忆
---
## 🔧 常见问题
### Q1: 某个模型初始化失败怎么办?
**A:** 系统会自动跳过失败的模型,使用其他可用模型。检查日志了解具体原因:
- 智谱模型:确认 `ZHIPUAI_API_KEY` 是否正确
- 本地模型:确认 vLLM 容器是否运行
### Q2: 如何添加新模型?
**A:**`agent.py` 中添加:
```python
def _create_new_model_llm(self):
"""创建新模型的 LLM"""
return YourChatModel(
model="model-name",
api_key="your-key",
# ... 其他参数
)
# 在 initialize() 方法的 model_configs 中添加
model_configs = {
"zhipu": self._create_zhipu_llm,
"local": self._create_local_llm,
"new_model": self._create_new_model_llm, # 新增
}
```
然后在前端 `frontend.py``MODEL_OPTIONS` 中添加对应选项。
### Q3: 会话记忆是如何工作的?
**A:**
- 使用 PostgreSQL 存储对话历史
- 通过 `thread_id` 关联同一会话的消息
- 不同模型共享同一个 checkpointer因此可以跨模型保持上下文
- 点击"新会话"按钮会生成新的 `thread_id`
### Q4: 性能优化建议
**A:**
- 智谱模型:适合快速响应场景,无需本地 GPU
- 本地模型:适合数据隐私要求高的场景,需要 GPU 支持
- 长时间对话建议定期开启新会话,避免上下文过长
---
## 📊 架构优势
### 预编译 Graph
每个模型在启动时都会预编译独立的 LangGraph
- ✅ 避免每次请求都重新编译,提升性能
- ✅ 各模型独立,互不影响
- ✅ 支持热插拔,可动态添加/移除模型
### 智能降级
如果选择的模型不可用:
1. 后端自动切换到第一个可用模型
2. 返回响应中包含 `model_used` 字段
3. 前端显示实际使用的模型
4. 用户无感知,体验流畅
### 统一接口
无论使用哪个模型:
- API 接口保持一致
- 工具调用方式相同
- 会话记忆机制统一
- 前端操作体验一致
---
## 🎓 进阶使用
### 固定会话 ID
如需在不同浏览器或设备间继续同一会话:
```python
# 在 frontend.py 中修改
st.session_state.thread_id = "my_fixed_session_id"
```
### 自定义超时时间
```python
# 在 frontend.py 中修改 timeout 参数
response = requests.post(
API_URL,
json={...},
timeout=120 # 增加到 120 秒
)
```
### 批量测试
```python
# 创建测试脚本
import requests
messages = ["问题1", "问题2", "问题3"]
for msg in messages:
response = requests.post(API_URL, json={"message": msg, "model": "zhipu"})
print(response.json()["reply"])
```
---
## 📞 获取帮助
- 查看完整文档:[README.md](README.md)
- 查看项目结构:参考 [README.md](README.md) 中的项目结构部分
- 报告问题:提交 Issue 并附上日志信息
---
**祝您使用愉快!** 🎉

268
README.md
View File

@@ -1,2 +1,268 @@
# ailine
# AI Agent - 个人生活助手和数据分析助手
## 项目概述
这是一个基于 LangGraph、LangChain 和 FastAPI 构建的 AI 助手系统,能够处理天气查询、文件读取、网页抓取等任务。采用前后端分离架构,支持 PostgreSQL 持久化对话记忆。
## 项目结构
```
Agent1/
├── tools.py # 工具定义(纯函数、@tool
├── graph_builder.py # LangGraph 状态图构建(状态定义、节点、边)
├── agent.py # AIAgentService 类模型初始化、graph 管理、消息处理)
├── backend.py # FastAPI 应用路由、WebSocket、lifespan
├── frontend.py # Streamlit 前端(通过 HTTP 调用后端)
├── .env # 环境变量ZHIPUAI_API_KEY 等)
├── requirement.txt # Python 依赖包列表
└── user_docs/ # 允许读取的文档目录
├── a.txt
├── b.pdf
└── c.xlsx
```
## 核心功能
- 🌤️ **天气查询**:获取指定地点的当前温度
- 📄 **文本文件读取**:读取 `.txt``.md` 等文本文件
- 📑 **PDF 文件读取**:解析 PDF 文件并提取文本内容
- 📊 **Excel 数据处理**:读取 Excel 文件并转换为 Markdown 表格
- 🌐 **网页抓取**:抓取网页正文内容
- 💾 **持久化记忆**:使用 PostgreSQL 保存对话历史,支持多轮对话上下文
- 🔄 **多模型动态切换**:前端可选择不同的大语言模型,后端自动切换处理
## 技术栈
- **后端框架**FastAPI + Uvicorn
- **前端框架**Streamlit
- **AI 框架**LangGraph + LangChain
- **数据库**PostgreSQL用于持久化对话记忆
- **LLM 支持**
- 智谱 AIglm-4.7-flash在线服务响应速度快
- 本地 vLLMgemma-4-E2B-it本地部署数据隐私性好
系统支持多种大语言模型,可在前端动态切换。每个模型在启动时都会预编译独立的 LangGraph确保最佳性能。如果某个模型初始化失败如 API Key 未配置),系统会自动降级到可用模型。
## 环境要求
- Python 3.10+
- PostgreSQL 16+
- Docker可选用于运行 PostgreSQL
## 安装步骤
### 1. 启动 PostgreSQL 容器
```bash
docker run -d \
--name postgres-langgraph \
-e POSTGRES_PASSWORD=mysecretpassword \
-e POSTGRES_DB=langgraph_db \
-p 5432:5432 \
-v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \
postgres:16
```
### 2. 安装 Python 依赖
```bash
pip install fastapi uvicorn streamlit requests psycopg[binary,pool] \
langgraph langgraph-checkpoint-postgres langchain langchain-community \
langchain-openai python-dotenv pypdf pandas beautifulsoup4
```
或者使用 requirements.txt
```bash
pip install -r requirement.txt
```
### 3. 配置环境变量
编辑 `.env` 文件,设置您的 API 密钥:
```env
ZHIPUAI_API_KEY=your_zhipuai_api_key_here
VLLM_LOCAL_KEY=token-abc123 # 如果使用本地模型
```
## 运行步骤
### 1. 启动后端服务
```bash
python backend.py
```
看到 `Uvicorn running on http://0.0.0.0:8001` 即表示启动成功。
### 2. 启动前端界面(新终端)
```bash
streamlit run frontend.py
```
浏览器会自动打开 `http://localhost:8501`,即可开始使用。
## API 接口
### POST /chat
同步对话接口,支持模型选择
**请求体:**
```json
{
"message": "今天北京天气怎么样?",
"thread_id": "optional-thread-id",
"model": "zhipu" // 可选: "zhipu" 或 "local"
}
```
**响应:**
```json
{
"reply": "当前北京的温度为25℃",
"thread_id": "generated-or-provided-thread-id",
"model_used": "zhipu" // 实际使用的模型
}
```
**模型选项:**
- `zhipu`:智谱 GLM-4.7-Flash在线
- `local`:本地 vLLM Gemma-4需要启动 vLLM 容器)
### WebSocket /ws
流式对话接口(可选扩展)
## 使用说明
### 工具调用示例
1. **查询天气**
```
用户:今天上海天气怎么样?
```
2. **读取文本文件**
```
用户:请读取 a.txt 文件的内容
```
3. **读取 PDF 文件**
```
用户:帮我总结一下 b.pdf 的内容
```
4. **读取 Excel 文件**
```
用户:显示 c.xlsx 的数据
```
5. **抓取网页**
```
用户:请抓取 https://example.com 的内容
```
### 会话记忆
- 系统会自动为每个会话生成唯一的 `thread_id`
- 相同 `thread_id` 的对话会共享历史记录
- 即使重启后端服务,对话历史依然保存在 PostgreSQL 中
- 如需固定会话 ID可在前端代码中修改 `st.session_state.thread_id` 为固定字符串
### 多模型切换
**前端操作:**
1. 在左侧边栏的"选择大模型"下拉框中选择模型
2. 可随时切换模型,甚至在同一会话中
3. 点击"🔄 新会话"按钮可清空当前对话并开始新的会话
**后端行为:**
- 启动时会预编译所有可用模型的 LangGraph
- 如果某个模型初始化失败(如 API Key 未配置),会自动跳过
- 请求时如果选择的模型不可用,会自动降级到第一个可用模型
- 响应中会返回 `model_used` 字段,显示实际使用的模型
**添加新模型:**
在 `agent.py` 的 `initialize()` 方法中的 `model_configs` 字典添加新模型即可:
```python
model_configs = {
"zhipu": self._create_zhipu_llm,
"local": self._create_local_llm,
"new_model": self._create_new_model_llm, # 添加新模型
}
```
## 架构说明
### 模块职责
- **tools.py**:独立工具模块,包含所有 `@tool` 装饰的纯函数,无外部依赖,可单独测试
- **graph_builder.py**LangGraph 状态图构建器,定义状态、节点函数和条件边
- **agent.py**AIAgentService 服务类,负责模型初始化和 graph 编译,使用 `AsyncPostgresSaver`
- **backend.py**FastAPI 应用,提供 REST API 和 WebSocket 接口,端口 8001
- **frontend.py**Streamlit 前端,通过 HTTP 调用后端 API实现友好的用户界面
### 数据流
```
用户输入 → Streamlit 前端 → FastAPI 后端 → AIAgentService
→ LangGraph StateGraph → LLM + Tools → PostgreSQL (记忆)
→ 返回响应 → 前端展示
```
## 注意事项
1. **文件安全**:所有文件读取操作仅限于 `./user_docs` 目录,防止路径遍历攻击
2. **端口冲突**:后端使用 8001 端口,避免与本地 vLLM 服务的 8000 端口冲突
3. **API 密钥**:请妥善保管 `.env` 文件中的 API 密钥,不要提交到版本控制系统
4. **数据库持久化**PostgreSQL 数据卷挂载到 `~/docker_volumes/postgres_data`,确保数据安全
## 故障排除
### 问题:无法连接 PostgreSQL
**解决方案:**
```bash
# 检查容器是否运行
docker ps | grep postgres-langgraph
# 查看容器日志
docker logs postgres-langgraph
# 重新启动容器
docker restart postgres-langgraph
```
### 问题:后端启动失败
**解决方案:**
- 确认端口 8001 未被占用
- 检查 `.env` 文件中的 API 密钥是否正确配置
- 确认所有依赖包已正确安装
- 查看启动日志,确认至少有一个模型初始化成功
### 问题:模型切换后无响应
**解决方案:**
- 检查所选模型的配置是否正确(如智谱 API Key
- 确认 vLLM 容器是否正在运行(如果使用本地模型)
- 查看后端日志,确认模型是否初始化成功
- 尝试切换到另一个模型
### 问题:工具调用失败
**解决方案:**
- 确认文件位于 `./user_docs` 目录下
- 检查文件格式是否正确
- 查看后端日志获取详细错误信息
## 许可证
本项目采用 MIT 许可证。详见 [LICENSE](LICENSE) 文件。
## 贡献
欢迎提交 Issue 和 Pull Request

224
agent.py
View File

@@ -1,187 +1,85 @@
from bs4 import BeautifulSoup
from langchain.agents import create_agent
import requests
import pypdf
import pandas as pd
from dotenv import load_dotenv
"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import os
import time
from pathlib import Path
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_huggingface import HuggingFacePipeline,ChatHuggingFace
from transformers import AutoTokenizer, AutoModelForCausalLM,pipeline
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from transformers import BitsAndBytesConfig
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
##--基础定义
# 本地模块
from graph_builder import GraphBuilder
from tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
load_dotenv()
LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH","glm-4.7-flash")
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY")
VLLM_LOCAL_KEY = os.getenv("VLLM_LOCAL_KEY", "")
DEVICE = os.getenv("DEVICE")
##加载模型
local_llm = None
online_llm = None
class AIAgentService:
"""异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
def get_local_llm():
global local_llm
if local_llm is None:
local_llm = ChatOpenAI(
base_url="http://localhost:8000/v1",
api_key=SecretStr(VLLM_LOCAL_KEY),
model="gemma-4-E2B-it",
)
return local_llm
def __init__(self, checkpointer):
"""
初始化服务
Args:
checkpointer: 已经初始化的 AsyncPostgresSaver 实例
"""
self.checkpointer = checkpointer
self.graphs = {} # 存储不同模型对应的 graph 实例
def get_online_llm():
global online_llm
if online_llm is None:
online_llm = ChatZhipuAI(
def _create_zhipu_llm(self):
"""创建智谱在线 LLM"""
api_key = os.getenv("ZHIPUAI_API_KEY")
if not api_key:
raise ValueError("ZHIPUAI_API_KEY not set in environment")
return ChatZhipuAI(
model="glm-4.7-flash",
api_key=ZHIPUAI_API_KEY,
api_key=api_key,
temperature=0.1,
max_tokens=4096,
)
return online_llm
##工具调用
def _create_local_llm(self):
"""创建本地 vLLM 服务 LLM"""
return ChatOpenAI(
base_url="http://localhost:8000/v1",
api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")),
model="gemma-4-E2B-it",
)
@tool
def get_currenttemperature(location: str) -> str:
"""获取指定地点的当前温度,当用户询问天气或温度时使用此工具。"""
return f'当前{location}的温度为25℃'
async def initialize(self):
"""预编译所有模型的 graph使用传入的 checkpointer"""
model_configs = {
"zhipu": self._create_zhipu_llm,
"local": self._create_local_llm,
}
# sym:file_allow_check
def file_allow_check(filename: str) -> Path:
"""
检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。
返回合法的 Path 对象,若不合法则抛出异常。
"""
allowed_dir = Path("./user_docs").resolve()
allowed_dir.mkdir(exist_ok=True)
file_path = (allowed_dir / filename).resolve()
if not str(file_path).startswith(str(allowed_dir)):
raise ValueError("错误:非法文件路径。")
if not file_path.exists():
raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")
return file_path
@tool
def read_local_file(filename: str) -> str:
"""
读取用户指定名称的本地文本文件内容并返回摘要。
参数 filename: 文件名,例如 'project_plan.txt''notes.md'
"""
for model_name, llm_creator in model_configs.items():
try:
file_path = file_allow_check(filename)
except (ValueError, FileNotFoundError) as e:
return str(e)
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
# 2. 内容过长时可以在此处增加一个简单的摘要逻辑或者直接返回前N个字符
# 为了演示这里返回前1000个字符
return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..."
llm = llm_creator()
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
graph = builder.compile(checkpointer=self.checkpointer)
self.graphs[model_name] = graph
print(f"✅ 模型 '{model_name}' 初始化成功")
except Exception as e:
return f"读取文件时出错:{str(e)}"
print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
if not self.graphs:
raise RuntimeError("没有可用的模型,请检查配置")
@tool
def read_pdf_summary(filename: str) -> str:
"""
读取PDF文件并返回内容文本。参数 filename: PDF文件名例如 'report.pdf'
"""
try:
file_path = file_allow_check(filename)
except (ValueError, FileNotFoundError) as e:
return str(e)
try:
text = ""
with open(file_path, 'rb') as f:
reader = pypdf.PdfReader(f)
for page in reader.pages[:3]:
text += page.extract_text()
return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..."
except Exception as e:
return f"读取PDF出错{e}"
return self
@tool
def read_excel_as_markdown(filename: str) -> str:
"""
读取Excel文件并将其主要数据转换为Markdown表格格式。参数 filename: Excel文件名例如 'data.xlsx'
"""
try:
file_path = file_allow_check(filename)
except (ValueError, FileNotFoundError) as e:
return str(e)
try:
df = pd.read_excel(file_path)
markdown_table = df.head(10).to_markdown(index=False)
return f"Excel文件 '{filename}' 的数据预览前10行\n{markdown_table}"
except Exception as e:
return f"读取Excel出错{e}"
@tool
def fetch_webpage_content(url: str) -> str:
"""
抓取给定URL的网页正文内容并返回清晰的纯文本。
参数 url: 完整的网页地址,例如 'https://example.com/article'
"""
try:
response = requests.get(url, timeout=10)
response.raise_for_status()
soup = BeautifulSoup(response.text, 'html.parser')
# 简单的正文提取,去除脚本和样式
for script in soup(["script", "style"]):
script.decompose()
text = soup.get_text()
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = '\n'.join(chunk for chunk in chunks if chunk)
return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..."
except Exception as e:
return f"抓取网页时出错:{str(e)}"
#使用langgraph
agent=create_agent(
model=get_local_llm(),
tools=[get_currenttemperature,read_local_file,fetch_webpage_content,read_pdf_summary,read_excel_as_markdown],
system_prompt=(
"你是一个个人生活助手和数据分析助手。请说中文。"
"当用户询问天气或温度时使用get_currenttemperature工具获取信息。"
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读PDF文件时请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读Excel文件时请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
"当用户要求分析文档时请使用合适的工具读取内容然后1. 总结核心发现。2. 如果涉及数据请以Markdown表格或列表的形式清晰地呈现。"
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述、<think>标记或内部推理。直接给出最终答案或工具调用指令。"
)
)
while True:
user_input = input("请输入: ")
if user_input.lower() == "exit":
break
# 记录开始时间
start_time = time.time()
response=agent.invoke({"messages":[HumanMessage(content=user_input)]})
# 计算思考时间
thinking_time = time.time() - start_time
# 提取回答内容
final_answer=response["messages"][-1].content
# 打印回答和统计信息
print(f"\n{final_answer}")
print(f"思考时间: {thinking_time:.2f}")
print("-" * 50)
async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str:
"""处理用户消息,返回最终答案"""
if model not in self.graphs:
fallback_model = next(iter(self.graphs.keys()))
print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
model = fallback_model
graph = self.graphs[model]
config = {"configurable": {"thread_id": thread_id}}
input_state = {"messages": [HumanMessage(content=message)]}
result = await graph.ainvoke(input_state, config=config)
return result["messages"][-1].content

115
backend.py Normal file
View File

@@ -0,0 +1,115 @@
"""
FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
采用依赖注入模式,优雅管理资源生命周期
"""
import uuid
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from agent import AIAgentService
# PostgreSQL 连接字符串
DB_URI = "postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable"
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务"""
# 1. 创建数据库连接池并初始化表
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
await checkpointer.setup()
# 2. 构建 AI Agent 服务
agent_service = AIAgentService(checkpointer)
await agent_service.initialize()
# 3. 将服务实例存入 app.state
app.state.agent_service = agent_service
# 应用运行中...
yield
# 4. 关闭时自动清理数据库连接async with 负责)
print("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan)
# CORS 中间件(允许前端跨域)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ========== Pydantic 模型 ==========
class ChatRequest(BaseModel):
message: str
thread_id: str | None = None
model: str = "zhipu"
class ChatResponse(BaseModel):
reply: str
thread_id: str
model_used: str
# ========== 依赖注入函数 ==========
def get_agent_service(request: Request) -> AIAgentService:
"""从 app.state 中获取全局 AIAgentService 实例"""
return request.app.state.agent_service
# ========== HTTP 端点 ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
request: ChatRequest,
agent_service: AIAgentService = Depends(get_agent_service)
):
"""同步对话接口,支持模型选择"""
if not request.message:
raise HTTPException(status_code=400, detail="message required")
thread_id = request.thread_id or str(uuid.uuid4())
reply = await agent_service.process_message(
request.message, thread_id, request.model
)
actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
return ChatResponse(reply=reply, thread_id=thread_id, model_used=actual_model)
# ========== WebSocket 端点(可选) ==========
@app.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
agent_service: AIAgentService = Depends(get_agent_service)
):
await websocket.accept()
try:
while True:
data = await websocket.receive_json()
message = data.get("message")
thread_id = data.get("thread_id", str(uuid.uuid4()))
model = data.get("model", "zhipu")
if not message:
await websocket.send_json({"error": "missing message"})
continue
reply = await agent_service.process_message(message, thread_id, model)
actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
await websocket.send_json({"reply": reply, "thread_id": thread_id, "model_used": actual_model})
except WebSocketDisconnect:
pass
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)

94
frontend.py Normal file
View File

@@ -0,0 +1,94 @@
"""
Streamlit 前端 - 支持模型选择
"""
# 标准库
import uuid
# 第三方库
import requests
import streamlit as st
# 后端 API 地址(端口 8001
API_URL = "http://localhost:8001/chat"
st.set_page_config(page_title="AI 个人助手", page_icon="🤖")
st.title("🤖 个人生活与数据分析助手")
# 模型选项(与后端支持的模型名称一致)
MODEL_OPTIONS = {
"zhipu": "智谱 GLM-4.7-Flash在线",
"local": "本地 vLLMGemma-4"
}
# 初始化会话状态
if "messages" not in st.session_state:
st.session_state.messages = []
if "thread_id" not in st.session_state:
st.session_state.thread_id = str(uuid.uuid4())
if "selected_model" not in st.session_state:
st.session_state.selected_model = "zhipu"
# 侧边栏:模型选择和会话管理
with st.sidebar:
st.header("⚙️ 设置")
# 模型选择
selected_model_key = st.selectbox(
"选择大模型",
options=list(MODEL_OPTIONS.keys()),
format_func=lambda x: MODEL_OPTIONS[x],
index=0
)
st.session_state.selected_model = selected_model_key
# 会话信息显示
st.write(f"当前会话 ID: `{st.session_state.thread_id[:8]}...`")
# 新会话按钮
if st.button("🔄 新会话"):
st.session_state.thread_id = str(uuid.uuid4())
st.session_state.messages = []
st.rerun()
# 显示历史消息
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
# 用户输入
if prompt := st.chat_input("请输入您的问题..."):
# 显示用户消息
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
# 调用后端 API携带模型参数
with st.chat_message("assistant"):
with st.spinner("思考中..."):
try:
response = requests.post(
API_URL,
json={
"message": prompt,
"thread_id": st.session_state.thread_id,
"model": st.session_state.selected_model
},
timeout=60
)
response.raise_for_status()
data = response.json()
reply = data["reply"]
model_used = data["model_used"]
# 显示回复
st.markdown(reply)
# 显示使用的模型(小字提示)
st.caption(f"🤖 使用模型: {MODEL_OPTIONS.get(model_used, model_used)}")
st.session_state.messages.append({"role": "assistant", "content": reply})
except Exception as e:
error_msg = f"请求失败: {e}"
st.error(error_msg)
st.session_state.messages.append({"role": "assistant", "content": error_msg})

127
graph_builder.py Normal file
View File

@@ -0,0 +1,127 @@
"""
LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
"""
import operator
import asyncio
from typing import Literal, Annotated, Any
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
class MessageState(TypedDict):
"""对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add]
llm_calls: int
class GraphBuilder:
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]):
"""
初始化构建器
Args:
llm: 大语言模型实例
tools: 工具列表
tools_by_name: 名称到工具函数的映射
"""
self.llm = llm
self.tools = tools
self.tools_by_name = tools_by_name
self._llm_with_tools = llm.bind_tools(tools)
self._prompt = self._create_prompt()
self._chain = self._prompt | self._llm_with_tools
@staticmethod
def _create_prompt() -> ChatPromptTemplate:
"""创建系统提示模板(静态方法,无需访问实例)"""
return ChatPromptTemplate.from_messages([
SystemMessage(content=(
"你是一个个人生活助手和数据分析助手。请说中文。"
"当用户询问天气或温度时,使用get_current_temperature工具获取信息。"
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。"
)),
MessagesPlaceholder(variable_name="message")
])
async def call_llm(self, state: MessageState) -> dict:
"""
LLM 调用节点(异步方法)
注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
"""
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
None,
lambda: self._chain.invoke({"message": state["messages"]})
)
return {
"messages": [response],
"llm_calls": state.get('llm_calls', 0) + 1
}
async def call_tools(self, state: MessageState) -> dict:
"""
工具执行节点(异步方法)
对于每个工具调用,在线程池中执行同步工具函数
"""
last_message = state['messages'][-1]
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
return {"messages": []}
results = []
loop = asyncio.get_event_loop()
for tool_call in last_message.tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool_id = tool_call["id"]
tool_func = self.tools_by_name.get(tool_name)
if tool_func is None:
results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id))
continue
try:
# 同步工具函数在线程池中执行
observation = await loop.run_in_executor(
None,
lambda: tool_func.invoke(tool_args)
)
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
except Exception as e:
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
return {"messages": results}
@staticmethod
def should_continue(state: MessageState) -> Literal['tool_node', END]:
"""
条件边判断(静态方法)
决定下一步是进入工具节点还是结束
"""
last_message = state["messages"][-1]
if isinstance(last_message, AIMessage) and bool(last_message.tool_calls):
return 'tool_node'
return END
def build(self) -> StateGraph:
"""
构建未编译的状态图(返回 StateGraph 实例)
图中节点直接使用实例方法 call_llm, call_tools
"""
builder = StateGraph(MessageState)
builder.add_node("llm_call", self.call_llm)
builder.add_node("tool_node", self.call_tools)
builder.add_edge(START, "llm_call")
builder.add_conditional_edges("llm_call", self.should_continue, ["tool_node", END])
builder.add_edge("tool_node", "llm_call")
return builder

View File

@@ -13,11 +13,28 @@ langchain-huggingface>=0.0.3
langchain-core>=0.1.0
langchain-openai>=0.0.5
# LangGraph
langgraph>=0.0.30
langgraph-checkpoint-postgres>=0.0.5
# ZhipuAI (智谱AI)
zhipuai>=1.0.0
# Backend
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
websockets>=12.0
# Frontend
streamlit>=1.30.0
# Database
psycopg[binary,pool]>=3.1.0
# Pydantic
pydantic>=2.0.0
# Utilities
python-dotenv>=1.0.0
typing-extensions>=4.9.0
ipython>=8.0.0

145
start.sh Normal file
View File

@@ -0,0 +1,145 @@
#!/bin/bash
# AI Agent 启动脚本
# 用法: ./start.sh [backend|frontend|both]
set -e
# 颜色定义
GREEN='\033[0;32m'
BLUE='\033[0;34m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo -e "${BLUE}========================================${NC}"
echo -e "${BLUE} AI Agent - 个人生活助手启动脚本${NC}"
echo -e "${BLUE}========================================${NC}"
echo ""
# 检查 vLLM 容器是否运行
check_vllm() {
if ! docker ps --format '{{.Names}}' | grep -q "^gemma4-server$"; then
echo -e "${YELLOW}⚠️ vLLM 容器未运行!${NC}"
echo "正在启动 vLLM 容器Gemma-4 模型)..."
# 检查模型文件是否存在
if [ ! -d "/home/huang/Study/AIModel/gemma-4-E2B-it" ]; then
echo -e "${RED}✗ 错误:模型目录不存在: /home/huang/Study/AIModel/gemma-4-E2B-it${NC}"
echo "请先下载模型或修改模型路径"
exit 1
fi
docker run -d \
--name gemma4-server \
--group-add=video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
--device=/dev/kfd \
--device=/dev/dri \
-v /home/huang/Study/AIModel/gemma-4-E2B-it:/models/gemma-4-E2B-it \
-e VLLM_ROCM_USE_AITER=0 \
-e HF_TOKEN="${HF_TOKEN}" \
-p 8000:8000 \
--ipc=host \
--entrypoint vllm \
my-vllm-gemma4:working \
serve /models/gemma-4-E2B-it \
--served-model-name gemma-4-E2B-it \
--dtype auto \
--api-key token-abc123 \
--trust-remote-code \
--port 8000 \
--gpu-memory-utilization 0.85 \
--max-model-len 8192
echo -e "${GREEN}✓ vLLM 容器已启动${NC}"
echo -e "${YELLOW}⏳ 等待模型加载(可能需要几分钟)...${NC}"
sleep 10
else
echo -e "${GREEN}✓ vLLM 容器正在运行${NC}"
fi
}
# 检查 PostgreSQL 容器是否运行
check_postgres() {
if ! docker ps | grep -q postgres-langgraph; then
echo -e "${YELLOW}⚠️ PostgreSQL 容器未运行!${NC}"
echo "正在启动 PostgreSQL 容器..."
docker run -d \
--name postgres-langgraph \
-e POSTGRES_PASSWORD=mysecretpassword \
-e POSTGRES_DB=langgraph_db \
-p 5432:5432 \
-v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \
postgres:16
echo -e "${GREEN}✓ PostgreSQL 容器已启动${NC}"
sleep 3
else
echo -e "${GREEN}✓ PostgreSQL 容器正在运行${NC}"
fi
}
# 启动后端
start_backend() {
echo -e "\n${BLUE}🚀 启动后端服务 (端口 8001)...${NC}"
python backend.py &
BACKEND_PID=$!
echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"
sleep 2
}
# 启动前端
start_frontend() {
echo -e "\n${BLUE}🎨 启动前端界面...${NC}"
streamlit run frontend.py &
FRONTEND_PID=$!
echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}"
echo -e "${GREEN}✓ 请在浏览器中打开: http://localhost:8501${NC}"
}
# 清理函数
cleanup() {
echo -e "\n${RED}🛑 正在停止所有服务...${NC}"
if [ ! -z "$BACKEND_PID" ]; then
kill $BACKEND_PID 2>/dev/null || true
echo -e "${GREEN}✓ 后端服务已停止${NC}"
fi
if [ ! -z "$FRONTEND_PID" ]; then
kill $FRONTEND_PID 2>/dev/null || true
echo -e "${GREEN}✓ 前端服务已停止${NC}"
fi
echo -e "${YELLOW}💡 提示Docker 容器需要手动停止${NC}"
echo -e " 停止 vLLM: docker stop gemma4-server"
echo -e " 停止 PostgreSQL: docker stop postgres-langgraph"
exit 0
}
# 捕获 Ctrl+C
trap cleanup SIGINT SIGTERM
# 主逻辑
case "${1:-both}" in
backend)
check_vllm
check_postgres
start_backend
echo -e "\n${GREEN}后端服务正在运行,按 Ctrl+C 停止${NC}"
wait $BACKEND_PID
;;
frontend)
start_frontend
echo -e "\n${GREEN}前端服务正在运行,按 Ctrl+C 停止${NC}"
wait $FRONTEND_PID
;;
both|*)
check_vllm
check_postgres
start_backend
start_frontend
echo -e "\n${GREEN}所有服务正在运行,按 Ctrl+C 停止 Python 服务${NC}"
echo -e "${YELLOW}注意Docker 容器会在后台继续运行${NC}"
wait
;;
esac

View File

@@ -1,20 +0,0 @@
from openai import OpenAI
# 连接本地 vLLM 服务
client = OpenAI(
base_url="http://localhost:8000/v1", # 容器映射的地址
api_key="token-abc123", # 与你启动命令中的 --api-key 一致
)
# 发起对话
response = client.chat.completions.create(
model="gemma-4-E2B-it", # --served-model-name 指定的名称
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "请用中文介绍一下你自己"}
],
temperature=0.7,
max_tokens=512,
)
print(response.choices[0].message.content)

134
test_multi_model.py Normal file
View File

@@ -0,0 +1,134 @@
"""
多模型切换功能测试脚本
用于验证后端是否正确支持多模型动态切换
"""
import requests
import json
API_URL = "http://localhost:8001/chat"
def test_model_switching():
"""测试模型切换功能"""
print("=" * 60)
print("测试多模型切换功能")
print("=" * 60)
# 测试消息
test_message = "你好,请简单介绍一下自己"
# 测试不同的模型
models_to_test = ["zhipu", "local"]
for model in models_to_test:
print(f"\n📤 测试模型: {model}")
print("-" * 60)
try:
response = requests.post(
API_URL,
json={
"message": test_message,
"model": model
},
timeout=30
)
if response.status_code == 200:
data = response.json()
print(f"✅ 成功!")
print(f" 使用的模型: {data['model_used']}")
print(f" 会话 ID: {data['thread_id'][:8]}...")
print(f" 回复预览: {data['reply'][:100]}...")
else:
print(f"❌ 失败! 状态码: {response.status_code}")
print(f" 错误信息: {response.text}")
except requests.exceptions.Timeout:
print(f"⏰ 超时! 模型 '{model}' 响应时间过长")
except requests.exceptions.ConnectionError:
print(f"🔌 连接失败! 请确认后端服务正在运行 (python backend.py)")
except Exception as e:
print(f"💥 异常: {str(e)}")
print("\n" + "=" * 60)
print("测试完成!")
print("=" * 60)
def test_conversation_memory():
"""测试跨模型的会话记忆"""
print("\n" + "=" * 60)
print("测试跨模型会话记忆")
print("=" * 60)
import uuid
thread_id = str(uuid.uuid4())
print(f"\n📝 使用固定会话 ID: {thread_id[:8]}...")
# 第一轮对话 - 使用 zhipu 模型
print("\n📤 第1轮 - 使用 zhipu 模型")
try:
response1 = requests.post(
API_URL,
json={
"message": "我叫小明,记住我的名字",
"thread_id": thread_id,
"model": "zhipu"
},
timeout=30
)
if response1.status_code == 200:
data1 = response1.json()
print(f" ✅ 回复: {data1['reply'][:100]}...")
print(f" 🤖 使用模型: {data1['model_used']}")
except Exception as e:
print(f" ❌ 失败: {e}")
return
# 第二轮对话 - 切换到 local 模型,测试是否记得名字
print("\n📤 第2轮 - 切换到 local 模型")
try:
response2 = requests.post(
API_URL,
json={
"message": "我叫什么名字?",
"thread_id": thread_id,
"model": "local"
},
timeout=30
)
if response2.status_code == 200:
data2 = response2.json()
print(f" ✅ 回复: {data2['reply'][:100]}...")
print(f" 🤖 使用模型: {data2['model_used']}")
# 检查是否记得名字
if "小明" in data2['reply']:
print(" 🎉 成功!跨模型记忆功能正常")
else:
print(" ⚠️ 注意:模型可能没有正确回忆上下文")
except Exception as e:
print(f" ❌ 失败: {e}")
print("\n" + "=" * 60)
print("会话记忆测试完成!")
print("=" * 60)
if __name__ == "__main__":
print("\n⚠️ 请确保后端服务正在运行 (python backend.py)\n")
# 运行基本测试
test_model_switching()
# 询问是否运行记忆测试
choice = input("\n是否运行会话记忆测试?(y/n): ").strip().lower()
if choice == 'y':
test_conversation_memory()
print("\n✨ 所有测试完成!")

103
tools.py Normal file
View File

@@ -0,0 +1,103 @@
"""
工具定义模块 - 纯函数工具,无依赖 AIAgent 类
"""
# 标准库
import os
from pathlib import Path
# 第三方库
import pandas as pd
import pypdf
import requests
from bs4 import BeautifulSoup
from langchain_core.tools import tool
def _file_allow_check(filename: str) -> Path:
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
allowed_dir = Path("./user_docs").resolve()
allowed_dir.mkdir(exist_ok=True)
file_path = (allowed_dir / filename).resolve()
if not str(file_path).startswith(str(allowed_dir)):
raise ValueError("错误:非法文件路径。")
if not file_path.exists():
raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")
return file_path
@tool
def get_current_temperature(location: str) -> str:
"""获取指定地点的当前温度。"""
return f'当前{location}的温度为25℃'
@tool
def read_local_file(filename: str) -> str:
"""读取用户指定名称的本地文本文件内容并返回摘要。"""
try:
file_path = _file_allow_check(filename)
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..."
except Exception as e:
return f"读取文件时出错:{str(e)}"
@tool
def read_pdf_summary(filename: str) -> str:
"""读取PDF文件并返回内容文本摘要。"""
try:
file_path = _file_allow_check(filename)
text = ""
with open(file_path, 'rb') as f:
reader = pypdf.PdfReader(f)
for page in reader.pages[:3]:
text += page.extract_text()
return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..."
except Exception as e:
return f"读取PDF出错{e}"
@tool
def read_excel_as_markdown(filename: str) -> str:
"""读取Excel文件并将其主要数据转换为Markdown表格格式。"""
try:
file_path = _file_allow_check(filename)
df = pd.read_excel(file_path)
markdown_table = df.head(10).to_markdown(index=False)
return f"Excel文件 '{filename}' 的数据预览前10行\n{markdown_table}"
except Exception as e:
return f"读取Excel出错{e}"
@tool
def fetch_webpage_content(url: str) -> str:
"""抓取给定URL的网页正文内容并返回清晰的纯文本。"""
try:
response = requests.get(url, timeout=10)
response.raise_for_status()
soup = BeautifulSoup(response.text, 'html.parser')
for script in soup(["script", "style"]):
script.decompose()
text = soup.get_text()
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = '\n'.join(chunk for chunk in chunks if chunk)
return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..."
except Exception as e:
return f"抓取网页时出错:{str(e)}"
# 工具列表和映射(全局常量)
AVAILABLE_TOOLS = [
get_current_temperature,
read_local_file,
fetch_webpage_content,
read_pdf_summary,
read_excel_as_markdown
]
TOOLS_BY_NAME = {tool.name: tool for tool in AVAILABLE_TOOLS}