RAG数据库生成

This commit is contained in:
2026-04-19 15:01:40 +08:00
parent c18e8a9860
commit cc8ef41ef9
17 changed files with 1089 additions and 577 deletions

View File

@@ -1,5 +1,5 @@
"""
Qdrant vector store wrapper.
Qdrant 向量数据库包装器。
"""
import logging
@@ -16,67 +16,85 @@ from .embedders import LlamaCppEmbedder
# Module-wide logger, named after the importing module.
logger = logging.getLogger(__name__)

# Connection defaults, read from the environment once at import time.
# The URL falls back to a local Qdrant instance; the API key has no default.
QDRANT_URL = os.environ.get("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
class QdrantVectorStore:
"""Wrapper for Qdrant vector database operations."""
"""Qdrant 向量数据库操作包装器。"""
def __init__(
self,
collection_name: str,
embeddings: Optional[Any] = None,
qdrant_url: Optional[str] = None,
api_key: Optional[str] = None,
):
# Set up the wrapper: remember the collection name and connection settings,
# choose the embedding model, and build the LangChain vector store.
#
# NOTE(review): this span appears to interleave the removed and the added
# lines of a diff hunk (the +/- markers are missing), so some statements and
# keyword arguments show up twice; it is not runnable exactly as shown.
self.collection_name = collection_name
# An explicit qdrant_url wins; otherwise use the QDRANT_URL environment
# variable, and finally the local default address.
self.qdrant_url = qdrant_url or os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
self.api_key = api_key
# Cache slot for the client handed out by get_client(); starts empty.
self._client: Optional[QdrantClient] = None
# Embeddings
# Embedding model
if embeddings is None:
# No model supplied: build a LlamaCppEmbedder and use its LangChain adapter.
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
else:
self.embeddings = embeddings
# Qdrant client
self.client = QdrantClient(url=self.qdrant_url, api_key=self.api_key)
# Create the collection first
self.create_collection()
# LangChain vector store
# LangChain vector store
self.vector_store = LangchainQdrantVS(
client=self.client,
client=self.get_client(),
collection_name=self.collection_name,
embeddings=self.embeddings,
embedding=self.embeddings,
)
def get_client(self) -> QdrantClient:
    """Return the Qdrant client for this store, creating it on first use.

    The client is built once and cached on ``self._client``; later calls
    return the cached object until ``refresh_client()`` clears it. No
    connectivity check is performed here.

    The connection uses the URL and API key held by this instance
    (``self.qdrant_url`` / ``self.api_key``). The module-level
    ``QDRANT_URL`` / ``QDRANT_API_KEY`` values are used only when the
    instance has no value of its own, so a store constructed with an
    explicit ``qdrant_url`` or ``api_key`` talks to the server it was
    given.

    Returns:
        The cached ``QdrantClient``.
    """
    if self._client is None:
        # Prefer per-instance settings; fall back to the module defaults.
        url = getattr(self, "qdrant_url", None) or QDRANT_URL
        api_key = getattr(self, "api_key", None)
        if api_key is None:
            api_key = QDRANT_API_KEY
        self._client = QdrantClient(
            url=url,
            api_key=api_key,
            timeout=120,
            http2=False,
        )
    return self._client
def refresh_client(self) -> None:
    """Close the cached Qdrant client, if any, and forget it.

    No replacement is created here: the next ``get_client()`` call builds
    a fresh connection. The cached reference is cleared *before* closing,
    so a ``close()`` that raises cannot leave a dead client in the cache;
    the exception still propagates to the caller.

    NOTE(review): objects that were handed the old client earlier (for
    example ``self.vector_store``, which receives it in ``__init__``) keep
    their own reference and are not updated here - confirm whether they
    should be rebuilt after a refresh.
    """
    stale = self._client
    if stale is None:
        return
    # Drop the cache entry first so a failing close() cannot strand it.
    self._client = None
    stale.close()
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
"""Create collection with appropriate vector size."""
"""Create the collection with a suitable vector dimension."""
# NOTE(review): this span appears to interleave the removed and the added
# lines of a diff hunk (the +/- markers are missing), so client calls and log
# statements show up in both their `self.client` and `client` forms.
#
# With no explicit size, ask a freshly constructed LlamaCppEmbedder for its
# embedding dimension.
# NOTE(review): this does not consult self.embeddings, so a custom embedding
# model passed to __init__ is not reflected here - confirm that is intended.
if vector_size is None:
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
collections = self.client.get_collections().collections
client = self.get_client()
collections = client.get_collections().collections
# The collection exists when any listed collection carries this name.
exists = any(c.name == self.collection_name for c in collections)
# force_recreate drops an existing collection so it is created again below.
if exists and force_recreate:
self.client.delete_collection(self.collection_name)
client.delete_collection(self.collection_name)
exists = False
if not exists:
self.client.create_collection(
client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)
logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
# NOTE(review): the message below opens a parenthesis that is never closed.
logger.info("集合 '%s' 已创建(维度=%d", self.collection_name, vector_size)
else:
logger.info("Collection '%s' already exists", self.collection_name)
logger.info("集合 '%s' 已存在", self.collection_name)
def add_documents(self, documents: List[Document], batch_size: int = 100):
"""Add documents to vector store."""
"""Add documents to the vector database."""
# NOTE(review): the paired docstrings and log calls appear to be the removed
# and added lines of a diff hunk shown without +/- markers.
#
# Nothing to insert: return an empty id list straight away.
if not documents:
return []
# create_collection() creates the collection only when it is missing.
self.create_collection()
# Delegate the insert to the LangChain vector store, forwarding batch_size.
ids = self.vector_store.add_documents(documents, batch_size=batch_size)
logger.info("Added %d documents to '%s'", len(ids), self.collection_name)
logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids))
# Return whatever ids the vector store reported for the inserted documents.
return ids
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
@@ -86,16 +104,21 @@ class QdrantVectorStore:
return self.vector_store.similarity_search_with_score(query, k=k)
def delete_collection(self):
# Delete this wrapper's collection from Qdrant, then log the removal.
# NOTE(review): the paired statements below appear to be the removed and the
# added lines of a diff hunk shown without +/- markers.
self.client.delete_collection(self.collection_name)
logger.info("Collection '%s' deleted", self.collection_name)
self.get_client().delete_collection(self.collection_name)
logger.info("集合 '%s' 已删除", self.collection_name)
def get_collection_info(self) -> Dict[str, Any]:
# Fetch the collection description from Qdrant and summarise it as a dict
# with the keys "name", "vectors_count", "status" and "vector_size".
# NOTE(review): this span appears to interleave the removed and the added
# lines of a diff hunk (no +/- markers), hence the repeated dict keys.
info = self.client.get_collection(self.collection_name)
info = self.get_client().get_collection(self.collection_name)
vectors_config = info.config.params.vectors
# The vectors config is either a dict (the size is then taken from its first
# value) or a single object exposing `.size` directly.
# Presumably the dict form is Qdrant's named-vectors layout - TODO confirm.
if isinstance(vectors_config, dict):
vector_size = next(iter(vectors_config.values())).size
else:
vector_size = vectors_config.size
return {
"name": info.name,
"vectors_count": info.vectors_count,
"name": self.collection_name,
# A missing (None) points_count is reported as 0.
"vectors_count": info.points_count or 0,
"status": info.status,
"vector_size": info.config.params.vectors.size,
"vector_size": vector_size,
}
def as_langchain_vectorstore(self):
@@ -107,4 +130,4 @@ class QdrantVectorStore:
def get_qdrant_client(self):
"""Return the native Qdrant client (for managing collections manually)."""
# NOTE(review): the two returns appear to be the removed and the added line of
# a diff hunk shown without +/- markers.
return self.client
return self.get_client()