"""
Text splitters used to cut documents into chunks.

Provides configuration dataclasses, an adapter that exposes ``SemanticChunker``
through the LangChain ``TextSplitter`` interface, a factory function, and a
parent/child splitter that keeps track of which child chunk belongs to which
parent chunk.
"""

from enum import Enum
from typing import List, Optional, Set, Tuple, Dict, Any
from dataclasses import dataclass, field, fields

from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_experimental.text_splitter import SemanticChunker


class SplitterType(str, Enum):
    """Identifiers of the supported splitting strategies."""

    RECURSIVE = "recursive"
    SEMANTIC = "semantic"
    PARENT_CHILD = "parent_child"


# ---------- Configuration dataclasses: one place for all parameters ----------

@dataclass
class RecursiveSplitterConfig:
    """Settings forwarded to ``RecursiveCharacterTextSplitter``."""

    chunk_size: int = 500
    chunk_overlap: int = 50
    separators: List[str] = field(
        default_factory=lambda: ["\n\n", "\n", "。", "!", "?", " ", ""]
    )
    keep_separator: bool = True
    strip_whitespace: bool = True


@dataclass
class SemanticSplitterConfig:
    """Settings for the semantic splitter; every field is passed to ``SemanticChunker``."""

    embeddings: Any
    buffer_size: int = 1
    add_start_index: bool = False
    breakpoint_threshold_type: str = "percentile"
    breakpoint_threshold_amount: Optional[float] = None
    number_of_chunks: Optional[int] = None
    # NOTE(review): the character class lists "?!" twice; presumably the second
    # pair was meant to be the full-width forms. Left untouched -- confirm.
    sentence_split_regex: str = r"(?<=[.?!。?!])\s+"
    min_chunk_size: int = 100


@dataclass
class ParentChildSplitterConfig:
    """Settings for ``ParentChildSplitter``."""

    embeddings: Any  # required by the semantic split that produces child chunks
    parent_chunk_size: int = 1000
    parent_chunk_overlap: int = 100
    child_buffer_size: int = 1
    child_breakpoint_threshold_type: str = "percentile"
    child_breakpoint_threshold_amount: Optional[float] = None
    child_min_chunk_size: int = 100
    # NOTE(review): nothing in this module reads child_max_chunk_size, so child
    # chunks are currently not capped -- confirm whether a cap is intended.
    child_max_chunk_size: Optional[int] = 200


def _config_from_kwargs(config_cls: Any, kwargs: Dict[str, Any]) -> Any:
    """Build ``config_cls`` from the entries of ``kwargs`` that match its fields.

    Keys that are not fields of the dataclass are ignored; fields that are not
    present in ``kwargs`` keep the dataclass default.
    """
    known = {f.name for f in fields(config_cls)}
    return config_cls(**{key: value for key, value in kwargs.items() if key in known})


# ---------- Adapter: make SemanticChunker implement the TextSplitter interface ----------

