Files
ailine/rag_core/vector_store.py

181 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Qdrant 向量数据库包装器。
"""
import logging
import os
import time
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from rag_core.client import create_qdrant_client
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Qdrant endpoint read from the environment; falls back to a local instance.
QDRANT_URL = os.environ.get("QDRANT_URL", "http://127.0.0.1:6333")
# Optional API key read from the environment; None when the variable is unset.
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
class QdrantVectorStore:
    """Wrapper around one Qdrant collection exposed through LangChain.

    The wrapper owns a lazily created ``QdrantClient`` (see ``get_client``)
    and a LangChain vector store bound to that client. When the client is
    replaced (``refresh_client`` followed by ``get_client``), the LangChain
    store is rebuilt on next use so it never keeps talking through a closed
    client.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
    ):
        """Create the collection if needed and bind a LangChain store to it.

        Args:
            collection_name: Name of the Qdrant collection to operate on.
            embeddings: Embeddings object handed to the LangChain store. When
                omitted, a ``LlamaCppEmbedder`` is created and its LangChain
                adapter is used.
        """
        self.collection_name = collection_name
        self._client: Optional[QdrantClient] = None
        self._connection_attempts = 0
        self._last_connection_time: Optional[float] = None
        # Embedding dimension, resolved once and reused by create_collection().
        self._vector_size: Optional[int] = None
        # Default embedder; stays None when the caller supplied embeddings.
        self._embedder: Optional[Any] = None
        # Client the current LangChain store was built with.
        self._vector_store_client: Optional[Any] = None
        if embeddings is None:
            from rag_core.embedders import LlamaCppEmbedder

            self._embedder = LlamaCppEmbedder()
            self.embeddings = self._embedder.as_langchain_embeddings()
        else:
            self.embeddings = embeddings
        self.create_collection()
        self.vector_store = self._build_vector_store()

    def get_client(self) -> "QdrantClient":
        """Return the current client, creating one on first use or after a refresh."""
        if self._client is None:
            self._client = create_qdrant_client(timeout=300)
            self._connection_attempts += 1
            self._last_connection_time = time.time()
            logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts)
        return self._client

    def refresh_client(self):
        """Close and drop the current client.

        No new client is created here; the next ``get_client()`` call creates
        it. Errors raised while closing are logged and otherwise ignored.
        """
        if self._client is not None:
            try:
                self._client.close()
                logger.debug("Qdrant 旧连接已关闭")
            except Exception as e:
                logger.warning("关闭 Qdrant 连接时出现异常: %s", e)
            finally:
                self._client = None
                self._last_connection_time = None

    def check_connection_health(self) -> bool:
        """Report whether the current client can reach Qdrant.

        Returns:
            True when ``get_collections()`` succeeds. False when no client has
            been created yet, or when the probe fails with a connection-level
            error; in the latter case the client is dropped so that the next
            ``get_client()`` call creates a fresh one.
        """
        if self._client is None:
            logger.info("Qdrant 客户端未初始化,将创建新连接")
            return False
        try:
            client = self.get_client()
            client.get_collections()
            logger.debug("Qdrant 连接健康检查通过")
            return True
        except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
            logger.warning("Qdrant 连接健康检查失败: %s", e)
            self.refresh_client()
            return False

    def get_connection_stats(self) -> Dict[str, Any]:
        """Return connection counters: attempts, last creation time, and init state."""
        return {
            "connection_attempts": self._connection_attempts,
            "last_connection_time": self._last_connection_time,
            "client_initialized": self._client is not None,
        }

    def _resolve_vector_size(self) -> int:
        """Return the embedding dimension, computing and caching it on first use.

        The default embedder created in ``__init__`` is reused when present.
        Caller-supplied embeddings are probed with a single ``embed_query``
        call so the collection matches the vectors that will be stored in it.
        If neither is possible, a ``LlamaCppEmbedder`` is created as a fallback.
        """
        cached = getattr(self, "_vector_size", None)
        if cached is not None:
            return cached
        embedder = getattr(self, "_embedder", None)
        embeddings = getattr(self, "embeddings", None)
        if embedder is None and hasattr(embeddings, "embed_query"):
            size = len(embeddings.embed_query("dummy_text"))
        else:
            if embedder is None:
                from rag_core.embedders import LlamaCppEmbedder

                embedder = LlamaCppEmbedder()
            size = embedder.get_embedding_dimension()
        self._vector_size = size
        return size

    def _build_vector_store(self):
        """Create a LangChain store on the current client and remember that client."""
        client = self.get_client()
        store = LangchainQdrantVS(
            client=client,
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )
        self._vector_store_client = client
        return store

    def _ensure_vector_store(self):
        """Return the LangChain store, rebuilding it if the client was replaced."""
        client = self.get_client()
        if getattr(self, "_vector_store_client", None) is not client:
            # The store was built on a client that has since been closed.
            self.vector_store = self._build_vector_store()
        return self.vector_store

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection (cosine distance) unless it already exists.

        Args:
            vector_size: Vector dimension. When omitted, the dimension of the
                configured embeddings is used.
            force_recreate: Delete an existing collection and create it again.

        Raises:
            RemoteProtocolError, ConnectionError, OSError,
            ResponseHandlingException: re-raised when the third attempt also
            fails. Earlier failures drop the client and retry after 2s, 4s.
        """
        if vector_size is None:
            vector_size = self._resolve_vector_size()
        max_retries = 3
        base_delay = 2
        for attempt in range(max_retries):
            try:
                client = self.get_client()
                collections = client.get_collections().collections
                exists = any(c.name == self.collection_name for c in collections)
                if exists and force_recreate:
                    client.delete_collection(self.collection_name)
                    exists = False
                if not exists:
                    client.create_collection(
                        collection_name=self.collection_name,
                        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
                    )
                    logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)
                else:
                    logger.info("集合 '%s' 已存在", self.collection_name)
                return
            except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
                if attempt == max_retries - 1:
                    logger.error(
                        "创建集合 '%s' 重试 %d 次后仍然失败: %s",
                        self.collection_name, max_retries, e,
                    )
                    raise
                # Exponential backoff: base_delay, then base_delay * 2.
                wait_time = base_delay * (2 ** attempt)
                error_type = type(e).__name__
                logger.warning(
                    "创建集合 '%s' 遇到网络异常 [%s], %d秒后重试 (%d/%d): %s",
                    self.collection_name, error_type, wait_time, attempt + 1, max_retries, e
                )
                self.refresh_client()
                logger.debug("已刷新 Qdrant 客户端连接")
                time.sleep(wait_time)

    def add_documents(self, documents: "List[Document]", batch_size: int = 100):
        """Add documents to the collection and return their ids.

        Returns an empty list without touching Qdrant when ``documents`` is
        empty. Otherwise the collection is created if missing first.
        """
        if not documents:
            return []
        self.create_collection()
        # create_collection() may have replaced the client while retrying.
        store = self._ensure_vector_store()
        ids = store.add_documents(documents, batch_size=batch_size)
        logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
        return ids

    def similarity_search(self, query: str, k: int = 5) -> "List[Document]":
        """Return the ``k`` documents most similar to ``query``."""
        return self._ensure_vector_store().similarity_search(query, k=k)

    def similarity_search_with_score(self, query: str, k: int = 5) -> "List[tuple[Document, float]]":
        """Return the ``k`` most similar documents paired with their scores."""
        return self._ensure_vector_store().similarity_search_with_score(query, k=k)

    def delete_collection(self):
        """Delete this wrapper's collection from Qdrant."""
        self.get_client().delete_collection(self.collection_name)
        logger.info("集合 '%s' 已删除", self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        """Return name, point count, status and vector size of the collection.

        When the collection uses named vectors (a dict), the size of the
        first entry is reported; a missing configuration yields size 0.
        """
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors
        if isinstance(vectors_config, dict):
            first_config = next(iter(vectors_config.values()), None)
            vector_size = first_config.size if first_config else 0
        else:
            vector_size = vectors_config.size if vectors_config else 0
        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,
            "status": info.status,
            "vector_size": vector_size,
        }

    def as_langchain_vectorstore(self):
        """Return the LangChain vector store bound to the current client."""
        return self._ensure_vector_store()

    def get_langchain_vectorstore(self):
        """Return the LangChain vector store (alias of ``as_langchain_vectorstore``)."""
        return self._ensure_vector_store()

    def get_qdrant_client(self):
        """Return the native Qdrant client (e.g. to manage collections manually)."""
        return self.get_client()