Files
ailine/rag_indexer/splitters.py
root 933d418d77
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s
检索器重构
2026-04-19 22:01:55 +08:00

211 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
文本切分器,用于将文档切分成块。
"""
from enum import Enum
from typing import List, Optional, Tuple, Dict, Any
from dataclasses import dataclass, field
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_experimental.text_splitter import SemanticChunker
class SplitterType(str, Enum):
    """Enumerates the supported text-splitting strategies.

    Inherits from ``str`` so members compare equal to their plain string
    values (convenient for config files and JSON payloads).
    """

    # Recursive character splitting over a separator hierarchy.
    RECURSIVE = "recursive"
    # Embedding-driven semantic boundary detection.
    SEMANTIC = "semantic"
    # Large parent chunks for context + small child chunks for retrieval.
    PARENT_CHILD = "parent_child"
# ---------- 配置数据类,统一参数 ----------
@dataclass
class RecursiveSplitterConfig:
    """Configuration for the recursive character splitter.

    Attributes:
        chunk_size: Target maximum chunk size, in characters.
        chunk_overlap: Characters shared between adjacent chunks.
        separators: Split points tried in order, coarse to fine; the final
            ``""`` falls back to character-level splitting.
        keep_separator: Whether the separator stays attached to the chunk.
        strip_whitespace: Whether chunks are stripped of surrounding
            whitespace.
    """
    chunk_size: int = 500
    chunk_overlap: int = 50
    # BUGFIX: the previous default held empty strings where CJK sentence
    # punctuation belonged (lost to Unicode mangling); empty entries before
    # " " forced premature character-level splits. Restored to match the
    # CJK enders used by SemanticSplitterConfig.sentence_split_regex.
    separators: List[str] = field(
        default_factory=lambda: ["\n\n", "\n", "。", "！", "？", " ", ""]
    )
    keep_separator: bool = True
    strip_whitespace: bool = True
@dataclass
class SemanticSplitterConfig:
    """Configuration for the semantic splitter.

    Mirrors only the constructor parameters accepted by
    ``langchain_experimental.text_splitter.SemanticChunker``
    (see SemanticChunkerAdapter, which forwards every field).
    """
    # Embedding model used to score sentence-boundary similarity.
    # NOTE(review): typed ``Any`` — presumably a LangChain ``Embeddings``
    # instance; confirm against callers.
    embeddings: Any
    # Number of neighbouring sentences grouped when comparing embeddings.
    buffer_size: int = 1
    # Whether produced chunks record their start offset in the source text.
    add_start_index: bool = False
    # Breakpoint selection strategy; defaults to "percentile".
    breakpoint_threshold_type: str = "percentile"
    # Numeric threshold for the strategy above; None keeps the library default.
    breakpoint_threshold_amount: Optional[float] = None
    # If set, the chunker targets this many chunks instead of a threshold.
    number_of_chunks: Optional[int] = None
    # Sentence-boundary regex covering both ASCII and CJK terminators.
    sentence_split_regex: str = r"(?<=[.?!。?!])\s+"
    # Minimum chunk size forwarded to SemanticChunker.
    min_chunk_size: int = 100
@dataclass
class ParentChildSplitterConfig:
    """Configuration for the parent/child splitter.

    Parent chunks are produced by plain recursive character splitting;
    child chunks are produced semantically and therefore need an
    embedding model.
    """
    embeddings: Any  # embedding model required for semantic child splitting
    # Target size / overlap (characters) for the large parent chunks.
    parent_chunk_size: int = 1000
    parent_chunk_overlap: int = 100
    # The child_* fields are forwarded to SemanticSplitterConfig.
    child_buffer_size: int = 1
    child_breakpoint_threshold_type: str = "percentile"
    child_breakpoint_threshold_amount: Optional[float] = None
    child_min_chunk_size: int = 100
    # NOTE(review): never forwarded anywhere in this module —
    # SemanticSplitterConfig has no matching field; confirm intent.
    child_max_chunk_size: Optional[int] = 200
# ---------- 适配器:让 SemanticChunker 实现 TextSplitter 接口 ----------
class SemanticChunkerAdapter(TextSplitter):
    """Adapter exposing ``SemanticChunker`` through the LangChain
    ``TextSplitter`` interface."""

    def __init__(self, config: SemanticSplitterConfig, **kwargs):
        """Build the wrapped chunker from *config*.

        Extra keyword arguments are passed through to the ``TextSplitter``
        base class.
        """
        super().__init__(**kwargs)
        self._config = config
        self._chunker = SemanticChunker(
            embeddings=config.embeddings,
            buffer_size=config.buffer_size,
            add_start_index=config.add_start_index,
            breakpoint_threshold_type=config.breakpoint_threshold_type,
            breakpoint_threshold_amount=config.breakpoint_threshold_amount,
            number_of_chunks=config.number_of_chunks,
            sentence_split_regex=config.sentence_split_regex,
            min_chunk_size=config.min_chunk_size,
        )

    def split_text(self, text: str) -> List[str]:
        """Delegate raw-text splitting to the wrapped semantic chunker."""
        return self._chunker.split_text(text)

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split each document into semantic chunks.

        Every chunk keeps a copy of its source document's metadata, plus a
        ``chunk_index`` giving its position within that document.
        """
        pieces: List[Document] = []
        for source in documents:
            for idx, text in enumerate(self.split_text(source.page_content)):
                meta = dict(source.metadata)
                meta["chunk_index"] = idx
                pieces.append(Document(page_content=text, metadata=meta))
        return pieces
