187 lines
6.7 KiB
Python
187 lines
6.7 KiB
Python
|
|
from bs4 import BeautifulSoup
|
|||
|
|
from langchain.agents import create_agent
|
|||
|
|
import requests
|
|||
|
|
import pypdf
|
|||
|
|
import pandas as pd
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
from pathlib import Path
|
|||
|
|
from langchain_community.chat_models import ChatZhipuAI
|
|||
|
|
from langchain_huggingface import HuggingFacePipeline,ChatHuggingFace
|
|||
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM,pipeline
|
|||
|
|
from langchain_core.tools import tool
|
|||
|
|
from langchain_core.messages import HumanMessage
|
|||
|
|
from transformers import BitsAndBytesConfig
|
|||
|
|
from langchain_openai import ChatOpenAI
|
|||
|
|
from pydantic import SecretStr
|
|||
|
|
|
|||
|
|
## -- Basic configuration -----------------------------------------------------
# Copy variables from a local .env file (when present) into os.environ so the
# lookups below can see them.
load_dotenv()

# Settings read from the environment (None, or the given default, when unset).
# NOTE(review): LOCAL_MODEL_PATH and DEVICE are not referenced by the code in
# this file — confirm they are used elsewhere before relying on them.
LOCAL_MODEL_PATH = os.environ.get("LOCAL_MODEL_PATH", "glm-4.7-flash")
ZHIPUAI_API_KEY = os.environ.get("ZHIPUAI_API_KEY")
VLLM_LOCAL_KEY = os.environ.get("VLLM_LOCAL_KEY", "")
DEVICE = os.environ.get("DEVICE")

## -- Model clients -----------------------------------------------------------
# Module-level caches, filled on first use by get_local_llm() / get_online_llm().
local_llm = None
online_llm = None
|
|||
|
|
|
|||
|
|
def get_local_llm():
    """Return the shared chat client for the local OpenAI-compatible (vLLM) server.

    The client is created on the first call and cached in the module-level
    ``local_llm`` global; later calls return the same object.

    The server address and served model name default to
    ``http://localhost:8000/v1`` and ``gemma-4-E2B-it``. They can be overridden
    with the ``VLLM_BASE_URL`` and ``VLLM_MODEL`` environment variables.
    """
    global local_llm
    if local_llm is None:
        local_llm = ChatOpenAI(
            base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
            # VLLM_LOCAL_KEY defaults to "". An empty secret may be treated as
            # "no key" by the OpenAI client, which then looks for OPENAI_API_KEY
            # and fails. A vLLM server started without --api-key accepts any
            # token, so fall back to the conventional "EMPTY" placeholder.
            api_key=SecretStr(VLLM_LOCAL_KEY or "EMPTY"),
            model=os.getenv("VLLM_MODEL", "gemma-4-E2B-it"),
        )
    return local_llm
|
|||
|
|
|
|||
|
|
def get_online_llm():
    """Return the shared ZhipuAI (GLM) chat client, building it on first use.

    The instance is cached in the module-level ``online_llm`` global. It is
    configured with model ``glm-4.7-flash``, the ``ZHIPUAI_API_KEY`` setting,
    temperature 0.1 and a 4096-token output cap.
    """
    global online_llm
    if online_llm is not None:
        return online_llm

    online_llm = ChatZhipuAI(
        model="glm-4.7-flash",
        api_key=ZHIPUAI_API_KEY,
        temperature=0.1,
        max_tokens=4096,
    )
    return online_llm
|
|||
|
|
|
|||
|
|
## -- Tools exposed to the agent ----------------------------------------------

@tool
def get_currenttemperature(location: str) -> str:
    """获取指定地点的当前温度,当用户询问天气或温度时使用此工具。"""
    # English gloss of the tool description above (kept verbatim because @tool
    # sends it to the model): "Get the current temperature for a location; use
    # this tool when the user asks about the weather or temperature."
    # NOTE(review): placeholder implementation. No weather service is queried,
    # and every location reports the same fixed value.
    return "当前{}的温度为25℃".format(location)