class SemanticChunkerAdapter(TextSplitter):
    """Expose a ``SemanticChunker`` through the LangChain ``TextSplitter`` interface."""

    def __init__(self, config: SemanticSplitterConfig, **kwargs):
        """Create the wrapped chunker from ``config``.

        Args:
            config: Semantic splitter settings (must carry the embeddings).
            **kwargs: Forwarded unchanged to ``TextSplitter.__init__``.
        """
        super().__init__(**kwargs)
        self._config = config
        self._chunker = SemanticChunker(
            embeddings=config.embeddings,
            buffer_size=config.buffer_size,
            add_start_index=config.add_start_index,
            breakpoint_threshold_type=config.breakpoint_threshold_type,
            breakpoint_threshold_amount=config.breakpoint_threshold_amount,
            number_of_chunks=config.number_of_chunks,
            sentence_split_regex=config.sentence_split_regex,
            min_chunk_size=config.min_chunk_size,
        )

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` with the wrapped semantic chunker."""
        return self._chunker.split_text(text)

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split every document and return the chunks as new documents.

        Each chunk gets a shallow copy of its source document's metadata plus a
        ``chunk_index`` (position of the chunk within that document). When the
        config enables ``add_start_index``, a ``start_index`` entry is added for
        every chunk whose text can be located in the source document.
        """
        result: List[Document] = []
        for doc in documents:
            search_from = 0
            for i, chunk in enumerate(self.split_text(doc.page_content)):
                metadata = {**doc.metadata, "chunk_index": i}
                if self._config.add_start_index:
                    # This method bypasses the chunker's own document creation,
                    # so the flag has to be honoured here. The chunk may not be
                    # a verbatim substring; in that case no index is recorded.
                    position = doc.page_content.find(chunk, search_from)
                    if position != -1:
                        metadata["start_index"] = position
                        search_from = position + len(chunk)
                result.append(Document(page_content=chunk, metadata=metadata))
        return result


# ---------- Factory: single entry point for creating splitters ----------

def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
    """Create a splitter of the requested type.

    Settings can be given either as a ready-made config object
    (``config=RecursiveSplitterConfig(...)`` / ``config=SemanticSplitterConfig(...)``)
    or as individual keyword arguments named after the config fields. A
    ``config`` of the wrong type is ignored and the keyword arguments are used.

    Args:
        splitter_type: Which splitter to build.
        **kwargs: ``config`` and/or individual config fields.

    Returns:
        A ``RecursiveCharacterTextSplitter`` or a ``SemanticChunkerAdapter``.

    Raises:
        ValueError: If a semantic splitter is requested without embeddings (and
            without a config object), if ``PARENT_CHILD`` is requested, or if
            the type is unknown.
    """
    if splitter_type == SplitterType.RECURSIVE:
        config = kwargs.get("config")
        if not isinstance(config, RecursiveSplitterConfig):
            config = _config_from_kwargs(RecursiveSplitterConfig, kwargs)
        return RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=config.separators,
            keep_separator=config.keep_separator,
            strip_whitespace=config.strip_whitespace,
        )

    if splitter_type == SplitterType.SEMANTIC:
        config = kwargs.get("config")
        if not isinstance(config, SemanticSplitterConfig):
            # Only the keyword-argument path needs a separate embeddings value;
            # a config object already carries its own.
            if kwargs.get("embeddings") is None:
                raise ValueError("语义切分器需要提供 'embeddings' 参数")
            config = _config_from_kwargs(SemanticSplitterConfig, kwargs)
        return SemanticChunkerAdapter(config)

    if splitter_type == SplitterType.PARENT_CHILD:
        # The parent/child splitter is not a TextSplitter and is built elsewhere.
        raise ValueError("父子切分器应通过 IndexBuilder 创建,不支持 get_splitter 直接构建")

    raise ValueError(f"不支持的切分器类型: {splitter_type}")


# ---------- Parent/child splitter ----------

class ParentChildSplitter:
    """
    Split documents into parent chunks (large, used as context) and child
    chunks (small, used for indexing/retrieval), and keep the mapping between
    the two in ``parent_to_children`` / ``child_to_parent``.
    """

    def __init__(self, config: ParentChildSplitterConfig):
        """Create the parent (recursive) and child (semantic) splitters from ``config``."""
        self.config = config

        # Parent chunks: recursive character split.
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.parent_chunk_size,
            chunk_overlap=config.parent_chunk_overlap,
        )

        # Child chunks: semantic split.
        semantic_config = SemanticSplitterConfig(
            embeddings=config.embeddings,
            buffer_size=config.child_buffer_size,
            breakpoint_threshold_type=config.child_breakpoint_threshold_type,
            breakpoint_threshold_amount=config.child_breakpoint_threshold_amount,
            min_chunk_size=config.child_min_chunk_size,
        )
        self.child_splitter = SemanticChunkerAdapter(semantic_config)

        # Mapping between parent and child chunk IDs, filled by split_documents.
        self.parent_to_children: Dict[str, List[str]] = {}
        self.child_to_parent: Dict[str, str] = {}

    def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]:
        """Split ``documents`` and rebuild the internal parent/child mapping.

        Returns:
            ``(parent_chunks, child_chunks)``. Every chunk has an ``id`` in its
            metadata; matched children also have a ``parent_id``.
        """
        parent_chunks = self.parent_splitter.split_documents(documents)
        child_chunks = self.child_splitter.split_documents(documents)

        # Matching is text-based (see _build_mappings); children that cannot be
        # located in any parent are returned without a "parent_id".
        self._build_mappings(parent_chunks, child_chunks)
        return parent_chunks, child_chunks

    @staticmethod
    def _claim_unique_id(doc: Document, prefix: str, index: int, taken: Set[Any]) -> Any:
        """Return an ID for ``doc`` that is not in ``taken`` yet, and reserve it.

        An ``id`` already present in the metadata is kept when it is still free.
        Chunk metadata is copied from the source document, so several chunks
        can carry the same inherited ``id``; later ones get a derived ID.
        """
        base = doc.metadata.get("id", f"{prefix}_{index}")
        candidate = base
        suffix = index
        while candidate in taken:
            candidate = f"{base}_{prefix}_{suffix}"
            suffix += 1
        taken.add(candidate)
        return candidate

    @staticmethod
    def _locate_parent(
        child_text: str,
        parent_texts: List[str],
        normalized_parent_texts: List[str],
    ) -> Optional[int]:
        """Return the index of the first parent containing ``child_text``, or None.

        Exact containment is tried first; if no parent matches, the comparison
        is repeated with every whitespace run collapsed to a single space.
        """
        for idx, text in enumerate(parent_texts):
            if child_text in text:
                return idx

        needle = " ".join(child_text.split())
        if not needle:
            return None
        for idx, text in enumerate(normalized_parent_texts):
            if needle in text:
                return idx
        return None

    def _build_mappings(self, parents: List[Document], children: List[Document]) -> None:
        """Assign IDs to all chunks and map each child to the first parent containing its text.

        Writes ``id`` into every chunk's metadata and ``parent_id`` into every
        matched child's metadata; resets and refills ``parent_to_children`` and
        ``child_to_parent``. This is a text-containment heuristic: a child that
        straddles a parent boundary stays unmapped.
        """
        self.parent_to_children.clear()
        self.child_to_parent.clear()

        parent_ids: List[Any] = []
        taken_parent_ids: Set[Any] = set()
        for p_idx, parent in enumerate(parents):
            parent_id = self._claim_unique_id(parent, "parent", p_idx, taken_parent_ids)
            parent.metadata["id"] = parent_id
            parent_ids.append(parent_id)
            self.parent_to_children[parent_id] = []

        # Computed once so the fallback comparison does not re-normalize every
        # parent for every child.
        parent_texts = [parent.page_content for parent in parents]
        normalized_parent_texts = [" ".join(text.split()) for text in parent_texts]

        taken_child_ids: Set[Any] = set()
        for c_idx, child in enumerate(children):
            child_id = self._claim_unique_id(child, "child", c_idx, taken_child_ids)
            child.metadata["id"] = child_id

            match = self._locate_parent(
                child.page_content, parent_texts, normalized_parent_texts
            )
            if match is None:
                continue
            parent_id = parent_ids[match]
            self.parent_to_children[parent_id].append(child_id)
            self.child_to_parent[child_id] = parent_id
            child.metadata["parent_id"] = parent_id

    def get_parent_for_child(self, child_id: str) -> Optional[str]:
        """Return the parent chunk ID for ``child_id``, or None if it is unmapped."""
        return self.child_to_parent.get(child_id)

    def get_children_for_parent(self, parent_id: str) -> List[str]:
        """Return the child chunk IDs of ``parent_id`` (empty list if unknown).

        A copy is returned so callers cannot alter the internal mapping.
        """
        return list(self.parent_to_children.get(parent_id, []))