2026-04-18 16:56:23 +08:00
|
|
|
|
"""
|
2026-04-19 15:01:40 +08:00
|
|
|
|
文本切分器,用于将文档切分成块。
|
2026-04-18 16:56:23 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from enum import Enum
|
2026-04-19 22:01:55 +08:00
|
|
|
|
from typing import List, Optional, Tuple, Dict, Any
|
|
|
|
|
|
from dataclasses import dataclass, field
|
2026-04-18 16:56:23 +08:00
|
|
|
|
|
|
|
|
|
|
from langchain_core.documents import Document
|
2026-04-19 15:01:40 +08:00
|
|
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
2026-04-18 16:56:23 +08:00
|
|
|
|
from langchain_experimental.text_splitter import SemanticChunker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SplitterType(str, Enum):
    """Identifiers for the supported chunking strategies.

    Because the enum mixes in ``str``, every member compares equal to its
    plain string value (``SplitterType.RECURSIVE == "recursive"``).
    """

    # Fixed-size character chunks (RecursiveCharacterTextSplitter).
    RECURSIVE = "recursive"
    # Embedding-driven chunks (SemanticChunker).
    SEMANTIC = "semantic"
    # Two-level parent/child split; built via ParentChildSplitter, not get_splitter.
    PARENT_CHILD = "parent_child"
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-19 22:01:55 +08:00
|
|
|
|
@dataclass
class RecursiveSplitterConfig:
    """Settings for a ``RecursiveCharacterTextSplitter``.

    Attributes:
        chunk_size: Maximum chunk length, as measured by the splitter.
        chunk_overlap: Amount of text shared by consecutive chunks.
        separators: Split points tried in order, coarsest first.
        keep_separator: Whether separators stay attached to the chunk text.
        strip_whitespace: Whether leading/trailing whitespace is trimmed
            from every chunk.
    """

    chunk_size: int = 500
    chunk_overlap: int = 50
    # Paragraphs, then lines, then sentence-ending punctuation, then spaces,
    # and finally "" (split anywhere). ``list.copy`` of a literal gives each
    # instance its own fresh list.
    separators: List[str] = field(
        default_factory=["\n\n", "\n", "。", "!", "?", " ", ""].copy
    )
    keep_separator: bool = True
    strip_whitespace: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class SemanticSplitterConfig:
    """Settings forwarded to ``langchain_experimental``'s ``SemanticChunker``.

    ``embeddings`` is required (it has no default); every other field is
    passed to the ``SemanticChunker`` argument of the same name.
    """

    # Embeddings model used to compare neighbouring sentences.
    embeddings: Any
    buffer_size: int = 1
    add_start_index: bool = False
    breakpoint_threshold_type: str = "percentile"
    # Deliberately non-None so splitting is more aggressive.
    # NOTE(review): for the "percentile" type SemanticChunker feeds this to
    # np.percentile, which expects 0-100. So 0.6 means the 0.6th percentile:
    # nearly every sentence boundary qualifies as a breakpoint, and
    # min_chunk_size ends up doing most of the grouping. Confirm 0.6 (not 60)
    # is intended.
    breakpoint_threshold_amount: float = 0.6
    number_of_chunks: Optional[int] = None
    # Zero-width lookbehind: split right after Chinese or ASCII sentence-final
    # punctuation, keeping the mark with its sentence (Chinese-friendly).
    sentence_split_regex: str = r"(?<=[。!?;.!?;])"
    min_chunk_size: int = 100
|
2026-04-19 15:01:40 +08:00
|
|
|
|
|
2026-05-05 23:17:00 +08:00
|
|
|
|
|
2026-04-19 22:01:55 +08:00
|
|
|
|
@dataclass
class ParentChildSplitterConfig:
    """Settings for ``ParentChildSplitter``.

    Fields prefixed ``semantic_`` configure the semantic split that yields
    parent chunks. Fields prefixed ``child_`` configure the recursive
    character split applied inside each parent.
    """

    # Embeddings model for the semantic (parent) split; required.
    embeddings: Any

    # --- parent chunks: semantic splitting ---
    semantic_threshold_type: str = "percentile"
    semantic_threshold_amount: float = 0.6
    semantic_buffer_size: int = 1
    semantic_min_chunk_size: int = 100

    # --- child chunks: recursive character splitting ---
    child_chunk_size: int = 400
    child_chunk_overlap: int = 50
|
2026-04-19 15:01:40 +08:00
|
|
|
|
|
2026-05-05 23:17:00 +08:00
|
|
|
|
|
|
|
|
|
|
# ---------- 适配器 ----------
|
|
|
|
|
|
class SemanticChunkerAdapter(TextSplitter):
    """Expose ``SemanticChunker`` through the ``TextSplitter`` interface.

    ``split_text`` delegates entirely to the wrapped ``SemanticChunker``. The
    ``TextSplitter`` base settings (taken from ``**kwargs``) therefore do not
    influence where chunks break.
    """

    def __init__(self, config: SemanticSplitterConfig, **kwargs: Any) -> None:
        """Build the wrapped chunker from *config*.

        Args:
            config: Semantic chunking settings. Each field is forwarded to the
                ``SemanticChunker`` argument of the same name.
            **kwargs: Passed unchanged to ``TextSplitter.__init__``.
        """
        super().__init__(**kwargs)
        self._config = config
        self._chunker = SemanticChunker(
            embeddings=config.embeddings,
            buffer_size=config.buffer_size,
            add_start_index=config.add_start_index,
            breakpoint_threshold_type=config.breakpoint_threshold_type,
            breakpoint_threshold_amount=config.breakpoint_threshold_amount,
            number_of_chunks=config.number_of_chunks,
            sentence_split_regex=config.sentence_split_regex,
            min_chunk_size=config.min_chunk_size,
        )

    def split_text(self, text: str) -> List[str]:
        """Split *text* into semantically coherent chunks.

        Blank (empty or whitespace-only) chunks are dropped. The wrapped
        chunker can produce them for empty input, or because ``re.split`` on a
        zero-width lookbehind such as the default ``sentence_split_regex``
        leaves an empty trailing piece after the final punctuation mark.
        Empty chunks would otherwise become empty documents downstream.
        """
        return [chunk for chunk in self._chunker.split_text(text) if chunk.strip()]

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split every document and tag each chunk with its position.

        Each chunk gets a shallow copy of its source document's metadata plus
        ``chunk_index``, its 0-based position within that document.

        When ``config.add_start_index`` is true, ``start_index`` is also added.
        It is computed as ``SemanticChunker.create_documents`` does: the
        running total of the preceding chunks' lengths. Chunks are rejoined
        sentences, so treat it as approximate rather than an exact offset into
        ``page_content``. (``SemanticChunker`` only honours that flag inside
        ``create_documents``, which this override does not call, so it has to
        be applied here.)

        Args:
            documents: Documents to split.

        Returns:
            The chunk documents, in input order.
        """
        result: List[Document] = []
        for doc in documents:
            offset = 0
            for index, chunk in enumerate(self.split_text(doc.page_content)):
                metadata = {**doc.metadata, "chunk_index": index}
                if self._config.add_start_index:
                    metadata["start_index"] = offset
                result.append(Document(page_content=chunk, metadata=metadata))
                offset += len(chunk)
        return result