# ---------- 工厂函数,统一创建切分器 ----------
def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
    """Factory that builds a text splitter of the requested type.

    Accepts either a prebuilt config object (``config=...``) or loose
    keyword arguments matching the corresponding config dataclass.

    Args:
        splitter_type: Which splitter strategy to instantiate.
        **kwargs: Splitter parameters; see RecursiveSplitterConfig /
            SemanticSplitterConfig for the accepted names.

    Returns:
        A configured ``TextSplitter`` instance.

    Raises:
        ValueError: If the semantic splitter is requested without
            embeddings, for ``PARENT_CHILD`` (built via ``IndexBuilder``
            instead), or for an unknown ``splitter_type``.
    """
    if splitter_type == SplitterType.RECURSIVE:
        cfg = kwargs.get("config")
        if isinstance(cfg, RecursiveSplitterConfig):
            # Accept a prebuilt config, mirroring the semantic branch.
            config = cfg
        else:
            defaults = RecursiveSplitterConfig()
            config = RecursiveSplitterConfig(
                chunk_size=kwargs.get("chunk_size", defaults.chunk_size),
                chunk_overlap=kwargs.get("chunk_overlap", defaults.chunk_overlap),
                # Defer to the dataclass default instead of duplicating it here.
                separators=kwargs.get("separators", defaults.separators),
            )
        return RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=config.separators,
            keep_separator=config.keep_separator,
            strip_whitespace=config.strip_whitespace,
        )
    elif splitter_type == SplitterType.SEMANTIC:
        cfg = kwargs.get("config")
        if isinstance(cfg, SemanticSplitterConfig):
            # BUGFIX: a config object carries its own embeddings, so do not
            # demand a separate 'embeddings' kwarg in this case (previously
            # this path raised ValueError even when embeddings were present).
            config = cfg
        else:
            embeddings = kwargs.get("embeddings")
            if embeddings is None:
                raise ValueError("语义切分器需要提供 'embeddings' 参数")
            # Forward only fields SemanticSplitterConfig declares; defaults
            # mirror the dataclass. add_start_index / sentence_split_regex
            # were previously dropped and are now passed through.
            config = SemanticSplitterConfig(
                embeddings=embeddings,
                buffer_size=kwargs.get("buffer_size", 1),
                add_start_index=kwargs.get("add_start_index", False),
                breakpoint_threshold_type=kwargs.get("breakpoint_threshold_type", "percentile"),
                breakpoint_threshold_amount=kwargs.get("breakpoint_threshold_amount"),
                number_of_chunks=kwargs.get("number_of_chunks"),
                sentence_split_regex=kwargs.get(
                    "sentence_split_regex", SemanticSplitterConfig.sentence_split_regex
                ),
                min_chunk_size=kwargs.get("min_chunk_size", 100),
            )
        return SemanticChunkerAdapter(config)
    elif splitter_type == SplitterType.PARENT_CHILD:
        # Parent/child splitting is wired up inside the builder, not here.
        raise ValueError("父子切分器应通过 IndexBuilder 创建,不支持 get_splitter 直接构建")
    else:
        raise ValueError(f"不支持的切分器类型: {splitter_type}")
