参数配置统一
This commit is contained in:
@@ -37,7 +37,6 @@ RAG 检索与生成模块
|
|||||||
from .retriever import (
|
from .retriever import (
|
||||||
create_base_retriever,
|
create_base_retriever,
|
||||||
create_hybrid_retriever,
|
create_hybrid_retriever,
|
||||||
create_qdrant_client,
|
|
||||||
)
|
)
|
||||||
from .reranker import LLaMaCPPReranker
|
from .reranker import LLaMaCPPReranker
|
||||||
from .query_transform import MultiQueryGenerator
|
from .query_transform import MultiQueryGenerator
|
||||||
@@ -50,7 +49,6 @@ __all__ = [
|
|||||||
# 检索器工厂函数
|
# 检索器工厂函数
|
||||||
"create_base_retriever",
|
"create_base_retriever",
|
||||||
"create_hybrid_retriever",
|
"create_hybrid_retriever",
|
||||||
"create_qdrant_client",
|
|
||||||
|
|
||||||
# 重排序器
|
# 重排序器
|
||||||
"LLaMaCPPReranker",
|
"LLaMaCPPReranker",
|
||||||
|
|||||||
@@ -25,66 +25,25 @@ Qdrant 向量检索器模块
|
|||||||
>>> docs = retriever.invoke("什么是 RAG?")
|
>>> docs = retriever.invoke("什么是 RAG?")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Dict, Any
|
from typing import Dict, Any
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
from langchain_qdrant import QdrantVectorStore
|
from langchain_qdrant import QdrantVectorStore
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
from rag_core import QDRANT_URL, QDRANT_API_KEY
|
from rag_core import QDRANT_URL, QDRANT_API_KEY, LlamaCppEmbedder
|
||||||
|
from rag_core.client import create_qdrant_client as create_core_qdrant_client
|
||||||
|
|
||||||
# 模块级常量
|
# 模块级常量
|
||||||
DEFAULT_SEARCH_K = 20
|
DEFAULT_SEARCH_K = 20
|
||||||
DEFAULT_SCORE_THRESHOLD = 0.3
|
DEFAULT_SCORE_THRESHOLD = 0.3
|
||||||
|
|
||||||
|
|
||||||
def create_qdrant_client(
|
|
||||||
url: Optional[str] = None,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
timeout: int = 30,
|
|
||||||
) -> QdrantClient:
|
|
||||||
"""
|
|
||||||
创建并返回一个配置好的 Qdrant 客户端。
|
|
||||||
|
|
||||||
优先使用传入参数,若未提供则回退到环境变量 QDRANT_URL 和 QDRANT_API_KEY。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: Qdrant 服务地址,例如 "http://localhost:6333"。
|
|
||||||
默认从环境变量 QDRANT_URL 读取。
|
|
||||||
api_key: API 密钥(若 Qdrant 启用了认证)。
|
|
||||||
默认从环境变量 QDRANT_API_KEY 读取。
|
|
||||||
timeout: 请求超时时间(秒),默认 30 秒。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
配置好的 QdrantClient 实例。
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果 url 为空且环境变量也未设置。
|
|
||||||
"""
|
|
||||||
effective_url = url or QDRANT_URL
|
|
||||||
if not effective_url:
|
|
||||||
raise ValueError(
|
|
||||||
"Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL"
|
|
||||||
)
|
|
||||||
|
|
||||||
effective_api_key = api_key or QDRANT_API_KEY
|
|
||||||
|
|
||||||
client_kwargs = {
|
|
||||||
"url": effective_url,
|
|
||||||
"timeout": timeout,
|
|
||||||
}
|
|
||||||
if effective_api_key:
|
|
||||||
client_kwargs["api_key"] = effective_api_key
|
|
||||||
|
|
||||||
return QdrantClient(**client_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def create_base_retriever(
|
def create_base_retriever(
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
embeddings: Embeddings,
|
search_kwargs: Dict[str, Any] | None = None,
|
||||||
search_kwargs: Optional[Dict[str, Any]] = None,
|
client: QdrantClient | None = None,
|
||||||
client: Optional[QdrantClient] = None,
|
|
||||||
) -> BaseRetriever:
|
) -> BaseRetriever:
|
||||||
"""
|
"""
|
||||||
创建基础向量检索器(仅稠密向量检索)。
|
创建基础向量检索器(仅稠密向量检索)。
|
||||||
@@ -94,7 +53,6 @@ def create_base_retriever(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name: Qdrant 集合名称(需预先创建并索引)。
|
collection_name: Qdrant 集合名称(需预先创建并索引)。
|
||||||
embeddings: LangChain 兼容的嵌入模型实例。
|
|
||||||
search_kwargs: 搜索参数,可包含:
|
search_kwargs: 搜索参数,可包含:
|
||||||
- k (int): 返回的文档数量,默认 20。
|
- k (int): 返回的文档数量,默认 20。
|
||||||
- score_threshold (float): 相似度阈值,仅返回高于此分数的文档。
|
- score_threshold (float): 相似度阈值,仅返回高于此分数的文档。
|
||||||
@@ -108,6 +66,10 @@ def create_base_retriever(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: 如果集合不存在或嵌入模型无效。
|
ValueError: 如果集合不存在或嵌入模型无效。
|
||||||
"""
|
"""
|
||||||
|
# 嵌入模型
|
||||||
|
embedder = LlamaCppEmbedder()
|
||||||
|
embeddings = embedder.as_langchain_embeddings()
|
||||||
|
|
||||||
# 合并默认搜索参数
|
# 合并默认搜索参数
|
||||||
merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
|
merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
|
||||||
if search_kwargs:
|
if search_kwargs:
|
||||||
@@ -115,7 +77,7 @@ def create_base_retriever(
|
|||||||
|
|
||||||
# 创建或复用 Qdrant 客户端
|
# 创建或复用 Qdrant 客户端
|
||||||
if client is None:
|
if client is None:
|
||||||
client = create_qdrant_client()
|
client = create_core_qdrant_client()
|
||||||
|
|
||||||
# 验证集合是否存在(可选,便于提前发现问题)
|
# 验证集合是否存在(可选,便于提前发现问题)
|
||||||
try:
|
try:
|
||||||
@@ -140,11 +102,10 @@ def create_base_retriever(
|
|||||||
|
|
||||||
def create_hybrid_retriever(
|
def create_hybrid_retriever(
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
embeddings: Embeddings,
|
|
||||||
dense_k: int = 10,
|
dense_k: int = 10,
|
||||||
sparse_k: int = 10,
|
sparse_k: int = 10,
|
||||||
score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD,
|
score_threshold: float | None = DEFAULT_SCORE_THRESHOLD,
|
||||||
client: Optional[QdrantClient] = None,
|
client: QdrantClient | None = None,
|
||||||
) -> BaseRetriever:
|
) -> BaseRetriever:
|
||||||
"""
|
"""
|
||||||
创建混合检索器(稠密向量 + BM25 稀疏向量)。
|
创建混合检索器(稠密向量 + BM25 稀疏向量)。
|
||||||
@@ -157,7 +118,6 @@ def create_hybrid_retriever(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name: Qdrant 集合名称。
|
collection_name: Qdrant 集合名称。
|
||||||
embeddings: 嵌入模型(用于稠密向量)。
|
|
||||||
dense_k: 稠密向量检索返回数量,默认 10。
|
dense_k: 稠密向量检索返回数量,默认 10。
|
||||||
sparse_k: 稀疏向量检索返回数量,默认 10。
|
sparse_k: 稀疏向量检索返回数量,默认 10。
|
||||||
score_threshold: 相似度阈值,默认 0.3。
|
score_threshold: 相似度阈值,默认 0.3。
|
||||||
@@ -177,7 +137,6 @@ def create_hybrid_retriever(
|
|||||||
# 复用基础检索器创建逻辑,只需调整搜索参数
|
# 复用基础检索器创建逻辑,只需调整搜索参数
|
||||||
return create_base_retriever(
|
return create_base_retriever(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
embeddings=embeddings,
|
|
||||||
search_kwargs=search_kwargs,
|
search_kwargs=search_kwargs,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
@@ -186,9 +145,8 @@ def create_hybrid_retriever(
|
|||||||
# 可选:提供异步友好的辅助函数
|
# 可选:提供异步友好的辅助函数
|
||||||
async def acreate_base_retriever(
|
async def acreate_base_retriever(
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
embeddings: Embeddings,
|
search_kwargs: Dict[str, Any] | None = None,
|
||||||
search_kwargs: Optional[Dict[str, Any]] = None,
|
client: QdrantClient | None = None,
|
||||||
client: Optional[QdrantClient] = None,
|
|
||||||
) -> BaseRetriever:
|
) -> BaseRetriever:
|
||||||
"""
|
"""
|
||||||
异步创建基础向量检索器(与同步版本功能相同)。
|
异步创建基础向量检索器(与同步版本功能相同)。
|
||||||
@@ -196,4 +154,4 @@ async def acreate_base_retriever(
|
|||||||
适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。
|
适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。
|
||||||
"""
|
"""
|
||||||
# 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可
|
# 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可
|
||||||
return create_base_retriever(collection_name, embeddings, search_kwargs, client)
|
return create_base_retriever(collection_name, search_kwargs, client)
|
||||||
|
|||||||
@@ -5,9 +5,17 @@ RAG Core - 公共 RAG 组件包
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .embedders import LlamaCppEmbedder
|
from .embedders import LlamaCppEmbedder
|
||||||
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
|
from .vector_store import QdrantVectorStore
|
||||||
from .store import PostgresDocStore, create_docstore
|
from .store import PostgresDocStore, create_docstore
|
||||||
from .retriever_factory import create_parent_retriever
|
from .retriever_factory import create_parent_retriever
|
||||||
|
from .config import (
|
||||||
|
QDRANT_URL,
|
||||||
|
QDRANT_API_KEY,
|
||||||
|
LLAMACPP_EMBEDDING_URL,
|
||||||
|
LLAMACPP_API_KEY,
|
||||||
|
DB_URI,
|
||||||
|
DOCSTORE_URI,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -15,6 +23,10 @@ __all__ = [
|
|||||||
"QdrantVectorStore",
|
"QdrantVectorStore",
|
||||||
"QDRANT_URL",
|
"QDRANT_URL",
|
||||||
"QDRANT_API_KEY",
|
"QDRANT_API_KEY",
|
||||||
|
"LLAMACPP_EMBEDDING_URL",
|
||||||
|
"LLAMACPP_API_KEY",
|
||||||
|
"DB_URI",
|
||||||
|
"DOCSTORE_URI",
|
||||||
"PostgresDocStore",
|
"PostgresDocStore",
|
||||||
"create_docstore",
|
"create_docstore",
|
||||||
"create_parent_retriever",
|
"create_parent_retriever",
|
||||||
|
|||||||
@@ -1,27 +1,30 @@
|
|||||||
# rag_core/client.py
|
# rag_core/client.py
|
||||||
import os
|
import os
|
||||||
from .config import QDRANT_URL, QDRANT_API_KEY
|
from .config import QDRANT_URL, QDRANT_API_KEY
|
||||||
from typing import Optional
|
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
|
|
||||||
|
|
||||||
|
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
|
||||||
|
"""
|
||||||
|
创建并返回一个配置好的 Qdrant 客户端。
|
||||||
|
|
||||||
def create_qdrant_client(
|
Args:
|
||||||
url: Optional[str] = None,
|
timeout: 请求超时时间(秒),默认 300 秒(索引构建需要较长超时)。
|
||||||
api_key: Optional[str] = None,
|
|
||||||
timeout: int = 300, # 索引构建需要较长超时
|
|
||||||
) -> QdrantClient:
|
|
||||||
effective_url = url or QDRANT_URL
|
|
||||||
effective_api_key = api_key or QDRANT_API_KEY
|
|
||||||
|
|
||||||
if not effective_url:
|
Returns:
|
||||||
|
配置好的 QdrantClient 实例。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果 QDRANT_URL 未配置。
|
||||||
|
"""
|
||||||
|
if not QDRANT_URL:
|
||||||
raise ValueError("Qdrant URL 未配置")
|
raise ValueError("Qdrant URL 未配置")
|
||||||
|
|
||||||
client_kwargs = {
|
client_kwargs = {
|
||||||
"url": effective_url,
|
"url": QDRANT_URL,
|
||||||
"timeout": timeout,
|
"timeout": timeout,
|
||||||
}
|
}
|
||||||
if effective_api_key:
|
if QDRANT_API_KEY:
|
||||||
client_kwargs["api_key"] = effective_api_key
|
client_kwargs["api_key"] = QDRANT_API_KEY
|
||||||
|
|
||||||
return QdrantClient(**client_kwargs)
|
return QdrantClient(**client_kwargs)
|
||||||
@@ -5,21 +5,21 @@
|
|||||||
import os
|
import os
|
||||||
from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
|
from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
|
||||||
import httpx
|
import httpx
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
|
|
||||||
class LlamaCppEmbedder:
|
class LlamaCppEmbedder:
|
||||||
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
|
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"):
|
||||||
self,
|
"""
|
||||||
base_url: Optional[str] = None,
|
Args:
|
||||||
api_key: Optional[str] = None,
|
model: 嵌入模型名称,默认 "Qwen3-Embedding-0.6B-Q8_0"。
|
||||||
model: str = "Qwen3-Embedding-0.6B-Q8_0",
|
"""
|
||||||
):
|
self.base_url = LLAMACPP_EMBEDDING_URL
|
||||||
self.base_url = base_url or LLAMACPP_EMBEDDING_URL
|
self.api_key = LLAMACPP_API_KEY
|
||||||
self.api_key = api_key or LLAMACPP_API_KEY
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def as_langchain_embeddings(self) -> Embeddings:
|
def as_langchain_embeddings(self) -> Embeddings:
|
||||||
@@ -30,7 +30,7 @@ class LlamaCppEmbedder:
|
|||||||
"""嵌入一批文档。"""
|
"""嵌入一批文档。"""
|
||||||
return self._call_embedding_api(texts)
|
return self._call_embedding_api(texts)
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[List[float]]:
|
||||||
"""嵌入单个查询。"""
|
"""嵌入单个查询。"""
|
||||||
return self._call_embedding_api([text])[0]
|
return self._call_embedding_api([text])[0]
|
||||||
|
|
||||||
@@ -70,6 +70,7 @@ class LlamaCppEmbedder:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
|
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
|
||||||
|
|
||||||
|
|
||||||
class _LlamaCppLangchainAdapter(Embeddings):
|
class _LlamaCppLangchainAdapter(Embeddings):
|
||||||
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""
|
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""
|
||||||
|
|
||||||
@@ -79,5 +80,5 @@ class _LlamaCppLangchainAdapter(Embeddings):
|
|||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
return self._embedder.embed_documents(texts)
|
return self._embedder.embed_documents(texts)
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[List[float]]:
|
||||||
return self._embedder.embed_query(text)
|
return self._embedder.embed_query(text)
|
||||||
@@ -1,38 +1,46 @@
|
|||||||
# rag_core/retriever_factory.py
|
# rag_core/retriever_factory.py
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
||||||
from typing import Optional
|
|
||||||
from langchain_core.embeddings import Embeddings
|
|
||||||
from langchain_core.stores import BaseStore
|
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
from langchain_core.stores import BaseStore
|
||||||
|
|
||||||
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
|
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
|
||||||
|
|
||||||
|
|
||||||
def create_parent_retriever(
|
def create_parent_retriever(
|
||||||
collection_name: str = "rag_documents",
|
collection_name: str = "rag_documents",
|
||||||
embeddings: Optional[Embeddings] = None,
|
parent_splitter: TextSplitter | None = None,
|
||||||
parent_splitter: Optional[TextSplitter] = None,
|
child_splitter: TextSplitter | None = None,
|
||||||
child_splitter: Optional[TextSplitter] = None,
|
docstore: BaseStore | None = None,
|
||||||
docstore: Optional[BaseStore] = None,
|
|
||||||
search_k: int = 5,
|
search_k: int = 5,
|
||||||
# 若未传入切分器,则用以下参数创建默认切分器
|
|
||||||
parent_chunk_size: int = 1000,
|
parent_chunk_size: int = 1000,
|
||||||
parent_chunk_overlap: int = 100,
|
parent_chunk_overlap: int = 100,
|
||||||
child_chunk_size: int = 200,
|
child_chunk_size: int = 200,
|
||||||
child_chunk_overlap: int = 20,
|
child_chunk_overlap: int = 20,
|
||||||
) -> ParentDocumentRetriever:
|
) -> ParentDocumentRetriever:
|
||||||
|
"""
|
||||||
|
创建 ParentDocumentRetriever 实例。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Qdrant 集合名称,默认 "rag_documents"
|
||||||
|
parent_splitter: 父文档切分器,默认 None(使用默认参数创建)
|
||||||
|
child_splitter: 子文档切分器,默认 None(使用默认参数创建)
|
||||||
|
docstore: 文档存储实例,默认 None(使用默认参数创建)
|
||||||
|
search_k: 检索时返回的结果数,默认 5
|
||||||
|
parent_chunk_size: 父文档块大小,默认 1000
|
||||||
|
parent_chunk_overlap: 父文档块重叠大小,默认 100
|
||||||
|
child_chunk_size: 子文档块大小,默认 200
|
||||||
|
child_chunk_overlap: 子文档块重叠大小,默认 20
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParentDocumentRetriever 实例
|
||||||
|
"""
|
||||||
# 嵌入模型
|
# 嵌入模型
|
||||||
if embeddings is None:
|
|
||||||
embedder = LlamaCppEmbedder()
|
embedder = LlamaCppEmbedder()
|
||||||
embeddings = embedder.as_langchain_embeddings()
|
embeddings = embedder.as_langchain_embeddings()
|
||||||
|
|
||||||
# 向量存储(只读)
|
# 向量存储(只读)
|
||||||
vector_store = QdrantVectorStore(
|
vector_store = QdrantVectorStore(collection_name=collection_name)
|
||||||
collection_name=collection_name,
|
|
||||||
embeddings=embeddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 切分器(若未提供则创建默认)
|
# 切分器(若未提供则创建默认)
|
||||||
if parent_splitter is None:
|
if parent_splitter is None:
|
||||||
@@ -48,7 +56,7 @@ def create_parent_retriever(
|
|||||||
|
|
||||||
# 文档存储
|
# 文档存储
|
||||||
if docstore is None:
|
if docstore is None:
|
||||||
docstore, _ = create_docstore() # 从环境变量读取连接
|
docstore, _ = create_docstore()
|
||||||
|
|
||||||
return ParentDocumentRetriever(
|
return ParentDocumentRetriever(
|
||||||
vectorstore=vector_store.get_langchain_vectorstore(),
|
vectorstore=vector_store.get_langchain_vectorstore(),
|
||||||
|
|||||||
@@ -9,14 +9,13 @@
|
|||||||
|
|
||||||
>>> # 创建 PostgreSQL 存储
|
>>> # 创建 PostgreSQL 存储
|
||||||
>>> store, conn = create_docstore(
|
>>> store, conn = create_docstore(
|
||||||
... connection_string="postgresql://user:pass@host:5432/db",
|
|
||||||
... table_name="parent_docs"
|
... table_name="parent_docs"
|
||||||
... )
|
... )
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from .postgres import PostgresDocStore
|
from .postgres import PostgresDocStore
|
||||||
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
|
from .factory import create_docstore, get_docstore_uri
|
||||||
|
|
||||||
__version__ = "2.0.0"
|
__version__ = "2.0.0"
|
||||||
|
|
||||||
@@ -27,5 +26,4 @@ __all__ = [
|
|||||||
# 工厂函数
|
# 工厂函数
|
||||||
"create_docstore",
|
"create_docstore",
|
||||||
"get_docstore_uri",
|
"get_docstore_uri",
|
||||||
"DEFAULT_DB_URI",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,18 +5,15 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from ..config import DB_URI, DOCSTORE_URI
|
from ..config import DOCSTORE_URI
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from langchain_core.stores import BaseStore
|
from langchain_core.stores import BaseStore
|
||||||
from .postgres import PostgresDocStore
|
from .postgres import PostgresDocStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 默认连接字符串(从环境变量读取)
|
|
||||||
DEFAULT_DB_URI = DB_URI
|
|
||||||
|
|
||||||
|
|
||||||
def get_docstore_uri() -> str:
|
def get_docstore_uri() -> str:
|
||||||
"""获取 docstore 专用的数据库连接字符串(可与主库相同)"""
|
"""获取 docstore 专用的数据库连接字符串(可与主库相同)"""
|
||||||
@@ -24,18 +21,14 @@ def get_docstore_uri() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def create_docstore(
|
def create_docstore(
|
||||||
store_type: str = "postgres",
|
|
||||||
connection_string: Optional[str] = None,
|
|
||||||
table_name: str = "parent_documents",
|
table_name: str = "parent_documents",
|
||||||
pool_config: Optional[dict] = None,
|
pool_config: dict | None = None,
|
||||||
max_concurrency: Optional[int] = None
|
max_concurrency: int | None = None
|
||||||
) -> Tuple[BaseStore, Optional[str]]:
|
) -> Tuple[BaseStore, str]:
|
||||||
"""
|
"""
|
||||||
工厂函数,创建 PostgreSQL 文档存储。
|
工厂函数,创建 PostgreSQL 文档存储。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
store_type: 存储类型,目前仅支持 "postgres"(默认)
|
|
||||||
connection_string: PostgreSQL 连接字符串
|
|
||||||
table_name: PostgreSQL 表名(默认:parent_documents)
|
table_name: PostgreSQL 表名(默认:parent_documents)
|
||||||
pool_config: 连接池配置
|
pool_config: 连接池配置
|
||||||
max_concurrency: 最大并发操作数,如果为 None 则不限制
|
max_concurrency: 最大并发操作数,如果为 None 则不限制
|
||||||
@@ -44,21 +37,16 @@ def create_docstore(
|
|||||||
元组 (存储实例, 连接字符串)
|
元组 (存储实例, 连接字符串)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 不支持的存储类型
|
|
||||||
ImportError: 缺少必要的依赖
|
ImportError: 缺少必要的依赖
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> # 创建 PostgreSQL 存储
|
>>> # 创建 PostgreSQL 存储
|
||||||
>>> store, conn = create_docstore(
|
>>> store, conn = create_docstore(
|
||||||
... connection_string="postgresql://user:pass@host:5432/db",
|
|
||||||
... table_name="parent_docs",
|
... table_name="parent_docs",
|
||||||
... max_concurrency=10
|
... max_concurrency=10
|
||||||
... )
|
... )
|
||||||
"""
|
"""
|
||||||
store_type = store_type.lower()
|
conn_str = get_docstore_uri()
|
||||||
|
|
||||||
if store_type == "postgres":
|
|
||||||
conn_str = connection_string or get_docstore_uri()
|
|
||||||
store = PostgresDocStore(
|
store = PostgresDocStore(
|
||||||
connection_string=conn_str,
|
connection_string=conn_str,
|
||||||
table_name=table_name,
|
table_name=table_name,
|
||||||
@@ -66,6 +54,3 @@ def create_docstore(
|
|||||||
max_concurrency=max_concurrency
|
max_concurrency=max_concurrency
|
||||||
)
|
)
|
||||||
return store, conn_str
|
return store, conn_str
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres")
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Qdrant 向量数据库包装器。
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from .config import QDRANT_URL, QDRANT_API_KEY
|
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
@@ -14,31 +13,28 @@ from qdrant_client import QdrantClient
|
|||||||
from qdrant_client.http.models import Distance, VectorParams
|
from qdrant_client.http.models import Distance, VectorParams
|
||||||
from httpx import RemoteProtocolError
|
from httpx import RemoteProtocolError
|
||||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||||
|
|
||||||
from .client import create_qdrant_client
|
from .client import create_qdrant_client
|
||||||
|
from .embedders import LlamaCppEmbedder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class QdrantVectorStore:
|
class QdrantVectorStore:
|
||||||
"""Qdrant 向量数据库操作包装器。"""
|
"""Qdrant 向量数据库操作包装器。"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, collection_name: str):
|
||||||
self,
|
"""
|
||||||
collection_name: str,
|
Args:
|
||||||
embeddings: Optional[Any] = None,
|
collection_name: Qdrant 集合名称。
|
||||||
):
|
"""
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self._client: Optional[QdrantClient] = None
|
self._client: Optional[QdrantClient] = None
|
||||||
self._connection_attempts = 0
|
self._connection_attempts = 0
|
||||||
self._last_connection_time: Optional[float] = None
|
self._last_connection_time: Optional[float] = None
|
||||||
|
|
||||||
if embeddings is None:
|
|
||||||
from rag_core.embedders import LlamaCppEmbedder
|
|
||||||
embedder = LlamaCppEmbedder()
|
embedder = LlamaCppEmbedder()
|
||||||
self.embeddings = embedder.as_langchain_embeddings()
|
self.embeddings = embedder.as_langchain_embeddings()
|
||||||
else:
|
|
||||||
self.embeddings = embeddings
|
|
||||||
|
|
||||||
self.create_collection()
|
self.create_collection()
|
||||||
|
|
||||||
@@ -92,10 +88,8 @@ class QdrantVectorStore:
|
|||||||
"client_initialized": self._client is not None,
|
"client_initialized": self._client is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
|
def create_collection(self, force_recreate: bool = False):
|
||||||
"""创建集合,设置合适的向量维度。"""
|
"""创建集合,设置合适的向量维度。"""
|
||||||
if vector_size is None:
|
|
||||||
from rag_core.embedders import LlamaCppEmbedder
|
|
||||||
embedder = LlamaCppEmbedder()
|
embedder = LlamaCppEmbedder()
|
||||||
vector_size = embedder.get_embedding_dimension()
|
vector_size = embedder.get_embedding_dimension()
|
||||||
|
|
||||||
|
|||||||
@@ -37,11 +37,10 @@ logger = logging.getLogger(__name__)
|
|||||||
@dataclass
|
@dataclass
|
||||||
class DocstoreConfig:
|
class DocstoreConfig:
|
||||||
"""文档存储配置(用于父块存储)。"""
|
"""文档存储配置(用于父块存储)。"""
|
||||||
connection_string: Optional[str] = None
|
pool_config: Dict[str, Any] | None = None
|
||||||
pool_config: Optional[Dict[str, Any]] = None
|
max_concurrency: int | None = None
|
||||||
max_concurrency: Optional[int] = None
|
|
||||||
# 若要从外部注入已创建好的 docstore,可直接设置此字段
|
# 若要从外部注入已创建好的 docstore,可直接设置此字段
|
||||||
instance: Optional[BaseStore] = None
|
instance: BaseStore | None = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IndexBuilderConfig:
|
class IndexBuilderConfig:
|
||||||
@@ -147,7 +146,6 @@ class IndexBuilder:
|
|||||||
# 使用工厂函数创建检索器,避免重复代码
|
# 使用工厂函数创建检索器,避免重复代码
|
||||||
self.retriever = create_parent_retriever(
|
self.retriever = create_parent_retriever(
|
||||||
collection_name=cfg.collection_name,
|
collection_name=cfg.collection_name,
|
||||||
embeddings=self.embeddings,
|
|
||||||
parent_splitter=self.parent_splitter,
|
parent_splitter=self.parent_splitter,
|
||||||
child_splitter=self.child_splitter,
|
child_splitter=self.child_splitter,
|
||||||
docstore=self.docstore,
|
docstore=self.docstore,
|
||||||
@@ -164,7 +162,6 @@ class IndexBuilder:
|
|||||||
|
|
||||||
# 使用 create_docstore 创建 PostgreSQL 存储
|
# 使用 create_docstore 创建 PostgreSQL 存储
|
||||||
docstore, conn_info = create_docstore(
|
docstore, conn_info = create_docstore(
|
||||||
connection_string=cfg.connection_string,
|
|
||||||
pool_config=cfg.pool_config,
|
pool_config=cfg.pool_config,
|
||||||
max_concurrency=cfg.max_concurrency,
|
max_concurrency=cfg.max_concurrency,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user