Files
ailine/rag_indexer/loaders.py

171 lines
6.0 KiB
Python
Raw Normal View History

2026-04-18 16:56:23 +08:00
"""
2026-04-19 15:01:40 +08:00
文档加载器使用 unstructured 库解析文档
2026-04-18 16:56:23 +08:00
"""
import logging
2026-04-19 22:01:55 +08:00
import os
2026-04-18 16:56:23 +08:00
from pathlib import Path
2026-04-19 22:01:55 +08:00
from typing import Any, Dict, List, Optional, Union
2026-04-18 16:56:23 +08:00
from langchain_core.documents import Document
2026-04-19 22:01:55 +08:00
from unstructured.documents.elements import Element
2026-04-18 16:56:23 +08:00
from unstructured.partition.auto import partition
2026-04-21 18:41:14 +08:00
# 相对导入配置
from .config import RAG_OCR_LANGUAGES, RAG_DOC_LANGUAGES
2026-04-18 16:56:23 +08:00
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Configure the variable once at import time rather than on every load call.
# A value already present in the environment is left untouched.
# NOTE(review): "false" presumably disables unstructured's language checks — confirm.
if "UNSTRUCTURED_LANGUAGE_CHECKS" not in os.environ:
    os.environ["UNSTRUCTURED_LANGUAGE_CHECKS"] = "false"
