实现前后端分离的agent

This commit is contained in:
2026-04-13 19:49:18 +08:00
parent 09a5440045
commit 4385fabc22
13 changed files with 1317 additions and 188 deletions

228
agent.py
View File

@@ -1,187 +1,85 @@
from bs4 import BeautifulSoup
from langchain.agents import create_agent
import requests
import pypdf
import pandas as pd
from dotenv import load_dotenv
"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import os
import time
from pathlib import Path
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_huggingface import HuggingFacePipeline,ChatHuggingFace
from transformers import AutoTokenizer, AutoModelForCausalLM,pipeline
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from transformers import BitsAndBytesConfig
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
##--基础定义
# 本地模块
from graph_builder import GraphBuilder
from tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
# Load environment variables (python-dotenv) before the lookups below read them.
load_dotenv()
# Read from the LOCAL_MODEL_PATH environment variable; defaults to "glm-4.7-flash" when unset.
LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH","glm-4.7-flash")
# API key passed to ChatZhipuAI; None when the variable is unset.
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY")
# API key passed to the ChatOpenAI client that targets http://localhost:8000/v1; defaults to "".
VLLM_LOCAL_KEY = os.getenv("VLLM_LOCAL_KEY", "")
# NOTE(review): DEVICE is read here but not referenced in the visible code — confirm it is still needed.
DEVICE = os.getenv("DEVICE")
## Model loading: module-level caches, populated on first use by get_local_llm() / get_online_llm()
local_llm = None
online_llm = None
class AIAgentService:
"""异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
def get_local_llm():
    """Return the shared ChatOpenAI client for the local endpoint.

    The client is created on the first call and stored in the module-level
    ``local_llm`` variable; every later call returns that same instance.
    """
    global local_llm
    # Fast path: the client was already built by an earlier call.
    if local_llm is not None:
        return local_llm
    local_llm = ChatOpenAI(
        model="gemma-4-E2B-it",
        api_key=SecretStr(VLLM_LOCAL_KEY),
        base_url="http://localhost:8000/v1",
    )
    return local_llm
def __init__(self, checkpointer):
"""
Initialize the service.
Args:
checkpointer: An already-initialized AsyncPostgresSaver instance.
"""
self.checkpointer = checkpointer
self.graphs = {} # Graph instances, keyed by model name
def get_online_llm():
global online_llm
if online_llm is None:
online_llm = ChatZhipuAI(
def _create_zhipu_llm(self):
"""创建智谱在线 LLM"""
api_key = os.getenv("ZHIPUAI_API_KEY")
if not api_key:
raise ValueError("ZHIPUAI_API_KEY not set in environment")
return ChatZhipuAI(
model="glm-4.7-flash",
api_key=ZHIPUAI_API_KEY,
api_key=api_key,
temperature=0.1,
max_tokens=4096,
)
return online_llm
##工具调用
def _create_local_llm(self):
    """Create the LLM client for the local vLLM service.

    The API key is read from the VLLM_LOCAL_KEY environment variable at call
    time and falls back to an empty string when the variable is unset.
    """
    local_key = SecretStr(os.getenv("VLLM_LOCAL_KEY", ""))
    return ChatOpenAI(
        model="gemma-4-E2B-it",
        api_key=local_key,
        base_url="http://localhost:8000/v1",
    )
@tool
def get_currenttemperature(location: str) -> str:
"""获取指定地点的当前温度,当用户询问天气或温度时使用此工具。"""
# NOTE(review): the docstring above is presumably what @tool exposes to the model as the
# tool description, so it is left untranslated — confirm against the LangChain docs.
# Stub implementation: performs no lookup and always reports 25℃ for the given location.
return f'当前{location}的温度为25℃'
async def initialize(self):
"""预编译所有模型的 graph使用传入的 checkpointer"""
model_configs = {
"zhipu": self._create_zhipu_llm,
"local": self._create_local_llm,
}
# sym:file_allow_check
def file_allow_check(filename: str) -> Path:
"""
检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。
返回合法的 Path 对象,若不合法则抛出异常。
"""
allowed_dir = Path("./user_docs").resolve()
allowed_dir.mkdir(exist_ok=True)
for model_name, llm_creator in model_configs.items():
try:
llm = llm_creator()
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
graph = builder.compile(checkpointer=self.checkpointer)
self.graphs[model_name] = graph
print(f"✅ 模型 '{model_name}' 初始化成功")
except Exception as e:
print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
file_path = (allowed_dir / filename).resolve()
if not str(file_path).startswith(str(allowed_dir)):
raise ValueError("错误:非法文件路径。")
if not self.graphs:
raise RuntimeError("没有可用的模型,请检查配置")
if not file_path.exists():
raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")
return self
return file_path
async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str:
"""处理用户消息,返回最终答案"""
if model not in self.graphs:
fallback_model = next(iter(self.graphs.keys()))
print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
model = fallback_model
@tool
def read_local_file(filename: str) -> str:
    """
    读取用户指定名称的本地文本文件内容并返回摘要。
    参数 filename: 文件名,例如 'project_plan.txt''notes.md'
    """
    # NOTE(review): the docstring is presumably consumed by @tool as the tool
    # description, so it is reproduced verbatim.
    # Path validation failures are reported as the tool result instead of raising.
    try:
        checked_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as path_error:
        return str(path_error)

    try:
        with open(checked_path, 'r', encoding='utf-8') as handle:
            full_text = handle.read()
    except Exception as read_error:
        return f"读取文件时出错:{str(read_error)}"
    else:
        # No summarisation is done: only the first 1000 characters are returned,
        # always followed by a literal "...".
        preview = full_text[:1000]
        return f"文件 '{filename}' 的内容开头:\n{preview}..."
@tool
def read_pdf_summary(filename: str) -> str:
    """
    读取PDF文件并返回内容文本。参数 filename: PDF文件名例如 'report.pdf'
    """
    # NOTE(review): the docstring is presumably consumed by @tool as the tool
    # description, so it is reproduced verbatim.
    # Path validation failures are reported as the tool result instead of raising.
    try:
        pdf_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as path_error:
        return str(path_error)

    try:
        with open(pdf_path, 'rb') as stream:
            # Only the first three pages are read; their text is concatenated
            # with no separator while the file is still open.
            first_pages = pypdf.PdfReader(stream).pages[:3]
            extracted = "".join(page.extract_text() for page in first_pages)
        # The result is capped at 2000 characters and always ends with "...".
        return f"PDF文件 '{filename}' 的前几页内容:\n{extracted[:2000]}..."
    except Exception as pdf_error:
        return f"读取PDF出错{pdf_error}"
@tool
def read_excel_as_markdown(filename: str) -> str:
    """
    读取Excel文件并将其主要数据转换为Markdown表格格式。参数 filename: Excel文件名例如 'data.xlsx'
    """
    # NOTE(review): the docstring is presumably consumed by @tool as the tool
    # description, so it is reproduced verbatim.
    # Path validation failures are reported as the tool result instead of raising.
    try:
        workbook_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as path_error:
        return str(path_error)

    try:
        # Preview only: the first 10 rows, rendered without the index column.
        preview_rows = pd.read_excel(workbook_path).head(10)
        table = preview_rows.to_markdown(index=False)
        return f"Excel文件 '{filename}' 的数据预览前10行\n{table}"
    except Exception as excel_error:
        return f"读取Excel出错{excel_error}"
@tool
def fetch_webpage_content(url: str) -> str:
    """
    抓取给定URL的网页正文内容并返回清晰的纯文本。
    参数 url: 完整的网页地址,例如 'https://example.com/article'
    """
    # NOTE(review): the docstring is presumably consumed by @tool as the tool
    # description, so it is reproduced verbatim.
    # NOTE(review): the URL is requested exactly as given; no host or address
    # restriction is visible in this block — confirm that is acceptable.
    try:
        page = requests.get(url, timeout=10)
        page.raise_for_status()
        document = BeautifulSoup(page.text, 'html.parser')
        # Drop script and style elements so only visible text remains.
        for node in document(["script", "style"]):
            node.decompose()
        # Split every stripped line on single spaces and keep the non-empty
        # fragments, one per output line.
        # NOTE(review): splitting on a single space puts each word on its own
        # line — confirm this is intended rather than a double-space split.
        pieces = []
        for raw_line in document.get_text().splitlines():
            for fragment in raw_line.strip().split(" "):
                fragment = fragment.strip()
                if fragment:
                    pieces.append(fragment)
        body = '\n'.join(pieces)
        # The result is capped at 1500 characters and always ends with "...".
        return f"成功抓取网页 {url},正文内容开头:\n{body[:1500]}..."
    except Exception as fetch_error:
        return f"抓取网页时出错:{str(fetch_error)}"
# Build the agent (original note: "uses langgraph"): the model comes from get_local_llm(),
# the tool list is the five @tool functions defined in this file, and the system prompt
# below is runtime text sent to the model, so it is kept verbatim.
agent=create_agent(
model=get_local_llm(),
tools=[get_currenttemperature,read_local_file,fetch_webpage_content,read_pdf_summary,read_excel_as_markdown],
system_prompt=(
"你是一个个人生活助手和数据分析助手。请说中文。"
"当用户询问天气或温度时使用get_currenttemperature工具获取信息。"
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读PDF文件时请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求读Excel文件时请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
"当用户要求分析文档时请使用合适的工具读取内容然后1. 总结核心发现。2. 如果涉及数据请以Markdown表格或列表的形式清晰地呈现。"
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述、<think>标记或内部推理。直接给出最终答案或工具调用指令。"
)
)
# Interactive console loop: keep prompting until the user types "exit"
# (case-insensitive), sending each input to the agent as a single human message.
while (user_input := input("请输入: ")).lower() != "exit":
    # Wall-clock time spent inside the agent call.
    start_time = time.time()
    response = agent.invoke({"messages": [HumanMessage(content=user_input)]})
    thinking_time = time.time() - start_time
    # The last message in the returned list holds the final answer.
    final_answer = response["messages"][-1].content
    print(f"\n{final_answer}")
    print(f"思考时间: {thinking_time:.2f}")
    print("-" * 50)
graph = self.graphs[model]
config = {"configurable": {"thread_id": thread_id}}
input_state = {"messages": [HumanMessage(content=message)]}
result = await graph.ainvoke(input_state, config=config)
return result["messages"][-1].content