"""Text splitters used to break documents into chunks."""

from enum import Enum
from typing import List, Optional, Tuple, Dict, Any
from dataclasses import dataclass, field, fields, replace

from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_experimental.text_splitter import SemanticChunker


class SplitterType(str, Enum):
    """Identifiers of the supported splitting strategies."""

    RECURSIVE = "recursive"
    SEMANTIC = "semantic"
    PARENT_CHILD = "parent_child"


@dataclass
class RecursiveSplitterConfig:
    """Settings forwarded to ``RecursiveCharacterTextSplitter``."""

    chunk_size: int = 500
    chunk_overlap: int = 50
    # Tried in order, from coarse (paragraph) to fine (single character).
    separators: List[str] = field(
        default_factory=lambda: ["\n\n", "\n", "。", "!", "?", " ", ""]
    )
    keep_separator: bool = True
    strip_whitespace: bool = True


@dataclass
class SemanticSplitterConfig:
    """Settings forwarded to ``SemanticChunker``."""

    embeddings: Any
    buffer_size: int = 1
    add_start_index: bool = False
    breakpoint_threshold_type: str = "percentile"
    # Explicit (non-None) value so that splitting is more aggressive.
    # NOTE(review): langchain_experimental appears to interpret the
    # "percentile" amount on a 0-100 scale (library default 95) -- confirm
    # that 0.6 (i.e. the 0.6th percentile) is really what is intended.
    breakpoint_threshold_amount: float = 0.6
    number_of_chunks: Optional[int] = None
    # Chinese-friendly sentence boundary pattern.
    # NOTE(review): the character class lists `!?;` twice; full-width
    # variants may have been intended for one of the groups -- confirm.
    sentence_split_regex: str = r"(?<=[。!?;.!?;])"
    min_chunk_size: int = 100


@dataclass
class ParentChildSplitterConfig:
    """Settings for the two-stage parent/child splitter."""

    embeddings: Any
    # Semantic splitting (produces the parent chunks).
    semantic_threshold_type: str = "percentile"
    semantic_threshold_amount: float = 0.6
    semantic_buffer_size: int = 1
    semantic_min_chunk_size: int = 100
    # Child chunks (recursive character splitting).
    child_chunk_size: int = 400
    child_chunk_overlap: int = 50


# ---------- Adapter ----------


class SemanticChunkerAdapter(TextSplitter):
    """Expose ``SemanticChunker`` through the ``TextSplitter`` interface."""

    def __init__(self, config: SemanticSplitterConfig, **kwargs):
        """Build the wrapped ``SemanticChunker`` from ``config``.

        Args:
            config: Semantic splitting settings.
            **kwargs: Forwarded unchanged to ``TextSplitter.__init__``.
        """
        super().__init__(**kwargs)
        self._config = config
        self._chunker = SemanticChunker(
            embeddings=config.embeddings,
            buffer_size=config.buffer_size,
            add_start_index=config.add_start_index,
            breakpoint_threshold_type=config.breakpoint_threshold_type,
            breakpoint_threshold_amount=config.breakpoint_threshold_amount,
            number_of_chunks=config.number_of_chunks,
            sentence_split_regex=config.sentence_split_regex,
            min_chunk_size=config.min_chunk_size,
        )

    def split_text(self, text: str) -> List[str]:
        """Return the chunks the wrapped ``SemanticChunker`` yields for ``text``."""
        return self._chunker.split_text(text)

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split every document and tag each chunk with its position.

        Each output document carries a shallow copy of the source metadata
        plus ``chunk_index``, the chunk's position within its source
        document. Only ``chunk_index`` is added here; no ``start_index`` key
        is written by this method.

        Args:
            documents: Documents to split.

        Returns:
            All chunks, in input order.
        """
        return [
            Document(
                page_content=chunk,
                metadata={**doc.metadata, "chunk_index": index},
            )
            for doc in documents
            for index, chunk in enumerate(self.split_text(doc.page_content))
        ]


# ---------- Factory ----------


def _config_kwargs(config_cls: type, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return the entries of ``kwargs`` that name a field of ``config_cls``."""
    field_names = {f.name for f in fields(config_cls)}
    return {key: value for key, value in kwargs.items() if key in field_names}


