Files
ailine/rag_core/vector_store.py

181 lines
7.1 KiB
Python
Raw Normal View History

2026-04-18 16:56:23 +08:00
"""
Qdrant vector database wrapper.
"""
import logging
import os
2026-04-20 14:05:57 +08:00
import time
2026-04-18 16:56:23 +08:00
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
2026-04-20 14:05:57 +08:00
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from rag_core.client import create_qdrant_client
2026-04-18 16:56:23 +08:00
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
2026-04-19 15:01:40 +08:00
# Qdrant endpoint settings, read once from the environment at import time.
# The URL falls back to a local instance; the API key is None when unset.
QDRANT_URL = os.environ.get("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
2026-04-18 16:56:23 +08:00
class QdrantVectorStore:
    """Wrapper around one Qdrant collection and its LangChain vector store.

    The Qdrant client is created lazily by ``get_client()`` and can be dropped
    with ``refresh_client()``. The LangChain wrapper exposed as
    ``vector_store`` is tied to the client it was built with, so it is
    discarded whenever the client is refreshed and rebuilt on next access.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
    ):
        """Ensure the collection exists and bind a LangChain vector store to it.

        Args:
            collection_name: Name of the Qdrant collection to operate on.
            embeddings: Embeddings object handed to the LangChain vector
                store. When omitted, a ``LlamaCppEmbedder`` is created and
                its LangChain adapter is used.
        """
        self.collection_name = collection_name
        self._client: Optional["QdrantClient"] = None
        self._connection_attempts = 0
        self._last_connection_time: Optional[float] = None
        # Backing field of the ``vector_store`` property (None = rebuild).
        self._vector_store: Optional[Any] = None
        # Cached embedding dimension, so the embedder is queried only once.
        self._vector_size: Optional[int] = None
        # Default embedder, kept only when no embeddings were supplied.
        self._embedder: Optional[Any] = None

        if embeddings is None:
            from rag_core.embedders import LlamaCppEmbedder

            self._embedder = LlamaCppEmbedder()
            self.embeddings = self._embedder.as_langchain_embeddings()
        else:
            self.embeddings = embeddings

        self.create_collection()
        # Build eagerly so construction-time errors from the LangChain
        # wrapper still surface here rather than on first use.
        self._vector_store = self._build_vector_store()

    def _build_vector_store(self):
        """Create a LangChain vector store bound to the current client."""
        return LangchainQdrantVS(
            client=self.get_client(),
            collection_name=self.collection_name,
            embedding=self.embeddings,
        )

    @property
    def vector_store(self):
        """LangChain vector store bound to the live Qdrant client.

        ``refresh_client()`` closes the client the previous wrapper was built
        on, so the wrapper is rebuilt here on demand instead of being reused
        with a closed connection.
        """
        if self._vector_store is None:
            self._vector_store = self._build_vector_store()
        return self._vector_store

    @vector_store.setter
    def vector_store(self, value):
        # Keeps plain attribute assignment working for existing callers.
        self._vector_store = value

    def get_client(self) -> "QdrantClient":
        """Return the Qdrant client, creating it on first use.

        Each creation increments the connection counter and records the
        creation time reported by ``get_connection_stats()``.
        """
        if self._client is None:
            # timeout is forwarded to the project factory (presumably seconds).
            self._client = create_qdrant_client(timeout=300)
            self._connection_attempts += 1
            self._last_connection_time = time.time()
            logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts)
        return self._client

    def refresh_client(self):
        """Close the current client so the next ``get_client()`` reconnects.

        Does nothing when no client exists. Closing is best effort: a failure
        is logged and the client reference is dropped anyway. The cached
        LangChain vector store is dropped too, because it holds the client
        that was just closed.
        """
        if self._client is None:
            return
        try:
            self._client.close()
            logger.debug("Qdrant 旧连接已关闭")
        except Exception as e:
            logger.warning("关闭 Qdrant 连接时出现异常: %s", e)
        finally:
            self._client = None
            self._last_connection_time = None
            self._vector_store = None

    def check_connection_health(self) -> bool:
        """Probe the connection by listing collections.

        Returns:
            True when the probe succeeds. False when no client has been
            created yet, or when the probe fails with a connection-level
            error; in the latter case the client is discarded so that the
            next ``get_client()`` call builds a new one.
        """
        if self._client is None:
            logger.info("Qdrant 客户端未初始化,将创建新连接")
            return False
        try:
            client = self.get_client()
            client.get_collections()
            logger.debug("Qdrant 连接健康检查通过")
            return True
        except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
            logger.warning("Qdrant 连接健康检查失败: %s", e)
            self.refresh_client()
            return False

    def get_connection_stats(self) -> Dict[str, Any]:
        """Return connection counters: attempts, last creation time, state."""
        return {
            "connection_attempts": self._connection_attempts,
            "last_connection_time": self._last_connection_time,
            "client_initialized": self._client is not None,
        }

    def _resolve_vector_size(self) -> int:
        """Return the embedding dimension of the embeddings in use (cached)."""
        if self._vector_size is None:
            if self._embedder is not None:
                self._vector_size = self._embedder.get_embedding_dimension()
            else:
                # Caller-supplied embeddings: measure one embedded query.
                # Assumes the LangChain Embeddings interface (embed_query),
                # which the LangChain vector store is given as well.
                self._vector_size = len(self.embeddings.embed_query("dimension probe"))
        return self._vector_size

    def _ensure_collection(self, vector_size: int, force_recreate: bool) -> None:
        """Make a single attempt to create (or recreate) the collection."""
        client = self.get_client()
        exists = any(
            c.name == self.collection_name
            for c in client.get_collections().collections
        )
        if exists and force_recreate:
            client.delete_collection(self.collection_name)
            exists = False
        if exists:
            logger.info("集合 '%s' 已存在", self.collection_name)
            return
        client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
        logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection if it is missing, retrying on network errors.

        Args:
            vector_size: Vector dimension for a newly created collection.
                When omitted, it is taken from the embeddings in use.
            force_recreate: Delete an existing collection first and create
                it again.

        Raises:
            RemoteProtocolError, ConnectionError, OSError,
            ResponseHandlingException: re-raised when the last of three
            attempts still fails.
        """
        if vector_size is None:
            vector_size = self._resolve_vector_size()

        max_retries = 3
        base_delay = 2
        for attempt in range(max_retries):
            try:
                self._ensure_collection(vector_size, force_recreate)
                return
            except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
                if attempt == max_retries - 1:
                    logger.error("创建集合 '%s' 重试 %d 次后仍然失败: %s", self.collection_name, max_retries, e)
                    raise
                # Exponential backoff: base_delay, then twice that.
                wait_time = base_delay * (2 ** attempt)
                error_type = type(e).__name__
                logger.warning(
                    "创建集合 '%s' 遇到网络异常 [%s]%d秒后重试 (%d/%d): %s",
                    self.collection_name, error_type, wait_time, attempt + 1, max_retries, e
                )
                # Drop the possibly broken connection before retrying.
                self.refresh_client()
                logger.debug("已刷新 Qdrant 客户端连接")
                time.sleep(wait_time)

    def add_documents(self, documents: List["Document"], batch_size: int = 100):
        """Add documents to the collection and return their ids.

        Returns an empty list without touching Qdrant when ``documents`` is
        empty. The collection is (re)ensured before inserting.
        """
        if not documents:
            return []
        self.create_collection()
        ids = self.vector_store.add_documents(documents, batch_size=batch_size)
        logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
        return ids

    def similarity_search(self, query: str, k: int = 5) -> List["Document"]:
        """Return the ``k`` documents most similar to ``query``."""
        return self.vector_store.similarity_search(query, k=k)

    def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple["Document", float]]:
        """Return the ``k`` most similar documents paired with their scores."""
        return self.vector_store.similarity_search_with_score(query, k=k)

    def delete_collection(self):
        """Delete this store's collection from Qdrant."""
        self.get_client().delete_collection(self.collection_name)
        logger.info("集合 '%s' 已删除", self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        """Return name, point count, status and vector size of the collection.

        When the collection reports a dict of vector configurations, the
        size of the first entry is used. A missing configuration yields 0.
        """
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors
        if isinstance(vectors_config, dict):
            first_config = next(iter(vectors_config.values()), None)
            vector_size = first_config.size if first_config else 0
        else:
            vector_size = vectors_config.size if vectors_config else 0
        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,
            "status": info.status,
            "vector_size": vector_size,
        }

    def as_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store object."""
        return self.vector_store

    def get_langchain_vectorstore(self):
        """Return the LangChain Qdrant vector store object (alias)."""
        return self.vector_store

    def get_qdrant_client(self):
        """Return the native Qdrant client, e.g. to manage collections by hand."""
        return self.get_client()