Files
ailine/rag_indexer/splitters.py
2026-04-19 15:01:40 +08:00

83 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
文本切分器,用于将文档切分成块。
"""
import uuid
from enum import Enum
from typing import List, Optional

from langchain_core.documents import Document
from langchain_experimental.text_splitter import SemanticChunker
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
class SplitterType(str, Enum):
    """Text-splitting strategies known to this module.

    The ``str`` mixin makes every member compare equal to its plain string
    value (``SplitterType.RECURSIVE == "recursive"``), so callers may pass
    either the enum member or the raw string.
    """

    # Character-based splitting driven by a prioritized separator list.
    RECURSIVE = "recursive"
    # Embedding-driven semantic chunking.
    SEMANTIC = "semantic"
    # Large parent chunks for context plus small child chunks for retrieval.
    PARENT_CHILD = "parent_child"
def get_splitter(splitter_type: SplitterType, **kwargs):
    """Factory that creates a text splitter for the requested strategy.

    Args:
        splitter_type: Strategy to build. ``SplitterType`` is a ``str`` enum,
            so the plain string values ("recursive", "semantic",
            "parent_child") are accepted as well.
        **kwargs: Strategy-specific options.
            - RECURSIVE: ``chunk_size`` (default 500), ``chunk_overlap``
              (default 50), and optionally ``separators`` to replace the
              default separator list. Other keys are ignored.
            - SEMANTIC: ``embeddings`` (required). All remaining keys are
              forwarded to ``SemanticChunkerAdapter``.
            - PARENT_CHILD: ``parent_chunk_size``, ``child_chunk_size``,
              ``parent_chunk_overlap`` and ``child_chunk_overlap`` are
              forwarded to ``ParentChildSplitter`` when present. Other keys
              are ignored.

    Returns:
        A ``RecursiveCharacterTextSplitter``, a ``SemanticChunkerAdapter``, or
        a ``ParentChildSplitter``. The last one's ``split_documents`` returns
        a ``(parents, children)`` tuple rather than a single list.

    Raises:
        ValueError: if SEMANTIC is requested without ``embeddings``, or if
            ``splitter_type`` is not a supported strategy.
    """
    if splitter_type == SplitterType.RECURSIVE:
        # Separators are tried in order. An empty string means "split into
        # single characters", and RecursiveCharacterTextSplitter stops scanning
        # as soon as it reaches "". So "" must only appear last, otherwise
        # every separator after it (e.g. " ") is unreachable.
        # NOTE(review): the previous list had three "" entries in the slots
        # now holding "。", "！", "？". These were presumably Chinese
        # sentence-ending punctuation lost in an encoding round-trip; confirm.
        separators = kwargs.get(
            "separators", ["\n\n", "\n", "。", "！", "？", " ", ""]
        )
        return RecursiveCharacterTextSplitter(
            chunk_size=kwargs.get("chunk_size", 500),
            chunk_overlap=kwargs.get("chunk_overlap", 50),
            separators=separators,
        )

    if splitter_type == SplitterType.SEMANTIC:
        embeddings = kwargs.pop("embeddings", None)
        if embeddings is None:
            raise ValueError("语义切分器需要提供 'embeddings' 参数")
        return SemanticChunkerAdapter(embeddings=embeddings, **kwargs)

    if splitter_type == SplitterType.PARENT_CHILD:
        parent_child_keys = (
            "parent_chunk_size",
            "child_chunk_size",
            "parent_chunk_overlap",
            "child_chunk_overlap",
        )
        options = {key: kwargs[key] for key in parent_child_keys if key in kwargs}
        return ParentChildSplitter(**options)

    raise ValueError(f"不支持的切分器类型: {splitter_type}")
class SemanticChunkerAdapter(TextSplitter):
    """Adapt ``SemanticChunker`` to the ``TextSplitter`` interface.

    Splitting itself is delegated to ``SemanticChunker.split_text``. The
    ``TextSplitter`` base class supplies the document-level helpers.
    """

    # Keyword arguments that belong to TextSplitter.__init__ only. They must
    # not be forwarded to SemanticChunker, whose constructor rejects them.
    _BASE_ONLY_KWARGS = frozenset(
        {
            "chunk_size",
            "chunk_overlap",
            "length_function",
            "keep_separator",
            "strip_whitespace",
        }
    )
    # Keyword arguments accepted by both constructors; passed to each.
    _SHARED_KWARGS = frozenset({"add_start_index"})

    def __init__(self, embeddings, **kwargs):
        """Build the adapter.

        Args:
            embeddings: Embeddings model handed to ``SemanticChunker``.
            **kwargs: Routing rules:
                - ``chunk_size``, ``chunk_overlap``, ``length_function``,
                  ``keep_separator`` and ``strip_whitespace`` go to the
                  ``TextSplitter`` base class.
                - ``add_start_index`` goes to both constructors.
                - Everything else (e.g. ``breakpoint_threshold_type``) goes
                  to ``SemanticChunker``.
        """
        base_kwargs = {}
        for key in list(kwargs):
            if key in self._BASE_ONLY_KWARGS:
                base_kwargs[key] = kwargs.pop(key)
            elif key in self._SHARED_KWARGS:
                base_kwargs[key] = kwargs[key]
        super().__init__(**base_kwargs)
        self._chunker = SemanticChunker(embeddings=embeddings, **kwargs)

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` by delegating to the wrapped ``SemanticChunker``."""
        return self._chunker.split_text(text)
