Files
ailine/backend/rag_core/retriever_factory.py
root 60afa86ded
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
feat: 实现 BM25 稀疏 + 稠密向量混合检索功能
2026-05-04 02:01:22 +08:00

111 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

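"""
RAG retriever factory module.

Provides factory functions for creating various retrievers, including:
- a basic vector retriever
- ParentDocumentRetriever (parent/child documents)
- a hybrid retriever (dense + sparse)
"""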
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_core.stores import BaseStore
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever(
    collection_name: str = "rag_documents",
    parent_splitter: Optional[TextSplitter] = None,
    child_splitter: Optional[TextSplitter] = None,
    docstore: Optional[BaseStore] = None,
    search_k: int = 5,
    parent_chunk_size: int = 1000,
    parent_chunk_overlap: int = 100,
    child_chunk_size: int = 200,
    child_chunk_overlap: int = 20,
    embeddings: Optional[Embeddings] = None,
) -> ParentDocumentRetriever:
    """Build a ParentDocumentRetriever (basic dense-vector variant).

    Every collaborator that is not supplied by the caller is created here
    with defaults; explicitly supplied collaborators are used as-is.

    Args:
        collection_name: Qdrant collection name handed to QdrantVectorStore.
        parent_splitter: Splitter for parent documents. When None, a
            RecursiveCharacterTextSplitter is built from parent_chunk_size
            and parent_chunk_overlap.
        child_splitter: Splitter for child documents. When None, a
            RecursiveCharacterTextSplitter is built from child_chunk_size
            and child_chunk_overlap.
        docstore: Document store. When None, the first element of the pair
            returned by create_docstore() is used.
        search_k: Value passed to the retriever as search_kwargs["k"].
        parent_chunk_size: Chunk size of the default parent splitter;
            ignored when parent_splitter is given.
        parent_chunk_overlap: Chunk overlap of the default parent splitter;
            ignored when parent_splitter is given.
        child_chunk_size: Chunk size of the default child splitter;
            ignored when child_splitter is given.
        child_chunk_overlap: Chunk overlap of the default child splitter;
            ignored when child_splitter is given.
        embeddings: Embedding model. When None, a LlamaCppEmbedder is
            created and converted via as_langchain_embeddings().

    Returns:
        The configured ParentDocumentRetriever instance.
    """
    # Resolve the embedding model first: the vector store needs it.
    resolved_embeddings = (
        embeddings
        if embeddings is not None
        else LlamaCppEmbedder().as_langchain_embeddings()
    )

    qdrant_store = QdrantVectorStore(
        collection_name=collection_name,
        embeddings=resolved_embeddings,
    )

    # Fall back to default splitters only for the ones the caller omitted.
    resolved_parent_splitter = (
        parent_splitter
        if parent_splitter is not None
        else RecursiveCharacterTextSplitter(
            chunk_size=parent_chunk_size,
            chunk_overlap=parent_chunk_overlap,
        )
    )
    resolved_child_splitter = (
        child_splitter
        if child_splitter is not None
        else RecursiveCharacterTextSplitter(
            chunk_size=child_chunk_size,
            chunk_overlap=child_chunk_overlap,
        )
    )

    resolved_docstore = docstore
    if resolved_docstore is None:
        # create_docstore() yields a pair; only its first item is needed here.
        resolved_docstore, _unused = create_docstore()

    langchain_store = qdrant_store.get_langchain_vectorstore()
    return ParentDocumentRetriever(
        vectorstore=langchain_store,
        docstore=resolved_docstore,
        child_splitter=resolved_child_splitter,
        parent_splitter=resolved_parent_splitter,
        search_kwargs={"k": search_k},
    )
def create_hybrid_retriever_factory(
    collection_name: str = "rag_documents",
    search_k: int = 5,
    embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
    """[Incomplete placeholder] Return a basic retriever over a Qdrant collection.

    Per the original author's note, this is only a stand-in for a hybrid
    retriever factory: the complete hybrid retrieval logic is implemented
    in app/rag/retriever.py, and this function covers dense retrieval only.

    Args:
        collection_name: Qdrant collection name handed to QdrantVectorStore.
        search_k: Value passed to the retriever as search_kwargs["k"].
        embeddings: Embedding model. When None, a LlamaCppEmbedder is
            created and converted via as_langchain_embeddings().

    Returns:
        The retriever produced by calling as_retriever() on the LangChain
        vector store wrapped by QdrantVectorStore.
    """
    # Resolve the embedding model, creating the default one on demand.
    resolved_embeddings = (
        embeddings
        if embeddings is not None
        else LlamaCppEmbedder().as_langchain_embeddings()
    )

    qdrant_store = QdrantVectorStore(
        collection_name=collection_name,
        embeddings=resolved_embeddings,
    )

    # Expose the store through the LangChain retriever interface.
    langchain_store = qdrant_store.get_langchain_vectorstore()
    return langchain_store.as_retriever(search_kwargs={"k": search_k})