"""Qdrant vector database wrapper."""

import logging
import os
from typing import List, Optional, Dict, Any

from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams

from .embedders import LlamaCppEmbedder

logger = logging.getLogger(__name__)

# Connection settings, read once from the environment at import time.
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")


class QdrantVectorStore:
    """Wrapper around one Qdrant collection and a LangChain vector store bound to it.

    On construction the collection is created if it does not exist, and a
    LangChain ``QdrantVectorStore`` sharing this wrapper's client is built for
    document insertion and similarity search.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
    ):
        """
        Args:
            collection_name: Name of the Qdrant collection to operate on.
            embeddings: LangChain-compatible embeddings object. When ``None``,
                a default ``LlamaCppEmbedder`` is created and its LangChain
                adapter is used.
        """
        self.collection_name = collection_name
        self._client: Optional[QdrantClient] = None
        # Keep the default embedder (if any) so its dimension can be reused
        # instead of constructing a new LlamaCppEmbedder whenever a
        # collection has to be created.
        self._embedder: Optional[LlamaCppEmbedder] = None

        if embeddings is None:
            self._embedder = LlamaCppEmbedder()
            self.embeddings = self._embedder.as_langchain_embeddings()
        else:
            self.embeddings = embeddings

        # The collection must exist before the LangChain store is built on it.
        self.create_collection()
        self.vector_store = self._build_vector_store()

    def _build_vector_store(self) -> LangchainQdrantVS:
        """Create a LangChain vector store bound to the current client."""
        return LangchainQdrantVS(
            client=self.get_client(),
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )

    def _embedding_dimension(self) -> int:
        """Return the output dimension of the configured embeddings."""
        if self._embedder is not None:
            return self._embedder.get_embedding_dimension()
        # Custom embeddings were supplied: measure their real output size
        # rather than assuming the default LlamaCpp model's dimension, so the
        # collection matches the vectors that will actually be inserted.
        return len(self.embeddings.embed_query("dimension probe"))

    def get_client(self) -> QdrantClient:
        """Return the shared Qdrant client, creating it on first use.

        This only creates the client lazily; it does not check that the
        server is reachable. Use ``refresh_client`` to replace a client.
        """
        if self._client is None:
            self._client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                timeout=120,
                http2=False,
            )
        return self._client

    def refresh_client(self):
        """Close the current client, open a new one and rebind the vector store.

        The LangChain vector store holds on to the client it was built with,
        so it is rebuilt here. Otherwise it would keep using the closed client.
        The collection is ensured to exist first, because building the
        LangChain store requires it (``add_documents`` does the same).
        """
        if self._client is not None:
            self._client.close()
        self._client = None
        self.get_client()
        self.create_collection()
        self.vector_store = self._build_vector_store()

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection if it does not exist, using cosine distance.

        Args:
            vector_size: Vector dimension for a newly created collection. When
                ``None`` it is derived from the configured embeddings, and it
                is only computed if a collection actually has to be created.
            force_recreate: If the collection exists, delete and recreate it.
                This drops every point stored in it.
        """
        client = self.get_client()
        exists = any(
            c.name == self.collection_name
            for c in client.get_collections().collections
        )

        if exists and force_recreate:
            client.delete_collection(self.collection_name)
            exists = False

        if exists:
            logger.info("集合 '%s' 已存在", self.collection_name)
            return

        if vector_size is None:
            vector_size = self._embedding_dimension()
        client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
        logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)

    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Embed and insert documents into the collection.

        The collection is (re)created first if it is missing, e.g. after
        ``delete_collection``.

        Args:
            documents: Documents to insert. An empty list is a no-op.
            batch_size: Batch size forwarded to the LangChain store.

        Returns:
            The IDs returned by the LangChain store (``[]`` for no documents).
        """
        if not documents:
            return []
        self.create_collection()
        ids = self.vector_store.add_documents(documents, batch_size=batch_size)
        logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
        return ids

    def similarity_search(self, query: str, k: int = 5) -> List[Document]:
        """Return the ``k`` documents most similar to ``query``."""
        return self.vector_store.similarity_search(query, k=k)

    def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
        """Return the ``k`` most similar documents paired with their scores."""
        return self.vector_store.similarity_search_with_score(query, k=k)

    def delete_collection(self):
        """Delete the collection and all points in it."""
        self.get_client().delete_collection(self.collection_name)
        logger.info("集合 '%s' 已删除", self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        """Return a summary of the collection.

        Returns:
            Dict with ``name``, ``vectors_count``, ``status`` and
            ``vector_size``. ``vectors_count`` holds the collection's
            ``points_count``; the key name is kept for compatibility.
            ``vector_size`` is ``None`` when no vector config is present.
        """
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors
        if isinstance(vectors_config, dict):
            # Named vectors: report the size of the first configured vector.
            first = next(iter(vectors_config.values()), None)
            vector_size = first.size if first is not None else None
        elif vectors_config is not None:
            vector_size = vectors_config.size
        else:
            vector_size = None
        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,
            "status": info.status,
            "vector_size": vector_size,
        }

    def as_langchain_vectorstore(self):
        """Return the underlying LangChain Qdrant vector store."""
        return self.vector_store

    def get_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store (alias of ``as_langchain_vectorstore``)."""
        return self.vector_store

    def get_qdrant_client(self):
        """Return the native Qdrant client, e.g. for manual collection management."""
        return self.get_client()