59 lines
2.1 KiB
Python
59 lines
2.1 KiB
Python
# rag_core/retriever_factory.py
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_classic.retrievers import ParentDocumentRetriever
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
from typing import Optional
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.stores import BaseStore
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
|
from langchain_classic.retrievers import ParentDocumentRetriever
|
|
|
|
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
|
|
|
|
def create_parent_retriever(
    collection_name: str = "rag_documents",
    embeddings: Optional[Embeddings] = None,
    parent_splitter: Optional[TextSplitter] = None,
    child_splitter: Optional[TextSplitter] = None,
    docstore: Optional[BaseStore] = None,
    search_k: int = 5,
    # The four chunk settings below are only used when the corresponding
    # splitter argument is omitted and a default splitter has to be built.
    parent_chunk_size: int = 1000,
    parent_chunk_overlap: int = 100,
    child_chunk_size: int = 200,
    child_chunk_overlap: int = 20,
) -> ParentDocumentRetriever:
    """Assemble a ``ParentDocumentRetriever`` from optional building blocks.

    Every collaborator that is not supplied is created with a default:

    * ``embeddings`` -- ``LlamaCppEmbedder().as_langchain_embeddings()``.
    * ``parent_splitter`` / ``child_splitter`` -- a
      ``RecursiveCharacterTextSplitter`` configured from the matching
      ``*_chunk_size`` / ``*_chunk_overlap`` arguments.
    * ``docstore`` -- the first element of the pair returned by
      ``create_docstore()``.

    Args:
        collection_name: Collection name handed to ``QdrantVectorStore``.
        embeddings: Embedding model to use; built from ``LlamaCppEmbedder``
            when ``None``.
        parent_splitter: Splitter for parent documents; a default is built
            when ``None``.
        child_splitter: Splitter for child chunks; a default is built when
            ``None``.
        docstore: Store for parent documents; obtained from
            ``create_docstore()`` when ``None``.
        search_k: Value passed to the retriever as ``search_kwargs["k"]``.
        parent_chunk_size: ``chunk_size`` of the default parent splitter.
        parent_chunk_overlap: ``chunk_overlap`` of the default parent
            splitter.
        child_chunk_size: ``chunk_size`` of the default child splitter.
        child_chunk_overlap: ``chunk_overlap`` of the default child splitter.

    Returns:
        The configured ``ParentDocumentRetriever``.
    """
    # Embedding model: fall back to the llama.cpp-backed embedder.
    if embeddings is None:
        embeddings = LlamaCppEmbedder().as_langchain_embeddings()

    # Vector store (the original note marks it as read-only here).
    qdrant_store = QdrantVectorStore(
        collection_name=collection_name,
        embeddings=embeddings,
    )

    # Splitters: keep whatever the caller passed, otherwise build defaults.
    # The parent splitter is resolved before the child splitter.
    parent_splitter = (
        RecursiveCharacterTextSplitter(
            chunk_size=parent_chunk_size,
            chunk_overlap=parent_chunk_overlap,
        )
        if parent_splitter is None
        else parent_splitter
    )
    child_splitter = (
        RecursiveCharacterTextSplitter(
            chunk_size=child_chunk_size,
            chunk_overlap=child_chunk_overlap,
        )
        if child_splitter is None
        else child_splitter
    )

    # Document store: per the original note, create_docstore() reads its
    # connection settings from environment variables. Only the first item
    # of the returned pair is needed.
    if docstore is None:
        docstore, _unused = create_docstore()

    langchain_vectorstore = qdrant_store.get_langchain_vectorstore()
    return ParentDocumentRetriever(
        vectorstore=langchain_vectorstore,
        docstore=docstore,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": search_k},
    )