def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
    """Create a splitter of the requested type.

    Args:
        splitter_type: Which strategy to build.
        **kwargs: For ``RECURSIVE``, any ``RecursiveSplitterConfig`` field.
            For ``SEMANTIC``, either ``config`` (a ``SemanticSplitterConfig``)
            or any ``SemanticSplitterConfig`` field; ``embeddings`` must be
            available from one of the two. Unrecognised keys are ignored.

    Returns:
        The configured splitter.

    Raises:
        ValueError: If no embeddings are available for ``SEMANTIC``, if
            ``PARENT_CHILD`` is requested (use ``ParentChildSplitter``
            directly), or if the type is not supported.
    """
    if splitter_type == SplitterType.RECURSIVE:
        # Defaults come from the dataclass, so they are declared only once
        # and every config field (not just a subset) can be overridden.
        config = RecursiveSplitterConfig(
            **_config_kwargs(RecursiveSplitterConfig, kwargs)
        )
        return RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=config.separators,
            keep_separator=config.keep_separator,
            strip_whitespace=config.strip_whitespace,
        )

    if splitter_type == SplitterType.SEMANTIC:
        embeddings = kwargs.get("embeddings")
        supplied = kwargs.get("config")

        if isinstance(supplied, SemanticSplitterConfig):
            # A ready-made config already carries its embeddings; only fall
            # back to the keyword argument when the config has none.
            if supplied.embeddings is None:
                if embeddings is None:
                    raise ValueError("语义切分器需要提供 'embeddings' 参数")
                supplied = replace(supplied, embeddings=embeddings)
            return SemanticChunkerAdapter(supplied)

        if embeddings is None:
            raise ValueError("语义切分器需要提供 'embeddings' 参数")
        config = SemanticSplitterConfig(
            **_config_kwargs(SemanticSplitterConfig, kwargs)
        )
        return SemanticChunkerAdapter(config)

    if splitter_type == SplitterType.PARENT_CHILD:
        raise ValueError("父子切分器应通过 ParentChildSplitter 直接创建")

    raise ValueError(f"不支持的切分器类型: {splitter_type}")


# ---------- Parent/child splitter ----------


class ParentChildSplitter:
    """Two-stage splitter.

    Pipeline:
        1. Semantic splitting -> parent chunks.
        2. Recursive character splitting of each parent -> child chunks.

    The parent/child relationships of everything split by this instance are
    recorded in ``parent_to_children`` and ``child_to_parent``.
    """

    def __init__(self, config: ParentChildSplitterConfig):
        """Create both underlying splitters from ``config``."""
        self.config = config

        # Semantic splitting (parent chunks).
        semantic_config = SemanticSplitterConfig(
            embeddings=config.embeddings,
            buffer_size=config.semantic_buffer_size,
            breakpoint_threshold_type=config.semantic_threshold_type,
            breakpoint_threshold_amount=config.semantic_threshold_amount,
            min_chunk_size=config.semantic_min_chunk_size,
        )
        self.semantic_splitter = SemanticChunkerAdapter(semantic_config)

        # Recursive character splitting (child chunks; size is governed by
        # child_chunk_size).
        self.recursive_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.child_chunk_size,
            chunk_overlap=config.child_chunk_overlap,
            separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""],
        )

        self.parent_to_children: Dict[str, List[str]] = {}
        self.child_to_parent: Dict[str, str] = {}

        # Running counters used to mint IDs. They live on the instance so
        # that repeated split_documents() calls never re-issue an ID that is
        # already present in the mappings above.
        self._parent_count = 0
        self._child_count = 0

    def split_documents(
        self, documents: List[Document]
    ) -> Tuple[List[Document], List[Document]]:
        """Split documents into parent chunks and their child chunks.

        Parent metadata gains ``id`` and ``chunk_index`` (position within the
        source document). Child metadata gains ``id``, ``parent_id`` and
        ``child_index`` (position within the parent). IDs have the form
        ``parent_<n>`` / ``child_<n>``; numbering starts at 0 on the first
        call and continues across later calls on the same instance.

        Args:
            documents: Documents to split.

        Returns:
            ``(parent_chunks, child_chunks)`` produced by this call.
        """
        parent_chunks: List[Document] = []
        child_chunks: List[Document] = []

        for doc in documents:
            # Step 1: semantic splitting (parent chunks).
            semantic_blocks = self.semantic_splitter.split_text(doc.page_content)

            for p_idx, semantic_block in enumerate(semantic_blocks):
                parent_id = f"parent_{self._parent_count}"
                self._parent_count += 1
                parent_chunks.append(
                    Document(
                        page_content=semantic_block,
                        metadata={
                            **doc.metadata,
                            "id": parent_id,
                            "chunk_index": p_idx,
                        },
                    )
                )

                # Step 2: recursive character splitting (child chunks).
                sub_chunks = self.recursive_splitter.split_text(semantic_block)
                for c_idx, sub_chunk in enumerate(sub_chunks):
                    child_id = f"child_{self._child_count}"
                    self._child_count += 1
                    child_chunks.append(
                        Document(
                            page_content=sub_chunk,
                            metadata={
                                **doc.metadata,
                                "id": child_id,
                                "parent_id": parent_id,
                                "child_index": c_idx,
                            },
                        )
                    )
                    self.child_to_parent[child_id] = parent_id
                    # A parent is only registered once it has a child.
                    self.parent_to_children.setdefault(parent_id, []).append(
                        child_id
                    )

        return parent_chunks, child_chunks

    def get_parent_for_child(self, child_id: str) -> Optional[str]:
        """Return the parent ID recorded for ``child_id``, or ``None``."""
        return self.child_to_parent.get(child_id)

    def get_children_for_parent(self, parent_id: str) -> List[str]:
        """Return the child IDs recorded for ``parent_id`` (empty if none)."""
        return self.parent_to_children.get(parent_id, [])