Files
ailine/rag_indexer/vector_store.py
2026-04-19 15:01:40 +08:00

133 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant 向量数据库包装器。
"""
import logging
import os
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams
from .embedders import LlamaCppEmbedder
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Qdrant connection settings, read once at import time from the environment.
# QDRANT_URL falls back to a local instance; QDRANT_API_KEY is None when unset.
QDRANT_URL = os.environ.get("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
class QdrantVectorStore:
    """Wrapper around one Qdrant collection and a LangChain vector store.

    On construction the collection is created if missing, and a
    ``langchain_qdrant.QdrantVectorStore`` bound to it is built. The raw
    ``QdrantClient`` is created lazily and can be recycled with
    ``refresh_client()``.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
    ):
        """Create the wrapper and make sure the collection exists.

        Args:
            collection_name: Name of the Qdrant collection to use.
            embeddings: LangChain-compatible embeddings object. When ``None``,
                a default ``LlamaCppEmbedder`` is created and its LangChain
                adapter is used.
        """
        self.collection_name = collection_name
        self._client: Optional[QdrantClient] = None
        self._vector_store: Optional[LangchainQdrantVS] = None
        # Embedding dimension, resolved once and cached (see _resolve_vector_size).
        self._vector_size: Optional[int] = None

        # Embedding model. Keep the default embedder so its dimension can be
        # queried later without instantiating a second one.
        if embeddings is None:
            self._embedder: Optional[LlamaCppEmbedder] = LlamaCppEmbedder()
            self.embeddings = self._embedder.as_langchain_embeddings()
        else:
            self._embedder = None
            self.embeddings = embeddings

        # The collection must exist before the LangChain store is built.
        self.create_collection()
        # Build eagerly so configuration problems surface at construction time.
        self._vector_store = self._build_vector_store()

    @property
    def vector_store(self) -> LangchainQdrantVS:
        """LangChain vector store bound to the current client.

        Rebuilt on demand after ``refresh_client()`` has discarded the old
        client, so it never keeps using a closed connection.
        """
        if self._vector_store is None:
            self._vector_store = self._build_vector_store()
        return self._vector_store

    @vector_store.setter
    def vector_store(self, value: LangchainQdrantVS) -> None:
        # Kept assignable for code that set the former plain attribute directly.
        self._vector_store = value

    def _build_vector_store(self) -> LangchainQdrantVS:
        """Create a LangChain Qdrant store for this collection and client."""
        return LangchainQdrantVS(
            client=self.get_client(),
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )

    def get_client(self) -> QdrantClient:
        """Return the Qdrant client, creating it on first use.

        No connectivity check is done here; a client discarded by
        ``refresh_client()`` is simply recreated on the next call.
        """
        if self._client is None:
            self._client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                timeout=120,
                http2=False,
            )
        return self._client

    def refresh_client(self) -> None:
        """Close the current client so the next access opens a new one.

        The cached LangChain store is dropped too, because it holds a
        reference to the old (closed) client.
        """
        # Reset state first so a failing close() cannot leave a broken client cached.
        old_client, self._client = self._client, None
        self._vector_store = None
        if old_client is not None:
            old_client.close()

    def _resolve_vector_size(self) -> int:
        """Return (and cache) the dimension of the configured embeddings."""
        if self._vector_size is None:
            if self._embedder is not None:
                self._vector_size = self._embedder.get_embedding_dimension()
            else:
                # Custom embeddings: measure the dimension directly so the
                # collection matches the vectors that will actually be written.
                self._vector_size = len(self.embeddings.embed_query("dummy_text"))
        return self._vector_size

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection with cosine distance if it does not exist.

        Args:
            vector_size: Vector dimension. When ``None``, it is taken from the
                configured embeddings (resolved once, then cached).
            force_recreate: Delete and recreate the collection if it already
                exists, discarding its stored points.
        """
        client = self.get_client()
        exists = client.collection_exists(self.collection_name)
        if exists and force_recreate:
            client.delete_collection(self.collection_name)
            exists = False
        if exists:
            logger.info("集合 '%s' 已存在", self.collection_name)
            return
        if vector_size is None:
            vector_size = self._resolve_vector_size()
        client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
        logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)

    def add_documents(self, documents: List[Document], batch_size: int = 100) -> List[str]:
        """Embed and store documents in the collection.

        The collection is created first if it is missing (e.g. after
        ``delete_collection()``).

        Returns:
            The IDs returned by the LangChain store; ``[]`` when ``documents``
            is empty.
        """
        if not documents:
            return []
        self.create_collection()
        ids = self.vector_store.add_documents(documents, batch_size=batch_size)
        logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
        return ids

    def similarity_search(self, query: str, k: int = 5) -> List[Document]:
        """Return the ``k`` documents most similar to ``query``."""
        return self.vector_store.similarity_search(query, k=k)

    def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
        """Return the ``k`` most similar documents paired with their scores."""
        return self.vector_store.similarity_search_with_score(query, k=k)

    def delete_collection(self) -> None:
        """Delete the collection from Qdrant."""
        self.get_client().delete_collection(self.collection_name)
        logger.info("集合 '%s' 已删除", self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        """Return name, point count, status and vector size of the collection.

        ``vector_size`` is ``None`` when no dense vector config is present;
        for named vectors, the first configured vector's size is reported.
        """
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors
        if isinstance(vectors_config, dict):
            first_params = next(iter(vectors_config.values()), None)
            vector_size = first_params.size if first_params is not None else None
        elif vectors_config is not None:
            vector_size = vectors_config.size
        else:
            vector_size = None
        return {
            "name": self.collection_name,
            # Key name kept for callers; the value is the collection's point count.
            "vectors_count": info.points_count or 0,
            "status": info.status,
            "vector_size": vector_size,
        }

    def as_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store."""
        return self.vector_store

    def get_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store (alias)."""
        return self.vector_store

    def get_qdrant_client(self):
        """Return the native Qdrant client (e.g. for manual collection management)."""
        return self.get_client()