Files
ailine/rag_indexer/splitters.py
root 3ae9daa01a
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m44s
导入方式修改
2026-05-05 23:17:00 +08:00

196 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
文本切分器,用于将文档切分成块。
"""
from enum import Enum
from typing import List, Optional, Tuple, Dict, Any
from dataclasses import dataclass, field
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_experimental.text_splitter import SemanticChunker
class SplitterType(str, Enum):
    """Supported chunking strategies.

    Members subclass ``str``, so each one compares equal to its plain string
    value (for example ``SplitterType.RECURSIVE == "recursive"``).
    """

    # Recursive character splitting; get_splitter returns a
    # RecursiveCharacterTextSplitter for this type.
    RECURSIVE = "recursive"
    # Embedding-based semantic splitting; get_splitter returns a
    # SemanticChunkerAdapter for this type.
    SEMANTIC = "semantic"
    # Two-level parent/child splitting; built directly via ParentChildSplitter,
    # not through get_splitter.
    PARENT_CHILD = "parent_child"
@dataclass
class RecursiveSplitterConfig:
    """Options for a ``RecursiveCharacterTextSplitter``.

    ``get_splitter`` forwards every field unchanged to the splitter.
    """

    # Target maximum chunk length.
    chunk_size: int = 500
    # Maximum overlap between consecutive chunks.
    chunk_overlap: int = 50
    # Tried in order, coarse to fine: paragraphs, lines, Chinese sentence
    # punctuation (。!?), spaces, and finally "" (per-character split).
    # "" must appear only as the last entry: the recursive splitter stops at
    # the first "" it meets, so any separator listed after it is never tried.
    # Full-width punctuation is spelled as escapes so it cannot be confused
    # with ASCII look-alikes or silently lost.
    separators: List[str] = field(
        default_factory=lambda: ["\n\n", "\n", "\u3002", "\uff01", "\uff1f", " ", ""]
    )
    keep_separator: bool = True
    strip_whitespace: bool = True
@dataclass
class SemanticSplitterConfig:
    """Options for ``SemanticChunker``.

    ``SemanticChunkerAdapter`` passes every field, by the same name, to the
    ``SemanticChunker`` constructor.
    """

    # Embeddings object handed to SemanticChunker (required, no default).
    embeddings: Any
    # SemanticChunker's buffer_size (neighbouring sentences per window).
    buffer_size: int = 1
    # SemanticChunker's add_start_index flag.
    add_start_index: bool = False
    # How the breakpoint threshold is computed.
    breakpoint_threshold_type: str = "percentile"
    # Explicit (non-None) threshold so splitting is more aggressive.
    # NOTE(review): with the "percentile" type, SemanticChunker presumably
    # reads this on a 0-100 scale (its own default is 95), so 0.6 would mark
    # nearly every sentence gap as a breakpoint. Confirm 0.6 (not 60) is
    # intended.
    breakpoint_threshold_amount: float = 0.6
    # Optional target chunk count; None leaves it to the threshold.
    number_of_chunks: Optional[int] = None
    # Chinese-friendly sentence split: a zero-width lookbehind after Chinese
    # or ASCII sentence-ending punctuation, so each mark stays with its sentence.
    sentence_split_regex: str = r"(?<=[。!?;.!?;])"
    # SemanticChunker's minimum chunk length.
    min_chunk_size: int = 100
@dataclass
class ParentChildSplitterConfig:
    """Options for ``ParentChildSplitter``.

    ``semantic_*`` fields configure the semantic splitter that produces parent
    chunks. ``child_*`` fields configure the recursive character splitter that
    cuts each parent into child chunks.
    """

    # Embeddings object handed to the parent-level semantic splitter.
    embeddings: Any

    # --- Semantic splitting (parent chunks) ---
    # Used as SemanticSplitterConfig.breakpoint_threshold_type.
    semantic_threshold_type: str = "percentile"
    # Used as SemanticSplitterConfig.breakpoint_threshold_amount.
    # NOTE(review): under "percentile" this is presumably on a 0-100 scale,
    # so 0.6 splits at almost every sentence gap. Confirm it is intended.
    semantic_threshold_amount: float = 0.6
    # Used as SemanticSplitterConfig.buffer_size.
    semantic_buffer_size: int = 1
    # Used as SemanticSplitterConfig.min_chunk_size.
    semantic_min_chunk_size: int = 100

    # --- Child chunks (recursive character splitting) ---
    child_chunk_size: int = 400
    child_chunk_overlap: int = 50
# ---------- 适配器 ----------
class SemanticChunkerAdapter(TextSplitter):
    """Wrap langchain_experimental's ``SemanticChunker`` as a ``TextSplitter``.

    This lets ``get_splitter`` return the semantic chunker wherever a
    ``TextSplitter`` is expected.
    """

    def __init__(self, config: SemanticSplitterConfig, **kwargs):
        """Build the wrapped chunker.

        Args:
            config: Options passed field-by-field to ``SemanticChunker``.
            **kwargs: Forwarded to ``TextSplitter.__init__``.
        """
        super().__init__(**kwargs)
        self._config = config
        self._chunker = SemanticChunker(
            embeddings=config.embeddings,
            buffer_size=config.buffer_size,
            add_start_index=config.add_start_index,
            breakpoint_threshold_type=config.breakpoint_threshold_type,
            breakpoint_threshold_amount=config.breakpoint_threshold_amount,
            number_of_chunks=config.number_of_chunks,
            sentence_split_regex=config.sentence_split_regex,
            min_chunk_size=config.min_chunk_size,
        )

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into chunks using the wrapped ``SemanticChunker``."""
        return self._chunker.split_text(text)

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split each document into semantic chunks.

        Args:
            documents: Source documents.

        Returns:
            One ``Document`` per non-blank chunk. Each chunk's metadata is a
            copy of its source document's metadata plus ``chunk_index``
            (0-based position among that document's kept chunks). When
            ``config.add_start_index`` is true, it also includes ``start_index``.
        """
        result: List[Document] = []
        for doc in documents:
            offset = 0
            chunk_index = 0
            for chunk in self.split_text(doc.page_content):
                start = offset
                offset += len(chunk)
                # SemanticChunker returns [""] for empty input; such chunks
                # carry nothing worth indexing.
                if not chunk.strip():
                    continue
                metadata = {**doc.metadata, "chunk_index": chunk_index}
                # SemanticChunker only honours add_start_index inside its own
                # create_documents, which this adapter bypasses, so apply it
                # here. The offset is approximate: it is the cumulative length
                # of the preceding chunks, because chunks are re-joined
                # sentences and may not occur verbatim in page_content.
                if self._config.add_start_index:
                    metadata["start_index"] = start
                result.append(Document(page_content=chunk, metadata=metadata))
                chunk_index += 1
        return result
# ---------- 工厂函数 ----------
def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
    """Build a text splitter of the requested type.

    Args:
        splitter_type: Strategy to build. ``SplitterType`` is a ``str`` enum,
            so plain strings such as ``"recursive"`` match as well.
        **kwargs: Type-specific options.

            * RECURSIVE: any of ``chunk_size``, ``chunk_overlap``,
              ``separators``, ``keep_separator``, ``strip_whitespace``.
              Omitted options use the ``RecursiveSplitterConfig`` defaults.
            * SEMANTIC: either ``config`` (a ``SemanticSplitterConfig``), or
              ``embeddings`` plus any of ``buffer_size``, ``add_start_index``,
              ``breakpoint_threshold_type``, ``breakpoint_threshold_amount``,
              ``number_of_chunks``, ``sentence_split_regex``,
              ``min_chunk_size``. Omitted options use the
              ``SemanticSplitterConfig`` defaults.

    Returns:
        A ``RecursiveCharacterTextSplitter`` or a ``SemanticChunkerAdapter``.

    Raises:
        ValueError: if a semantic splitter has no embeddings, if PARENT_CHILD
            is requested (build ``ParentChildSplitter`` directly), or if the
            type is unknown.
    """
    if splitter_type == SplitterType.RECURSIVE:
        recursive_options = (
            "chunk_size", "chunk_overlap", "separators", "keep_separator", "strip_whitespace",
        )
        # Pass only what the caller supplied, so the defaults live in one
        # place (the dataclass) instead of being duplicated here.
        config = RecursiveSplitterConfig(
            **{name: kwargs[name] for name in recursive_options if name in kwargs}
        )
        return RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=config.separators,
            keep_separator=config.keep_separator,
            strip_whitespace=config.strip_whitespace,
        )
    elif splitter_type == SplitterType.SEMANTIC:
        config = kwargs.get("config")
        if not isinstance(config, SemanticSplitterConfig):
            semantic_options = (
                "buffer_size", "add_start_index", "breakpoint_threshold_type",
                "breakpoint_threshold_amount", "number_of_chunks",
                "sentence_split_regex", "min_chunk_size",
            )
            config = SemanticSplitterConfig(
                embeddings=kwargs.get("embeddings"),
                **{name: kwargs[name] for name in semantic_options if name in kwargs},
            )
        # Validate the resolved config: a ready-made config already carries
        # its embeddings, so a separate `embeddings` kwarg is not required.
        if config.embeddings is None:
            raise ValueError("语义切分器需要提供 'embeddings' 参数")
        return SemanticChunkerAdapter(config)
    elif splitter_type == SplitterType.PARENT_CHILD:
        raise ValueError("父子切分器应通过 ParentChildSplitter 直接创建")
    else:
        raise ValueError(f"不支持的切分器类型: {splitter_type}")
# ---------- 父子切分器 ----------
class ParentChildSplitter:
    """Two-level splitter producing parent chunks and their child chunks.

    Pipeline:
      1. Semantic splitting -> parent chunks.
      2. Recursive character splitting of each parent -> child chunks.

    The instance keeps id mappings (``parent_to_children``,
    ``child_to_parent``) across calls to ``split_documents``. Ids are numbered
    sequentially per instance, so repeated calls never reuse an id. Separate
    instances each start again at ``parent_0`` / ``child_0``.
    """

    def __init__(self, config: ParentChildSplitterConfig):
        """Build both splitters from ``config``."""
        self.config = config
        # Parent level: semantic splitting.
        semantic_config = SemanticSplitterConfig(
            embeddings=config.embeddings,
            buffer_size=config.semantic_buffer_size,
            breakpoint_threshold_type=config.semantic_threshold_type,
            breakpoint_threshold_amount=config.semantic_threshold_amount,
            min_chunk_size=config.semantic_min_chunk_size,
        )
        self.semantic_splitter = SemanticChunkerAdapter(semantic_config)
        # Child level: recursive character splitting sized by child_chunk_size.
        # Separators run coarse to fine, using Chinese punctuation (。!?;,)
        # spelled as escapes. "" must be last: the splitter stops at the first
        # "" it meets, so anything after it would never be tried.
        self.recursive_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.child_chunk_size,
            chunk_overlap=config.child_chunk_overlap,
            separators=["\n\n", "\n", "\u3002", "\uff01", "\uff1f", "\uff1b", "\uff0c", " ", ""],
        )
        # parent_id -> [child_id, ...] and child_id -> parent_id.
        self.parent_to_children: Dict[str, List[str]] = {}
        self.child_to_parent: Dict[str, str] = {}
        # Id counters kept on the instance (not derived from per-call lists)
        # so ids stay unique across repeated split_documents calls.
        self._next_parent_seq = 0
        self._next_child_seq = 0

    def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]:
        """Split documents into parent chunks and child chunks.

        Args:
            documents: Source documents. Their metadata is copied onto every
                parent and child chunk derived from them.

        Returns:
            ``(parent_chunks, child_chunks)``.

            * Parent metadata gains ``id`` and ``chunk_index`` (position among
              the parents of its source document).
            * Child metadata gains ``id``, ``parent_id`` and ``child_index``
              (position within its parent).

        Side effects:
            Records every new id in ``parent_to_children`` and
            ``child_to_parent``.
        """
        parent_chunks: List[Document] = []
        child_chunks: List[Document] = []
        for doc in documents:
            # Step 1: semantic split into parent blocks. Blank blocks (the
            # semantic chunker yields [""] for empty text) are dropped.
            semantic_blocks = [
                block
                for block in self.semantic_splitter.split_text(doc.page_content)
                if block.strip()
            ]
            for p_idx, semantic_block in enumerate(semantic_blocks):
                parent_id = f"parent_{self._next_parent_seq}"
                self._next_parent_seq += 1
                parent_chunks.append(Document(
                    page_content=semantic_block,
                    metadata={**doc.metadata, "id": parent_id, "chunk_index": p_idx},
                ))
                # Step 2: recursive character split of the parent into children.
                sub_chunks = self.recursive_splitter.split_text(semantic_block)
                for c_idx, sub_chunk in enumerate(sub_chunks):
                    child_id = f"child_{self._next_child_seq}"
                    self._next_child_seq += 1
                    child_chunks.append(Document(
                        page_content=sub_chunk,
                        metadata={
                            **doc.metadata,
                            "id": child_id,
                            "parent_id": parent_id,
                            "child_index": c_idx,
                        },
                    ))
                    self.child_to_parent[child_id] = parent_id
                    self.parent_to_children.setdefault(parent_id, []).append(child_id)
        return parent_chunks, child_chunks

    def get_parent_for_child(self, child_id: str) -> Optional[str]:
        """Return the parent id recorded for ``child_id``, or None if unknown."""
        return self.child_to_parent.get(child_id)

    def get_children_for_parent(self, parent_id: str) -> List[str]:
        """Return a copy of the child ids recorded for ``parent_id`` ([] if none).

        A copy is returned so callers cannot mutate the internal mapping.
        """
        return list(self.parent_to_children.get(parent_id, ()))