参数配置统一

This commit is contained in:
2026-04-21 19:06:34 +08:00
parent e2eaac9498
commit 37e86f3bb1
10 changed files with 120 additions and 166 deletions

View File

@@ -5,9 +5,17 @@ RAG Core - 公共 RAG 组件包
"""
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .vector_store import QdrantVectorStore
from .store import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever
from .config import (
QDRANT_URL,
QDRANT_API_KEY,
LLAMACPP_EMBEDDING_URL,
LLAMACPP_API_KEY,
DB_URI,
DOCSTORE_URI,
)
__all__ = [
@@ -15,6 +23,10 @@ __all__ = [
"QdrantVectorStore",
"QDRANT_URL",
"QDRANT_API_KEY",
"LLAMACPP_EMBEDDING_URL",
"LLAMACPP_API_KEY",
"DB_URI",
"DOCSTORE_URI",
"PostgresDocStore",
"create_docstore",
"create_parent_retriever",

View File

@@ -1,27 +1,30 @@
# rag_core/client.py
import os
from .config import QDRANT_URL, QDRANT_API_KEY
from typing import Optional
from qdrant_client import QdrantClient
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
"""
创建并返回一个配置好的 Qdrant 客户端。
def create_qdrant_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
timeout: int = 300, # 索引构建需要较长超时
) -> QdrantClient:
effective_url = url or QDRANT_URL
effective_api_key = api_key or QDRANT_API_KEY
Args:
timeout: 请求超时时间(秒),默认 300 秒(索引构建需要较长超时)。
if not effective_url:
Returns:
配置好的 QdrantClient 实例。
Raises:
ValueError: 如果 QDRANT_URL 未配置。
"""
if not QDRANT_URL:
raise ValueError("Qdrant URL 未配置")
client_kwargs = {
"url": effective_url,
"url": QDRANT_URL,
"timeout": timeout,
}
if effective_api_key:
client_kwargs["api_key"] = effective_api_key
if QDRANT_API_KEY:
client_kwargs["api_key"] = QDRANT_API_KEY
return QdrantClient(**client_kwargs)
return QdrantClient(**client_kwargs)

View File

@@ -5,21 +5,21 @@
import os
from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
import httpx
from typing import List, Optional
from typing import List
from langchain_core.embeddings import Embeddings
class LlamaCppEmbedder:
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
def __init__(
self,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
model: str = "Qwen3-Embedding-0.6B-Q8_0",
):
self.base_url = base_url or LLAMACPP_EMBEDDING_URL
self.api_key = api_key or LLAMACPP_API_KEY
def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"):
"""
Args:
model: 嵌入模型名称,默认 "Qwen3-Embedding-0.6B-Q8_0"
"""
self.base_url = LLAMACPP_EMBEDDING_URL
self.api_key = LLAMACPP_API_KEY
self.model = model
def as_langchain_embeddings(self) -> Embeddings:
@@ -30,7 +30,7 @@ class LlamaCppEmbedder:
"""嵌入一批文档。"""
return self._call_embedding_api(texts)
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> List[List[float]]:
"""嵌入单个查询。"""
return self._call_embedding_api([text])[0]
@@ -70,6 +70,7 @@ class LlamaCppEmbedder:
else:
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
class _LlamaCppLangchainAdapter(Embeddings):
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""
@@ -79,5 +80,5 @@ class _LlamaCppLangchainAdapter(Embeddings):
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self._embedder.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
return self._embedder.embed_query(text)
def embed_query(self, text: str) -> List[List[float]]:
return self._embedder.embed_query(text)

View File

@@ -1,38 +1,46 @@
# rag_core/retriever_factory.py
# rag_core/retriever_factory.py
from langchain_core.embeddings import Embeddings
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.stores import BaseStore
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever(
collection_name: str = "rag_documents",
embeddings: Optional[Embeddings] = None,
parent_splitter: Optional[TextSplitter] = None,
child_splitter: Optional[TextSplitter] = None,
docstore: Optional[BaseStore] = None,
parent_splitter: TextSplitter | None = None,
child_splitter: TextSplitter | None = None,
docstore: BaseStore | None = None,
search_k: int = 5,
# 若未传入切分器,则用以下参数创建默认切分器
parent_chunk_size: int = 1000,
parent_chunk_overlap: int = 100,
child_chunk_size: int = 200,
child_chunk_overlap: int = 20,
) -> ParentDocumentRetriever:
"""
创建 ParentDocumentRetriever 实例。
Args:
collection_name: Qdrant 集合名称,默认 "rag_documents"
parent_splitter: 父文档切分器,默认 None使用默认参数创建
child_splitter: 子文档切分器,默认 None使用默认参数创建
docstore: 文档存储实例,默认 None使用默认参数创建
search_k: 检索时返回的结果数,默认 5
parent_chunk_size: 父文档块大小,默认 1000
parent_chunk_overlap: 父文档块重叠大小,默认 100
child_chunk_size: 子文档块大小,默认 200
child_chunk_overlap: 子文档块重叠大小,默认 20
Returns:
ParentDocumentRetriever 实例
"""
# 嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 向量存储(只读)
vector_store = QdrantVectorStore(
collection_name=collection_name,
embeddings=embeddings,
)
vector_store = QdrantVectorStore(collection_name=collection_name)
# 切分器(若未提供则创建默认)
if parent_splitter is None:
@@ -48,7 +56,7 @@ def create_parent_retriever(
# 文档存储
if docstore is None:
docstore, _ = create_docstore() # 从环境变量读取连接
docstore, _ = create_docstore()
return ParentDocumentRetriever(
vectorstore=vector_store.get_langchain_vectorstore(),
@@ -56,4 +64,4 @@ def create_parent_retriever(
child_splitter=child_splitter,
parent_splitter=parent_splitter,
search_kwargs={"k": search_k},
)
)

View File

@@ -9,14 +9,13 @@
>>> # 创建 PostgreSQL 存储
>>> store, conn = create_docstore(
... connection_string="postgresql://user:pass@host:5432/db",
... table_name="parent_docs"
... )
"""
from .postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
from .factory import create_docstore, get_docstore_uri
__version__ = "2.0.0"
@@ -27,5 +26,4 @@ __all__ = [
# 工厂函数
"create_docstore",
"get_docstore_uri",
"DEFAULT_DB_URI",
]