2026-04-18 16:56:23 +08:00
class DocumentLoader:
    """Load documents from various file formats using ``unstructured``.

    Each file is split into elements by ``unstructured.partition.auto.partition``.
    Every element with non-blank text becomes one LangChain ``Document`` that
    carries source-file metadata plus the element's page number and category.
    """

    # Lower-case file suffixes (with leading dot) accepted by load_file/load_directory.
    SUPPORTED_EXTENSIONS = {
        ".pdf", ".docx", ".doc", ".txt", ".md",
        ".html", ".pptx", ".xlsx", ".json"
    }

    def __init__(
        self,
        extract_images: bool = False,
        strategy: str = "auto",
        ocr_languages: Optional[List[str]] = None,
        languages: Optional[List[str]] = None,
        include_page_breaks: bool = False,
        pdf_infer_table_structure: bool = True,
        partition_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            extract_images: Whether to extract images embedded in PDFs.
            strategy: PDF parsing strategy (auto, fast, hi_res, ocr_only).
            ocr_languages: Tesseract language codes for OCR on PDFs, e.g.
                ['chi_sim', 'eng'] (a "+"-joined string is also accepted).
                Falls back to RAG_OCR_LANGUAGES when None or empty.
            languages: Main document languages, e.g. ['zh']; mainly for
                non-OCR cases (non-PDF files, or PDFs without OCR languages).
                Falls back to RAG_DOC_LANGUAGES when None or empty.
            include_page_breaks: Whether to emit page-break elements.
            pdf_infer_table_structure: Whether to infer table structure
                (relevant to the hi_res strategy).
            partition_kwargs: Extra keyword arguments forwarded to
                ``partition``; they override every default built here.
        """
        self.extract_images = extract_images
        self.strategy = strategy
        self.ocr_languages = ocr_languages or RAG_OCR_LANGUAGES
        self.languages = languages or RAG_DOC_LANGUAGES
        self.include_page_breaks = include_page_breaks
        self.pdf_infer_table_structure = pdf_infer_table_structure
        # Copy so later mutation of the caller's dict cannot change this loader.
        self.partition_kwargs = dict(partition_kwargs or {})

    @staticmethod
    def _ocr_language_string(ocr_languages: Union[str, List[str]]) -> str:
        """Return OCR languages in Tesseract's "+"-joined form, e.g. "chi_sim+eng"."""
        if isinstance(ocr_languages, str):
            return ocr_languages
        return "+".join(ocr_languages)

    def _build_partition_kwargs(self, file_path: Path) -> Dict[str, Any]:
        """Build the keyword arguments for ``partition`` based on the file type."""
        kwargs: Dict[str, Any] = {
            "include_page_breaks": self.include_page_breaks,
        }
        # PDF-specific options.
        if file_path.suffix.lower() == ".pdf":
            kwargs.update({
                "strategy": self.strategy,
                "extract_images_in_pdf": self.extract_images,
                "pdf_infer_table_structure": self.pdf_infer_table_structure,
            })
            if self.ocr_languages:
                # unstructured expects `ocr_languages` as a "+"-joined string
                # and refuses calls that also set `languages`, so only one of
                # the two language options is sent per call.
                kwargs["ocr_languages"] = self._ocr_language_string(self.ocr_languages)
        # Document-language hint for every other case.
        if self.languages and "ocr_languages" not in kwargs:
            kwargs["languages"] = self.languages
        # User-supplied options override the defaults above.
        kwargs.update(self.partition_kwargs)
        return kwargs

    def _element_to_document(self, element: Element, file_path: Path) -> Optional[Document]:
        """Convert one element into a Document, or return None if its text is blank."""
        text = getattr(element, "text", "")
        if not text or not text.strip():
            return None
        element_metadata = getattr(element, "metadata", None)
        metadata = {
            "source": str(file_path),
            "file_name": file_path.name,
            "file_type": file_path.suffix.lower(),
            "page_number": getattr(element_metadata, "page_number", None),
            # The category (e.g. "Title", "NarrativeText") lives on the element
            # itself, not on element.metadata; reading it from metadata always
            # produced None and the key was dropped.
            "category": getattr(element, "category", None),
        }
        # Drop metadata entries that are unavailable for this element.
        metadata = {key: value for key, value in metadata.items() if value is not None}
        return Document(page_content=text, metadata=metadata)

    def load_file(self, file_path: Union[str, Path]) -> List[Document]:
        """Load a single file as LangChain Document objects.

        Raises:
            FileNotFoundError: if the path does not exist.
            ValueError: if the file extension is not in SUPPORTED_EXTENSIONS.
            RuntimeError: if ``partition`` fails (original error chained).
        """
        file_path = Path(file_path).resolve()
        if not file_path.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")

        suffix = file_path.suffix.lower()
        if suffix not in self.SUPPORTED_EXTENSIONS:
            raise ValueError(
                f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}"
            )

        kwargs = self._build_partition_kwargs(file_path)

        try:
            elements = partition(filename=str(file_path), **kwargs)
        except Exception as e:
            logger.exception("解析文件 %s 失败", file_path)
            raise RuntimeError(f"文件解析失败: {file_path}") from e

        documents = [
            doc
            for doc in (self._element_to_document(elem, file_path) for elem in elements)
            if doc
        ]

        if not documents:
            logger.warning("未从 %s 提取到文本内容", file_path)
        return documents

    def load_directory(
        self,
        directory_path: Union[str, Path],
        recursive: bool = True,
        fail_fast: bool = False
    ) -> List[Document]:
        """
        Load every supported file in a directory.

        Files are visited in sorted path order so repeated runs produce the
        same document order. Files with unsupported extensions are skipped.

        Args:
            directory_path: Directory to scan.
            recursive: Whether to descend into subdirectories.
            fail_fast: If True, re-raise the first loading error; otherwise
                log it and continue with the remaining files.

        Raises:
            NotADirectoryError: if directory_path is not a directory.
        """
        directory_path = Path(directory_path).resolve()
        if not directory_path.is_dir():
            raise NotADirectoryError(f"不是目录: {directory_path}")

        all_documents: List[Document] = []

        pattern = "**/*" if recursive else "*"
        # glob() order depends on the filesystem; sort for reproducible output.
        for file_path in sorted(directory_path.glob(pattern)):
            if not file_path.is_file():
                continue
            if file_path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
                continue
            try:
                docs = self.load_file(file_path)
                all_documents.extend(docs)
            except Exception as e:
                logger.error("加载 %s 失败: %s", file_path, e)
                if fail_fast:
                    raise

        return all_documents