import os
import time
from pathlib import Path

import pandas as pd
import pypdf
import requests
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from pydantic import SecretStr

from langchain.agents import create_agent
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

##-- Basic configuration
load_dotenv()

LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH", "glm-4.7-flash")
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY")
VLLM_LOCAL_KEY = os.getenv("VLLM_LOCAL_KEY", "")
DEVICE = os.getenv("DEVICE")

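# A sample .env for the variables read above (placeholder values, not real
# keys; adjust to your own deployment):
#
#   LOCAL_MODEL_PATH=glm-4.7-flash
#   ZHIPUAI_API_KEY=<your ZhipuAI key>
#   VLLM_LOCAL_KEY=<token configured on the local vLLM server>
#   DEVICE=cuda
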
## Load models: both clients are created lazily and cached as module-level singletons.
local_llm = None
online_llm = None

def get_local_llm():
    """Return a cached client for the local OpenAI-compatible server (vLLM, per VLLM_LOCAL_KEY)."""
    global local_llm
    if local_llm is None:
        local_llm = ChatOpenAI(
            base_url="http://localhost:8000/v1",
            api_key=SecretStr(VLLM_LOCAL_KEY),
            model="gemma-4-E2B-it",
        )
    return local_llm

def get_online_llm():
    """Return a cached ZhipuAI chat client."""
    global online_llm
    if online_llm is None:
        online_llm = ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=ZHIPUAI_API_KEY,
            temperature=0.1,
            max_tokens=4096,
        )
    return online_llm

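# A minimal sketch (not in the original) of how a caller might prefer the
# local vLLM endpoint but fall back to the ZhipuAI API when the local server
# is unreachable. The helper name get_llm_with_fallback is hypothetical.
def get_llm_with_fallback(prefer_local: bool = True):
    if prefer_local:
        try:
            llm = get_local_llm()
            llm.invoke("ping")  # cheap liveness check against the local server
            return llm
        except Exception:
            pass  # assumption: any failure here means the server is down
    return get_online_llm()
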
## Tools

@tool
def get_current_temperature(location: str) -> str:
    """Get the current temperature at the given location. Use this tool when the user asks about weather or temperature."""
    # Stub: returns a fixed reading for demonstration.
    return f"The current temperature in {location} is 25°C"

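# Note: the @tool decorator publishes each function's docstring as the tool
# description the model sees, so these docstrings double as routing hints.
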
def file_allow_check(filename: str) -> Path:
    """
    Check that the requested filename resolves inside the allowed './user_docs'
    directory, guarding against path-traversal attacks.
    Returns the validated Path; raises if the path is illegal or missing.
    """
    allowed_dir = Path("./user_docs").resolve()
    allowed_dir.mkdir(exist_ok=True)

    file_path = (allowed_dir / filename).resolve()
    # A plain str.startswith() check can be fooled by sibling directories such
    # as './user_docs_evil'; is_relative_to() (Python 3.9+) compares path parts.
    if not file_path.is_relative_to(allowed_dir):
        raise ValueError("Error: illegal file path.")

    if not file_path.exists():
        raise FileNotFoundError(f"Error: file '{filename}' does not exist.")

    return file_path

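# A quick sketch of the guard's behaviour (hypothetical filenames):
#   file_allow_check("notes.md")   -> Path("<cwd>/user_docs/notes.md")
#   file_allow_check("../.env")    -> ValueError (resolves outside user_docs)
#   file_allow_check("missing.md") -> FileNotFoundError (inside, but absent)
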
@tool
def read_local_file(filename: str) -> str:
    """
    Read the contents of the named local text file and return a preview.
    Arg filename: a file name such as 'project_plan.txt' or 'notes.md'.
    """
    try:
        file_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as e:
        return str(e)

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        # For over-long content, a simple summarisation step could be added
        # here; for demonstration, return the first 1000 characters.
        return f"Beginning of file '{filename}':\n{content[:1000]}..."
    except Exception as e:
        return f"Error reading file: {str(e)}"

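# A sketch of the summarisation step the comment above alludes to; the helper
# name _summarise_long_text is hypothetical and is not wired into any tool.
def _summarise_long_text(content: str, limit: int = 1000) -> str:
    if len(content) <= limit:
        return content
    llm = get_online_llm()
    reply = llm.invoke(f"Summarise this document in a few sentences:\n{content[:8000]}")
    return str(reply.content)
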
@tool
def read_pdf_summary(filename: str) -> str:
    """
    Read a PDF file and return its text content.
    Arg filename: a PDF file name such as 'report.pdf'.
    """
    try:
        file_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as e:
        return str(e)
    try:
        text = ""
        with open(file_path, 'rb') as f:
            reader = pypdf.PdfReader(f)
            for page in reader.pages[:3]:
                # extract_text() can return None for image-only pages.
                text += page.extract_text() or ""
        return f"First pages of PDF '{filename}':\n{text[:2000]}..."
    except Exception as e:
        return f"Error reading PDF: {e}"

@tool
def read_excel_as_markdown(filename: str) -> str:
    """
    Read an Excel file and convert its main data into a Markdown table.
    Arg filename: an Excel file name such as 'data.xlsx'.
    """
    try:
        file_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as e:
        return str(e)

    try:
        df = pd.read_excel(file_path)
        # to_markdown() needs the optional 'tabulate' dependency.
        markdown_table = df.head(10).to_markdown(index=False)
        return f"Preview of Excel file '{filename}' (first 10 rows):\n{markdown_table}"
    except Exception as e:
        return f"Error reading Excel file: {e}"

@tool
def fetch_webpage_content(url: str) -> str:
    """
    Fetch the main body text of the given URL and return it as clean plain text.
    Arg url: a full web address such as 'https://example.com/article'.
    """
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        # Simple body extraction: drop script and style elements first.
        for script in soup(["script", "style"]):
            script.decompose()
        text = soup.get_text()
        # Normalise whitespace: strip each line, split on double spaces
        # (splitting on a single space would put every word on its own line),
        # then drop empty chunks.
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
        text = '\n'.join(chunk for chunk in chunks if chunk)
        return f"Fetched {url}; beginning of body text:\n{text[:1500]}..."
    except Exception as e:
        return f"Error fetching webpage: {str(e)}"

# Build the agent (LangGraph-backed create_agent)
agent = create_agent(
    model=get_local_llm(),
    tools=[
        get_current_temperature,
        read_local_file,
        fetch_webpage_content,
        read_pdf_summary,
        read_excel_as_markdown,
    ],
    system_prompt=(
        "You are a personal life assistant and data-analysis assistant. Respond in Chinese. "
        "When the user asks about weather or temperature, use the get_current_temperature tool. "
        "When the user asks to read a text file, use the read_local_file tool; only files under './user_docs' may be read. "
        "When the user asks to read a PDF file, use the read_pdf_summary tool; only files under './user_docs' may be read. "
        "When the user asks to read an Excel file, use the read_excel_as_markdown tool; only files under './user_docs' may be read. "
        "When the user asks to fetch a webpage, use the fetch_webpage_content tool. "
        "When the user asks you to analyse a document, read it with the appropriate tool, then: "
        "1. summarise the key findings; 2. if data is involved, present it clearly as a Markdown table or list. "
        "Important: answer concisely and directly, with no description of your thought process, "
        "no <think> tags, and no internal reasoning. Give the final answer or tool call directly."
    ),
)

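# A one-shot call sketch (hypothetical query) showing the result shape:
# the agent returns a dict carrying the full message history, with the
# model's final reply last.
#   result = agent.invoke({"messages": [HumanMessage(content="What's the temperature in Beijing?")]})
#   result["messages"][-1].content  # final answer, after any tool calls
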
if __name__ == "__main__":
    while True:
        user_input = input("Enter a query: ")
        if user_input.lower() == "exit":
            break
        # Record the start time
        start_time = time.time()
        response = agent.invoke({"messages": [HumanMessage(content=user_input)]})
        # Measure how long the agent took
        thinking_time = time.time() - start_time
        # The final answer is the last message in the returned history
        final_answer = response["messages"][-1].content
        # Print the answer and timing statistics
        print(f"\n{final_answer}")
        print(f"Thinking time: {thinking_time:.2f}s")
        print("-" * 50)