diff --git a/backend/app/rag/__init__.py b/backend/app/rag/__init__.py index 4f86e0b..5499d7d 100644 --- a/backend/app/rag/__init__.py +++ b/backend/app/rag/__init__.py @@ -37,7 +37,6 @@ RAG 检索与生成模块 from .retriever import ( create_base_retriever, create_hybrid_retriever, - create_qdrant_client, ) from .reranker import LLaMaCPPReranker from .query_transform import MultiQueryGenerator @@ -50,7 +49,6 @@ __all__ = [ # 检索器工厂函数 "create_base_retriever", "create_hybrid_retriever", - "create_qdrant_client", # 重排序器 "LLaMaCPPReranker", diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index 483c8b9..c6cd3bb 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -25,66 +25,25 @@ Qdrant 向量检索器模块 >>> docs = retriever.invoke("什么是 RAG?") """ -from typing import Optional, Dict, Any +from typing import Dict, Any from qdrant_client import QdrantClient from qdrant_client.http.exceptions import UnexpectedResponse from langchain_qdrant import QdrantVectorStore from langchain_core.embeddings import Embeddings from langchain_core.retrievers import BaseRetriever -from rag_core import QDRANT_URL, QDRANT_API_KEY +from rag_core import QDRANT_URL, QDRANT_API_KEY, LlamaCppEmbedder +from rag_core.client import create_qdrant_client as create_core_qdrant_client # 模块级常量 DEFAULT_SEARCH_K = 20 DEFAULT_SCORE_THRESHOLD = 0.3 -def create_qdrant_client( - url: Optional[str] = None, - api_key: Optional[str] = None, - timeout: int = 30, -) -> QdrantClient: - """ - 创建并返回一个配置好的 Qdrant 客户端。 - - 优先使用传入参数,若未提供则回退到环境变量 QDRANT_URL 和 QDRANT_API_KEY。 - - Args: - url: Qdrant 服务地址,例如 "http://localhost:6333"。 - 默认从环境变量 QDRANT_URL 读取。 - api_key: API 密钥(若 Qdrant 启用了认证)。 - 默认从环境变量 QDRANT_API_KEY 读取。 - timeout: 请求超时时间(秒),默认 30 秒。 - - Returns: - 配置好的 QdrantClient 实例。 - - Raises: - ValueError: 如果 url 为空且环境变量也未设置。 - """ - effective_url = url or QDRANT_URL - if not effective_url: - raise ValueError( - "Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL" - ) - - effective_api_key = api_key or QDRANT_API_KEY - - client_kwargs = { - "url": effective_url, - "timeout": timeout, - } - if effective_api_key: - client_kwargs["api_key"] = effective_api_key - - return QdrantClient(**client_kwargs) - - def create_base_retriever( collection_name: str, - embeddings: Embeddings, - search_kwargs: Optional[Dict[str, Any]] = None, - client: Optional[QdrantClient] = None, + search_kwargs: Dict[str, Any] | None = None, + client: QdrantClient | None = None, ) -> BaseRetriever: """ 创建基础向量检索器(仅稠密向量检索)。 @@ -94,7 +53,6 @@ def create_base_retriever( Args: collection_name: Qdrant 集合名称(需预先创建并索引)。 - embeddings: LangChain 兼容的嵌入模型实例。 search_kwargs: 搜索参数,可包含: - k (int): 返回的文档数量,默认 20。 - score_threshold (float): 相似度阈值,仅返回高于此分数的文档。 @@ -108,6 +66,10 @@ def create_base_retriever( Raises: ValueError: 如果集合不存在或嵌入模型无效。 """ + # 嵌入模型 + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() + # 合并默认搜索参数 merged_search_kwargs = {"k": DEFAULT_SEARCH_K} if search_kwargs: @@ -115,7 +77,7 @@ def create_base_retriever( # 创建或复用 Qdrant 客户端 if client is None: - client = create_qdrant_client() + client = create_core_qdrant_client() # 验证集合是否存在(可选,便于提前发现问题) try: @@ -140,11 +102,10 @@ def create_base_retriever( def create_hybrid_retriever( collection_name: str, - embeddings: Embeddings, dense_k: int = 10, sparse_k: int = 10, - score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD, - client: Optional[QdrantClient] = None, + score_threshold: float | None = DEFAULT_SCORE_THRESHOLD, + client: QdrantClient | None = None, ) -> BaseRetriever: """ 创建混合检索器(稠密向量 + BM25 稀疏向量)。 @@ -157,7 +118,6 @@ def create_hybrid_retriever( Args: collection_name: Qdrant 集合名称。 - embeddings: 嵌入模型(用于稠密向量)。 dense_k: 稠密向量检索返回数量,默认 10。 sparse_k: 稀疏向量检索返回数量,默认 10。 score_threshold: 相似度阈值,默认 0.3。 @@ -177,7 +137,6 @@ def create_hybrid_retriever( # 复用基础检索器创建逻辑,只需调整搜索参数 return create_base_retriever( collection_name=collection_name, - embeddings=embeddings, search_kwargs=search_kwargs, client=client, ) @@ -186,9 +145,8 @@ def create_hybrid_retriever( # 可选:提供异步友好的辅助函数 async def acreate_base_retriever( collection_name: str, - embeddings: Embeddings, - search_kwargs: Optional[Dict[str, Any]] = None, - client: Optional[QdrantClient] = None, + search_kwargs: Dict[str, Any] | None = None, + client: QdrantClient | None = None, ) -> BaseRetriever: """ 异步创建基础向量检索器(与同步版本功能相同)。 @@ -196,4 +154,4 @@ async def acreate_base_retriever( 适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。 """ # 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可 - return create_base_retriever(collection_name, embeddings, search_kwargs, client) \ No newline at end of file + return create_base_retriever(collection_name, search_kwargs, client) diff --git a/backend/rag_core/__init__.py b/backend/rag_core/__init__.py index 318a066..6eb92ed 100644 --- a/backend/rag_core/__init__.py +++ b/backend/rag_core/__init__.py @@ -5,9 +5,17 @@ RAG Core - 公共 RAG 组件包 """ from .embedders import LlamaCppEmbedder -from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY +from .vector_store import QdrantVectorStore from .store import PostgresDocStore, create_docstore from .retriever_factory import create_parent_retriever +from .config import ( + QDRANT_URL, + QDRANT_API_KEY, + LLAMACPP_EMBEDDING_URL, + LLAMACPP_API_KEY, + DB_URI, + DOCSTORE_URI, +) __all__ = [ @@ -15,6 +23,10 @@ __all__ = [ "QdrantVectorStore", "QDRANT_URL", "QDRANT_API_KEY", + "LLAMACPP_EMBEDDING_URL", + "LLAMACPP_API_KEY", + "DB_URI", + "DOCSTORE_URI", "PostgresDocStore", "create_docstore", "create_parent_retriever", diff --git a/backend/rag_core/client.py b/backend/rag_core/client.py index d689a1f..bd352d9 100644 --- a/backend/rag_core/client.py +++ b/backend/rag_core/client.py @@ -1,27 +1,30 @@ # rag_core/client.py import os from .config import QDRANT_URL, QDRANT_API_KEY -from typing import Optional from qdrant_client import QdrantClient +def create_qdrant_client(timeout: int = 300) -> QdrantClient: + """ + 创建并返回一个配置好的 Qdrant 客户端。 -def create_qdrant_client( - url: Optional[str] = None, - api_key: Optional[str] = None, - timeout: int = 300, # 索引构建需要较长超时 -) -> QdrantClient: - effective_url = url or QDRANT_URL - effective_api_key = api_key or QDRANT_API_KEY + Args: + timeout: 请求超时时间(秒),默认 300 秒(索引构建需要较长超时)。 - if not effective_url: + Returns: + 配置好的 QdrantClient 实例。 + + Raises: + ValueError: 如果 QDRANT_URL 未配置。 + """ + if not QDRANT_URL: raise ValueError("Qdrant URL 未配置") client_kwargs = { - "url": effective_url, + "url": QDRANT_URL, "timeout": timeout, } - if effective_api_key: - client_kwargs["api_key"] = effective_api_key + if QDRANT_API_KEY: + client_kwargs["api_key"] = QDRANT_API_KEY - return QdrantClient(**client_kwargs) \ No newline at end of file + return QdrantClient(**client_kwargs) diff --git a/backend/rag_core/embedders.py b/backend/rag_core/embedders.py index eff11a0..81c2267 100644 --- a/backend/rag_core/embedders.py +++ b/backend/rag_core/embedders.py @@ -5,21 +5,21 @@ import os from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY import httpx -from typing import List, Optional +from typing import List from langchain_core.embeddings import Embeddings + class LlamaCppEmbedder: """通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。""" - def __init__( - self, - base_url: Optional[str] = None, - api_key: Optional[str] = None, - model: str = "Qwen3-Embedding-0.6B-Q8_0", - ): - self.base_url = base_url or LLAMACPP_EMBEDDING_URL - self.api_key = api_key or LLAMACPP_API_KEY + def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"): + """ + Args: + model: 嵌入模型名称,默认 "Qwen3-Embedding-0.6B-Q8_0"。 + """ + self.base_url = LLAMACPP_EMBEDDING_URL + self.api_key = LLAMACPP_API_KEY self.model = model def as_langchain_embeddings(self) -> Embeddings: @@ -30,7 +30,7 @@ class LlamaCppEmbedder: """嵌入一批文档。""" return self._call_embedding_api(texts) - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> List[List[float]]: """嵌入单个查询。""" return self._call_embedding_api([text])[0] @@ -70,6 +70,7 @@ class LlamaCppEmbedder: else: raise ValueError(f"未知的嵌入 API 响应格式: {data}") + class _LlamaCppLangchainAdapter(Embeddings): """将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。""" @@ -79,5 +80,5 @@ class _LlamaCppLangchainAdapter(Embeddings): def embed_documents(self, texts: List[str]) -> List[List[float]]: return self._embedder.embed_documents(texts) - def embed_query(self, text: str) -> List[float]: - return self._embedder.embed_query(text) \ No newline at end of file + def embed_query(self, text: str) -> List[List[float]]: + return self._embedder.embed_query(text) diff --git a/backend/rag_core/retriever_factory.py b/backend/rag_core/retriever_factory.py index 25dab1c..b0e5ab6 100644 --- a/backend/rag_core/retriever_factory.py +++ b/backend/rag_core/retriever_factory.py @@ -1,38 +1,46 @@ -# rag_core/retriever_factory.py +# rag_core/retriever_factory.py from langchain_core.embeddings import Embeddings from langchain_classic.retrievers import ParentDocumentRetriever -from langchain_text_splitters import RecursiveCharacterTextSplitter -from typing import Optional -from langchain_core.embeddings import Embeddings -from langchain_core.stores import BaseStore from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter -from langchain_classic.retrievers import ParentDocumentRetriever +from langchain_core.stores import BaseStore from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore + def create_parent_retriever( collection_name: str = "rag_documents", - embeddings: Optional[Embeddings] = None, - parent_splitter: Optional[TextSplitter] = None, - child_splitter: Optional[TextSplitter] = None, - docstore: Optional[BaseStore] = None, + parent_splitter: TextSplitter | None = None, + child_splitter: TextSplitter | None = None, + docstore: BaseStore | None = None, search_k: int = 5, - # 若未传入切分器,则用以下参数创建默认切分器 parent_chunk_size: int = 1000, parent_chunk_overlap: int = 100, child_chunk_size: int = 200, child_chunk_overlap: int = 20, ) -> ParentDocumentRetriever: + """ + 创建 ParentDocumentRetriever 实例。 + + Args: + collection_name: Qdrant 集合名称,默认 "rag_documents" + parent_splitter: 父文档切分器,默认 None(使用默认参数创建) + child_splitter: 子文档切分器,默认 None(使用默认参数创建) + docstore: 文档存储实例,默认 None(使用默认参数创建) + search_k: 检索时返回的结果数,默认 5 + parent_chunk_size: 父文档块大小,默认 1000 + parent_chunk_overlap: 父文档块重叠大小,默认 100 + child_chunk_size: 子文档块大小,默认 200 + child_chunk_overlap: 子文档块重叠大小,默认 20 + + Returns: + ParentDocumentRetriever 实例 + """ # 嵌入模型 - if embeddings is None: - embedder = LlamaCppEmbedder() - embeddings = embedder.as_langchain_embeddings() + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() # 向量存储(只读) - vector_store = QdrantVectorStore( - collection_name=collection_name, - embeddings=embeddings, - ) + vector_store = QdrantVectorStore(collection_name=collection_name) # 切分器(若未提供则创建默认) if parent_splitter is None: @@ -48,7 +56,7 @@ def create_parent_retriever( # 文档存储 if docstore is None: - docstore, _ = create_docstore() # 从环境变量读取连接 + docstore, _ = create_docstore() return ParentDocumentRetriever( vectorstore=vector_store.get_langchain_vectorstore(), @@ -56,4 +64,4 @@ def create_parent_retriever( child_splitter=child_splitter, parent_splitter=parent_splitter, search_kwargs={"k": search_k}, - ) \ No newline at end of file + ) diff --git a/backend/rag_core/store/__init__.py b/backend/rag_core/store/__init__.py index 359db76..476dd6b 100644 --- a/backend/rag_core/store/__init__.py +++ b/backend/rag_core/store/__init__.py @@ -9,14 +9,13 @@ >>> # 创建 PostgreSQL 存储 >>> store, conn = create_docstore( - ... connection_string="postgresql://user:pass@host:5432/db", ... table_name="parent_docs" ... ) """ from .postgres import PostgresDocStore -from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI +from .factory import create_docstore, get_docstore_uri __version__ = "2.0.0" @@ -27,5 +26,4 @@ __all__ = [ # 工厂函数 "create_docstore", "get_docstore_uri", - "DEFAULT_DB_URI", ] diff --git a/backend/rag_core/store/factory.py b/backend/rag_core/store/factory.py index 6c87ac9..bf9ade7 100644 --- a/backend/rag_core/store/factory.py +++ b/backend/rag_core/store/factory.py @@ -5,17 +5,14 @@ """ import os -from ..config import DB_URI, DOCSTORE_URI +from ..config import DOCSTORE_URI import logging -from typing import Optional, Tuple +from typing import Tuple from langchain_core.stores import BaseStore from .postgres import PostgresDocStore -logger = logging.getLogger(__name__) - -# 默认连接字符串(从环境变量读取) -DEFAULT_DB_URI = DB_URI +logger = logging.getLogger(__name__) def get_docstore_uri() -> str: @@ -24,48 +21,36 @@ def get_docstore_uri() -> str: def create_docstore( - store_type: str = "postgres", - connection_string: Optional[str] = None, table_name: str = "parent_documents", - pool_config: Optional[dict] = None, - max_concurrency: Optional[int] = None -) -> Tuple[BaseStore, Optional[str]]: + pool_config: dict | None = None, + max_concurrency: int | None = None +) -> Tuple[BaseStore, str]: """ 工厂函数,创建 PostgreSQL 文档存储。 - + Args: - store_type: 存储类型,目前仅支持 "postgres"(默认) - connection_string: PostgreSQL 连接字符串 table_name: PostgreSQL 表名(默认:parent_documents) pool_config: 连接池配置 max_concurrency: 最大并发操作数,如果为 None 则不限制 - + Returns: 元组 (存储实例, 连接字符串) - + Raises: - ValueError: 不支持的存储类型 ImportError: 缺少必要的依赖 - + Example: >>> # 创建 PostgreSQL 存储 >>> store, conn = create_docstore( - ... connection_string="postgresql://user:pass@host:5432/db", ... table_name="parent_docs", ... max_concurrency=10 ... ) """ - store_type = store_type.lower() - - if store_type == "postgres": - conn_str = connection_string or get_docstore_uri() - store = PostgresDocStore( - connection_string=conn_str, - table_name=table_name, - pool_config=pool_config, - max_concurrency=max_concurrency - ) - return store, conn_str - - else: - raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres") + conn_str = get_docstore_uri() + store = PostgresDocStore( + connection_string=conn_str, + table_name=table_name, + pool_config=pool_config, + max_concurrency=max_concurrency + ) + return store, conn_str diff --git a/backend/rag_core/vector_store.py b/backend/rag_core/vector_store.py index b2ecd20..7848157 100644 --- a/backend/rag_core/vector_store.py +++ b/backend/rag_core/vector_store.py @@ -4,7 +4,6 @@ Qdrant 向量数据库包装器。 import logging import os -from .config import QDRANT_URL, QDRANT_API_KEY import time from typing import List, Optional, Dict, Any @@ -14,31 +13,28 @@ from qdrant_client import QdrantClient from qdrant_client.http.models import Distance, VectorParams from httpx import RemoteProtocolError from qdrant_client.http.exceptions import ResponseHandlingException + from .client import create_qdrant_client +from .embedders import LlamaCppEmbedder logger = logging.getLogger(__name__) - class QdrantVectorStore: """Qdrant 向量数据库操作包装器。""" - def __init__( - self, - collection_name: str, - embeddings: Optional[Any] = None, - ): + def __init__(self, collection_name: str): + """ + Args: + collection_name: Qdrant 集合名称。 + """ self.collection_name = collection_name self._client: Optional[QdrantClient] = None self._connection_attempts = 0 self._last_connection_time: Optional[float] = None - if embeddings is None: - from rag_core.embedders import LlamaCppEmbedder - embedder = LlamaCppEmbedder() - self.embeddings = embedder.as_langchain_embeddings() - else: - self.embeddings = embeddings + embedder = LlamaCppEmbedder() + self.embeddings = embedder.as_langchain_embeddings() self.create_collection() @@ -92,12 +88,10 @@ class QdrantVectorStore: "client_initialized": self._client is not None, } - def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False): + def create_collection(self, force_recreate: bool = False): """创建集合,设置合适的向量维度。""" - if vector_size is None: - from rag_core.embedders import LlamaCppEmbedder - embedder = LlamaCppEmbedder() - vector_size = embedder.get_embedding_dimension() + embedder = LlamaCppEmbedder() + vector_size = embedder.get_embedding_dimension() max_retries = 3 base_delay = 2 @@ -177,4 +171,4 @@ class QdrantVectorStore: def get_qdrant_client(self): """返回原生 Qdrant 客户端(如需手动管理 collection)""" - return self.get_client() \ No newline at end of file + return self.get_client() diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py index 587b1e4..a348120 100644 --- a/rag_indexer/index_builder.py +++ b/rag_indexer/index_builder.py @@ -37,11 +37,10 @@ logger = logging.getLogger(__name__) @dataclass class DocstoreConfig: """文档存储配置(用于父块存储)。""" - connection_string: Optional[str] = None - pool_config: Optional[Dict[str, Any]] = None - max_concurrency: Optional[int] = None + pool_config: Dict[str, Any] | None = None + max_concurrency: int | None = None # 若要从外部注入已创建好的 docstore,可直接设置此字段 - instance: Optional[BaseStore] = None + instance: BaseStore | None = None @dataclass class IndexBuilderConfig: @@ -147,7 +146,6 @@ class IndexBuilder: # 使用工厂函数创建检索器,避免重复代码 self.retriever = create_parent_retriever( collection_name=cfg.collection_name, - embeddings=self.embeddings, parent_splitter=self.parent_splitter, child_splitter=self.child_splitter, docstore=self.docstore, @@ -164,7 +162,6 @@ class IndexBuilder: # 使用 create_docstore 创建 PostgreSQL 存储 docstore, conn_info = create_docstore( - connection_string=cfg.connection_string, pool_config=cfg.pool_config, max_concurrency=cfg.max_concurrency, )