|
|||
|
|
|
|||
|
|
# sym:file_allow_check
def file_allow_check(filename: str, base_dir: str = "./user_docs") -> Path:
    """Resolve *filename* inside the allowed document directory, rejecting path traversal.

    The allowed directory (``base_dir``, by default ``./user_docs`` relative to
    the current working directory) is created if it does not exist.
    ``filename`` is joined to it and fully resolved, which collapses ``..``
    segments and follows symlinks. The result must lie strictly inside the
    allowed directory.

    Args:
        filename: User-supplied file name or relative path, e.g. ``'notes.md'``.
        base_dir: Directory that every readable file must live under.

    Returns:
        The resolved absolute ``Path`` of an existing entry inside ``base_dir``.

    Raises:
        ValueError: if the resolved path escapes ``base_dir`` or is
            ``base_dir`` itself.
        FileNotFoundError: if the resolved path does not exist.
    """
    allowed_dir = Path(base_dir).resolve()
    allowed_dir.mkdir(parents=True, exist_ok=True)

    file_path = (allowed_dir / filename).resolve()
    # Compare path components, not string prefixes. A prefix test would accept
    # sibling directories such as "../user_docs_evil/x", because
    # ".../user_docs_evil" starts with ".../user_docs". A path is never among
    # its own parents, so the directory itself is rejected too.
    if allowed_dir not in file_path.parents:
        raise ValueError("错误:非法文件路径。")

    if not file_path.exists():
        raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")

    return file_path
|
|||
|
|
|
|||
|
|
|
|||
|
|
@tool
def read_local_file(filename: str) -> str:
    """
    读取用户指定名称的本地文本文件内容并返回摘要。
    参数 filename: 文件名,例如 'project_plan.txt' 或 'notes.md'。
    """
    # English gloss of the tool description above (kept verbatim because @tool
    # sends it to the model): "Read the user-named local text file and return a
    # summary. Param filename: file name, e.g. 'project_plan.txt' or 'notes.md'."
    # The "summary" is currently just the first `max_chars` characters.
    # Errors are returned as text rather than raised, so the agent can relay them.
    max_chars = 1000

    try:
        file_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as e:
        return str(e)

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Read one character past the limit. That is enough to tell whether
            # the preview is truncated without loading a large file into memory.
            content = f.read(max_chars + 1)
    except Exception as e:
        return f"读取文件时出错:{e}"

    # Only mark the preview as truncated when it actually is.
    suffix = "..." if len(content) > max_chars else ""
    return f"文件 '{filename}' 的内容开头:\n{content[:max_chars]}{suffix}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
@tool
def read_pdf_summary(filename: str) -> str:
    """
    读取PDF文件并返回内容文本。参数 filename: PDF文件名,例如 'report.pdf'。
    """
    # English gloss of the tool description above (kept verbatim because @tool
    # sends it to the model): "Read a PDF file and return its text. Param
    # filename: PDF file name, e.g. 'report.pdf'."
    # Only the first `max_pages` pages are parsed, and at most `max_chars`
    # characters are returned, to keep the tool result small.
    max_pages = 3
    max_chars = 2000

    try:
        file_path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as e:
        return str(e)

    try:
        with open(file_path, 'rb') as f:
            reader = pypdf.PdfReader(f)
            # `or ""` guards against a page yielding no text object. Pages are
            # newline-separated so the last word of one page does not fuse with
            # the first word of the next.
            text = "\n".join(
                page.extract_text() or "" for page in reader.pages[:max_pages]
            )
            more_pages = len(reader.pages) > max_pages
    except Exception as e:
        return f"读取PDF出错:{e}"

    # Mark the preview as truncated only if pages or characters were dropped.
    suffix = "..." if more_pages or len(text) > max_chars else ""
    return f"PDF文件 '{filename}' 的前几页内容:\n{text[:max_chars]}{suffix}"
|
|||
|
|
|
|||
|
|
@tool
def read_excel_as_markdown(filename: str) -> str:
    """
    读取Excel文件,并将其主要数据转换为Markdown表格格式。参数 filename: Excel文件名,例如 'data.xlsx'。
    """
    # English gloss of the tool description above (kept verbatim because @tool
    # sends it to the model): "Read an Excel file and convert its main data to a
    # Markdown table. Param filename: Excel file name, e.g. 'data.xlsx'."
    # Only the first 10 rows of the sheet pd.read_excel loads by default (the
    # first sheet) are shown. Errors are returned as text, not raised.
    try:
        path = file_allow_check(filename)
    except (ValueError, FileNotFoundError) as err:
        return str(err)

    try:
        preview = pd.read_excel(path).head(10)
        table_md = preview.to_markdown(index=False)
    except Exception as err:
        # e.g. unreadable workbook, missing Excel engine, or missing `tabulate`
        # (required by DataFrame.to_markdown)
        return f"读取Excel出错:{err}"
    return f"Excel文件 '{filename}' 的数据预览(前10行):\n{table_md}"
