"""Qdrant vector database wrapper."""

import logging
import os
import time
from typing import List, Optional, Dict, Any

from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import ResponseHandlingException
from qdrant_client.http.models import Distance, VectorParams

from rag_core.client import create_qdrant_client

logger = logging.getLogger(__name__)

# Not read anywhere in this module; kept as module-level names because other
# modules may import them from here.
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# Exception types this module treats as "the connection is broken": they
# trigger a client refresh in the health check and a retry in
# ``create_collection``.
_TRANSIENT_ERRORS = (
    RemoteProtocolError,
    ConnectionError,
    OSError,
    ResponseHandlingException,
)


class QdrantVectorStore:
    """Wrapper around one Qdrant collection and its LangChain vector store.

    The wrapper owns a lazily created ``QdrantClient`` (see ``get_client``)
    and a LangChain ``QdrantVectorStore`` bound to that client. Whenever the
    client is discarded via ``refresh_client`` the LangChain store is
    invalidated as well and rebuilt on next access, so it never keeps using
    a closed client.
    """

    # Retry policy for ``create_collection``: up to ``_MAX_RETRIES`` attempts,
    # sleeping ``_BASE_DELAY_SECONDS * 2 ** attempt`` between them.
    _MAX_RETRIES = 3
    _BASE_DELAY_SECONDS = 2
    # Timeout value handed to ``create_qdrant_client``.
    _CLIENT_TIMEOUT = 300

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
    ):
        """Ensure the collection exists and bind a LangChain store to it.

        Args:
            collection_name: Name of the Qdrant collection to operate on.
            embeddings: LangChain-style embeddings object. When ``None`` a
                ``LlamaCppEmbedder`` is instantiated and its LangChain
                adapter is used.

        Raises:
            Whatever ``create_collection`` raises once its retries are
            exhausted, or any error from constructing the LangChain store.
        """
        self.collection_name = collection_name
        self._client: Optional[QdrantClient] = None
        self._connection_attempts = 0
        self._last_connection_time: Optional[float] = None
        # Both are filled lazily; see ``vector_store`` / ``_resolve_vector_size``.
        self._vector_store: Optional[LangchainQdrantVS] = None
        self._vector_size: Optional[int] = None

        if embeddings is None:
            from rag_core.embedders import LlamaCppEmbedder

            embedder = LlamaCppEmbedder()
            self.embeddings = embedder.as_langchain_embeddings()
            # Reuse this embedder for the dimension instead of building a
            # second one inside ``create_collection``.
            self._vector_size = embedder.get_embedding_dimension()
        else:
            self.embeddings = embeddings

        self.create_collection()
        self._vector_store = self._build_vector_store()

    def _build_vector_store(self) -> LangchainQdrantVS:
        """Create a LangChain store bound to the current client."""
        return LangchainQdrantVS(
            client=self.get_client(),
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )

    @property
    def vector_store(self) -> LangchainQdrantVS:
        """LangChain store for this collection, rebuilt after a client refresh."""
        if self._vector_store is None:
            self._vector_store = self._build_vector_store()
        return self._vector_store

    @vector_store.setter
    def vector_store(self, value: LangchainQdrantVS) -> None:
        # Keeps ``instance.vector_store = ...`` working as it did when this
        # was a plain attribute.
        self._vector_store = value

    def _resolve_vector_size(self) -> int:
        """Return the embedding dimension for ``self.embeddings`` (cached).

        The size is taken from the default embedder when one was created in
        ``__init__``; otherwise it is measured by embedding a short probe
        string with ``self.embeddings.embed_query``. If the embeddings object
        has no ``embed_query`` the dimension of a fresh ``LlamaCppEmbedder``
        is used, which is what this class did unconditionally before.

        Raises:
            ValueError: If the probe embedding is empty.
        """
        if self._vector_size is None:
            probe = getattr(self.embeddings, "embed_query", None)
            if callable(probe):
                size = len(probe("dimension probe"))
                if size <= 0:
                    raise ValueError(
                        "Embedding probe returned an empty vector; "
                        "cannot determine the collection vector size"
                    )
                self._vector_size = size
            else:
                from rag_core.embedders import LlamaCppEmbedder

                self._vector_size = LlamaCppEmbedder().get_embedding_dimension()
        return self._vector_size

    def get_client(self) -> QdrantClient:
        """Return the cached Qdrant client, creating it on first use.

        Each creation increments the connection counter and records the
        creation time (both reported by ``get_connection_stats``).
        """
        if self._client is None:
            self._client = create_qdrant_client(timeout=self._CLIENT_TIMEOUT)
            self._connection_attempts += 1
            self._last_connection_time = time.time()
            logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts)
        return self._client

    def refresh_client(self):
        """Close and discard the current client.

        A new client is created by the next ``get_client`` call. The cached
        LangChain store is dropped too, because it holds a reference to the
        client being closed. Objects previously handed out by
        ``as_langchain_vectorstore`` / ``get_qdrant_client`` are not updated;
        callers should fetch them again after a refresh.
        """
        # The LangChain store is bound to the client that is about to go away.
        self._vector_store = None
        if self._client is not None:
            try:
                self._client.close()
                logger.debug("Qdrant 旧连接已关闭")
            except Exception as e:
                # Best effort: a failing close must not prevent the reset.
                logger.warning("关闭 Qdrant 连接时出现异常: %s", e)
            finally:
                self._client = None
                self._last_connection_time = None

    def check_connection_health(self) -> bool:
        """Report whether the current client can reach Qdrant.

        Returns:
            ``True`` if listing collections succeeds. ``False`` if no client
            exists yet, or if the call fails with one of the transient
            connection errors; in the latter case the client is discarded so
            the next ``get_client`` call reconnects. This method itself does
            not open a new connection.
        """
        if self._client is None:
            logger.info("Qdrant 客户端未初始化,将创建新连接")
            return False
        try:
            client = self.get_client()
            client.get_collections()
            logger.debug("Qdrant 连接健康检查通过")
            return True
        except _TRANSIENT_ERRORS as e:
            logger.warning("Qdrant 连接健康检查失败: %s", e)
            self.refresh_client()
            return False

    def get_connection_stats(self) -> Dict[str, Any]:
        """Return connection counters: attempts, last creation time, liveness flag."""
        return {
            "connection_attempts": self._connection_attempts,
            "last_connection_time": self._last_connection_time,
            "client_initialized": self._client is not None,
        }

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection if it is missing (cosine distance).

        Args:
            vector_size: Dimension of the stored vectors. When ``None`` it is
                derived from ``self.embeddings`` (see ``_resolve_vector_size``).
            force_recreate: If ``True`` and the collection already exists,
                delete it first and create it again.

        Raises:
            RemoteProtocolError, ConnectionError, OSError,
            ResponseHandlingException: Re-raised when the last retry fails.
            Any other exception propagates immediately without a retry.
        """
        if vector_size is None:
            vector_size = self._resolve_vector_size()

        max_retries = self._MAX_RETRIES
        for attempt in range(max_retries):
            try:
                client = self.get_client()
                collections = client.get_collections().collections
                exists = any(c.name == self.collection_name for c in collections)

                if exists and force_recreate:
                    client.delete_collection(self.collection_name)
                    exists = False

                if not exists:
                    client.create_collection(
                        collection_name=self.collection_name,
                        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
                    )
                    logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)
                else:
                    logger.info("集合 '%s' 已存在", self.collection_name)
                return
            except _TRANSIENT_ERRORS as e:
                if attempt == max_retries - 1:
                    logger.error(
                        "创建集合 '%s' 重试 %d 次后仍然失败: %s",
                        self.collection_name, max_retries, e,
                    )
                    raise
                # Exponential backoff: 2s, 4s, ... with the default settings.
                wait_time = self._BASE_DELAY_SECONDS * (2 ** attempt)
                error_type = type(e).__name__
                logger.warning(
                    "创建集合 '%s' 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s",
                    self.collection_name, error_type, wait_time,
                    attempt + 1, max_retries, e,
                )
                # Start the next attempt on a brand-new connection.
                self.refresh_client()
                logger.debug("已刷新 Qdrant 客户端连接")
                time.sleep(wait_time)

    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Add documents to the collection and return their ids.

        Args:
            documents: Documents to embed and store. An empty list is a no-op.
            batch_size: Forwarded to the LangChain store's ``add_documents``.

        Returns:
            The list of ids returned by the LangChain store, or ``[]`` when
            ``documents`` is empty.
        """
        if not documents:
            return []
        # Make sure the collection exists; this may refresh the client, in
        # which case ``self.vector_store`` is rebuilt on the access below.
        self.create_collection()
        ids = self.vector_store.add_documents(documents, batch_size=batch_size)
        logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
        return ids

    def similarity_search(self, query: str, k: int = 5) -> List[Document]:
        """Return the ``k`` documents most similar to ``query``."""
        return self.vector_store.similarity_search(query, k=k)

    def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
        """Return the ``k`` most similar documents paired with their scores."""
        return self.vector_store.similarity_search_with_score(query, k=k)

    def delete_collection(self):
        """Delete this wrapper's collection from Qdrant."""
        self.get_client().delete_collection(self.collection_name)
        logger.info("集合 '%s' 已删除", self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        """Return name, point count, status and vector size of the collection.

        Note that the ``"vectors_count"`` key is filled from the collection's
        ``points_count`` (0 when that is ``None``). When the collection uses
        named vectors (a dict config), the size of the first entry is
        reported.
        """
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors

        if isinstance(vectors_config, dict):
            first_config = next(iter(vectors_config.values()), None)
            vector_size = first_config.size if first_config else 0
        else:
            vector_size = vectors_config.size if vectors_config else 0

        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,
            "status": info.status,
            "vector_size": vector_size,
        }

    def as_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store object."""
        return self.vector_store

    def get_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store object (alias)."""
        return self.vector_store

    def get_qdrant_client(self):
        """Return the native Qdrant client (for manual collection management)."""
        return self.get_client()