Files
ailine/rag_indexer/loaders.py
root 933d418d77
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s
检索器重构
2026-04-19 22:01:55 +08:00

168 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
文档加载器,使用 unstructured 库解析文档。
"""
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langchain_core.documents import Document
from unstructured.documents.elements import Element
from unstructured.partition.auto import partition
logger = logging.getLogger(__name__)
# Set the environment variable once, at module import time, instead of on
# every call. setdefault() leaves any value already present in the
# environment untouched.
os.environ.setdefault("UNSTRUCTURED_LANGUAGE_CHECKS", "false")
class DocumentLoader:
    """Load files of various formats as LangChain ``Document`` objects.

    Files are parsed with ``unstructured``'s ``partition`` function; every
    parsed element that carries non-blank text becomes one ``Document``.
    """

    # File suffixes (lower-case, including the leading dot) accepted by
    # load_file() and picked up by load_directory().
    SUPPORTED_EXTENSIONS = {
        ".pdf", ".docx", ".doc", ".txt", ".md",
        ".html", ".pptx", ".xlsx", ".json"
    }

    def __init__(
        self,
        extract_images: bool = False,
        strategy: str = "auto",
        ocr_languages: Optional[List[str]] = None,
        languages: Optional[List[str]] = None,
        include_page_breaks: bool = False,
        pdf_infer_table_structure: bool = True,
        partition_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Store the parsing options used for every subsequent load.

        Args:
            extract_images: Whether to extract images embedded in PDFs.
            strategy: Parsing strategy (auto, fast, hi_res, ocr_only).
            ocr_languages: OCR language list, e.g. ['chi_sim', 'eng'].
                Falls back to ['chi_sim', 'eng'] when omitted or empty.
            languages: Primary document languages, e.g. ['zh'] (mainly for
                non-OCR scenarios). Falls back to ['zh'] when omitted or
                empty.
            include_page_breaks: Whether to include page-break elements.
            pdf_infer_table_structure: Whether to detect table structure
                (requires the hi_res strategy).
            partition_kwargs: Extra keyword arguments forwarded to
                ``partition`` (advanced customisation); they override the
                defaults built by this class.
        """
        self.extract_images = extract_images
        self.strategy = strategy
        self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
        self.languages = languages or ["zh"]
        self.include_page_breaks = include_page_breaks
        self.pdf_infer_table_structure = pdf_infer_table_structure
        # Copy so that later mutation of the caller's dict cannot silently
        # change how this loader parses files.
        self.partition_kwargs = dict(partition_kwargs) if partition_kwargs else {}

    def _build_partition_kwargs(self, file_path: Path) -> Dict[str, Any]:
        """Build the keyword arguments passed to ``partition`` for a file.

        Args:
            file_path: Path of the file about to be parsed; only its suffix
                is inspected.

        Returns:
            A new dict of keyword arguments. PDF-only options are included
            only for ``.pdf`` files, and ``self.partition_kwargs`` is applied
            last so user-supplied values win.
        """
        kwargs: Dict[str, Any] = {
            "include_page_breaks": self.include_page_breaks,
        }

        # PDF-specific options.
        if file_path.suffix.lower() == ".pdf":
            kwargs.update({
                "strategy": self.strategy,
                "ocr_languages": self.ocr_languages,
                "extract_images_in_pdf": self.extract_images,
                "pdf_infer_table_structure": self.pdf_infer_table_structure,
            })

        # Language hint applied to every file type.
        if self.languages:
            kwargs["languages"] = self.languages

        # User-supplied options override the defaults above.
        kwargs.update(self.partition_kwargs)
        return kwargs

    def _build_metadata(self, element: "Element", file_path: Path) -> Dict[str, Any]:
        """Collect the metadata stored alongside an element's text.

        Args:
            element: Parsed element to describe.
            file_path: Path of the file the element came from.

        Returns:
            A dict with ``source``, ``file_name`` and ``file_type``, plus
            ``page_number`` and ``category`` when the element provides them.
            Keys whose value is ``None`` are omitted.
        """
        element_metadata = getattr(element, "metadata", None)

        # The category is read from the element itself first; the lookup on
        # the metadata object is kept only as a backward-compatible fallback.
        category = getattr(element, "category", None)
        if category is None:
            category = getattr(element_metadata, "category", None)

        metadata = {
            "source": str(file_path),
            "file_name": file_path.name,
            "file_type": file_path.suffix.lower(),
            # The values below come from the element and may be None.
            "page_number": getattr(element_metadata, "page_number", None),
            "category": category,
        }
        # Drop entries that the element did not provide.
        return {key: value for key, value in metadata.items() if value is not None}

    def _element_to_document(self, element: "Element", file_path: Path) -> "Optional[Document]":
        """Convert one parsed element into a ``Document``.

        Args:
            element: Parsed element.
            file_path: Path of the file the element came from.

        Returns:
            A ``Document`` holding the element text and its metadata, or
            ``None`` when the element has no text or only whitespace.
        """
        text = getattr(element, "text", "")
        if not text or not text.strip():
            return None
        return Document(
            page_content=text,
            metadata=self._build_metadata(element, file_path),
        )

    def load_file(self, file_path: Union[str, Path]) -> "List[Document]":
        """Load a single file as a list of LangChain ``Document`` objects.

        Args:
            file_path: Path of the file to parse.

        Returns:
            One ``Document`` per parsed element that has non-blank text, in
            the order returned by ``partition``. May be empty.

        Raises:
            FileNotFoundError: If the path does not exist.
            ValueError: If the file suffix is not in SUPPORTED_EXTENSIONS.
            RuntimeError: If ``partition`` fails; the original exception is
                chained as the cause.
        """
        file_path = Path(file_path).resolve()
        if not file_path.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")

        suffix = file_path.suffix.lower()
        if suffix not in self.SUPPORTED_EXTENSIONS:
            raise ValueError(
                f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}"
            )

        kwargs = self._build_partition_kwargs(file_path)
        try:
            elements = partition(filename=str(file_path), **kwargs)
        except Exception as e:
            # Log with traceback here, then surface one uniform error type.
            logger.exception("解析文件 %s 失败", file_path)
            raise RuntimeError(f"文件解析失败: {file_path}") from e

        converted = (self._element_to_document(elem, file_path) for elem in elements)
        documents = [doc for doc in converted if doc is not None]
        if not documents:
            logger.warning("未从 %s 提取到文本内容", file_path)
        return documents

    def load_directory(
        self,
        directory_path: Union[str, Path],
        recursive: bool = True,
        fail_fast: bool = False
    ) -> "List[Document]":
        """Load every supported file found in a directory.

        Args:
            directory_path: Directory to scan.
            recursive: Whether to descend into sub-directories.
            fail_fast: Whether to re-raise the first loading failure instead
                of logging it and continuing with the remaining files.

        Returns:
            The documents of all successfully loaded files, concatenated in
            sorted path order.

        Raises:
            NotADirectoryError: If ``directory_path`` is not a directory.
            Exception: Whatever ``load_file`` raised, when ``fail_fast`` is
                true.
        """
        directory_path = Path(directory_path).resolve()
        if not directory_path.is_dir():
            raise NotADirectoryError(f"不是目录: {directory_path}")

        all_documents: "List[Document]" = []
        pattern = "**/*" if recursive else "*"
        # glob() yields entries in filesystem-dependent order; sorting makes
        # the resulting document order reproducible across runs and machines.
        for file_path in sorted(directory_path.glob(pattern)):
            if not file_path.is_file():
                continue
            if file_path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
                continue
            try:
                all_documents.extend(self.load_file(file_path))
            except Exception as e:
                logger.error("加载 %s 失败: %s", file_path, e)
                if fail_fast:
                    raise
        return all_documents