参数配置统一

This commit is contained in:
2026-04-21 19:06:34 +08:00
parent e2eaac9498
commit 37e86f3bb1
10 changed files with 120 additions and 166 deletions

View File

@@ -1,38 +1,46 @@
# rag_core/retriever_factory.py
# rag_core/retriever_factory.py
from langchain_core.embeddings import Embeddings
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.stores import BaseStore
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever(
collection_name: str = "rag_documents",
embeddings: Optional[Embeddings] = None,
parent_splitter: Optional[TextSplitter] = None,
child_splitter: Optional[TextSplitter] = None,
docstore: Optional[BaseStore] = None,
parent_splitter: TextSplitter | None = None,
child_splitter: TextSplitter | None = None,
docstore: BaseStore | None = None,
search_k: int = 5,
# 若未传入切分器,则用以下参数创建默认切分器
parent_chunk_size: int = 1000,
parent_chunk_overlap: int = 100,
child_chunk_size: int = 200,
child_chunk_overlap: int = 20,
) -> ParentDocumentRetriever:
"""
创建 ParentDocumentRetriever 实例。
Args:
collection_name: Qdrant 集合名称,默认 "rag_documents"
parent_splitter: 父文档切分器,默认 None,使用默认参数创建
child_splitter: 子文档切分器,默认 None,使用默认参数创建
docstore: 文档存储实例,默认 None,使用默认参数创建
search_k: 检索时返回的结果数,默认 5
parent_chunk_size: 父文档块大小,默认 1000
parent_chunk_overlap: 父文档块重叠大小,默认 100
child_chunk_size: 子文档块大小,默认 200
child_chunk_overlap: 子文档块重叠大小,默认 20
Returns:
ParentDocumentRetriever 实例
"""
# 嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 向量存储(只读)
vector_store = QdrantVectorStore(
collection_name=collection_name,
embeddings=embeddings,
)
vector_store = QdrantVectorStore(collection_name=collection_name)
# 切分器(若未提供则创建默认)
if parent_splitter is None:
@@ -48,7 +56,7 @@ def create_parent_retriever(
# 文档存储
if docstore is None:
docstore, _ = create_docstore() # 从环境变量读取连接
docstore, _ = create_docstore()
return ParentDocumentRetriever(
vectorstore=vector_store.get_langchain_vectorstore(),
@@ -56,4 +64,4 @@ def create_parent_retriever(
child_splitter=child_splitter,
parent_splitter=parent_splitter,
search_kwargs={"k": search_k},
)
)