This commit is contained in:
@@ -7,6 +7,8 @@ RAG Core - 公共 RAG 组件包
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
|
||||
from .store import PostgresDocStore, create_docstore
|
||||
from .retriever_factory import create_parent_retriever
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LlamaCppEmbedder",
|
||||
@@ -15,4 +17,5 @@ __all__ = [
|
||||
"QDRANT_API_KEY",
|
||||
"PostgresDocStore",
|
||||
"create_docstore",
|
||||
"create_parent_retriever",
|
||||
]
|
||||
|
||||
24
rag_core/client.py
Normal file
24
rag_core/client.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# rag_core/client.py
|
||||
import os
|
||||
from typing import Optional
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
||||
|
||||
def create_qdrant_client(
|
||||
url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: int = 120, # 索引构建需要较长超时
|
||||
) -> QdrantClient:
|
||||
effective_url = url or QDRANT_URL
|
||||
effective_api_key = api_key or QDRANT_API_KEY
|
||||
|
||||
if not effective_url:
|
||||
raise ValueError("Qdrant URL 未配置")
|
||||
|
||||
client_kwargs = {"url": effective_url, "timeout": timeout}
|
||||
if effective_api_key:
|
||||
client_kwargs["api_key"] = effective_api_key
|
||||
|
||||
return QdrantClient(**client_kwargs)
|
||||
67
rag_core/retriever_factory.py
Normal file
67
rag_core/retriever_factory.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# rag_core/retriever_factory.py
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from rag_indexer.splitters import SplitterType, get_splitter
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional, Any, Dict, Tuple
|
||||
from httpx import RemoteProtocolError
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.stores import BaseStore
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||
|
||||
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
|
||||
|
||||
|
||||
def create_parent_retriever(
|
||||
collection_name: str = "rag_documents",
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
parent_splitter: Optional[TextSplitter] = None,
|
||||
child_splitter: Optional[TextSplitter] = None,
|
||||
docstore: Optional[BaseStore] = None,
|
||||
search_k: int = 5,
|
||||
# 若未传入切分器,则用以下参数创建默认切分器
|
||||
parent_chunk_size: int = 1000,
|
||||
parent_chunk_overlap: int = 100,
|
||||
child_chunk_size: int = 200,
|
||||
child_chunk_overlap: int = 20,
|
||||
) -> ParentDocumentRetriever:
|
||||
# 嵌入模型
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
embeddings = embedder.as_langchain_embeddings()
|
||||
|
||||
# 向量存储(只读)
|
||||
vector_store = QdrantVectorStore(
|
||||
collection_name=collection_name,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
|
||||
# 切分器(若未提供则创建默认)
|
||||
if parent_splitter is None:
|
||||
parent_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=parent_chunk_size,
|
||||
chunk_overlap=parent_chunk_overlap,
|
||||
)
|
||||
if child_splitter is None:
|
||||
child_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=child_chunk_size,
|
||||
chunk_overlap=child_chunk_overlap,
|
||||
)
|
||||
|
||||
# 文档存储
|
||||
if docstore is None:
|
||||
docstore, _ = create_docstore() # 从环境变量读取连接
|
||||
|
||||
return ParentDocumentRetriever(
|
||||
vectorstore=vector_store.get_langchain_vectorstore(),
|
||||
docstore=docstore,
|
||||
child_splitter=child_splitter,
|
||||
parent_splitter=parent_splitter,
|
||||
search_kwargs={"k": search_k},
|
||||
)
|
||||
@@ -10,6 +10,7 @@ from langchain_core.documents import Document
|
||||
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http.models import Distance, VectorParams
|
||||
from .client import create_qdrant_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,14 +45,8 @@ class QdrantVectorStore:
|
||||
)
|
||||
|
||||
def get_client(self) -> QdrantClient:
|
||||
"""懒加载客户端,每次获取时确保连接可用。"""
|
||||
if self._client is None:
|
||||
self._client = QdrantClient(
|
||||
url=QDRANT_URL,
|
||||
api_key=QDRANT_API_KEY,
|
||||
timeout=120,
|
||||
http2=False,
|
||||
)
|
||||
self._client = create_qdrant_client(timeout=120)
|
||||
return self._client
|
||||
|
||||
def refresh_client(self):
|
||||
|
||||
Reference in New Issue
Block a user