Files
ailine/backend/rag_core/retriever_factory.py

111 lines
3.9 KiB
Python
Raw Normal View History

"""
RAG 检索器工厂模块
提供创建各种检索器的工厂函数包括
- 基础向量检索器
- ParentDocumentRetriever父子文档
- 混合检索器稠密+稀疏
"""
from typing import Optional
2026-04-21 11:02:16 +08:00
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
2026-04-21 11:02:16 +08:00
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
2026-04-21 19:06:34 +08:00
from langchain_core.stores import BaseStore
2026-04-21 11:02:16 +08:00
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
2026-04-21 19:06:34 +08:00
2026-04-21 11:02:16 +08:00
def create_parent_retriever(
collection_name: str = "rag_documents",
parent_splitter: Optional[TextSplitter] = None,
child_splitter: Optional[TextSplitter] = None,
docstore: Optional[BaseStore] = None,
2026-04-21 11:02:16 +08:00
search_k: int = 5,
parent_chunk_size: int = 1000,
parent_chunk_overlap: int = 100,
child_chunk_size: int = 200,
child_chunk_overlap: int = 20,
embeddings: Optional[Embeddings] = None,
2026-04-21 11:02:16 +08:00
) -> ParentDocumentRetriever:
2026-04-21 19:06:34 +08:00
"""
创建 ParentDocumentRetriever 实例基础稠密向量版本
2026-04-21 19:06:34 +08:00
Args:
collection_name: Qdrant 集合名称默认 "rag_documents"
parent_splitter: 父文档切分器默认 None使用默认参数创建
child_splitter: 子文档切分器默认 None使用默认参数创建
docstore: 文档存储实例默认 None使用默认参数创建
search_k: 检索时返回的结果数默认 5
parent_chunk_size: 父文档块大小默认 1000
parent_chunk_overlap: 父文档块重叠大小默认 100
child_chunk_size: 子文档块大小默认 200
child_chunk_overlap: 子文档块重叠大小默认 20
embeddings: 嵌入模型实例默认 None使用内部默认的 LocalLlamaCppEmbedder
2026-04-21 19:06:34 +08:00
Returns:
ParentDocumentRetriever 实例
"""
2026-04-21 11:02:16 +08:00
# 嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
2026-04-21 11:02:16 +08:00
# 向量存储(只读)
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
2026-04-21 11:02:16 +08:00
# 切分器(若未提供则创建默认)
if parent_splitter is None:
parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=parent_chunk_size,
chunk_overlap=parent_chunk_overlap,
)
if child_splitter is None:
child_splitter = RecursiveCharacterTextSplitter(
chunk_size=child_chunk_size,
chunk_overlap=child_chunk_overlap,
)
2026-04-21 11:02:16 +08:00
# 文档存储
if docstore is None:
2026-04-21 19:06:34 +08:00
docstore, _ = create_docstore()
2026-04-21 11:02:16 +08:00
return ParentDocumentRetriever(
vectorstore=vector_store.get_langchain_vectorstore(),
docstore=docstore,
child_splitter=child_splitter,
parent_splitter=parent_splitter,
search_kwargs={"k": search_k},
2026-04-21 19:06:34 +08:00
)
def create_hybrid_retriever_factory(
collection_name: str = "rag_documents",
search_k: int = 5,
embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
"""
不完整仅占位创建混合检索器的工厂函数占位符
注意完整的混合检索逻辑在 app/rag/retriever.py 中实现
这里仅返回 QdrantVectorStore 作为基础
Args:
collection_name: Qdrant 集合名称
search_k: 检索返回结果数
embeddings: 嵌入模型实例
Returns:
基础的 QdrantVectorStore仅稠密检索
"""
# 嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 创建向量存储
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
# 返回 LangChain 兼容的 retriever
return vector_store.get_langchain_vectorstore().as_retriever(search_kwargs={"k": search_k})