View File

@@ -5,17 +5,14 @@
"""
import os
from ..config import DB_URI, DOCSTORE_URI
from ..config import DOCSTORE_URI
import logging
from typing import Optional, Tuple
from typing import Tuple
from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore
logger = logging.getLogger(__name__)
# 默认连接字符串(从环境变量读取)
DEFAULT_DB_URI = DB_URI
logger = logging.getLogger(__name__)
def get_docstore_uri() -> str:
@@ -24,48 +21,36 @@ def get_docstore_uri() -> str:
def create_docstore(
store_type: str = "postgres",
connection_string: Optional[str] = None,
table_name: str = "parent_documents",
pool_config: Optional[dict] = None,
max_concurrency: Optional[int] = None
) -> Tuple[BaseStore, Optional[str]]:
pool_config: dict | None = None,
max_concurrency: int | None = None
) -> Tuple[BaseStore, str]:
"""
工厂函数,创建 PostgreSQL 文档存储。
Args:
store_type: 存储类型,目前仅支持 "postgres"(默认)
connection_string: PostgreSQL 连接字符串
table_name: PostgreSQL 表名默认parent_documents
pool_config: 连接池配置
max_concurrency: 最大并发操作数,如果为 None 则不限制
Returns:
元组 (存储实例, 连接字符串)
Raises:
ValueError: 不支持的存储类型
ImportError: 缺少必要的依赖
Example:
>>> # 创建 PostgreSQL 存储
>>> store, conn = create_docstore(
... connection_string="postgresql://user:pass@host:5432/db",
... table_name="parent_docs",
... max_concurrency=10
... )
"""
store_type = store_type.lower()
if store_type == "postgres":
conn_str = connection_string or get_docstore_uri()
store = PostgresDocStore(
connection_string=conn_str,
table_name=table_name,
pool_config=pool_config,
max_concurrency=max_concurrency
)
return store, conn_str
else:
raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres")
conn_str = get_docstore_uri()
store = PostgresDocStore(
connection_string=conn_str,
table_name=table_name,
pool_config=pool_config,
max_concurrency=max_concurrency
)
return store, conn_str

View File

@@ -4,7 +4,6 @@ Qdrant 向量数据库包装器。
import logging
import os
from .config import QDRANT_URL, QDRANT_API_KEY
import time
from typing import List, Optional, Dict, Any
@@ -14,31 +13,28 @@ from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from .client import create_qdrant_client
from .embedders import LlamaCppEmbedder
logger = logging.getLogger(__name__)
class QdrantVectorStore:
"""Qdrant 向量数据库操作包装器。"""
def __init__(
self,
collection_name: str,
embeddings: Optional[Any] = None,
):
def __init__(self, collection_name: str):
"""
Args:
collection_name: Qdrant 集合名称。
"""
self.collection_name = collection_name
self._client: Optional[QdrantClient] = None
self._connection_attempts = 0
self._last_connection_time: Optional[float] = None
if embeddings is None:
from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
else:
self.embeddings = embeddings
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
self.create_collection()
@@ -92,12 +88,10 @@ class QdrantVectorStore:
"client_initialized": self._client is not None,
}
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
def create_collection(self, force_recreate: bool = False):
"""创建集合,设置合适的向量维度。"""
if vector_size is None:
from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
max_retries = 3
base_delay = 2
@@ -177,4 +171,4 @@ class QdrantVectorStore:
def get_qdrant_client(self):
"""返回原生 Qdrant 客户端(如需手动管理 collection"""
return self.get_client()
return self.get_client()