"""
Text splitter utilities for chunking documents.
"""
||
from enum import Enum
|
||
from typing import List, Optional, Tuple, Dict, Any
|
||
from dataclasses import dataclass, field
|
||
|
||
from langchain_core.documents import Document
|
||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||
from langchain_experimental.text_splitter import SemanticChunker
|
||
|
||
|
||
class SplitterType(str, Enum):
    """Supported text-splitting strategies.

    Inherits from ``str`` so members compare equal to their string values
    and serialize cleanly.
    """

    RECURSIVE = "recursive"        # plain recursive character splitting
    SEMANTIC = "semantic"          # embedding-based semantic splitting
    PARENT_CHILD = "parent_child"  # large parent chunks + small child chunks
|
||
|
||
|
||
# ---------- Configuration dataclasses: unified splitter parameters ----------
|
||
@dataclass
class RecursiveSplitterConfig:
    """Configuration for the recursive character splitter."""

    chunk_size: int = 500     # target chunk length in characters
    chunk_overlap: int = 50   # overlap between adjacent chunks
    # Split preference order: paragraph, line, CJK sentence enders, space,
    # then character-by-character as a last resort.
    separators: List[str] = field(
        default_factory=lambda: ["\n\n", "\n", "。", "!", "?", " ", ""]
    )
    keep_separator: bool = True      # retain the separator in the chunk text
    strip_whitespace: bool = True    # trim leading/trailing whitespace per chunk
|
||
|
||
|
||
@dataclass
class SemanticSplitterConfig:
    """Configuration for the semantic splitter.

    Contains only parameters supported by ``SemanticChunker``.
    """

    embeddings: Any                               # embedding model used to score sentence similarity
    buffer_size: int = 1                          # sentences of context on each side when embedding
    add_start_index: bool = False                 # record each chunk's start offset in metadata
    breakpoint_threshold_type: str = "percentile" # how breakpoints are chosen
    breakpoint_threshold_amount: Optional[float] = None
    number_of_chunks: Optional[int] = None        # force an approximate chunk count when set
    # Sentence boundary: lookbehind for ASCII/CJK sentence enders, then whitespace.
    sentence_split_regex: str = r"(?<=[.?!。?!])\s+"
    min_chunk_size: int = 100                     # discard/merge chunks smaller than this
|
||
|
||
@dataclass
class ParentChildSplitterConfig:
    """Configuration for the parent-child splitter."""

    embeddings: Any  # required for semantic splitting of child chunks
    # Parent (context) chunk parameters.
    parent_chunk_size: int = 1000
    parent_chunk_overlap: int = 100
    # Child (retrieval) chunk parameters, forwarded to the semantic splitter.
    child_buffer_size: int = 1
    child_breakpoint_threshold_type: str = "percentile"
    child_breakpoint_threshold_amount: Optional[float] = None
    child_min_chunk_size: int = 100
    child_max_chunk_size: Optional[int] = 200