|
|
|
|
|
|
|
2026-04-18 16:56:23 +08:00
|
|
|
|
|
2026-05-05 23:17:00 +08:00
|
|
|
|
# ---------- 工厂函数 ----------
|
2026-04-19 22:01:55 +08:00
|
|
|
|
def _config_overrides(config_cls: type, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return the entries of *kwargs* whose keys name a field of dataclass *config_cls*."""
    from dataclasses import fields

    names = {f.name for f in fields(config_cls)}
    return {key: value for key, value in kwargs.items() if key in names}


def get_splitter(splitter_type: SplitterType, **kwargs: Any) -> TextSplitter:
    """Create a ``TextSplitter`` for *splitter_type*.

    Keyword arguments are matched by name against the fields of the relevant
    config dataclass. Unrelated keys are ignored, and omitted fields keep the
    dataclass defaults. A ready-made config may instead be passed as
    ``config=``; it then takes precedence over individual keywords.

    Args:
        splitter_type: Strategy to build, as a ``SplitterType`` or its string
            value.
        **kwargs: Per-strategy settings:
            * RECURSIVE: any ``RecursiveSplitterConfig`` field, or
              ``config=RecursiveSplitterConfig(...)``.
            * SEMANTIC: ``embeddings`` (required) plus any other
              ``SemanticSplitterConfig`` field, or
              ``config=SemanticSplitterConfig(...)`` with embeddings set.

    Returns:
        A ``RecursiveCharacterTextSplitter`` or a ``SemanticChunkerAdapter``.

    Raises:
        ValueError: If semantic splitting has no embeddings, if PARENT_CHILD
            is requested (build ``ParentChildSplitter`` directly), or if the
            type is unknown.
    """
    if splitter_type == SplitterType.RECURSIVE:
        config = kwargs.get("config")
        if not isinstance(config, RecursiveSplitterConfig):
            config = RecursiveSplitterConfig(
                **_config_overrides(RecursiveSplitterConfig, kwargs)
            )
        return RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=config.separators,
            keep_separator=config.keep_separator,
            strip_whitespace=config.strip_whitespace,
        )

    if splitter_type == SplitterType.SEMANTIC:
        supplied = kwargs.get("config")
        if isinstance(supplied, SemanticSplitterConfig):
            # A supplied config already carries its embeddings, so no separate
            # ``embeddings=`` keyword is required.
            config = supplied
        elif kwargs.get("embeddings") is not None:
            config = SemanticSplitterConfig(
                **_config_overrides(SemanticSplitterConfig, kwargs)
            )
        else:
            config = None
        if config is None or config.embeddings is None:
            raise ValueError("语义切分器需要提供 'embeddings' 参数")
        return SemanticChunkerAdapter(config)

    if splitter_type == SplitterType.PARENT_CHILD:
        raise ValueError("父子切分器应通过 ParentChildSplitter 直接创建")

    raise ValueError(f"不支持的切分器类型: {splitter_type}")
|
2026-05-05 23:17:00 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------- 父子切分器 ----------
|
2026-04-18 16:56:23 +08:00
|
|
|
|
class ParentChildSplitter:
    """Two-level splitter producing parent chunks and smaller child chunks.

    Pipeline:
      1. Semantic splitting of each document produces parent chunks.
      2. Recursive character splitting of each parent produces child chunks.

    Every chunk's metadata gets an ``id``, and each child also gets its
    ``parent_id``. The splitter keeps in-memory child->parent and
    parent->children maps that accumulate across ``split_documents`` calls.
    IDs are numbered per splitter instance, so repeated calls never reuse an
    ID.
    """

    def __init__(self, config: ParentChildSplitterConfig):
        """Build the parent (semantic) and child (recursive) splitters.

        Args:
            config: Parent/child splitting settings.
        """
        self.config = config

        # Parent level: semantic chunks.
        semantic_config = SemanticSplitterConfig(
            embeddings=config.embeddings,
            buffer_size=config.semantic_buffer_size,
            breakpoint_threshold_type=config.semantic_threshold_type,
            breakpoint_threshold_amount=config.semantic_threshold_amount,
            min_chunk_size=config.semantic_min_chunk_size,
        )
        self.semantic_splitter = SemanticChunkerAdapter(semantic_config)

        # Child level: recursive character split sized by child_chunk_size.
        self.recursive_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.child_chunk_size,
            chunk_overlap=config.child_chunk_overlap,
            separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""],
        )

        self.parent_to_children: Dict[str, List[str]] = {}
        self.child_to_parent: Dict[str, str] = {}
        # The maps above persist across calls, so IDs must keep counting up
        # rather than restart at 0. Otherwise a second call's "parent_0" would
        # merge with, and overwrite, the first call's entries.
        self._id_counters: Dict[str, int] = {"parent": 0, "child": 0}

    def _next_id(self, kind: str) -> str:
        """Return the next unused ID of the form ``<kind>_<n>`` for this instance."""
        seq = self._id_counters[kind]
        self._id_counters[kind] = seq + 1
        return f"{kind}_{seq}"

    def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]:
        """Split *documents* into parent chunks and their child chunks.

        Parent metadata is the source metadata plus ``id`` and ``chunk_index``
        (position within the source document). Child metadata is the source
        metadata plus ``id``, ``parent_id`` and ``child_index`` (position
        within the parent). Blank semantic blocks are skipped.

        Args:
            documents: Documents to split.

        Returns:
            ``(parent_chunks, child_chunks)``.
        """
        parent_chunks: List[Document] = []
        child_chunks: List[Document] = []

        for doc in documents:
            # Step 1: semantic split -> parent chunks.
            blocks = [
                block
                for block in self.semantic_splitter.split_text(doc.page_content)
                if block.strip()
            ]
            for p_idx, block in enumerate(blocks):
                parent_id = self._next_id("parent")
                parent_chunks.append(Document(
                    page_content=block,
                    metadata={**doc.metadata, "id": parent_id, "chunk_index": p_idx},
                ))

                # Step 2: recursive character split of this parent -> children.
                for c_idx, sub_chunk in enumerate(self.recursive_splitter.split_text(block)):
                    child_id = self._next_id("child")
                    child_chunks.append(Document(
                        page_content=sub_chunk,
                        metadata={
                            **doc.metadata,
                            "id": child_id,
                            "parent_id": parent_id,
                            "child_index": c_idx,
                        },
                    ))
                    self.child_to_parent[child_id] = parent_id
                    self.parent_to_children.setdefault(parent_id, []).append(child_id)

        return parent_chunks, child_chunks

    def get_parent_for_child(self, child_id: str) -> Optional[str]:
        """Return the parent ID of *child_id*, or ``None`` if it is unknown."""
        return self.child_to_parent.get(child_id)

    def get_children_for_parent(self, parent_id: str) -> List[str]:
        """Return a copy of the child IDs of *parent_id* (empty if none are recorded)."""
        return list(self.parent_to_children.get(parent_id, []))
|