refactor!: 完全异步化 RAG 系统,移除 LangChain ParentDocumentRetriever 依赖
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m34s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m34s
- 重写 rag_core/vector_store.py:完全异步实现 aadd_documents、asimilarity_search - 重写 app/rag/retriever.py:异步混合检索,移除同步兼容代码 - 修改 rag_indexer/index_builder.py:全链路异步调用 - 删除 rag_core/retriever_factory.py:不再使用 LangChain ParentDocumentRetriever - 清理冗余导入和代码:移除 model_services 兼容、不需要的异常导入 - 更新 rag_indexer/README.md:反映新架构 核心改进: - 完全异步化:索引构建和检索全链路 async/await - 自定义实现:不再依赖 LangChain 的 ParentDocumentRetriever - 双向量支持:子文档同时存储 dense + sparse 向量到 Qdrant - 架构清晰:rag_core 公共组件、rag_indexer 索引、app/rag 检索
This commit is contained in:
@@ -1,14 +1,13 @@
|
||||
"""
|
||||
RAG Core - 公共 RAG 组件包
|
||||
|
||||
提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。
|
||||
"""
|
||||
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore
|
||||
from .embedders import get_embeddings, get_embedding_dimension
|
||||
from .vector_store import QdrantHybridStore
|
||||
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||
from .store import PostgresDocStore, create_docstore
|
||||
from .retriever_factory import create_parent_retriever
|
||||
from .client import create_qdrant_client, create_async_qdrant_client
|
||||
from .config import (
|
||||
QDRANT_URL,
|
||||
QDRANT_API_KEY,
|
||||
@@ -20,8 +19,9 @@ from .config import (
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LlamaCppEmbedder",
|
||||
"QdrantVectorStore",
|
||||
"get_embeddings",
|
||||
"get_embedding_dimension",
|
||||
"QdrantHybridStore",
|
||||
"BM25SparseEmbedder",
|
||||
"get_sparse_embedder",
|
||||
"QDRANT_URL",
|
||||
@@ -32,5 +32,6 @@ __all__ = [
|
||||
"DOCSTORE_URI",
|
||||
"PostgresDocStore",
|
||||
"create_docstore",
|
||||
"create_parent_retriever",
|
||||
"create_qdrant_client",
|
||||
"create_async_qdrant_client",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# rag_core/client.py
|
||||
import os
|
||||
from .config import QDRANT_URL, QDRANT_API_KEY
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client import QdrantClient, AsyncQdrantClient
|
||||
|
||||
|
||||
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
|
||||
@@ -28,3 +28,29 @@ def create_qdrant_client(timeout: int = 300) -> QdrantClient:
|
||||
client_kwargs["api_key"] = QDRANT_API_KEY
|
||||
|
||||
return QdrantClient(**client_kwargs)
|
||||
|
||||
|
||||
def create_async_qdrant_client(timeout: int = 300) -> AsyncQdrantClient:
|
||||
"""
|
||||
创建并返回一个配置好的 Qdrant 异步客户端。
|
||||
|
||||
Args:
|
||||
timeout: 请求超时时间(秒),默认 300 秒。
|
||||
|
||||
Returns:
|
||||
配置好的 AsyncQdrantClient 实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 QDRANT_URL 未配置。
|
||||
"""
|
||||
if not QDRANT_URL:
|
||||
raise ValueError("Qdrant URL 未配置")
|
||||
|
||||
client_kwargs = {
|
||||
"url": QDRANT_URL,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if QDRANT_API_KEY:
|
||||
client_kwargs["api_key"] = QDRANT_API_KEY
|
||||
|
||||
return AsyncQdrantClient(**client_kwargs)
|
||||
|
||||
@@ -1,121 +1,37 @@
|
||||
"""
|
||||
嵌入模型包装器 - 直接使用统一嵌入服务
|
||||
支持自动降级(本地 llama.cpp → 智谱),由 get_embedding_service() 内部处理
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaCppEmbedder:
|
||||
def get_embeddings() -> Embeddings:
|
||||
"""
|
||||
嵌入器包装类 - 直接使用统一的 get_embedding_service()
|
||||
降级逻辑完全由 app.model_services 处理
|
||||
获取统一的嵌入服务实例。
|
||||
|
||||
Returns:
|
||||
LangChain 兼容的 Embeddings 实例
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0", use_fallback: bool = True):
|
||||
"""
|
||||
Args:
|
||||
model: 嵌入模型名称(向后兼容,现在实际使用统一服务)
|
||||
use_fallback: 是否使用降级机制(保留参数,现在始终为 True)
|
||||
"""
|
||||
self.model = model
|
||||
self._fallback_embeddings = None
|
||||
|
||||
# 直接获取统一嵌入服务
|
||||
try:
|
||||
from backend.app.model_services import get_embedding_service
|
||||
self._fallback_embeddings = get_embedding_service()
|
||||
logger.info("✅ 统一嵌入服务加载成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 无法加载统一嵌入服务: {e}")
|
||||
# 保留向后兼容的初始化
|
||||
self.base_url = LLAMACPP_EMBEDDING_URL
|
||||
self.api_key = LLAMACPP_API_KEY
|
||||
|
||||
def as_langchain_embeddings(self) -> Embeddings:
|
||||
"""创建 LangChain 兼容的嵌入实例"""
|
||||
if self._fallback_embeddings:
|
||||
logger.info("✅ 使用统一嵌入服务(已内置降级机制)")
|
||||
return self._fallback_embeddings
|
||||
|
||||
# 向后兼容,仅在统一服务不可用时使用传统方式
|
||||
logger.warning("⚠️ 统一服务不可用,使用传统模式(不推荐)")
|
||||
return _LlamaCppLangchainAdapter(self)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""嵌入一批文档"""
|
||||
if self._fallback_embeddings:
|
||||
return self._fallback_embeddings.embed_documents(texts)
|
||||
|
||||
# 向后兼容
|
||||
return self._call_embedding_api(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""嵌入单个查询"""
|
||||
if self._fallback_embeddings:
|
||||
return self._fallback_embeddings.embed_query(text)
|
||||
|
||||
# 向后兼容
|
||||
return self._call_embedding_api([text])[0]
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""通过嵌入测试字符串获取嵌入维度"""
|
||||
test_embedding = self.embed_query("test")
|
||||
return len(test_embedding)
|
||||
|
||||
def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
|
||||
"""仅作为向后兼容的备用方法"""
|
||||
import httpx
|
||||
|
||||
if not hasattr(self, 'base_url') or not self.base_url:
|
||||
raise ValueError("LLAMACPP_EMBEDDING_URL 未配置且统一服务不可用")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
base = self.base_url.rstrip("/")
|
||||
if not base.endswith("/v1"):
|
||||
base = base + "/v1"
|
||||
|
||||
payload = {
|
||||
"input": texts,
|
||||
"model": self.model,
|
||||
}
|
||||
|
||||
with httpx.Client(timeout=120) as client:
|
||||
response = client.post(
|
||||
f"{base}/embeddings",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if isinstance(data, list):
|
||||
return [item["embedding"] for item in data]
|
||||
elif isinstance(data, dict) and "data" in data:
|
||||
return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
|
||||
else:
|
||||
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
|
||||
from backend.app.model_services import get_embedding_service
|
||||
return get_embedding_service()
|
||||
|
||||
|
||||
class _LlamaCppLangchainAdapter(Embeddings):
|
||||
"""仅作为向后兼容的适配器"""
|
||||
|
||||
def __init__(self, embedder: LlamaCppEmbedder):
|
||||
self._embedder = embedder
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return self._embedder.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._embedder.embed_query(text)
|
||||
def get_embedding_dimension(embeddings: Optional[Embeddings] = None) -> int:
|
||||
"""
|
||||
获取嵌入维度。
|
||||
|
||||
Args:
|
||||
embeddings: 可选的嵌入实例,如果不提供则自动获取
|
||||
|
||||
Returns:
|
||||
嵌入维度大小
|
||||
"""
|
||||
if embeddings is None:
|
||||
embeddings = get_embeddings()
|
||||
test_embedding = embeddings.embed_query("test")
|
||||
return len(test_embedding)
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
"""
|
||||
RAG 检索器工厂模块
|
||||
|
||||
提供创建各种检索器的工厂函数,包括:
|
||||
- 基础向量检索器
|
||||
- ParentDocumentRetriever(父子文档)
|
||||
- 混合检索器(稠密+稀疏)
|
||||
"""
|
||||
from typing import Optional
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from langchain_core.stores import BaseStore
|
||||
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore
|
||||
from .store import create_docstore
|
||||
|
||||
|
||||
def create_parent_retriever(
|
||||
collection_name: str = "rag_documents",
|
||||
parent_splitter: Optional[TextSplitter] = None,
|
||||
child_splitter: Optional[TextSplitter] = None,
|
||||
docstore: Optional[BaseStore] = None,
|
||||
search_k: int = 5,
|
||||
parent_chunk_size: int = 1000,
|
||||
parent_chunk_overlap: int = 100,
|
||||
child_chunk_size: int = 200,
|
||||
child_chunk_overlap: int = 20,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
) -> ParentDocumentRetriever:
|
||||
"""
|
||||
创建 ParentDocumentRetriever 实例(基础稠密向量版本)。
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称,默认 "rag_documents"
|
||||
parent_splitter: 父文档切分器,默认 None(使用默认参数创建)
|
||||
child_splitter: 子文档切分器,默认 None(使用默认参数创建)
|
||||
docstore: 文档存储实例,默认 None(使用默认参数创建)
|
||||
search_k: 检索时返回的结果数,默认 5
|
||||
parent_chunk_size: 父文档块大小,默认 1000
|
||||
parent_chunk_overlap: 父文档块重叠大小,默认 100
|
||||
child_chunk_size: 子文档块大小,默认 200
|
||||
child_chunk_overlap: 子文档块重叠大小,默认 20
|
||||
embeddings: 嵌入模型实例,默认 None(使用内部默认的 LocalLlamaCppEmbedder)
|
||||
|
||||
Returns:
|
||||
ParentDocumentRetriever 实例
|
||||
"""
|
||||
# 嵌入模型
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
embeddings = embedder.as_langchain_embeddings()
|
||||
|
||||
# 向量存储(只读)
|
||||
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||
|
||||
# 切分器(若未提供则创建默认)
|
||||
if parent_splitter is None:
|
||||
parent_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=parent_chunk_size,
|
||||
chunk_overlap=parent_chunk_overlap,
|
||||
)
|
||||
if child_splitter is None:
|
||||
child_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=child_chunk_size,
|
||||
chunk_overlap=child_chunk_overlap,
|
||||
)
|
||||
|
||||
# 文档存储
|
||||
if docstore is None:
|
||||
docstore, _ = create_docstore()
|
||||
|
||||
return ParentDocumentRetriever(
|
||||
vectorstore=vector_store.get_langchain_vectorstore(),
|
||||
docstore=docstore,
|
||||
child_splitter=child_splitter,
|
||||
parent_splitter=parent_splitter,
|
||||
search_kwargs={"k": search_k},
|
||||
)
|
||||
|
||||
|
||||
def create_hybrid_retriever_factory(
|
||||
collection_name: str = "rag_documents",
|
||||
search_k: int = 5,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
) -> BaseRetriever:
|
||||
"""
|
||||
【不完整,仅占位】创建混合检索器的工厂函数占位符。
|
||||
|
||||
注意:完整的混合检索逻辑在 app/rag/retriever.py 中实现。
|
||||
这里仅返回 QdrantVectorStore 作为基础。
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
search_k: 检索返回结果数
|
||||
embeddings: 嵌入模型实例
|
||||
|
||||
Returns:
|
||||
基础的 QdrantVectorStore(仅稠密检索)
|
||||
"""
|
||||
# 嵌入模型
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
embeddings = embedder.as_langchain_embeddings()
|
||||
|
||||
# 创建向量存储
|
||||
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||
|
||||
# 返回 LangChain 兼容的 retriever
|
||||
return vector_store.get_langchain_vectorstore().as_retriever(search_kwargs={"k": search_k})
|
||||
@@ -1,6 +1,5 @@
|
||||
"""
|
||||
Qdrant 向量数据库包装器。
|
||||
支持稠密+稀疏双向量存储。
|
||||
Qdrant 向量数据库包装器(完全异步实现)。
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -11,111 +10,91 @@ from typing import List, Optional, Dict, Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient
|
||||
from qdrant_client.http.models import (
|
||||
Distance, VectorParams, SparseVectorParams, PointStruct
|
||||
Distance, VectorParams, SparseVectorParams, PointStruct, models
|
||||
)
|
||||
from httpx import RemoteProtocolError
|
||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
|
||||
from .client import create_qdrant_client
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .client import create_qdrant_client, create_async_qdrant_client
|
||||
from .embedders import get_embeddings, get_embedding_dimension
|
||||
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QdrantVectorStore:
|
||||
"""Qdrant 向量数据库操作包装器 - 支持稠密+稀疏双向量存储。"""
|
||||
class QdrantHybridStore:
|
||||
"""
|
||||
Qdrant 向量数据库操作包装器 - 稠密+稀疏混合检索(完全异步)。
|
||||
直接使用 Qdrant 异步客户端实现,不依赖 LangChain。
|
||||
"""
|
||||
|
||||
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None, sparse_embedder: Optional[BM25SparseEmbedder] = None):
|
||||
"""
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称。
|
||||
embeddings: 嵌入模型实例,默认 None(使用内部默认的 LlamaCppEmbedder)。
|
||||
sparse_embedder: 稀疏嵌入模型实例,默认 None(自动加载BM25)。
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
sparse_embedder: Optional[BM25SparseEmbedder] = None,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self._client: Optional[QdrantClient] = None
|
||||
self._async_client: Optional[AsyncQdrantClient] = None
|
||||
self._connection_attempts = 0
|
||||
self._last_connection_time: Optional[float] = None
|
||||
|
||||
|
||||
# 稠密嵌入模型
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
self.embeddings = embedder.as_langchain_embeddings()
|
||||
self._embedder = embedder
|
||||
self.embeddings = get_embeddings()
|
||||
else:
|
||||
self.embeddings = embeddings
|
||||
self._embedder = None
|
||||
|
||||
|
||||
# 稀疏嵌入模型
|
||||
self.sparse_embedder = sparse_embedder or get_sparse_embedder()
|
||||
|
||||
# 集合初始化
|
||||
self.create_collection()
|
||||
|
||||
# 保留 LangChain 向量存储实例(用于兼容)
|
||||
self.vector_store = LangchainQdrantVS(
|
||||
client=self.get_client(),
|
||||
collection_name=self.collection_name,
|
||||
embedding=self.embeddings,
|
||||
vector_name="dense",
|
||||
)
|
||||
|
||||
# ---------- 同步连接管理 ----------
|
||||
def get_client(self) -> QdrantClient:
|
||||
if self._client is None:
|
||||
self._client = create_qdrant_client(timeout=300)
|
||||
self._connection_attempts += 1
|
||||
self._last_connection_time = time.time()
|
||||
logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts)
|
||||
logger.debug("Qdrant 同步客户端已创建 (第 %d 次连接)", self._connection_attempts)
|
||||
return self._client
|
||||
|
||||
def refresh_client(self):
|
||||
"""关闭旧连接,创建新连接。"""
|
||||
if self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
logger.debug("Qdrant 旧连接已关闭")
|
||||
logger.debug("Qdrant 旧同步连接已关闭")
|
||||
except Exception as e:
|
||||
logger.warning("关闭 Qdrant 连接时出现异常: %s", e)
|
||||
logger.warning("关闭 Qdrant 同步连接时出现异常: %s", e)
|
||||
finally:
|
||||
self._client = None
|
||||
self._last_connection_time = None
|
||||
|
||||
def check_connection_health(self) -> bool:
|
||||
"""检查连接健康状态,如果连接已失效则自动重建。"""
|
||||
if self._client is None:
|
||||
logger.info("Qdrant 客户端未初始化,将创建新连接")
|
||||
return False
|
||||
|
||||
try:
|
||||
client = self.get_client()
|
||||
client.get_collections()
|
||||
logger.debug("Qdrant 连接健康检查通过")
|
||||
return True
|
||||
except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
|
||||
logger.warning("Qdrant 连接健康检查失败: %s", e)
|
||||
self.refresh_client()
|
||||
return False
|
||||
# ---------- 异步连接管理 ----------
|
||||
def get_async_client(self) -> AsyncQdrantClient:
|
||||
if self._async_client is None:
|
||||
self._async_client = create_async_qdrant_client(timeout=300)
|
||||
logger.debug("Qdrant 异步客户端已创建")
|
||||
return self._async_client
|
||||
|
||||
def get_connection_stats(self) -> Dict[str, Any]:
|
||||
"""获取连接统计信息。"""
|
||||
return {
|
||||
"connection_attempts": self._connection_attempts,
|
||||
"last_connection_time": self._last_connection_time,
|
||||
"client_initialized": self._client is not None,
|
||||
}
|
||||
async def close_async_client(self):
|
||||
if self._async_client is not None:
|
||||
try:
|
||||
await self._async_client.close()
|
||||
logger.debug("Qdrant 异步连接已关闭")
|
||||
except Exception as e:
|
||||
logger.warning("关闭 Qdrant 异步连接时出现异常: %s", e)
|
||||
finally:
|
||||
self._async_client = None
|
||||
|
||||
# ---------- 集合创建(同步,用于初始化) ----------
|
||||
def create_collection(self, force_recreate: bool = False):
|
||||
"""创建集合,支持稠密+稀疏双向量。"""
|
||||
if self._embedder is not None:
|
||||
# 使用内部的 embedder 获取维度
|
||||
vector_size = self._embedder.get_embedding_dimension()
|
||||
else:
|
||||
# 使用外部传入的 embeddings,通过测试获取维度
|
||||
test_embedding = self.embeddings.embed_query("test")
|
||||
vector_size = len(test_embedding)
|
||||
"""创建集合,确保有 'dense' 和 'sparse' 两个命名向量字段。"""
|
||||
vector_size = get_embedding_dimension(self.embeddings)
|
||||
|
||||
max_retries = 3
|
||||
base_delay = 2
|
||||
@@ -130,90 +109,168 @@ class QdrantVectorStore:
|
||||
exists = False
|
||||
|
||||
if not exists:
|
||||
# 向量配置:稠密向量
|
||||
vectors_config = {
|
||||
"dense": VectorParams(
|
||||
size=vector_size,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
}
|
||||
|
||||
# 稀疏向量配置(简化版,不使用特殊索引类型)
|
||||
sparse_vectors_config = {
|
||||
"sparse": SparseVectorParams()
|
||||
}
|
||||
|
||||
client.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
vectors_config=vectors_config,
|
||||
sparse_vectors_config=sparse_vectors_config
|
||||
sparse_vectors_config=sparse_vectors_config,
|
||||
)
|
||||
logger.info(
|
||||
"集合 '%s' 已创建(维度=%d,稠密+稀疏双向量)",
|
||||
self.collection_name, vector_size,
|
||||
)
|
||||
logger.info("集合 '%s' 已创建(维度=%d,支持稠密+稀疏双向量)", self.collection_name, vector_size)
|
||||
else:
|
||||
logger.info("集合 '%s' 已存在", self.collection_name)
|
||||
return
|
||||
except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
|
||||
if attempt == max_retries - 1:
|
||||
logger.error("创建集合 '%s' 重试 %d 次后仍然失败: %s", self.collection_name, max_retries, e)
|
||||
raise
|
||||
wait_time = base_delay * (2 ** attempt)
|
||||
error_type = type(e).__name__
|
||||
logger.warning(
|
||||
"创建集合 '%s' 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s",
|
||||
self.collection_name, error_type, wait_time, attempt + 1, max_retries, e
|
||||
self.collection_name, type(e).__name__, wait_time, attempt + 1, max_retries, e,
|
||||
)
|
||||
self.refresh_client()
|
||||
logger.debug("已刷新 Qdrant 客户端连接")
|
||||
time.sleep(wait_time)
|
||||
|
||||
def add_documents(self, documents: List[Document], batch_size: int = 100):
|
||||
"""将文档添加到向量数据库,自动生成稠密+稀疏双向量。"""
|
||||
# ---------- 异步索引方法 ----------
|
||||
async def aadd_documents(self, documents: List[Document], batch_size: int = 100) -> List[str]:
|
||||
"""
|
||||
异步添加文档(自动生成稠密+稀疏向量并批量写入)。
|
||||
"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
# 确保集合存在
|
||||
self.create_collection()
|
||||
client = self.get_client()
|
||||
doc_ids = []
|
||||
|
||||
# 分批处理
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i:i+batch_size]
|
||||
texts = [doc.page_content for doc in batch_docs]
|
||||
|
||||
# 生成双向量
|
||||
dense_vectors = self.embeddings.embed_documents(texts)
|
||||
sparse_vectors = self.sparse_embedder.embed_documents(texts)
|
||||
all_ids = []
|
||||
total_docs = len(documents)
|
||||
|
||||
points = []
|
||||
for j, doc in enumerate(batch_docs):
|
||||
point_id = doc.metadata.get("id", str(uuid.uuid4()))
|
||||
doc_ids.append(point_id)
|
||||
for i in range(0, total_docs, batch_size):
|
||||
batch = documents[i:i+batch_size]
|
||||
batch_ids = await self._aadd_batch(batch)
|
||||
all_ids.extend(batch_ids)
|
||||
logger.info("已向 '%s' 添加批次 %d/%d,共 %d 个文档",
|
||||
self.collection_name,
|
||||
i//batch_size + 1,
|
||||
(total_docs + batch_size - 1)//batch_size,
|
||||
len(batch))
|
||||
|
||||
# 构造双向量
|
||||
named_vectors = {
|
||||
"dense": dense_vectors[j],
|
||||
"sparse": sparse_vectors[j]
|
||||
}
|
||||
logger.info("已向 '%s' 总共添加 %d 个文档(混合模式)", self.collection_name, len(all_ids))
|
||||
return all_ids
|
||||
|
||||
points.append(PointStruct(
|
||||
id=point_id,
|
||||
vector=named_vectors,
|
||||
payload={"text": doc.page_content, **doc.metadata}
|
||||
))
|
||||
async def _aadd_batch(self, documents: List[Document]) -> List[str]:
|
||||
"""异步添加单个批次的文档"""
|
||||
client = self.get_async_client()
|
||||
|
||||
# 批量插入
|
||||
client.upsert(collection_name=self.collection_name, points=points)
|
||||
logger.info("已向 '%s' 添加 %d 个文档(稠密+稀疏双向量)", self.collection_name, len(points))
|
||||
# 提取文本
|
||||
texts = [doc.page_content for doc in documents]
|
||||
|
||||
return doc_ids
|
||||
# 生成稠密向量
|
||||
dense_vectors = await self._aembed_texts(texts)
|
||||
|
||||
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""基础稠密向量检索(兼容原有接口)。"""
|
||||
return self.vector_store.similarity_search(query, k=k)
|
||||
# 生成稀疏向量
|
||||
sparse_vectors = self.sparse_embedder.embed_documents(texts)
|
||||
|
||||
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
|
||||
"""基础稠密向量检索带分数(兼容原有接口)。"""
|
||||
return self.vector_store.similarity_search_with_score(query, k=k)
|
||||
# 构建点结构
|
||||
points = []
|
||||
for doc, dense_vec, sparse_vec in zip(documents, dense_vectors, sparse_vectors):
|
||||
point_id = str(uuid.uuid4())
|
||||
payload = {
|
||||
"page_content": doc.page_content,
|
||||
**doc.metadata
|
||||
}
|
||||
point = PointStruct(
|
||||
id=point_id,
|
||||
vector={
|
||||
"dense": dense_vec,
|
||||
"sparse": models.SparseVector(
|
||||
indices=sparse_vec["indices"],
|
||||
values=sparse_vec["values"]
|
||||
)
|
||||
},
|
||||
payload=payload
|
||||
)
|
||||
points.append(point)
|
||||
|
||||
# 写入 Qdrant
|
||||
await client.upsert(
|
||||
collection_name=self.collection_name,
|
||||
points=points
|
||||
)
|
||||
|
||||
return [p.id for p in points]
|
||||
|
||||
async def _aembed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""异步生成稠密向量(适配同步 Embeddings 接口)"""
|
||||
# 注意:LangChain 的 Embeddings 接口目前主要是同步的
|
||||
# 使用线程池或直接调用(如果 embedding 内部有异步支持)
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self.embeddings.embed_documents, texts)
|
||||
|
||||
# ---------- 异步检索方法 ----------
|
||||
async def asimilarity_search(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""
|
||||
异步混合检索(稠密 + 稀疏),返回文档列表。
|
||||
使用 Qdrant 的 Universal Query API + RRF 融合。
|
||||
"""
|
||||
client = self.get_async_client()
|
||||
|
||||
# 生成查询向量
|
||||
dense_query = await self._aembed_query(query)
|
||||
sparse_query = self.sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
# 使用 Qdrant 的 query_points API
|
||||
response = await client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
prefetch=[
|
||||
models.Prefetch(
|
||||
query=dense_query,
|
||||
using="dense",
|
||||
limit=k
|
||||
),
|
||||
models.Prefetch(
|
||||
query=sparse_vec,
|
||||
using="sparse",
|
||||
limit=k
|
||||
)
|
||||
],
|
||||
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||
limit=k,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# 转换结果
|
||||
results = []
|
||||
for point in response.points:
|
||||
page_content = point.payload.pop("page_content", "")
|
||||
doc = Document(page_content=page_content, metadata=point.payload)
|
||||
results.append(doc)
|
||||
|
||||
logger.debug("混合检索返回 %d 个文档", len(results))
|
||||
return results
|
||||
|
||||
async def _aembed_query(self, text: str) -> List[float]:
|
||||
"""异步生成查询稠密向量"""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self.embeddings.embed_query, text)
|
||||
|
||||
# ---------- 同步管理方法(保留,用于初始化和管理) ----------
|
||||
def delete_collection(self):
|
||||
self.get_client().delete_collection(self.collection_name)
|
||||
logger.info("集合 '%s' 已删除", self.collection_name)
|
||||
@@ -233,13 +290,10 @@ class QdrantVectorStore:
|
||||
"vector_size": vector_size,
|
||||
}
|
||||
|
||||
def as_langchain_vectorstore(self):
|
||||
return self.vector_store
|
||||
|
||||
def get_langchain_vectorstore(self):
|
||||
"""返回 LangChain Qdrant 向量存储对象(别名)"""
|
||||
return self.vector_store
|
||||
|
||||
def get_qdrant_client(self):
|
||||
"""返回原生 Qdrant 客户端(用于自定义检索逻辑)"""
|
||||
"""返回原生 Qdrant 同步客户端(用于管理操作)。"""
|
||||
return self.get_client()
|
||||
|
||||
def get_async_qdrant_client(self):
|
||||
"""返回原生 Qdrant 异步客户端(用于索引和检索)。"""
|
||||
return self.get_async_client()
|
||||
|
||||
Reference in New Issue
Block a user