|
||
|
||
|
||
# ---------- Adapter: expose SemanticChunker through the TextSplitter interface ----------
|
||
class SemanticChunkerAdapter(TextSplitter):
    """Adapts ``SemanticChunker`` to the LangChain ``TextSplitter`` interface."""

    def __init__(self, config: SemanticSplitterConfig, **kwargs):
        """Build the underlying chunker from *config*; **kwargs go to TextSplitter."""
        super().__init__(**kwargs)
        self._config = config
        self._chunker = SemanticChunker(
            embeddings=config.embeddings,
            buffer_size=config.buffer_size,
            add_start_index=config.add_start_index,
            breakpoint_threshold_type=config.breakpoint_threshold_type,
            breakpoint_threshold_amount=config.breakpoint_threshold_amount,
            number_of_chunks=config.number_of_chunks,
            sentence_split_regex=config.sentence_split_regex,
            min_chunk_size=config.min_chunk_size,
        )

    def split_text(self, text: str) -> List[str]:
        """Split raw text into semantically coherent chunks."""
        return self._chunker.split_text(text)

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split each document; every chunk keeps the source metadata plus its index."""
        return [
            Document(
                page_content=piece,
                metadata={**source.metadata, "chunk_index": idx},
            )
            for source in documents
            for idx, piece in enumerate(self.split_text(source.page_content))
        ]
|
||
|
||
|
||
# ---------- Factory function: unified splitter creation ----------
|
||
def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
    """
    Create a text splitter for the given type.

    Accepts either a prebuilt config object (``config=...``) or loose
    keyword arguments matching the config fields.

    Args:
        splitter_type: Which splitter strategy to build.
        **kwargs: Either ``config=<RecursiveSplitterConfig|SemanticSplitterConfig>``
            or the individual config fields (``chunk_size``, ``embeddings``, ...).

    Returns:
        A configured ``TextSplitter`` instance.

    Raises:
        ValueError: if the semantic splitter has no embeddings, if
            PARENT_CHILD is requested (built via IndexBuilder instead),
            or if the type is unknown.
    """
    if splitter_type == SplitterType.RECURSIVE:
        config = kwargs.get("config")
        if not isinstance(config, RecursiveSplitterConfig):
            config = RecursiveSplitterConfig(
                chunk_size=kwargs.get("chunk_size", 500),
                chunk_overlap=kwargs.get("chunk_overlap", 50),
                separators=kwargs.get("separators", ["\n\n", "\n", "。", "!", "?", " ", ""]),
            )
        return RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=config.separators,
            keep_separator=config.keep_separator,
            strip_whitespace=config.strip_whitespace,
        )

    if splitter_type == SplitterType.SEMANTIC:
        # BUG FIX: the original required an 'embeddings' kwarg even when a
        # SemanticSplitterConfig (which already carries embeddings) was
        # passed via config=..., making the documented config path unusable.
        # Check for a config object first.
        config = kwargs.get("config")
        if not isinstance(config, SemanticSplitterConfig):
            embeddings = kwargs.get("embeddings")
            if embeddings is None:
                raise ValueError("语义切分器需要提供 'embeddings' 参数")
            config = SemanticSplitterConfig(
                embeddings=embeddings,
                buffer_size=kwargs.get("buffer_size", 1),
                add_start_index=kwargs.get("add_start_index", False),
                breakpoint_threshold_type=kwargs.get("breakpoint_threshold_type", "percentile"),
                breakpoint_threshold_amount=kwargs.get("breakpoint_threshold_amount"),
                number_of_chunks=kwargs.get("number_of_chunks"),
                sentence_split_regex=kwargs.get("sentence_split_regex", r"(?<=[.?!。?!])\s+"),
                min_chunk_size=kwargs.get("min_chunk_size", 100),
            )
        return SemanticChunkerAdapter(config)

    if splitter_type == SplitterType.PARENT_CHILD:
        # Parent-child splitting is wired up inside IndexBuilder, not here.
        raise ValueError("父子切分器应通过 IndexBuilder 创建,不支持 get_splitter 直接构建")

    raise ValueError(f"不支持的切分器类型: {splitter_type}")
|
||
|
||
# ---------- Parent-child splitter implementation ----------
|
||
class ParentChildSplitter:
    """
    Splits documents into parent chunks (large, for context) and child
    chunks (small, for index retrieval), and maintains the mapping
    between them internally.
    """

    def __init__(self, config: ParentChildSplitterConfig):
        """Set up the parent (recursive) and child (semantic) splitters."""
        self.config = config
        # Parents: plain recursive character splitting.
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.parent_chunk_size,
            chunk_overlap=config.parent_chunk_overlap,
        )
        # Children: semantic splitting via the adapter.
        self.child_splitter = SemanticChunkerAdapter(
            SemanticSplitterConfig(
                embeddings=config.embeddings,
                buffer_size=config.child_buffer_size,
                breakpoint_threshold_type=config.child_breakpoint_threshold_type,
                breakpoint_threshold_amount=config.child_breakpoint_threshold_amount,
                min_chunk_size=config.child_min_chunk_size,
            )
        )
        # Optional bookkeeping of the parent<->child relationship.
        self.parent_to_children: Dict[str, List[str]] = {}
        self.child_to_parent: Dict[str, str] = {}

    def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]:
        """
        Split *documents* both ways and rebuild the internal mapping.

        Returns:
            ``(parent_chunks, child_chunks)``.
        """
        parents = self.parent_splitter.split_documents(documents)
        children = self.child_splitter.split_documents(documents)
        # NOTE(review): the mapping below is a rough substring match, kept
        # from the original as an illustration — production code should use
        # exact offsets or embedding similarity for parent/child linking.
        self._build_mappings(parents, children)
        return parents, children

    def _build_mappings(self, parents: List[Document], children: List[Document]) -> None:
        """
        Rebuild the parent/child maps from chunk text.

        Simplified implementation: each child is attached to the FIRST
        parent whose text contains the child's text; unmatched children
        get no ``parent_id``.
        """
        self.parent_to_children.clear()
        self.child_to_parent.clear()

        # Ensure every parent carries a stable ID (index-based fallback).
        for p_idx, parent in enumerate(parents):
            parent_id = parent.metadata.get("id", f"parent_{p_idx}")
            parent.metadata["id"] = parent_id
            self.parent_to_children[parent_id] = []

        # Attach each child to the first parent containing its text.
        for c_idx, child in enumerate(children):
            child_id = child.metadata.get("id", f"child_{c_idx}")
            child.metadata["id"] = child_id
            owner = next(
                (p for p in parents if child.page_content in p.page_content),
                None,
            )
            if owner is not None:
                parent_id = owner.metadata["id"]
                self.parent_to_children[parent_id].append(child_id)
                self.child_to_parent[child_id] = parent_id
                child.metadata["parent_id"] = parent_id

    def get_parent_for_child(self, child_id: str) -> Optional[str]:
        """Return the parent chunk ID for *child_id*, or None if unmapped."""
        return self.child_to_parent.get(child_id)

    def get_children_for_parent(self, parent_id: str) -> List[str]:
        """Return all child chunk IDs for *parent_id* (empty list if none)."""
        return self.parent_to_children.get(parent_id, [])