Files
ailine/rag_core/retriever_factory.py
root 3c906e91d9
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 35m37s
重排,多路查询
2026-04-20 01:10:18 +08:00

67 lines
2.4 KiB
Python

# rag_core/retriever_factory.py
from langchain_core.embeddings import Embeddings
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter
from rag_indexer.splitters import SplitterType, get_splitter
import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever(
    collection_name: str = "rag_documents",
    embeddings: Optional[Embeddings] = None,
    parent_splitter: Optional[TextSplitter] = None,
    child_splitter: Optional[TextSplitter] = None,
    docstore: Optional[BaseStore] = None,
    search_k: int = 5,
    # The four sizes below are only used to build default splitters
    # when the corresponding splitter argument is not supplied.
    parent_chunk_size: int = 1000,
    parent_chunk_overlap: int = 100,
    child_chunk_size: int = 200,
    child_chunk_overlap: int = 20,
) -> ParentDocumentRetriever:
    """Assemble a ``ParentDocumentRetriever`` over a Qdrant collection.

    Every collaborator may be injected; any that is left as ``None`` is
    created here with a default.

    Args:
        collection_name: Name passed to ``QdrantVectorStore``.
        embeddings: Embedding model. When ``None``, a ``LlamaCppEmbedder``
            is instantiated and adapted via ``as_langchain_embeddings()``.
        parent_splitter: Splitter for parent documents. When ``None``, a
            ``RecursiveCharacterTextSplitter`` is built from
            ``parent_chunk_size`` / ``parent_chunk_overlap``.
        child_splitter: Splitter for child chunks. When ``None``, a
            ``RecursiveCharacterTextSplitter`` is built from
            ``child_chunk_size`` / ``child_chunk_overlap``.
        docstore: Parent-document store. When ``None``, the first element
            of the pair returned by ``create_docstore()`` is used.
        search_k: Forwarded to the retriever as ``search_kwargs["k"]``.
        parent_chunk_size: Chunk size of the default parent splitter.
        parent_chunk_overlap: Chunk overlap of the default parent splitter.
        child_chunk_size: Chunk size of the default child splitter.
        child_chunk_overlap: Chunk overlap of the default child splitter.

    Returns:
        The configured ``ParentDocumentRetriever``.
    """
    # Embedding model: fall back to the llama.cpp-backed embedder.
    if embeddings is None:
        embeddings = LlamaCppEmbedder().as_langchain_embeddings()

    # Vector store (the original author's note marks it as read-only here).
    qdrant_store = QdrantVectorStore(
        collection_name=collection_name,
        embeddings=embeddings,
    )

    # Splitters: build defaults only for the ones the caller did not pass.
    if parent_splitter is None:
        parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=parent_chunk_size,
            chunk_overlap=parent_chunk_overlap,
        )
    if child_splitter is None:
        child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=child_chunk_size,
            chunk_overlap=child_chunk_overlap,
        )

    # Document store: per the original note, create_docstore() reads its
    # connection settings from environment variables. Only the first item
    # of the returned pair is needed.
    if docstore is None:
        docstore, _unused = create_docstore()

    # The LangChain vector store is resolved last, after the docstore, to
    # keep the same call order as before.
    langchain_vectorstore = qdrant_store.get_langchain_vectorstore()
    retriever = ParentDocumentRetriever(
        vectorstore=langchain_vectorstore,
        docstore=docstore,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": search_k},
    )
    return retriever