Files
ailine/rag_core/vector_store.py
root 3c906e91d9
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 35m37s
重排,多路查询
2026-04-20 01:10:18 +08:00

124 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant 向量数据库包装器。
"""
import logging
import os
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from .client import create_qdrant_client
logger = logging.getLogger(__name__)
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
class QdrantVectorStore:
"""Qdrant 向量数据库操作包装器。"""
def __init__(
self,
collection_name: str,
embeddings: Optional[Any] = None,
):
self.collection_name = collection_name
self._client: Optional[QdrantClient] = None
if embeddings is None:
from .embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
else:
self.embeddings = embeddings
self.create_collection()
self.vector_store = LangchainQdrantVS(
client=self.get_client(),
collection_name=self.collection_name,
embedding=self.embeddings,
)
def get_client(self) -> QdrantClient:
if self._client is None:
self._client = create_qdrant_client(timeout=120)
return self._client
def refresh_client(self):
"""关闭旧连接,创建新连接。"""
if self._client is not None:
self._client.close()
self._client = None
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
"""创建集合,设置合适的向量维度。"""
if vector_size is None:
from .embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
client = self.get_client()
collections = client.get_collections().collections
exists = any(c.name == self.collection_name for c in collections)
if exists and force_recreate:
client.delete_collection(self.collection_name)
exists = False
if not exists:
client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)
logger.info("集合 '%s' 已创建(维度=%d", self.collection_name, vector_size)
else:
logger.info("集合 '%s' 已存在", self.collection_name)
def add_documents(self, documents: List[Document], batch_size: int = 100):
"""将文档添加到向量数据库。"""
if not documents:
return []
self.create_collection()
ids = self.vector_store.add_documents(documents, batch_size=batch_size)
logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
return ids
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
return self.vector_store.similarity_search(query, k=k)
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
return self.vector_store.similarity_search_with_score(query, k=k)
def delete_collection(self):
self.get_client().delete_collection(self.collection_name)
logger.info("集合 '%s' 已删除", self.collection_name)
def get_collection_info(self) -> Dict[str, Any]:
info = self.get_client().get_collection(self.collection_name)
vectors_config = info.config.params.vectors
if isinstance(vectors_config, dict):
vector_size = next(iter(vectors_config.values())).size
else:
vector_size = vectors_config.size
return {
"name": self.collection_name,
"vectors_count": info.points_count or 0,
"status": info.status,
"vector_size": vector_size,
}
def as_langchain_vectorstore(self):
return self.vector_store
def get_langchain_vectorstore(self):
"""返回 LangChain Qdrant 向量存储对象(别名)"""
return self.vector_store
def get_qdrant_client(self):
"""返回原生 Qdrant 客户端(如需手动管理 collection"""
return self.get_client()