2026-04-18 16:56:23 +08:00
|
|
|
|
"""
|
2026-04-19 15:01:40 +08:00
|
|
|
|
Qdrant 向量数据库包装器。
|
2026-04-18 16:56:23 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
import os
|
|
|
|
|
|
from typing import List, Optional, Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
|
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
|
|
|
|
|
from qdrant_client import QdrantClient
|
|
|
|
|
|
from qdrant_client.http.models import Distance, VectorParams
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
2026-04-19 15:01:40 +08:00
|
|
|
|
# Qdrant endpoint, read from the QDRANT_URL environment variable; defaults to
# http://127.0.0.1:6333 when the variable is unset.
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
|
|
|
|
|
# API key read from the QDRANT_API_KEY environment variable; None when unset.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
|
|
|
|
|
|
2026-04-18 16:56:23 +08:00
|
|
|
|
|
|
|
|
|
|
class QdrantVectorStore:
    """Wrapper around a single Qdrant collection and its LangChain vector store.

    The wrapper owns a lazily created ``QdrantClient`` (configured from the
    module-level ``QDRANT_URL`` / ``QDRANT_API_KEY``), makes sure the target
    collection exists, and exposes the LangChain vector store bound to it.
    """

    def __init__(
        self,
        collection_name: str,
        embeddings: Optional[Any] = None,
    ):
        """Bind to ``collection_name``, creating the collection if it is missing.

        Args:
            collection_name: Name of the Qdrant collection to operate on.
            embeddings: Embeddings object handed to the LangChain vector store.
                When ``None``, a ``LlamaCppEmbedder`` is created and its
                LangChain adapter is used.
        """
        self.collection_name = collection_name
        self._client: Optional["QdrantClient"] = None
        # Kept so the default embedder can be reused when sizing the
        # collection instead of constructing a second one.
        self._embedder: Optional[Any] = None

        if embeddings is None:
            from .embedders import LlamaCppEmbedder

            self._embedder = LlamaCppEmbedder()
            self.embeddings = self._embedder.as_langchain_embeddings()
        else:
            self.embeddings = embeddings

        self.create_collection()
        self.vector_store = self._build_vector_store()

    def _build_vector_store(self, **extra_kwargs: Any):
        """Create the LangChain vector store bound to the current client."""
        return LangchainQdrantVS(
            client=self.get_client(),
            collection_name=self.collection_name,
            embedding=self.embeddings,
            **extra_kwargs,
        )

    def _resolve_vector_size(self) -> int:
        """Return the vector dimension to use for a new collection.

        The default embedder built in ``__init__`` is asked directly. For
        caller-supplied embeddings the dimension is measured by embedding a
        short probe string, so the collection matches the vectors that will
        actually be stored. A fresh ``LlamaCppEmbedder`` is only constructed
        when neither is available.
        """
        embedder = getattr(self, "_embedder", None)
        embeddings = getattr(self, "embeddings", None)
        if embedder is None and embeddings is not None:
            # NOTE(review): assumes the object implements the LangChain
            # ``Embeddings.embed_query`` method - confirm for custom embedders.
            return len(embeddings.embed_query("dimension probe"))
        if embedder is None:
            from .embedders import LlamaCppEmbedder

            embedder = LlamaCppEmbedder()
        return embedder.get_embedding_dimension()

    def get_client(self) -> "QdrantClient":
        """Return the cached Qdrant client, creating it on first use."""
        if self._client is None:
            self._client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                timeout=120,
                http2=False,
            )
        return self._client

    def refresh_client(self):
        """Close the current client and switch to a fresh one.

        If a vector store has already been built it is rebuilt on the new
        client; otherwise it would keep using the client that was just closed.
        The collection was validated when the store was first created, so the
        rebuild skips re-validating the collection configuration.
        """
        stale_client = self._client
        if stale_client is not None:
            try:
                stale_client.close()
            finally:
                # Drop the reference even if close() fails, so the next
                # get_client() call always starts from a new client.
                self._client = None

        if getattr(self, "vector_store", None) is not None:
            self.vector_store = self._build_vector_store(
                validate_collection_config=False
            )

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection if it is missing (or recreate it on demand).

        Args:
            vector_size: Dimension of the stored vectors. When ``None`` it is
                resolved from the active embeddings, but only if a collection
                actually has to be created.
            force_recreate: When true, an existing collection is deleted and
                created again.
        """
        client = self.get_client()
        exists = any(
            c.name == self.collection_name
            for c in client.get_collections().collections
        )

        if exists and not force_recreate:
            logger.info("集合 '%s' 已存在", self.collection_name)
            return

        # Resolve the size before deleting anything, so a failure here cannot
        # leave the collection removed but not recreated.
        if vector_size is None:
            vector_size = self._resolve_vector_size()

        if exists:
            client.delete_collection(self.collection_name)

        client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
        logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)

    def add_documents(self, documents: List["Document"], batch_size: int = 100):
        """Add documents to the collection and return their ids.

        Args:
            documents: Documents to store; an empty list returns ``[]``
                without contacting the server.
            batch_size: Forwarded to the LangChain vector store.

        Returns:
            The ids returned by the LangChain vector store.
        """
        if not documents:
            return []

        # Make sure the collection exists before writing to it.
        self.create_collection()
        ids = self.vector_store.add_documents(documents, batch_size=batch_size)
        logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
        return ids

    def similarity_search(self, query: str, k: int = 5) -> List["Document"]:
        """Return the ``k`` documents most similar to ``query``."""
        return self.vector_store.similarity_search(query, k=k)

    def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple["Document", float]]:
        """Return the ``k`` most similar documents paired with their scores."""
        return self.vector_store.similarity_search_with_score(query, k=k)

    def delete_collection(self):
        """Delete the collection this wrapper is bound to."""
        self.get_client().delete_collection(self.collection_name)
        logger.info("集合 '%s' 已删除", self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        """Return a summary of the collection.

        Returns:
            A dict with ``name``, ``vectors_count`` (the point count, ``0``
            when the server reports none), ``status`` and ``vector_size``.
            ``vector_size`` is taken from the first entry when the vectors
            config is a dict of named vectors, and is ``None`` when no size
            is available.
        """
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors

        if isinstance(vectors_config, dict):
            first_params = next(iter(vectors_config.values()), None)
        else:
            first_params = vectors_config
        vector_size = getattr(first_params, "size", None)

        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,
            "status": info.status,
            "vector_size": vector_size,
        }

    def as_langchain_vectorstore(self):
        """Return the underlying LangChain vector store."""
        return self.vector_store

    def get_langchain_vectorstore(self):
        """Return the underlying LangChain vector store (alias)."""
        return self.vector_store

    def get_qdrant_client(self):
        """Return the native Qdrant client, e.g. for manual collection management."""
        return self.get_client()
|