RAG数据库生成

This commit is contained in:
2026-04-19 15:01:40 +08:00
parent c18e8a9860
commit cc8ef41ef9
17 changed files with 1089 additions and 577 deletions

View File

@@ -1,10 +1,10 @@
"""
Document loaders using unstructured library.
文档加载器,使用 unstructured 库解析文档。
"""
import logging
from pathlib import Path
from typing import List, Union
from typing import Any, Dict, List, Mapping, Optional, Union
from langchain_core.documents import Document
from unstructured.partition.auto import partition
@@ -13,33 +13,74 @@ logger = logging.getLogger(__name__)
class DocumentLoader:
"""Load documents from various file formats."""
"""从各种文件格式加载文档。"""
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx"}
SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx", ".json"}
def __init__(self, extract_images: bool = False):
def __init__(
self,
extract_images: bool = False,
strategy: str = "auto",
ocr_languages: Optional[List[str]] = None,
languages: Optional[List[str]] = None,
include_page_breaks: bool = False,
pdf_infer_table_structure: bool = True,
partition_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Args:
extract_images: Whether to extract images from documents (requires additional dependencies)
extract_images: 是否提取 PDF 中的图片
strategy: 解析策略 (auto, fast, hi_res, ocr_only)
ocr_languages: OCR 语言列表,如 ['chi_sim', 'eng']
languages: 文档主语言,如 ['zh']
include_page_breaks: 是否包含分页符
pdf_infer_table_structure: 是否识别表格结构 (需 hi_res 策略)
partition_kwargs: 额外的 partition 参数字典(高级定制)
"""
import os
os.environ["UNSTRUCTURED_LANGUAGE_CHECKS"] = "false"
self.extract_images = extract_images
self.strategy = strategy
self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
self.languages = languages or ["zh"]
self.include_page_breaks = include_page_breaks
self.pdf_infer_table_structure = pdf_infer_table_structure
self.partition_kwargs = partition_kwargs or {}
def load_file(self, file_path: Union[str, Path]) -> List[Document]:
"""Load a single file into LangChain Document objects."""
"""将单个文件加载为 LangChain Document 对象。"""
file_path = Path(file_path).resolve()
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
raise FileNotFoundError(f"文件不存在: {file_path}")
suffix = file_path.suffix.lower()
if suffix not in self.SUPPORTED_EXTENSIONS:
raise ValueError(
f"Unsupported file extension: {suffix}. Supported: {self.SUPPORTED_EXTENSIONS}"
f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}"
)
# Parse with unstructured
# 根据文件类型动态调整参数
extra_kwargs = {}
if suffix == ".pdf":
extra_kwargs["strategy"] = self.strategy
extra_kwargs["ocr_languages"] = self.ocr_languages
extra_kwargs["extract_images_in_pdf"] = self.extract_images
extra_kwargs["pdf_infer_table_structure"] = self.pdf_infer_table_structure
# languages 参数适用于所有文件类型
if self.languages:
extra_kwargs["languages"] = self.languages
extra_kwargs["include_page_breaks"] = self.include_page_breaks
# 合并用户自定义的额外参数(优先级最高)
extra_kwargs.update(self.partition_kwargs)
# 使用 unstructured 解析
elements = partition(
filename=str(file_path),
extract_images_in_pdf=self.extract_images,
**extra_kwargs
)
documents = []
@@ -48,23 +89,17 @@ class DocumentLoader:
if not text or not text.strip():
continue
# Base metadata
# 基础元数据
metadata = {
"source": str(file_path),
"file_name": file_path.name,
"file_type": suffix,
}
# Merge element-specific metadata without overwriting base fields
elem_meta = getattr(elem, "metadata", {}) or {}
for key, value in elem_meta.items():
if value and key not in metadata:
metadata[key] = value
documents.append(Document(page_content=text, metadata=metadata))
if not documents:
logger.warning("No text content extracted from %s", file_path)
logger.warning("未从 %s 提取到文本内容", file_path)
return []
return documents
@@ -72,10 +107,10 @@ class DocumentLoader:
def load_directory(
self, directory_path: Union[str, Path], recursive: bool = True
) -> List[Document]:
"""Load all supported files from a directory."""
"""从目录加载所有支持的文件。"""
directory_path = Path(directory_path).resolve()
if not directory_path.is_dir():
raise NotADirectoryError(f"Not a directory: {directory_path}")
raise NotADirectoryError(f"不是目录: {directory_path}")
all_documents = []
pattern = "**/*" if recursive else "*"
@@ -86,6 +121,6 @@ class DocumentLoader:
docs = self.load_file(file_path)
all_documents.extend(docs)
except Exception as e:
logger.error("Failed to load %s: %s", file_path, e)
logger.error("加载 %s 失败: %s", file_path, e)
return all_documents