This commit is contained in:
110
rag_indexer/vector_store.py
Normal file
110
rag_indexer/vector_store.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Qdrant vector store wrapper.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.models import Distance, VectorParams
|
||||
|
||||
from .embedders import LlamaCppEmbedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QdrantVectorStore:
    """Wrapper for Qdrant vector database operations.

    Holds a raw ``QdrantClient`` together with the LangChain Qdrant vector
    store built on top of it, and exposes helpers for creating the
    collection, adding documents and running similarity searches.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
        qdrant_url: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """Connect to Qdrant and build the LangChain vector store.

        Args:
            collection_name: Name of the Qdrant collection to operate on.
            embeddings: LangChain-compatible embeddings object. When omitted,
                a ``LlamaCppEmbedder`` is created and adapted via
                ``as_langchain_embeddings()``.
            qdrant_url: Qdrant endpoint. Falls back to the ``QDRANT_URL``
                environment variable, then to ``http://127.0.0.1:6333``.
            api_key: Optional API key passed to ``QdrantClient``.
        """
        self.collection_name = collection_name
        self.qdrant_url = qdrant_url or os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
        self.api_key = api_key

        # Keep the default embedder around so create_collection() can ask it
        # for the vector dimension without instantiating a second one.
        self._embedder = None
        if embeddings is None:
            self._embedder = LlamaCppEmbedder()
            self.embeddings = self._embedder.as_langchain_embeddings()
        else:
            self.embeddings = embeddings

        # Qdrant client
        self.client = QdrantClient(url=self.qdrant_url, api_key=self.api_key)

        # LangChain vector store.
        # The langchain_qdrant QdrantVectorStore constructor names this
        # keyword `embedding` (singular); `embeddings=` is rejected.
        # Its collection-config validation looks the collection up, so it is
        # only enabled when the collection already exists; otherwise the
        # wrapper could never be built before create_collection() runs.
        self.vector_store = LangchainQdrantVS(
            client=self.client,
            collection_name=self.collection_name,
            embedding=self.embeddings,
            validate_collection_config=self._collection_exists(),
        )

    def _collection_exists(self) -> bool:
        """Return True when ``self.collection_name`` is listed by the server."""
        collections = self.client.get_collections().collections
        return any(c.name == self.collection_name for c in collections)

    def _resolve_vector_size(self) -> int:
        """Determine the embedding dimension for a new collection."""
        embedder = getattr(self, "_embedder", None)
        if embedder is not None:
            return embedder.get_embedding_dimension()

        # Caller supplied their own embeddings: size the collection from what
        # that model actually produces rather than from the default embedder.
        embed_query = getattr(self.embeddings, "embed_query", None)
        if callable(embed_query):
            return len(embed_query("dimension probe"))

        # Last resort, matching the previous behaviour.
        return LlamaCppEmbedder().get_embedding_dimension()

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create collection with appropriate vector size.

        Args:
            vector_size: Dimension of the stored vectors. When ``None`` it is
                derived from the embeddings in use, and only if a collection
                actually has to be created.
            force_recreate: Delete an existing collection of the same name
                and create it again.
        """
        exists = self._collection_exists()

        if exists and not force_recreate:
            logger.info("Collection '%s' already exists", self.collection_name)
            return

        # Resolve the dimension before deleting anything, so a failure here
        # cannot leave the caller without a collection.
        if vector_size is None:
            vector_size = self._resolve_vector_size()

        if exists:
            self.client.delete_collection(self.collection_name)

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
        logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)

    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Add documents to vector store.

        Ensures the collection exists first. Returns the ids reported by the
        underlying LangChain vector store, or an empty list when
        ``documents`` is empty.
        """
        if not documents:
            return []
        self.create_collection()
        ids = self.vector_store.add_documents(documents, batch_size=batch_size)
        logger.info("Added %d documents to '%s'", len(ids), self.collection_name)
        return ids

    def similarity_search(self, query: str, k: int = 5) -> List[Document]:
        """Return the ``k`` documents most similar to ``query``."""
        return self.vector_store.similarity_search(query, k=k)

    def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
        """Return the ``k`` most similar documents paired with their scores."""
        return self.vector_store.similarity_search_with_score(query, k=k)

    def delete_collection(self):
        """Delete this wrapper's collection from Qdrant."""
        self.client.delete_collection(self.collection_name)
        logger.info("Collection '%s' deleted", self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        """Return a small summary of the collection.

        Keys: ``name``, ``vectors_count``, ``points_count``, ``status`` and
        ``vector_size``. ``vector_size`` is an int for a single unnamed
        vector config, or a ``{vector_name: size}`` dict when the collection
        uses named vectors.
        """
        info = self.client.get_collection(self.collection_name)

        vectors = info.config.params.vectors
        if isinstance(vectors, dict):
            vector_size: Any = {name: params.size for name, params in vectors.items()}
        else:
            vector_size = getattr(vectors, "size", None)

        return {
            # The collection info object carries no name of its own, so
            # report the name that was queried.
            "name": self.collection_name,
            # Not every qdrant-client release exposes vectors_count; read it
            # defensively and also report points_count.
            "vectors_count": getattr(info, "vectors_count", None),
            "points_count": getattr(info, "points_count", None),
            "status": info.status,
            "vector_size": vector_size,
        }

    def as_langchain_vectorstore(self):
        """Return the underlying LangChain Qdrant vector store object."""
        return self.vector_store

    def get_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store object (alias)."""
        return self.vector_store

    def get_qdrant_client(self):
        """Return the native Qdrant client (for manual collection management)."""
        return self.client
|
||||
Reference in New Issue
Block a user