# ---------- 父子切分器实现 ----------
class ParentChildSplitter:
    """
    Splits documents into parent chunks (large, for context) and child
    chunks (small, for index retrieval), maintaining the id mapping
    between them.
    """

    def __init__(self, config: ParentChildSplitterConfig):
        """Build the parent (recursive) and child (semantic) splitters.

        Args:
            config: Sizes, overlaps and the embedding model for children.
        """
        self.config = config
        # Parents: plain recursive character splitting.
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.parent_chunk_size,
            chunk_overlap=config.parent_chunk_overlap,
        )
        # Children: semantic splitting driven by the embedding model.
        # NOTE(review): config.child_max_chunk_size is not forwarded —
        # SemanticSplitterConfig has no matching field; confirm intent.
        semantic_config = SemanticSplitterConfig(
            embeddings=config.embeddings,
            buffer_size=config.child_buffer_size,
            breakpoint_threshold_type=config.child_breakpoint_threshold_type,
            breakpoint_threshold_amount=config.child_breakpoint_threshold_amount,
            min_chunk_size=config.child_min_chunk_size,
        )
        self.child_splitter = SemanticChunkerAdapter(semantic_config)
        # Parent/child id mappings, rebuilt on every split_documents call.
        self.parent_to_children: Dict[str, List[str]] = {}
        self.child_to_parent: Dict[str, str] = {}

    def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]:
        """Split *documents* and rebuild the internal id mappings.

        Returns:
            ``(parent_chunks, child_chunks)``.
        """
        parent_chunks = self.parent_splitter.split_documents(documents)
        child_chunks = self.child_splitter.split_documents(documents)
        # Mapping is a substring heuristic (see _build_mappings); production
        # use should prefer offsets or embedding similarity.
        self._build_mappings(parent_chunks, child_chunks)
        return parent_chunks, child_chunks

    def _build_mappings(self, parents: List[Document], children: List[Document]) -> None:
        """Assign unique ids and map each child to the first parent whose
        text contains it (simplified substring matching).

        BUGFIX: chunks split from the same source document inherit identical
        metadata, so a pre-existing ``metadata["id"]`` collided across
        sibling chunks (shared mapping keys, ambiguous child->parent links).
        Ids are now de-duplicated with an index suffix; behavior is unchanged
        when ids are absent or already unique.
        """
        self.parent_to_children.clear()
        self.child_to_parent.clear()
        # Assign each parent a unique id (metadata id if available, else index).
        seen_parent_ids = set()
        for p_idx, parent in enumerate(parents):
            parent_id = parent.metadata.get("id", f"parent_{p_idx}")
            if parent_id in seen_parent_ids:
                # Sibling chunks inherited the same source-document id.
                parent_id = f"{parent_id}#{p_idx}"
            seen_parent_ids.add(parent_id)
            parent.metadata["id"] = parent_id
            self.parent_to_children[parent_id] = []
        # Assign each child to the first parent containing its text.
        seen_child_ids = set()
        for c_idx, child in enumerate(children):
            child_id = child.metadata.get("id", f"child_{c_idx}")
            if child_id in seen_child_ids:
                child_id = f"{child_id}#{c_idx}"
            seen_child_ids.add(child_id)
            child.metadata["id"] = child_id
            for parent in parents:
                if child.page_content in parent.page_content:
                    parent_id = parent.metadata["id"]
                    self.parent_to_children[parent_id].append(child_id)
                    self.child_to_parent[child_id] = parent_id
                    child.metadata["parent_id"] = parent_id
                    break

    def get_parent_for_child(self, child_id: str) -> Optional[str]:
        """Return the parent chunk id for *child_id*, or None if unmapped."""
        return self.child_to_parent.get(child_id)

    def get_children_for_parent(self, parent_id: str) -> List[str]:
        """Return all child chunk ids for *parent_id* (empty if unknown)."""
        return self.parent_to_children.get(parent_id, [])