"""Document loader that parses files with the ``unstructured`` library."""

import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from langchain_core.documents import Document
from unstructured.documents.elements import Element
from unstructured.partition.auto import partition

# Project-local configuration (relative import).
from .config import RAG_OCR_LANGUAGES, RAG_DOC_LANGUAGES

logger = logging.getLogger(__name__)

# Set once at import time rather than on every call; setdefault leaves any
# value already present in the environment untouched.
os.environ.setdefault("UNSTRUCTURED_LANGUAGE_CHECKS", "false")


class DocumentLoader:
    """Load documents from various file formats as LangChain ``Document`` objects.

    Each file is parsed with ``unstructured.partition.auto.partition`` and every
    element with non-whitespace text becomes one ``Document``.
    """

    # File suffixes accepted by ``load_file`` / ``load_directory``; compared
    # against the lower-cased suffix of each path.
    SUPPORTED_EXTENSIONS = {
        ".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx", ".json"
    }

    def __init__(
        self,
        extract_images: bool = False,
        strategy: str = "auto",
        ocr_languages: Optional[List[str]] = None,
        languages: Optional[List[str]] = None,
        include_page_breaks: bool = False,
        pdf_infer_table_structure: bool = True,
        partition_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            extract_images: Whether to extract images from PDFs (forwarded as
                ``extract_images_in_pdf``; PDF files only).
            strategy: Parsing strategy (auto, fast, hi_res, ocr_only); PDF files only.
            ocr_languages: OCR language list, e.g. ['chi_sim', 'eng']; PDF files
                only. Falls back to ``RAG_OCR_LANGUAGES`` when None or empty.
            languages: Primary document languages, e.g. ['zh'] (mainly for non-OCR
                parsing); passed for every file type when non-empty. Falls back
                to ``RAG_DOC_LANGUAGES`` when None or empty.
            include_page_breaks: Whether ``partition`` should emit page breaks.
            pdf_infer_table_structure: Whether to infer table structure
                (requires the hi_res strategy); PDF files only.
            partition_kwargs: Extra keyword arguments for ``partition`` (advanced
                customization); they override the defaults built by this class.
        """
        self.extract_images = extract_images
        self.strategy = strategy
        self.ocr_languages = ocr_languages or RAG_OCR_LANGUAGES
        self.languages = languages or RAG_DOC_LANGUAGES
        self.include_page_breaks = include_page_breaks
        self.pdf_infer_table_structure = pdf_infer_table_structure
        # Copy so later mutation of the caller's dict cannot leak into this loader.
        self.partition_kwargs = dict(partition_kwargs or {})

    def _build_partition_kwargs(self, file_path: Path) -> Dict[str, Any]:
        """Build the keyword arguments for ``partition`` based on the file type.

        Args:
            file_path: File about to be parsed; only its suffix is inspected.

        Returns:
            A fresh dict. PDF-specific options are added only for ``.pdf`` files,
            and ``self.partition_kwargs`` is applied last so user values win.
        """
        kwargs: Dict[str, Any] = {
            "include_page_breaks": self.include_page_breaks,
        }

        # PDF-only options.
        if file_path.suffix.lower() == ".pdf":
            kwargs.update({
                "strategy": self.strategy,
                # NOTE(review): unstructured marks ``ocr_languages`` as deprecated in
                # favor of ``languages``, and some versions may raise ValueError when
                # both are supplied. Confirm against the installed version; either
                # can be set to None through ``partition_kwargs``.
                "ocr_languages": self.ocr_languages,
                "extract_images_in_pdf": self.extract_images,
                "pdf_infer_table_structure": self.pdf_infer_table_structure,
            })

        # Language hint applies to every file type.
        if self.languages:
            kwargs["languages"] = self.languages

        # User-supplied arguments override the defaults above.
        kwargs.update(self.partition_kwargs)
        return kwargs

    def _element_to_document(self, element: Element, file_path: Path) -> Optional[Document]:
        """Convert one parsed element into a ``Document``, keeping key metadata.

        Args:
            element: Element produced by ``partition``.
            file_path: File the element came from (used for source metadata).

        Returns:
            The ``Document``, or None when the element has no non-whitespace text.
        """
        text = getattr(element, "text", "")
        if not text or not text.strip():
            return None

        element_metadata = getattr(element, "metadata", None)
        metadata = {
            "source": str(file_path),
            "file_name": file_path.name,
            "file_type": file_path.suffix.lower(),
            # Page number is stored on the element's metadata object; may be None.
            "page_number": getattr(element_metadata, "page_number", None),
            # The element type (e.g. "Title", "NarrativeText") is an attribute of
            # the element itself, not of ``element.metadata``; reading it from the
            # metadata object always produced None, so the key was silently lost.
            "category": getattr(element, "category", None),
        }

        # Drop metadata entries whose value is None.
        return Document(
            page_content=text,
            metadata={key: value for key, value in metadata.items() if value is not None},
        )

    def load_file(self, file_path: Union[str, Path]) -> List[Document]:
        """Load a single file as a list of LangChain ``Document`` objects.

        Args:
            file_path: Path to the file; resolved to an absolute path.

        Returns:
            One ``Document`` per element with non-whitespace text (possibly empty;
            a warning is logged in that case).

        Raises:
            FileNotFoundError: The path does not exist.
            ValueError: The file suffix is not in ``SUPPORTED_EXTENSIONS``.
            RuntimeError: ``partition`` failed; the original error is chained.
        """
        file_path = Path(file_path).resolve()
        if not file_path.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")

        suffix = file_path.suffix.lower()
        if suffix not in self.SUPPORTED_EXTENSIONS:
            raise ValueError(
                f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}"
            )

        kwargs = self._build_partition_kwargs(file_path)
        try:
            elements = partition(filename=str(file_path), **kwargs)
        except Exception as e:
            logger.exception("解析文件 %s 失败", file_path)
            raise RuntimeError(f"文件解析失败: {file_path}") from e

        documents = [
            doc
            for doc in (self._element_to_document(elem, file_path) for elem in elements)
            if doc is not None
        ]

        if not documents:
            logger.warning("未从 %s 提取到文本内容", file_path)

        return documents

    def load_directory(
        self,
        directory_path: Union[str, Path],
        recursive: bool = True,
        fail_fast: bool = False,
    ) -> List[Document]:
        """Load every supported file in a directory.

        Args:
            directory_path: Directory to scan; resolved to an absolute path.
            recursive: Whether to descend into subdirectories.
            fail_fast: If True, re-raise the first error from ``load_file``;
                otherwise log it and continue with the next file.

        Returns:
            Documents from all successfully loaded files, in sorted path order.

        Raises:
            NotADirectoryError: ``directory_path`` is not a directory.
        """
        directory_path = Path(directory_path).resolve()
        if not directory_path.is_dir():
            raise NotADirectoryError(f"不是目录: {directory_path}")

        pattern = "**/*" if recursive else "*"
        all_documents: List[Document] = []

        # Sort so the output order is reproducible; raw glob order depends on the
        # filesystem.
        for file_path in sorted(directory_path.glob(pattern)):
            if not file_path.is_file():
                continue
            if file_path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
                continue
            try:
                all_documents.extend(self.load_file(file_path))
            except Exception as e:
                # Best-effort by default: log the failure and move on.
                logger.error("加载 %s 失败: %s", file_path, e)
                if fail_fast:
                    raise

        return all_documents