RAG数据库生成
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
"""
|
||||
Document loaders using unstructured library.
|
||||
文档加载器,使用 unstructured 库解析文档。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from unstructured.partition.auto import partition
|
||||
@@ -13,33 +13,74 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentLoader:
    """Load documents from various file formats via the ``unstructured`` library.

    Files are parsed into ``unstructured`` elements, which callers (see
    ``load_file``) convert into LangChain ``Document`` objects.
    """

    # File extensions this loader accepts (matched against the lower-cased suffix).
    SUPPORTED_EXTENSIONS = {
        ".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx", ".json",
    }

    def __init__(
        self,
        extract_images: bool = False,
        strategy: str = "auto",
        ocr_languages: Optional[List[str]] = None,
        languages: Optional[List[str]] = None,
        include_page_breaks: bool = False,
        pdf_infer_table_structure: bool = True,
        partition_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            extract_images: Whether to extract images embedded in PDFs.
            strategy: PDF parsing strategy (``auto``, ``fast``, ``hi_res``,
                ``ocr_only``).
            ocr_languages: OCR language list, e.g. ``['chi_sim', 'eng']``.
                Defaults to ``['chi_sim', 'eng']``.
            languages: Primary document language(s), e.g. ``['zh']``.
                Defaults to ``['zh']``.
            include_page_breaks: Whether to keep page-break elements.
            pdf_infer_table_structure: Whether to detect table structure
                (effective only with the ``hi_res`` strategy).
            partition_kwargs: Extra keyword arguments forwarded verbatim to
                ``unstructured.partition.auto.partition``; these take
                precedence over all of the options above.
        """
        import os

        # Disable unstructured's language checks, since languages are supplied
        # explicitly below.
        # NOTE(review): this mutates process-wide state on every construction;
        # consider moving it to module import time.
        os.environ["UNSTRUCTURED_LANGUAGE_CHECKS"] = "false"
        self.extract_images = extract_images
        self.strategy = strategy
        self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
        self.languages = languages or ["zh"]
        self.include_page_breaks = include_page_breaks
        self.pdf_infer_table_structure = pdf_infer_table_structure
        self.partition_kwargs = partition_kwargs or {}
|
||||
|
||||
def load_file(self, file_path: Union[str, Path]) -> List[Document]:
    """Load a single file into LangChain ``Document`` objects.

    Args:
        file_path: Path of the document to parse.

    Returns:
        One ``Document`` per non-empty parsed element; an empty list when
        no text could be extracted.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the extension is not in ``SUPPORTED_EXTENSIONS``.
    """
    file_path = Path(file_path).resolve()
    if not file_path.exists():
        raise FileNotFoundError(f"文件不存在: {file_path}")

    suffix = file_path.suffix.lower()
    if suffix not in self.SUPPORTED_EXTENSIONS:
        raise ValueError(
            f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}"
        )

    # Build partition() arguments, specialised per file type.
    extra_kwargs: Dict[str, Any] = {"extract_images_in_pdf": self.extract_images}
    if suffix == ".pdf":
        extra_kwargs["strategy"] = self.strategy
        extra_kwargs["ocr_languages"] = self.ocr_languages
        extra_kwargs["pdf_infer_table_structure"] = self.pdf_infer_table_structure

    # `languages` applies to every file type.
    if self.languages:
        extra_kwargs["languages"] = self.languages

    extra_kwargs["include_page_breaks"] = self.include_page_breaks

    # User-supplied overrides have the highest priority.
    extra_kwargs.update(self.partition_kwargs)

    # BUG FIX: the previous version passed `extract_images_in_pdf=` explicitly
    # in addition to `**extra_kwargs`, which for PDFs (where extra_kwargs also
    # contained that key) raised "got multiple values for keyword argument".
    # It is now supplied once, through extra_kwargs only.
    elements = partition(filename=str(file_path), **extra_kwargs)

    documents = []
    # NOTE(review): this loop header was elided diff context in the recovered
    # source; reconstructed to match the visible loop body — confirm against
    # the original file.
    for elem in elements:
        text = str(elem)
        if not text or not text.strip():
            continue

        # Base metadata, never overwritten by element metadata below.
        metadata = {
            "source": str(file_path),
            "file_name": file_path.name,
            "file_type": suffix,
        }

        # Merge element-specific metadata without clobbering the base fields.
        # unstructured elements expose an ElementMetadata object rather than a
        # plain dict, so convert via to_dict() when available; calling .items()
        # on the raw object would fail.
        elem_meta = getattr(elem, "metadata", None)
        if elem_meta is not None and hasattr(elem_meta, "to_dict"):
            elem_meta = elem_meta.to_dict()
        for key, value in (elem_meta or {}).items():
            if value and key not in metadata:
                metadata[key] = value

        documents.append(Document(page_content=text, metadata=metadata))

    if not documents:
        logger.warning("未从 %s 提取到文本内容", file_path)
        return []

    return documents
|
||||
@@ -72,10 +107,10 @@ class DocumentLoader:
|
||||
def load_directory(
|
||||
self, directory_path: Union[str, Path], recursive: bool = True
|
||||
) -> List[Document]:
|
||||
"""Load all supported files from a directory."""
|
||||
"""从目录加载所有支持的文件。"""
|
||||
directory_path = Path(directory_path).resolve()
|
||||
if not directory_path.is_dir():
|
||||
raise NotADirectoryError(f"Not a directory: {directory_path}")
|
||||
raise NotADirectoryError(f"不是目录: {directory_path}")
|
||||
|
||||
all_documents = []
|
||||
pattern = "**/*" if recursive else "*"
|
||||
@@ -86,6 +121,6 @@ class DocumentLoader:
|
||||
docs = self.load_file(file_path)
|
||||
all_documents.extend(docs)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load %s: %s", file_path, e)
|
||||
logger.error("加载 %s 失败: %s", file_path, e)
|
||||
|
||||
return all_documents
|
||||
Reference in New Issue
Block a user