This commit is contained in:
71
rag_indexer/splitters.py
Normal file
71
rag_indexer/splitters.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Text splitters for chunking documents.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain_experimental.text_splitter import SemanticChunker
|
||||
|
||||
|
||||
class SplitterType(str, Enum):
|
||||
RECURSIVE = "recursive"
|
||||
SEMANTIC = "semantic"
|
||||
PARENT_CHILD = "parent_child"
|
||||
|
||||
|
||||
def get_splitter(splitter_type: SplitterType, **kwargs):
|
||||
"""Factory function to create a text splitter."""
|
||||
if splitter_type == SplitterType.RECURSIVE:
|
||||
chunk_size = kwargs.get("chunk_size", 500)
|
||||
chunk_overlap = kwargs.get("chunk_overlap", 50)
|
||||
return RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separators=["\n\n", "\n", "。", "!", "?", " ", ""],
|
||||
)
|
||||
elif splitter_type == SplitterType.SEMANTIC:
|
||||
# Requires embeddings for semantic splitting
|
||||
embeddings = kwargs.get("embeddings")
|
||||
if embeddings is None:
|
||||
raise ValueError("Semantic splitter requires 'embeddings' parameter")
|
||||
return SemanticChunker(embeddings=embeddings)
|
||||
else:
|
||||
raise ValueError(f"Unsupported splitter type: {splitter_type}")
|
||||
|
||||
|
||||
class ParentChildSplitter:
|
||||
"""
|
||||
Splits documents into parent (large) and child (small) chunks.
|
||||
Child chunks are indexed for retrieval, parent chunks are stored for context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_chunk_size: int = 1000,
|
||||
child_chunk_size: int = 200,
|
||||
parent_chunk_overlap: int = 100,
|
||||
child_chunk_overlap: int = 20,
|
||||
):
|
||||
self.parent_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=parent_chunk_size,
|
||||
chunk_overlap=parent_chunk_overlap,
|
||||
)
|
||||
self.child_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=child_chunk_size,
|
||||
chunk_overlap=child_chunk_overlap,
|
||||
)
|
||||
|
||||
def split_documents(self, documents: List[Document]) -> tuple[List[Document], List[Document]]:
|
||||
"""
|
||||
Returns:
|
||||
(parent_chunks, child_chunks)
|
||||
"""
|
||||
parent_chunks = self.parent_splitter.split_documents(documents)
|
||||
child_chunks = self.child_splitter.split_documents(documents)
|
||||
|
||||
# Link child chunks to parent IDs (optional metadata)
|
||||
# In a real implementation, you'd map each child to a parent chunk ID.
|
||||
return parent_chunks, child_chunks
|
||||
Reference in New Issue
Block a user