|
|||
|
|
|
|||
|
|
@tool
def fetch_webpage_content(url: str) -> str:
    """
    抓取给定URL的网页正文内容,并返回清晰的纯文本。
    参数 url: 完整的网页地址,例如 'https://example.com/article'。
    """
    # English gloss of the tool description above (kept verbatim because @tool
    # sends it to the model): "Fetch the main text of the page at the given URL
    # and return clean plain text. Param url: full page address."
    # NOTE(security): `url` comes from the model/user, so this request can reach
    # internal hosts (e.g. a local model server). Consider an allow-list or
    # blocking private addresses if the agent is exposed to untrusted users.
    max_chars = 1500

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        # Without a charset in Content-Type, requests decodes text/* bodies as
        # ISO-8859-1, which garbles UTF-8/GBK pages. In that case hand
        # BeautifulSoup the raw bytes so it can honour the page's own
        # <meta charset> (or detect the encoding) instead.
        content_type = response.headers.get("Content-Type", "")
        markup = response.text if "charset" in content_type.lower() else response.content
        soup = BeautifulSoup(markup, 'html.parser')

        # Rough main-text extraction: drop script and style elements.
        for element in soup(["script", "style"]):
            element.decompose()

        text = soup.get_text()
        lines = (line.strip() for line in text.splitlines())
        # Split on double spaces (layout padding), not single spaces. Splitting
        # on " " would put every word of a sentence on its own line.
        chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
        text = '\n'.join(chunk for chunk in chunks if chunk)
    except Exception as e:
        return f"抓取网页时出错:{e}"

    # Only mark the preview as truncated when it actually is.
    suffix = "..." if len(text) > max_chars else ""
    return f"成功抓取网页 {url},正文内容开头:\n{text[:max_chars]}{suffix}"
|
|||
|
|
|
|||
|
|
# Build the agent with langchain's create_agent (LangGraph-based).
# Tools the model may call, in the order they are registered.
_AGENT_TOOLS = [
    get_currenttemperature,
    read_local_file,
    fetch_webpage_content,
    read_pdf_summary,
    read_excel_as_markdown,
]

# System prompt (Chinese, as sent to the model). It sets the persona and the
# reply language, maps each kind of request to its tool, restricts file access
# to './user_docs', and asks for concise answers without visible reasoning.
_SYSTEM_PROMPT = "".join((
    "你是一个个人生活助手和数据分析助手。请说中文。",
    "当用户询问天气或温度时,使用get_currenttemperature工具获取信息。",
    "当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。",
    "当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。",
    "当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。",
    "当用户要求抓取网页时,请使用 fetch_webpage_content 工具。",
    "当用户要求分析文档时,请使用合适的工具读取内容,然后:1. 总结核心发现。2. 如果涉及数据,请以Markdown表格或列表的形式清晰地呈现。",
    "重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述、<think>标记或内部推理。直接给出最终答案或工具调用指令。",
))

agent = create_agent(
    model=get_local_llm(),
    tools=_AGENT_TOOLS,
    system_prompt=_SYSTEM_PROMPT,
)
|
|||
|
|
|
|||
|
|
# Interactive loop: ask questions at the prompt; type "exit" to quit.
# NOTE: each turn sends only the current message, so the agent does not see
# earlier turns.
while True:
    try:
        user_input = input("请输入: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt: leave cleanly instead of a traceback.
        break
    if user_input.lower() == "exit":
        break
    if not user_input:
        # Nothing to ask; don't send an empty message to the model.
        continue

    # Wall-clock time for the whole agent run (model and tool calls).
    start_time = time.perf_counter()
    try:
        response = agent.invoke({"messages": [HumanMessage(content=user_input)]})
    except Exception as e:
        # Top-level boundary: report the failure (e.g. model server unreachable)
        # and keep the session alive instead of crashing the loop.
        print(f"\n出错: {e}")
        print("-" * 50)
        continue
    thinking_time = time.perf_counter() - start_time

    # The last message in the returned conversation holds the final answer.
    final_answer = response["messages"][-1].content

    print(f"\n{final_answer}")
    print(f"思考时间: {thinking_time:.2f}秒")
    print("-" * 50)
|
|||
|
|
|