From 4385fabc22bab0f9e1afef5189a29af1e817efbb Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Mon, 13 Apr 2026 19:49:18 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=89=8D=E5=90=8E=E7=AB=AF?= =?UTF-8?q?=E5=88=86=E7=A6=BB=E7=9A=84agent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env | 2 +- .vscode/settings.json | 5 + QUICKSTART.md | 245 ++++++++++++++++++++++++++++++++++++++ README.md | 268 +++++++++++++++++++++++++++++++++++++++++- agent.py | 228 ++++++++++------------------------- backend.py | 115 ++++++++++++++++++ frontend.py | 94 +++++++++++++++ graph_builder.py | 127 ++++++++++++++++++++ requirement.txt | 19 ++- start.sh | 145 +++++++++++++++++++++++ test_gemma.py | 20 ---- test_multi_model.py | 134 +++++++++++++++++++++ tools.py | 103 ++++++++++++++++ 13 files changed, 1317 insertions(+), 188 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 QUICKSTART.md create mode 100644 backend.py create mode 100644 frontend.py create mode 100644 graph_builder.py create mode 100644 start.sh delete mode 100644 test_gemma.py create mode 100644 test_multi_model.py create mode 100644 tools.py diff --git a/.env b/.env index 58f29ef..7ac96ac 100644 --- a/.env +++ b/.env @@ -1,4 +1,4 @@ -LOCAL_MODEL_PATH=glm-4.7-flash +LOCAL_MODEL_PATH=gemma-4-E2B-it ZHIPUAI_API_KEY=4d568a4367f1442bbc226cc0daf84566.44SsKVWkVIM2Mkeg VLLM_LOCAL_KEY=token-abc123 EOF \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..739e0c6 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "editor.fontSize": 24, + "editor.formatOnSave": true, + "files.autoSave": "onWindowChange" +} \ No newline at end of file diff --git a/QUICKSTART.md b/QUICKSTART.md new file mode 100644 index 0000000..09d2717 --- /dev/null +++ b/QUICKSTART.md @@ -0,0 +1,245 @@ +# 快速开始指南 - 多模型切换功能 + +## 🚀 5分钟快速启动 + +### 步骤 1: 启动必要的容器 + +```bash +# 使用提供的启动脚本(推荐) +./start.sh + +# 或者手动启动容器 +# 1. 启动 vLLM (如果需要本地模型) +docker run -d --rm \ + --group-add=video \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --device=/dev/kfd \ + --device=/dev/dri \ + -v /home/huang/Study/AIModel/gemma-4-E2B-it:/models/gemma-4-E2B-it \ + -e VLLM_ROCM_USE_AITER=0 \ + -e HF_TOKEN="$HF_TOKEN" \ + -p 8000:8000 \ + --ipc=host \ + --entrypoint vllm \ + my-vllm-gemma4:working \ + serve /models/gemma-4-E2B-it \ + --served-model-name gemma-4-E2B-it \ + --dtype auto \ + --api-key token-abc123 \ + --trust-remote-code \ + --port 8000 \ + --gpu-memory-utilization 0.85 \ + --max-model-len 8192 + +# 2. 启动 PostgreSQL +docker run -d \ + --name postgres-langgraph \ + -e POSTGRES_PASSWORD=mysecretpassword \ + -e POSTGRES_DB=langgraph_db \ + -p 5432:5432 \ + -v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \ + postgres:16 +``` + +### 步骤 2: 配置环境变量 + +编辑 `.env` 文件: + +```env +ZHIPUAI_API_KEY=your_actual_zhipuai_api_key +VLLM_LOCAL_KEY=token-abc123 +``` + +### 步骤 3: 启动服务 + +```bash +# 方式1: 使用启动脚本(推荐) +./start.sh + +# 方式2: 手动启动 +# 终端1: 启动后端 +python backend.py + +# 终端2: 启动前端 +streamlit run frontend.py +``` + +### 步骤 4: 访问应用 + +浏览器打开: `http://localhost:8501` + +--- + +## 🎯 使用多模型切换功能 + +### 在前端切换模型 + +1. **打开侧边栏**:点击左上角的菜单图标 +2. **选择模型**:在"选择大模型"下拉框中选择: + - 智谱 GLM-4.7-Flash(在线) + - 本地 vLLM(Gemma-4) +3. **开始对话**:输入您的问题,系统会使用选定的模型处理 + +### 特性说明 + +✅ **实时切换**:可以在对话过程中随时切换模型 +✅ **记忆共享**:同一会话 ID 下,不同模型共享对话历史 +✅ **自动降级**:如果选择的模型不可用,自动切换到可用模型 +✅ **状态显示**:每条回复下方会显示实际使用的模型 + +--- + +## 🧪 测试功能 + +### 运行自动化测试 + +```bash +# 确保后端正在运行 +python test_multi_model.py +``` + +测试内容包括: +- 各模型的可用性测试 +- 跨模型会话记忆测试 +- API 响应格式验证 + +### 手动测试 + +1. **测试智谱模型**: + - 选择"智谱 GLM-4.7-Flash" + - 询问:"你好,请介绍一下自己" + - 观察回复速度和内容质量 + +2. **测试本地模型**: + - 选择"本地 vLLM(Gemma-4)" + - 询问相同问题 + - 对比两个模型的回复差异 + +3. **测试记忆功能**: + - 第一轮(智谱模型):"我叫小明,记住我的名字" + - 第二轮(本地模型):"我叫什么名字?" + - 验证是否能正确回忆 + +--- + +## 🔧 常见问题 + +### Q1: 某个模型初始化失败怎么办? + +**A:** 系统会自动跳过失败的模型,使用其他可用模型。检查日志了解具体原因: +- 智谱模型:确认 `ZHIPUAI_API_KEY` 是否正确 +- 本地模型:确认 vLLM 容器是否运行 + +### Q2: 如何添加新模型? + +**A:** 在 `agent.py` 中添加: + +```python +def _create_new_model_llm(self): + """创建新模型的 LLM""" + return YourChatModel( + model="model-name", + api_key="your-key", + # ... 其他参数 + ) + +# 在 initialize() 方法的 model_configs 中添加 +model_configs = { + "zhipu": self._create_zhipu_llm, + "local": self._create_local_llm, + "new_model": self._create_new_model_llm, # 新增 +} +``` + +然后在前端 `frontend.py` 的 `MODEL_OPTIONS` 中添加对应选项。 + +### Q3: 会话记忆是如何工作的? + +**A:** +- 使用 PostgreSQL 存储对话历史 +- 通过 `thread_id` 关联同一会话的消息 +- 不同模型共享同一个 checkpointer,因此可以跨模型保持上下文 +- 点击"新会话"按钮会生成新的 `thread_id` + +### Q4: 性能优化建议 + +**A:** +- 智谱模型:适合快速响应场景,无需本地 GPU +- 本地模型:适合数据隐私要求高的场景,需要 GPU 支持 +- 长时间对话建议定期开启新会话,避免上下文过长 + +--- + +## 📊 架构优势 + +### 预编译 Graph + +每个模型在启动时都会预编译独立的 LangGraph: +- ✅ 避免每次请求都重新编译,提升性能 +- ✅ 各模型独立,互不影响 +- ✅ 支持热插拔,可动态添加/移除模型 + +### 智能降级 + +如果选择的模型不可用: +1. 后端自动切换到第一个可用模型 +2. 返回响应中包含 `model_used` 字段 +3. 前端显示实际使用的模型 +4. 用户无感知,体验流畅 + +### 统一接口 + +无论使用哪个模型: +- API 接口保持一致 +- 工具调用方式相同 +- 会话记忆机制统一 +- 前端操作体验一致 + +--- + +## 🎓 进阶使用 + +### 固定会话 ID + +如需在不同浏览器或设备间继续同一会话: + +```python +# 在 frontend.py 中修改 +st.session_state.thread_id = "my_fixed_session_id" +``` + +### 自定义超时时间 + +```python +# 在 frontend.py 中修改 timeout 参数 +response = requests.post( + API_URL, + json={...}, + timeout=120 # 增加到 120 秒 +) +``` + +### 批量测试 + +```python +# 创建测试脚本 +import requests + +messages = ["问题1", "问题2", "问题3"] +for msg in messages: + response = requests.post(API_URL, json={"message": msg, "model": "zhipu"}) + print(response.json()["reply"]) +``` + +--- + +## 📞 获取帮助 + +- 查看完整文档:[README.md](README.md) +- 查看项目结构:参考 [README.md](README.md) 中的项目结构部分 +- 报告问题:提交 Issue 并附上日志信息 + +--- + +**祝您使用愉快!** 🎉 diff --git a/README.md b/README.md index 0ad8e2c..2ed9a01 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,268 @@ -# ailine +# AI Agent - 个人生活助手和数据分析助手 +## 项目概述 + +这是一个基于 LangGraph、LangChain 和 FastAPI 构建的 AI 助手系统,能够处理天气查询、文件读取、网页抓取等任务。采用前后端分离架构,支持 PostgreSQL 持久化对话记忆。 + +## 项目结构 + +``` +Agent1/ +├── tools.py # 工具定义(纯函数、@tool) +├── graph_builder.py # LangGraph 状态图构建(状态定义、节点、边) +├── agent.py # AIAgentService 类(模型初始化、graph 管理、消息处理) +├── backend.py # FastAPI 应用(路由、WebSocket、lifespan) +├── frontend.py # Streamlit 前端(通过 HTTP 调用后端) +├── .env # 环境变量(ZHIPUAI_API_KEY 等) +├── requirement.txt # Python 依赖包列表 +└── user_docs/ # 允许读取的文档目录 + ├── a.txt + ├── b.pdf + └── c.xlsx +``` + +## 核心功能 + +- 🌤️ **天气查询**:获取指定地点的当前温度 +- 📄 **文本文件读取**:读取 `.txt`、`.md` 等文本文件 +- 📑 **PDF 文件读取**:解析 PDF 文件并提取文本内容 +- 📊 **Excel 数据处理**:读取 Excel 文件并转换为 Markdown 表格 +- 🌐 **网页抓取**:抓取网页正文内容 +- 💾 **持久化记忆**:使用 PostgreSQL 保存对话历史,支持多轮对话上下文 +- 🔄 **多模型动态切换**:前端可选择不同的大语言模型,后端自动切换处理 + +## 技术栈 + +- **后端框架**:FastAPI + Uvicorn +- **前端框架**:Streamlit +- **AI 框架**:LangGraph + LangChain +- **数据库**:PostgreSQL(用于持久化对话记忆) +- **LLM 支持**: + - 智谱 AI(glm-4.7-flash):在线服务,响应速度快 + - 本地 vLLM(gemma-4-E2B-it):本地部署,数据隐私性好 + +系统支持多种大语言模型,可在前端动态切换。每个模型在启动时都会预编译独立的 LangGraph,确保最佳性能。如果某个模型初始化失败(如 API Key 未配置),系统会自动降级到可用模型。 + +## 环境要求 + +- Python 3.10+ +- PostgreSQL 16+ +- Docker(可选,用于运行 PostgreSQL) + +## 安装步骤 + +### 1. 启动 PostgreSQL 容器 + +```bash +docker run -d \ + --name postgres-langgraph \ + -e POSTGRES_PASSWORD=mysecretpassword \ + -e POSTGRES_DB=langgraph_db \ + -p 5432:5432 \ + -v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \ + postgres:16 +``` + +### 2. 安装 Python 依赖 + +```bash +pip install fastapi uvicorn streamlit requests psycopg[binary,pool] \ + langgraph langgraph-checkpoint-postgres langchain langchain-community \ + langchain-openai python-dotenv pypdf pandas beautifulsoup4 +``` + +或者使用 requirements.txt: + +```bash +pip install -r requirement.txt +``` + +### 3. 配置环境变量 + +编辑 `.env` 文件,设置您的 API 密钥: + +```env +ZHIPUAI_API_KEY=your_zhipuai_api_key_here +VLLM_LOCAL_KEY=token-abc123 # 如果使用本地模型 +``` + +## 运行步骤 + +### 1. 启动后端服务 + +```bash +python backend.py +``` + +看到 `Uvicorn running on http://0.0.0.0:8001` 即表示启动成功。 + +### 2. 启动前端界面(新终端) + +```bash +streamlit run frontend.py +``` + +浏览器会自动打开 `http://localhost:8501`,即可开始使用。 + +## API 接口 + +### POST /chat + +同步对话接口,支持模型选择 + +**请求体:** +```json +{ + "message": "今天北京天气怎么样?", + "thread_id": "optional-thread-id", + "model": "zhipu" // 可选: "zhipu" 或 "local" +} +``` + +**响应:** +```json +{ + "reply": "当前北京的温度为25℃", + "thread_id": "generated-or-provided-thread-id", + "model_used": "zhipu" // 实际使用的模型 +} +``` + +**模型选项:** +- `zhipu`:智谱 GLM-4.7-Flash(在线) +- `local`:本地 vLLM Gemma-4(需要启动 vLLM 容器) + +### WebSocket /ws + +流式对话接口(可选扩展) + +## 使用说明 + +### 工具调用示例 + +1. **查询天气**: + ``` + 用户:今天上海天气怎么样? + ``` + +2. **读取文本文件**: + ``` + 用户:请读取 a.txt 文件的内容 + ``` + +3. **读取 PDF 文件**: + ``` + 用户:帮我总结一下 b.pdf 的内容 + ``` + +4. **读取 Excel 文件**: + ``` + 用户:显示 c.xlsx 的数据 + ``` + +5. **抓取网页**: + ``` + 用户:请抓取 https://example.com 的内容 + ``` + +### 会话记忆 + +- 系统会自动为每个会话生成唯一的 `thread_id` +- 相同 `thread_id` 的对话会共享历史记录 +- 即使重启后端服务,对话历史依然保存在 PostgreSQL 中 +- 如需固定会话 ID,可在前端代码中修改 `st.session_state.thread_id` 为固定字符串 + +### 多模型切换 + +**前端操作:** +1. 在左侧边栏的"选择大模型"下拉框中选择模型 +2. 可随时切换模型,甚至在同一会话中 +3. 点击"🔄 新会话"按钮可清空当前对话并开始新的会话 + +**后端行为:** +- 启动时会预编译所有可用模型的 LangGraph +- 如果某个模型初始化失败(如 API Key 未配置),会自动跳过 +- 请求时如果选择的模型不可用,会自动降级到第一个可用模型 +- 响应中会返回 `model_used` 字段,显示实际使用的模型 + +**添加新模型:** +在 `agent.py` 的 `initialize()` 方法中的 `model_configs` 字典添加新模型即可: +```python +model_configs = { + "zhipu": self._create_zhipu_llm, + "local": self._create_local_llm, + "new_model": self._create_new_model_llm, # 添加新模型 +} +``` + +## 架构说明 + +### 模块职责 + +- **tools.py**:独立工具模块,包含所有 `@tool` 装饰的纯函数,无外部依赖,可单独测试 +- **graph_builder.py**:LangGraph 状态图构建器,定义状态、节点函数和条件边 +- **agent.py**:AIAgentService 服务类,负责模型初始化和 graph 编译,使用 `AsyncPostgresSaver` +- **backend.py**:FastAPI 应用,提供 REST API 和 WebSocket 接口,端口 8001 +- **frontend.py**:Streamlit 前端,通过 HTTP 调用后端 API,实现友好的用户界面 + +### 数据流 + +``` +用户输入 → Streamlit 前端 → FastAPI 后端 → AIAgentService +→ LangGraph StateGraph → LLM + Tools → PostgreSQL (记忆) +→ 返回响应 → 前端展示 +``` + +## 注意事项 + +1. **文件安全**:所有文件读取操作仅限于 `./user_docs` 目录,防止路径遍历攻击 +2. **端口冲突**:后端使用 8001 端口,避免与本地 vLLM 服务的 8000 端口冲突 +3. **API 密钥**:请妥善保管 `.env` 文件中的 API 密钥,不要提交到版本控制系统 +4. **数据库持久化**:PostgreSQL 数据卷挂载到 `~/docker_volumes/postgres_data`,确保数据安全 + +## 故障排除 + +### 问题:无法连接 PostgreSQL + +**解决方案:** +```bash +# 检查容器是否运行 +docker ps | grep postgres-langgraph + +# 查看容器日志 +docker logs postgres-langgraph + +# 重新启动容器 +docker restart postgres-langgraph +``` + +### 问题:后端启动失败 + +**解决方案:** +- 确认端口 8001 未被占用 +- 检查 `.env` 文件中的 API 密钥是否正确配置 +- 确认所有依赖包已正确安装 +- 查看启动日志,确认至少有一个模型初始化成功 + +### 问题:模型切换后无响应 + +**解决方案:** +- 检查所选模型的配置是否正确(如智谱 API Key) +- 确认 vLLM 容器是否正在运行(如果使用本地模型) +- 查看后端日志,确认模型是否初始化成功 +- 尝试切换到另一个模型 + +### 问题:工具调用失败 + +**解决方案:** +- 确认文件位于 `./user_docs` 目录下 +- 检查文件格式是否正确 +- 查看后端日志获取详细错误信息 + +## 许可证 + +本项目采用 MIT 许可证。详见 [LICENSE](LICENSE) 文件。 + +## 贡献 + +欢迎提交 Issue 和 Pull Request! diff --git a/agent.py b/agent.py index c4b3d11..34e79f0 100644 --- a/agent.py +++ b/agent.py @@ -1,187 +1,85 @@ -from bs4 import BeautifulSoup -from langchain.agents import create_agent -import requests -import pypdf -import pandas as pd -from dotenv import load_dotenv +""" +AI Agent 服务类 - 支持多模型动态切换 +接收外部传入的 checkpointer,不负责管理连接生命周期 +""" + import os -import time -from pathlib import Path +from dotenv import load_dotenv from langchain_community.chat_models import ChatZhipuAI -from langchain_huggingface import HuggingFacePipeline,ChatHuggingFace -from transformers import AutoTokenizer, AutoModelForCausalLM,pipeline -from langchain_core.tools import tool from langchain_core.messages import HumanMessage -from transformers import BitsAndBytesConfig from langchain_openai import ChatOpenAI from pydantic import SecretStr -##--基础定义 +# 本地模块 +from graph_builder import GraphBuilder +from tools import AVAILABLE_TOOLS, TOOLS_BY_NAME + load_dotenv() -LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH","glm-4.7-flash") -ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY") -VLLM_LOCAL_KEY = os.getenv("VLLM_LOCAL_KEY", "") -DEVICE = os.getenv("DEVICE") -##加载模型 -local_llm = None -online_llm = None +class AIAgentService: + """异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer""" -def get_local_llm(): - global local_llm - if local_llm is None: - local_llm = ChatOpenAI( - base_url="http://localhost:8000/v1", - api_key=SecretStr(VLLM_LOCAL_KEY), - model="gemma-4-E2B-it", - ) - return local_llm + def __init__(self, checkpointer): + """ + 初始化服务 + Args: + checkpointer: 已经初始化的 AsyncPostgresSaver 实例 + """ + self.checkpointer = checkpointer + self.graphs = {} # 存储不同模型对应的 graph 实例 -def get_online_llm(): - global online_llm - if online_llm is None: - online_llm = ChatZhipuAI( + def _create_zhipu_llm(self): + """创建智谱在线 LLM""" + api_key = os.getenv("ZHIPUAI_API_KEY") + if not api_key: + raise ValueError("ZHIPUAI_API_KEY not set in environment") + return ChatZhipuAI( model="glm-4.7-flash", - api_key=ZHIPUAI_API_KEY, + api_key=api_key, temperature=0.1, max_tokens=4096, ) - return online_llm -##工具调用 + def _create_local_llm(self): + """创建本地 vLLM 服务 LLM""" + return ChatOpenAI( + base_url="http://localhost:8000/v1", + api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")), + model="gemma-4-E2B-it", + ) -@tool -def get_currenttemperature(location: str) -> str: - """获取指定地点的当前温度,当用户询问天气或温度时使用此工具。""" - return f'当前{location}的温度为25℃' + async def initialize(self): + """预编译所有模型的 graph(使用传入的 checkpointer)""" + model_configs = { + "zhipu": self._create_zhipu_llm, + "local": self._create_local_llm, + } -# sym:file_allow_check -def file_allow_check(filename: str) -> Path: - """ - 检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。 - 返回合法的 Path 对象,若不合法则抛出异常。 - """ - allowed_dir = Path("./user_docs").resolve() - allowed_dir.mkdir(exist_ok=True) + for model_name, llm_creator in model_configs.items(): + try: + llm = llm_creator() + builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build() + graph = builder.compile(checkpointer=self.checkpointer) + self.graphs[model_name] = graph + print(f"✅ 模型 '{model_name}' 初始化成功") + except Exception as e: + print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}") - file_path = (allowed_dir / filename).resolve() - if not str(file_path).startswith(str(allowed_dir)): - raise ValueError("错误:非法文件路径。") + if not self.graphs: + raise RuntimeError("没有可用的模型,请检查配置") - if not file_path.exists(): - raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。") + return self - return file_path + async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str: + """处理用户消息,返回最终答案""" + if model not in self.graphs: + fallback_model = next(iter(self.graphs.keys())) + print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'") + model = fallback_model - -@tool -def read_local_file(filename: str) -> str: - """ - 读取用户指定名称的本地文本文件内容并返回摘要。 - 参数 filename: 文件名,例如 'project_plan.txt' 或 'notes.md'。 - """ - try: - file_path = file_allow_check(filename) - except (ValueError, FileNotFoundError) as e: - return str(e) - - try: - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() - # 2. 内容过长时,可以在此处增加一个简单的摘要逻辑,或者直接返回前N个字符 - # 为了演示,这里返回前1000个字符 - return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..." - except Exception as e: - return f"读取文件时出错:{str(e)}" - - -@tool -def read_pdf_summary(filename: str) -> str: - """ - 读取PDF文件并返回内容文本。参数 filename: PDF文件名,例如 'report.pdf'。 - """ - try: - file_path = file_allow_check(filename) - except (ValueError, FileNotFoundError) as e: - return str(e) - try: - text = "" - with open(file_path, 'rb') as f: - reader = pypdf.PdfReader(f) - for page in reader.pages[:3]: - text += page.extract_text() - return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..." - except Exception as e: - return f"读取PDF出错:{e}" - -@tool -def read_excel_as_markdown(filename: str) -> str: - """ - 读取Excel文件,并将其主要数据转换为Markdown表格格式。参数 filename: Excel文件名,例如 'data.xlsx'。 - """ - try: - file_path = file_allow_check(filename) - except (ValueError, FileNotFoundError) as e: - return str(e) - - try: - df = pd.read_excel(file_path) - markdown_table = df.head(10).to_markdown(index=False) - return f"Excel文件 '{filename}' 的数据预览(前10行):\n{markdown_table}" - except Exception as e: - return f"读取Excel出错:{e}" - -@tool -def fetch_webpage_content(url: str) -> str: - """ - 抓取给定URL的网页正文内容,并返回清晰的纯文本。 - 参数 url: 完整的网页地址,例如 'https://example.com/article'。 - """ - try: - response = requests.get(url, timeout=10) - response.raise_for_status() - soup = BeautifulSoup(response.text, 'html.parser') - # 简单的正文提取,去除脚本和样式 - for script in soup(["script", "style"]): - script.decompose() - text = soup.get_text() - lines = (line.strip() for line in text.splitlines()) - chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) - text = '\n'.join(chunk for chunk in chunks if chunk) - return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..." - except Exception as e: - return f"抓取网页时出错:{str(e)}" - -#使用langgraph -agent=create_agent( - model=get_local_llm(), - tools=[get_currenttemperature,read_local_file,fetch_webpage_content,read_pdf_summary,read_excel_as_markdown], - system_prompt=( - "你是一个个人生活助手和数据分析助手。请说中文。" - "当用户询问天气或温度时,使用get_currenttemperature工具获取信息。" - "当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。" - "当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。" - "当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。" - "当用户要求抓取网页时,请使用 fetch_webpage_content 工具。" - "当用户要求分析文档时,请使用合适的工具读取内容,然后:1. 总结核心发现。2. 如果涉及数据,请以Markdown表格或列表的形式清晰地呈现。" - "重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述、标记或内部推理。直接给出最终答案或工具调用指令。" -) -) - -while True: - user_input = input("请输入: ") - if user_input.lower() == "exit": - break - # 记录开始时间 - start_time = time.time() - response=agent.invoke({"messages":[HumanMessage(content=user_input)]}) - # 计算思考时间 - thinking_time = time.time() - start_time - # 提取回答内容 - final_answer=response["messages"][-1].content - # 打印回答和统计信息 - print(f"\n{final_answer}") - print(f"思考时间: {thinking_time:.2f}秒") - print("-" * 50) - \ No newline at end of file + graph = self.graphs[model] + config = {"configurable": {"thread_id": thread_id}} + input_state = {"messages": [HumanMessage(content=message)]} + result = await graph.ainvoke(input_state, config=config) + return result["messages"][-1].content \ No newline at end of file diff --git a/backend.py b/backend.py new file mode 100644 index 0000000..6b4df52 --- /dev/null +++ b/backend.py @@ -0,0 +1,115 @@ +""" +FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆 +采用依赖注入模式,优雅管理资源生命周期 +""" + +import uuid +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver + +from agent import AIAgentService + +# PostgreSQL 连接字符串 +DB_URI = "postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable" + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理:创建并注入全局服务""" + # 1. 创建数据库连接池并初始化表 + async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: + await checkpointer.setup() + + # 2. 构建 AI Agent 服务 + agent_service = AIAgentService(checkpointer) + await agent_service.initialize() + + # 3. 将服务实例存入 app.state + app.state.agent_service = agent_service + + # 应用运行中... + yield + + # 4. 关闭时自动清理数据库连接(async with 负责) + print("🛑 应用关闭,数据库连接池已释放") + + +app = FastAPI(lifespan=lifespan) + +# CORS 中间件(允许前端跨域) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# ========== Pydantic 模型 ========== +class ChatRequest(BaseModel): + message: str + thread_id: str | None = None + model: str = "zhipu" + + +class ChatResponse(BaseModel): + reply: str + thread_id: str + model_used: str + + +# ========== 依赖注入函数 ========== +def get_agent_service(request: Request) -> AIAgentService: + """从 app.state 中获取全局 AIAgentService 实例""" + return request.app.state.agent_service + + +# ========== HTTP 端点 ========== +@app.post("/chat", response_model=ChatResponse) +async def chat_endpoint( + request: ChatRequest, + agent_service: AIAgentService = Depends(get_agent_service) +): + """同步对话接口,支持模型选择""" + if not request.message: + raise HTTPException(status_code=400, detail="message required") + + thread_id = request.thread_id or str(uuid.uuid4()) + reply = await agent_service.process_message( + request.message, thread_id, request.model + ) + actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys())) + return ChatResponse(reply=reply, thread_id=thread_id, model_used=actual_model) + + +# ========== WebSocket 端点(可选) ========== +@app.websocket("/ws") +async def websocket_endpoint( + websocket: WebSocket, + agent_service: AIAgentService = Depends(get_agent_service) +): + await websocket.accept() + try: + while True: + data = await websocket.receive_json() + message = data.get("message") + thread_id = data.get("thread_id", str(uuid.uuid4())) + model = data.get("model", "zhipu") + if not message: + await websocket.send_json({"error": "missing message"}) + continue + reply = await agent_service.process_message(message, thread_id, model) + actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys())) + await websocket.send_json({"reply": reply, "thread_id": thread_id, "model_used": actual_model}) + except WebSocketDisconnect: + pass + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8001) \ No newline at end of file diff --git a/frontend.py b/frontend.py new file mode 100644 index 0000000..3ac8117 --- /dev/null +++ b/frontend.py @@ -0,0 +1,94 @@ +""" +Streamlit 前端 - 支持模型选择 +""" + +# 标准库 +import uuid + +# 第三方库 +import requests +import streamlit as st + +# 后端 API 地址(端口 8001) +API_URL = "http://localhost:8001/chat" + +st.set_page_config(page_title="AI 个人助手", page_icon="🤖") +st.title("🤖 个人生活与数据分析助手") + +# 模型选项(与后端支持的模型名称一致) +MODEL_OPTIONS = { + "zhipu": "智谱 GLM-4.7-Flash(在线)", + "local": "本地 vLLM(Gemma-4)" +} + +# 初始化会话状态 +if "messages" not in st.session_state: + st.session_state.messages = [] +if "thread_id" not in st.session_state: + st.session_state.thread_id = str(uuid.uuid4()) +if "selected_model" not in st.session_state: + st.session_state.selected_model = "zhipu" + +# 侧边栏:模型选择和会话管理 +with st.sidebar: + st.header("⚙️ 设置") + + # 模型选择 + selected_model_key = st.selectbox( + "选择大模型", + options=list(MODEL_OPTIONS.keys()), + format_func=lambda x: MODEL_OPTIONS[x], + index=0 + ) + st.session_state.selected_model = selected_model_key + + # 会话信息显示 + st.write(f"当前会话 ID: `{st.session_state.thread_id[:8]}...`") + + # 新会话按钮 + if st.button("🔄 新会话"): + st.session_state.thread_id = str(uuid.uuid4()) + st.session_state.messages = [] + st.rerun() + +# 显示历史消息 +for msg in st.session_state.messages: + with st.chat_message(msg["role"]): + st.markdown(msg["content"]) + +# 用户输入 +if prompt := st.chat_input("请输入您的问题..."): + # 显示用户消息 + with st.chat_message("user"): + st.markdown(prompt) + st.session_state.messages.append({"role": "user", "content": prompt}) + + # 调用后端 API(携带模型参数) + with st.chat_message("assistant"): + with st.spinner("思考中..."): + try: + response = requests.post( + API_URL, + json={ + "message": prompt, + "thread_id": st.session_state.thread_id, + "model": st.session_state.selected_model + }, + timeout=60 + ) + response.raise_for_status() + data = response.json() + reply = data["reply"] + model_used = data["model_used"] + + # 显示回复 + st.markdown(reply) + + # 显示使用的模型(小字提示) + st.caption(f"🤖 使用模型: {MODEL_OPTIONS.get(model_used, model_used)}") + + st.session_state.messages.append({"role": "assistant", "content": reply}) + except Exception as e: + error_msg = f"请求失败: {e}" + st.error(error_msg) + st.session_state.messages.append({"role": "assistant", "content": error_msg}) diff --git a/graph_builder.py b/graph_builder.py new file mode 100644 index 0000000..8598ef1 --- /dev/null +++ b/graph_builder.py @@ -0,0 +1,127 @@ +""" +LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数 +""" + +import operator +import asyncio +from typing import Literal, Annotated, Any +from langchain_core.language_models import BaseLLM +from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langgraph.graph import StateGraph, START, END +from typing_extensions import TypedDict + + +class MessageState(TypedDict): + """对话状态类型定义""" + messages: Annotated[list[AnyMessage], operator.add] + llm_calls: int + + +class GraphBuilder: + """LangGraph 状态图构建器 - 所有节点均为类方法""" + + def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]): + """ + 初始化构建器 + + Args: + llm: 大语言模型实例 + tools: 工具列表 + tools_by_name: 名称到工具函数的映射 + """ + self.llm = llm + self.tools = tools + self.tools_by_name = tools_by_name + self._llm_with_tools = llm.bind_tools(tools) + self._prompt = self._create_prompt() + self._chain = self._prompt | self._llm_with_tools + + @staticmethod + def _create_prompt() -> ChatPromptTemplate: + """创建系统提示模板(静态方法,无需访问实例)""" + return ChatPromptTemplate.from_messages([ + SystemMessage(content=( + "你是一个个人生活助手和数据分析助手。请说中文。" + "当用户询问天气或温度时,使用get_current_temperature工具获取信息。" + "当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。" + "当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。" + "当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。" + "当用户要求抓取网页时,请使用 fetch_webpage_content 工具。" + "重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。" + )), + MessagesPlaceholder(variable_name="message") + ]) + + async def call_llm(self, state: MessageState) -> dict: + """ + LLM 调用节点(异步方法) + 注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环 + """ + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, + lambda: self._chain.invoke({"message": state["messages"]}) + ) + return { + "messages": [response], + "llm_calls": state.get('llm_calls', 0) + 1 + } + + async def call_tools(self, state: MessageState) -> dict: + """ + 工具执行节点(异步方法) + 对于每个工具调用,在线程池中执行同步工具函数 + """ + last_message = state['messages'][-1] + if not isinstance(last_message, AIMessage) or not last_message.tool_calls: + return {"messages": []} + + results = [] + loop = asyncio.get_event_loop() + + for tool_call in last_message.tool_calls: + tool_name = tool_call["name"] + tool_args = tool_call["args"] + tool_id = tool_call["id"] + tool_func = self.tools_by_name.get(tool_name) + + if tool_func is None: + results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id)) + continue + + try: + # 同步工具函数在线程池中执行 + observation = await loop.run_in_executor( + None, + lambda: tool_func.invoke(tool_args) + ) + results.append(ToolMessage(content=str(observation), tool_call_id=tool_id)) + except Exception as e: + results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id)) + + return {"messages": results} + + @staticmethod + def should_continue(state: MessageState) -> Literal['tool_node', END]: + """ + 条件边判断(静态方法) + 决定下一步是进入工具节点还是结束 + """ + last_message = state["messages"][-1] + if isinstance(last_message, AIMessage) and bool(last_message.tool_calls): + return 'tool_node' + return END + + def build(self) -> StateGraph: + """ + 构建未编译的状态图(返回 StateGraph 实例) + 图中节点直接使用实例方法 call_llm, call_tools + """ + builder = StateGraph(MessageState) + builder.add_node("llm_call", self.call_llm) + builder.add_node("tool_node", self.call_tools) + builder.add_edge(START, "llm_call") + builder.add_conditional_edges("llm_call", self.should_continue, ["tool_node", END]) + builder.add_edge("tool_node", "llm_call") + return builder \ No newline at end of file diff --git a/requirement.txt b/requirement.txt index 83b2518..28a34ba 100644 --- a/requirement.txt +++ b/requirement.txt @@ -13,11 +13,28 @@ langchain-huggingface>=0.0.3 langchain-core>=0.1.0 langchain-openai>=0.0.5 +# LangGraph +langgraph>=0.0.30 +langgraph-checkpoint-postgres>=0.0.5 + # ZhipuAI (智谱AI) zhipuai>=1.0.0 +# Backend +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 +websockets>=12.0 + +# Frontend +streamlit>=1.30.0 + +# Database +psycopg[binary,pool]>=3.1.0 + # Pydantic pydantic>=2.0.0 # Utilities -python-dotenv>=1.0.0 \ No newline at end of file +python-dotenv>=1.0.0 +typing-extensions>=4.9.0 +ipython>=8.0.0 diff --git a/start.sh b/start.sh new file mode 100644 index 0000000..f1e0574 --- /dev/null +++ b/start.sh @@ -0,0 +1,145 @@ +#!/bin/bash + +# AI Agent 启动脚本 +# 用法: ./start.sh [backend|frontend|both] + +set -e + +# 颜色定义 +GREEN='\033[0;32m' +BLUE='\033[0;34m' +RED='\033[0;31m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +echo -e "${BLUE}========================================${NC}" +echo -e "${BLUE} AI Agent - 个人生活助手启动脚本${NC}" +echo -e "${BLUE}========================================${NC}" +echo "" + +# 检查 vLLM 容器是否运行 +check_vllm() { + if ! docker ps --format '{{.Names}}' | grep -q "^gemma4-server$"; then + echo -e "${YELLOW}⚠️ vLLM 容器未运行!${NC}" + echo "正在启动 vLLM 容器(Gemma-4 模型)..." + + # 检查模型文件是否存在 + if [ ! -d "/home/huang/Study/AIModel/gemma-4-E2B-it" ]; then + echo -e "${RED}✗ 错误:模型目录不存在: /home/huang/Study/AIModel/gemma-4-E2B-it${NC}" + echo "请先下载模型或修改模型路径" + exit 1 + fi + + docker run -d \ + --name gemma4-server \ + --group-add=video \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --device=/dev/kfd \ + --device=/dev/dri \ + -v /home/huang/Study/AIModel/gemma-4-E2B-it:/models/gemma-4-E2B-it \ + -e VLLM_ROCM_USE_AITER=0 \ + -e HF_TOKEN="${HF_TOKEN}" \ + -p 8000:8000 \ + --ipc=host \ + --entrypoint vllm \ + my-vllm-gemma4:working \ + serve /models/gemma-4-E2B-it \ + --served-model-name gemma-4-E2B-it \ + --dtype auto \ + --api-key token-abc123 \ + --trust-remote-code \ + --port 8000 \ + --gpu-memory-utilization 0.85 \ + --max-model-len 8192 + + echo -e "${GREEN}✓ vLLM 容器已启动${NC}" + echo -e "${YELLOW}⏳ 等待模型加载(可能需要几分钟)...${NC}" + sleep 10 + else + echo -e "${GREEN}✓ vLLM 容器正在运行${NC}" + fi +} + +# 检查 PostgreSQL 容器是否运行 +check_postgres() { + if ! docker ps | grep -q postgres-langgraph; then + echo -e "${YELLOW}⚠️ PostgreSQL 容器未运行!${NC}" + echo "正在启动 PostgreSQL 容器..." + docker run -d \ + --name postgres-langgraph \ + -e POSTGRES_PASSWORD=mysecretpassword \ + -e POSTGRES_DB=langgraph_db \ + -p 5432:5432 \ + -v ~/docker_volumes/postgres_data:/var/lib/postgresql/data \ + postgres:16 + + echo -e "${GREEN}✓ PostgreSQL 容器已启动${NC}" + sleep 3 + else + echo -e "${GREEN}✓ PostgreSQL 容器正在运行${NC}" + fi +} + +# 启动后端 +start_backend() { + echo -e "\n${BLUE}🚀 启动后端服务 (端口 8001)...${NC}" + python backend.py & + BACKEND_PID=$! + echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}" + sleep 2 +} + +# 启动前端 +start_frontend() { + echo -e "\n${BLUE}🎨 启动前端界面...${NC}" + streamlit run frontend.py & + FRONTEND_PID=$! + echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}" + echo -e "${GREEN}✓ 请在浏览器中打开: http://localhost:8501${NC}" +} + +# 清理函数 +cleanup() { + echo -e "\n${RED}🛑 正在停止所有服务...${NC}" + if [ ! -z "$BACKEND_PID" ]; then + kill $BACKEND_PID 2>/dev/null || true + echo -e "${GREEN}✓ 后端服务已停止${NC}" + fi + if [ ! -z "$FRONTEND_PID" ]; then + kill $FRONTEND_PID 2>/dev/null || true + echo -e "${GREEN}✓ 前端服务已停止${NC}" + fi + echo -e "${YELLOW}💡 提示:Docker 容器需要手动停止${NC}" + echo -e " 停止 vLLM: docker stop gemma4-server" + echo -e " 停止 PostgreSQL: docker stop postgres-langgraph" + exit 0 +} + +# 捕获 Ctrl+C +trap cleanup SIGINT SIGTERM + +# 主逻辑 +case "${1:-both}" in + backend) + check_vllm + check_postgres + start_backend + echo -e "\n${GREEN}后端服务正在运行,按 Ctrl+C 停止${NC}" + wait $BACKEND_PID + ;; + frontend) + start_frontend + echo -e "\n${GREEN}前端服务正在运行,按 Ctrl+C 停止${NC}" + wait $FRONTEND_PID + ;; + both|*) + check_vllm + check_postgres + start_backend + start_frontend + echo -e "\n${GREEN}所有服务正在运行,按 Ctrl+C 停止 Python 服务${NC}" + echo -e "${YELLOW}注意:Docker 容器会在后台继续运行${NC}" + wait + ;; +esac diff --git a/test_gemma.py b/test_gemma.py deleted file mode 100644 index a32ce7c..0000000 --- a/test_gemma.py +++ /dev/null @@ -1,20 +0,0 @@ -from openai import OpenAI - -# 连接本地 vLLM 服务 -client = OpenAI( - base_url="http://localhost:8000/v1", # 容器映射的地址 - api_key="token-abc123", # 与你启动命令中的 --api-key 一致 -) - -# 发起对话 -response = client.chat.completions.create( - model="gemma-4-E2B-it", # --served-model-name 指定的名称 - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "请用中文介绍一下你自己"} - ], - temperature=0.7, - max_tokens=512, -) - -print(response.choices[0].message.content) \ No newline at end of file diff --git a/test_multi_model.py b/test_multi_model.py new file mode 100644 index 0000000..a64d5bd --- /dev/null +++ b/test_multi_model.py @@ -0,0 +1,134 @@ +""" +多模型切换功能测试脚本 +用于验证后端是否正确支持多模型动态切换 +""" + +import requests +import json + +API_URL = "http://localhost:8001/chat" + + +def test_model_switching(): + """测试模型切换功能""" + + print("=" * 60) + print("测试多模型切换功能") + print("=" * 60) + + # 测试消息 + test_message = "你好,请简单介绍一下自己" + + # 测试不同的模型 + models_to_test = ["zhipu", "local"] + + for model in models_to_test: + print(f"\n📤 测试模型: {model}") + print("-" * 60) + + try: + response = requests.post( + API_URL, + json={ + "message": test_message, + "model": model + }, + timeout=30 + ) + + if response.status_code == 200: + data = response.json() + print(f"✅ 成功!") + print(f" 使用的模型: {data['model_used']}") + print(f" 会话 ID: {data['thread_id'][:8]}...") + print(f" 回复预览: {data['reply'][:100]}...") + else: + print(f"❌ 失败! 状态码: {response.status_code}") + print(f" 错误信息: {response.text}") + + except requests.exceptions.Timeout: + print(f"⏰ 超时! 模型 '{model}' 响应时间过长") + except requests.exceptions.ConnectionError: + print(f"🔌 连接失败! 请确认后端服务正在运行 (python backend.py)") + except Exception as e: + print(f"💥 异常: {str(e)}") + + print("\n" + "=" * 60) + print("测试完成!") + print("=" * 60) + + +def test_conversation_memory(): + """测试跨模型的会话记忆""" + + print("\n" + "=" * 60) + print("测试跨模型会话记忆") + print("=" * 60) + + import uuid + thread_id = str(uuid.uuid4()) + + print(f"\n📝 使用固定会话 ID: {thread_id[:8]}...") + + # 第一轮对话 - 使用 zhipu 模型 + print("\n📤 第1轮 - 使用 zhipu 模型") + try: + response1 = requests.post( + API_URL, + json={ + "message": "我叫小明,记住我的名字", + "thread_id": thread_id, + "model": "zhipu" + }, + timeout=30 + ) + if response1.status_code == 200: + data1 = response1.json() + print(f" ✅ 回复: {data1['reply'][:100]}...") + print(f" 🤖 使用模型: {data1['model_used']}") + except Exception as e: + print(f" ❌ 失败: {e}") + return + + # 第二轮对话 - 切换到 local 模型,测试是否记得名字 + print("\n📤 第2轮 - 切换到 local 模型") + try: + response2 = requests.post( + API_URL, + json={ + "message": "我叫什么名字?", + "thread_id": thread_id, + "model": "local" + }, + timeout=30 + ) + if response2.status_code == 200: + data2 = response2.json() + print(f" ✅ 回复: {data2['reply'][:100]}...") + print(f" 🤖 使用模型: {data2['model_used']}") + + # 检查是否记得名字 + if "小明" in data2['reply']: + print(" 🎉 成功!跨模型记忆功能正常") + else: + print(" ⚠️ 注意:模型可能没有正确回忆上下文") + except Exception as e: + print(f" ❌ 失败: {e}") + + print("\n" + "=" * 60) + print("会话记忆测试完成!") + print("=" * 60) + + +if __name__ == "__main__": + print("\n⚠️ 请确保后端服务正在运行 (python backend.py)\n") + + # 运行基本测试 + test_model_switching() + + # 询问是否运行记忆测试 + choice = input("\n是否运行会话记忆测试?(y/n): ").strip().lower() + if choice == 'y': + test_conversation_memory() + + print("\n✨ 所有测试完成!") diff --git a/tools.py b/tools.py new file mode 100644 index 0000000..6db7668 --- /dev/null +++ b/tools.py @@ -0,0 +1,103 @@ +""" +工具定义模块 - 纯函数工具,无依赖 AIAgent 类 +""" + +# 标准库 +import os +from pathlib import Path + +# 第三方库 +import pandas as pd +import pypdf +import requests +from bs4 import BeautifulSoup +from langchain_core.tools import tool + + +def _file_allow_check(filename: str) -> Path: + """检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。""" + allowed_dir = Path("./user_docs").resolve() + allowed_dir.mkdir(exist_ok=True) + + file_path = (allowed_dir / filename).resolve() + if not str(file_path).startswith(str(allowed_dir)): + raise ValueError("错误:非法文件路径。") + + if not file_path.exists(): + raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。") + + return file_path + + +@tool +def get_current_temperature(location: str) -> str: + """获取指定地点的当前温度。""" + return f'当前{location}的温度为25℃' + + +@tool +def read_local_file(filename: str) -> str: + """读取用户指定名称的本地文本文件内容并返回摘要。""" + try: + file_path = _file_allow_check(filename) + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..." + except Exception as e: + return f"读取文件时出错:{str(e)}" + + +@tool +def read_pdf_summary(filename: str) -> str: + """读取PDF文件并返回内容文本摘要。""" + try: + file_path = _file_allow_check(filename) + text = "" + with open(file_path, 'rb') as f: + reader = pypdf.PdfReader(f) + for page in reader.pages[:3]: + text += page.extract_text() + return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..." + except Exception as e: + return f"读取PDF出错:{e}" + + +@tool +def read_excel_as_markdown(filename: str) -> str: + """读取Excel文件,并将其主要数据转换为Markdown表格格式。""" + try: + file_path = _file_allow_check(filename) + df = pd.read_excel(file_path) + markdown_table = df.head(10).to_markdown(index=False) + return f"Excel文件 '{filename}' 的数据预览(前10行):\n{markdown_table}" + except Exception as e: + return f"读取Excel出错:{e}" + + +@tool +def fetch_webpage_content(url: str) -> str: + """抓取给定URL的网页正文内容,并返回清晰的纯文本。""" + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + soup = BeautifulSoup(response.text, 'html.parser') + for script in soup(["script", "style"]): + script.decompose() + text = soup.get_text() + lines = (line.strip() for line in text.splitlines()) + chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) + text = '\n'.join(chunk for chunk in chunks if chunk) + return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..." + except Exception as e: + return f"抓取网页时出错:{str(e)}" + + +# 工具列表和映射(全局常量) +AVAILABLE_TOOLS = [ + get_current_temperature, + read_local_file, + fetch_webpage_content, + read_pdf_summary, + read_excel_as_markdown +] +TOOLS_BY_NAME = {tool.name: tool for tool in AVAILABLE_TOOLS}