class ParentChildSplitter:
    """Split documents into parent chunks (large) and child chunks (small).

    Child chunks are meant for indexing and retrieval. Parent chunks hold the
    surrounding context. Children are cut out of each parent chunk, not out
    of the original documents, so every child lies inside exactly one parent.
    Parent and children share a generated id stored under ``metadata[id_key]``,
    which lets a retrieved child be mapped back to its parent.
    """

    def __init__(
        self,
        parent_chunk_size: int = 1000,
        child_chunk_size: int = 200,
        parent_chunk_overlap: int = 100,
        child_chunk_overlap: int = 20,
        id_key: str = "parent_id",
    ):
        """Create the parent and child splitters.

        Args:
            parent_chunk_size: ``chunk_size`` of the parent splitter.
            child_chunk_size: ``chunk_size`` of the child splitter.
            parent_chunk_overlap: ``chunk_overlap`` of the parent splitter.
            child_chunk_overlap: ``chunk_overlap`` of the child splitter.
            id_key: Metadata key that stores the parent id on both parent and
                child chunks.
        """
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=parent_chunk_size,
            chunk_overlap=parent_chunk_overlap,
        )
        self.child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=child_chunk_size,
            chunk_overlap=child_chunk_overlap,
        )
        self.id_key = id_key

    def split_documents(
        self, documents: List[Document]
    ) -> tuple[List[Document], List[Document]]:
        """Split ``documents`` into linked parent and child chunks.

        Returns:
            ``(parent_chunks, child_chunks)``. Each parent carries a fresh
            uuid4 string in ``metadata[self.id_key]``, and each child carries
            the id of the parent it was cut from. An empty input yields
            ``([], [])``.
        """
        parent_chunks: List[Document] = []
        child_chunks: List[Document] = []
        for parent in self.parent_splitter.split_documents(documents):
            parent_id = str(uuid.uuid4())
            # Rebind the metadata dict instead of mutating it in place, so no
            # dict shared with another Document can be modified.
            parent.metadata = {**parent.metadata, self.id_key: parent_id}
            parent_chunks.append(parent)
            for child in self.child_splitter.split_documents([parent]):
                child.metadata = {**child.metadata, self.id_key: parent_id}
                child_chunks.append(child)
        return parent_chunks, child_chunks