refactor!: 完全异步化 RAG 系统,移除 LangChain ParentDocumentRetriever 依赖
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m34s

- 重写 rag_core/vector_store.py:完全异步实现 aadd_documents、asimilarity_search
- 重写 app/rag/retriever.py:异步混合检索,移除同步兼容代码
- 修改 rag_indexer/index_builder.py:全链路异步调用
- 删除 rag_core/retriever_factory.py:不再使用 LangChain ParentDocumentRetriever
- 清理冗余导入和代码:移除 model_services 兼容、不需要的异常导入
- 更新 rag_indexer/README.md:反映新架构

核心改进:
- 完全异步化:索引构建和检索全链路 async/await
- 自定义实现:不再依赖 LangChain 的 ParentDocumentRetriever
- 双向量支持:子文档同时存储 dense + sparse 向量到 Qdrant
- 架构清晰:rag_core 公共组件、rag_indexer 索引、app/rag 检索
This commit is contained in:
2026-05-04 14:33:12 +08:00
parent 4209386c77
commit a07e398739
14 changed files with 651 additions and 592 deletions

View File

@@ -1,14 +1,13 @@
"""
RAG Core - 公共 RAG 组件包
提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。
"""
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore
from .embedders import get_embeddings, get_embedding_dimension
from .vector_store import QdrantHybridStore
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
from .store import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever
from .client import create_qdrant_client, create_async_qdrant_client
from .config import (
QDRANT_URL,
QDRANT_API_KEY,
@@ -20,8 +19,9 @@ from .config import (
__all__ = [
"LlamaCppEmbedder",
"QdrantVectorStore",
"get_embeddings",
"get_embedding_dimension",
"QdrantHybridStore",
"BM25SparseEmbedder",
"get_sparse_embedder",
"QDRANT_URL",
@@ -32,5 +32,6 @@ __all__ = [
"DOCSTORE_URI",
"PostgresDocStore",
"create_docstore",
"create_parent_retriever",
"create_qdrant_client",
"create_async_qdrant_client",
]

View File

@@ -1,7 +1,7 @@
# rag_core/client.py
import os
from .config import QDRANT_URL, QDRANT_API_KEY
from qdrant_client import QdrantClient
from qdrant_client import QdrantClient, AsyncQdrantClient
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
@@ -28,3 +28,29 @@ def create_qdrant_client(timeout: int = 300) -> QdrantClient:
client_kwargs["api_key"] = QDRANT_API_KEY
return QdrantClient(**client_kwargs)
def create_async_qdrant_client(timeout: int = 300) -> AsyncQdrantClient:
    """Build an asynchronous Qdrant client from the module configuration.

    Args:
        timeout: Request timeout in seconds; defaults to 300.

    Returns:
        An ``AsyncQdrantClient`` constructed with ``QDRANT_URL`` and the
        given timeout.

    Raises:
        ValueError: If ``QDRANT_URL`` is empty or unset.
    """
    if not QDRANT_URL:
        raise ValueError("Qdrant URL 未配置")

    # The API key is forwarded only when one is configured.
    auth_kwargs = {"api_key": QDRANT_API_KEY} if QDRANT_API_KEY else {}
    return AsyncQdrantClient(url=QDRANT_URL, timeout=timeout, **auth_kwargs)

View File

@@ -1,121 +1,37 @@
"""
嵌入模型包装器 - 直接使用统一嵌入服务
支持自动降级(本地 llama.cpp → 智谱),由 get_embedding_service() 内部处理
"""
import sys
import logging
from typing import List
from pathlib import Path
from typing import List, Optional
from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
from langchain_core.embeddings import Embeddings
logger = logging.getLogger(__name__)
class LlamaCppEmbedder:
def get_embeddings() -> Embeddings:
"""
嵌入器包装类 - 直接使用统一的 get_embedding_service()
降级逻辑完全由 app.model_services 处理
获取统一的嵌入服务实例。
Returns:
LangChain 兼容的 Embeddings 实例
"""
def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0", use_fallback: bool = True):
"""
Args:
model: 嵌入模型名称(向后兼容,现在实际使用统一服务)
use_fallback: 是否使用降级机制(保留参数,现在始终为 True
"""
self.model = model
self._fallback_embeddings = None
# 直接获取统一嵌入服务
try:
from backend.app.model_services import get_embedding_service
self._fallback_embeddings = get_embedding_service()
logger.info("✅ 统一嵌入服务加载成功")
except Exception as e:
logger.warning(f"⚠️ 无法加载统一嵌入服务: {e}")
# 保留向后兼容的初始化
self.base_url = LLAMACPP_EMBEDDING_URL
self.api_key = LLAMACPP_API_KEY
def as_langchain_embeddings(self) -> Embeddings:
"""创建 LangChain 兼容的嵌入实例"""
if self._fallback_embeddings:
logger.info("✅ 使用统一嵌入服务(已内置降级机制)")
return self._fallback_embeddings
# 向后兼容,仅在统一服务不可用时使用传统方式
logger.warning("⚠️ 统一服务不可用,使用传统模式(不推荐)")
return _LlamaCppLangchainAdapter(self)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""嵌入一批文档"""
if self._fallback_embeddings:
return self._fallback_embeddings.embed_documents(texts)
# 向后兼容
return self._call_embedding_api(texts)
def embed_query(self, text: str) -> List[float]:
"""嵌入单个查询"""
if self._fallback_embeddings:
return self._fallback_embeddings.embed_query(text)
# 向后兼容
return self._call_embedding_api([text])[0]
def get_embedding_dimension(self) -> int:
"""通过嵌入测试字符串获取嵌入维度"""
test_embedding = self.embed_query("test")
return len(test_embedding)
def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
"""仅作为向后兼容的备用方法"""
import httpx
if not hasattr(self, 'base_url') or not self.base_url:
raise ValueError("LLAMACPP_EMBEDDING_URL 未配置且统一服务不可用")
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
base = self.base_url.rstrip("/")
if not base.endswith("/v1"):
base = base + "/v1"
payload = {
"input": texts,
"model": self.model,
}
with httpx.Client(timeout=120) as client:
response = client.post(
f"{base}/embeddings",
headers=headers,
json=payload,
)
response.raise_for_status()
data = response.json()
if isinstance(data, list):
return [item["embedding"] for item in data]
elif isinstance(data, dict) and "data" in data:
return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
else:
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
from backend.app.model_services import get_embedding_service
return get_embedding_service()
class _LlamaCppLangchainAdapter(Embeddings):
    """Adapter kept only for backward compatibility.

    Exposes a ``LlamaCppEmbedder`` through the LangChain ``Embeddings``
    interface by delegating both methods to the wrapped embedder.
    """

    def __init__(self, embedder: LlamaCppEmbedder):
        # Wrapped embedder that every call below is forwarded to.
        self._embedder = embedder

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Forward to the wrapped embedder's ``embed_documents``."""
        return self._embedder.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Forward to the wrapped embedder's ``embed_query``."""
        return self._embedder.embed_query(text)
def get_embedding_dimension(embeddings: Optional[Embeddings] = None) -> int:
    """Return the length of the vector produced for a probe query.

    Args:
        embeddings: Embeddings instance to probe. When ``None``, the
            instance returned by ``get_embeddings()`` is used.

    Returns:
        The number of components in the embedding of the string ``"test"``.
    """
    provider = get_embeddings() if embeddings is None else embeddings
    # The dimension is measured by embedding a fixed probe string.
    return len(provider.embed_query("test"))

View File

@@ -1,112 +0,0 @@
"""
RAG 检索器工厂模块
提供创建各种检索器的工厂函数,包括:
- 基础向量检索器
- ParentDocumentRetriever(父子文档)
- 混合检索器(稠密+稀疏)
"""
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_core.stores import BaseStore
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore
from .store import create_docstore
def create_parent_retriever(
    collection_name: str = "rag_documents",
    parent_splitter: Optional[TextSplitter] = None,
    child_splitter: Optional[TextSplitter] = None,
    docstore: Optional[BaseStore] = None,
    search_k: int = 5,
    parent_chunk_size: int = 1000,
    parent_chunk_overlap: int = 100,
    child_chunk_size: int = 200,
    child_chunk_overlap: int = 20,
    embeddings: Optional[Embeddings] = None,
) -> ParentDocumentRetriever:
    """Create a ``ParentDocumentRetriever`` (dense-vector variant).

    Args:
        collection_name: Qdrant collection name.
        parent_splitter: Splitter for parent documents. When ``None``, a
            ``RecursiveCharacterTextSplitter`` is built from
            ``parent_chunk_size`` / ``parent_chunk_overlap``.
        child_splitter: Splitter for child documents. When ``None``, a
            ``RecursiveCharacterTextSplitter`` is built from
            ``child_chunk_size`` / ``child_chunk_overlap``.
        docstore: Document store. When ``None``, the first element returned
            by ``create_docstore()`` is used.
        search_k: Value passed as ``k`` in the retriever's ``search_kwargs``.
        parent_chunk_size: Chunk size for the default parent splitter.
        parent_chunk_overlap: Chunk overlap for the default parent splitter.
        child_chunk_size: Chunk size for the default child splitter.
        child_chunk_overlap: Chunk overlap for the default child splitter.
        embeddings: Embeddings instance. When ``None``, one is obtained from
            a default ``LlamaCppEmbedder``.

    Returns:
        The configured ``ParentDocumentRetriever``.
    """
    # Resolve the embedding model first; the vector store needs it.
    if embeddings is None:
        embeddings = LlamaCppEmbedder().as_langchain_embeddings()

    store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)

    # The chunk-size arguments only matter when no splitter was supplied.
    resolved_parent = (
        RecursiveCharacterTextSplitter(
            chunk_size=parent_chunk_size,
            chunk_overlap=parent_chunk_overlap,
        )
        if parent_splitter is None
        else parent_splitter
    )
    resolved_child = (
        RecursiveCharacterTextSplitter(
            chunk_size=child_chunk_size,
            chunk_overlap=child_chunk_overlap,
        )
        if child_splitter is None
        else child_splitter
    )

    if docstore is None:
        docstore, _unused = create_docstore()

    retriever_kwargs = {
        "vectorstore": store.get_langchain_vectorstore(),
        "docstore": docstore,
        "child_splitter": resolved_child,
        "parent_splitter": resolved_parent,
        "search_kwargs": {"k": search_k},
    }
    return ParentDocumentRetriever(**retriever_kwargs)
def create_hybrid_retriever_factory(
    collection_name: str = "rag_documents",
    search_k: int = 5,
    embeddings: Optional[Embeddings] = None,
) -> BaseRetriever:
    """Placeholder factory for a hybrid retriever (incomplete).

    No hybrid logic lives here: the function wraps a ``QdrantVectorStore``
    and returns the retriever produced by its LangChain vector store.
    The original note states that the full hybrid retrieval logic is
    implemented in ``app/rag/retriever.py``.

    Args:
        collection_name: Qdrant collection name.
        search_k: Value passed as ``k`` in the retriever's ``search_kwargs``.
        embeddings: Embeddings instance. When ``None``, one is obtained from
            a default ``LlamaCppEmbedder``.

    Returns:
        The result of ``as_retriever`` on the store's LangChain vector store.
    """
    if embeddings is None:
        embeddings = LlamaCppEmbedder().as_langchain_embeddings()

    dense_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
    langchain_store = dense_store.get_langchain_vectorstore()
    return langchain_store.as_retriever(search_kwargs={"k": search_k})

View File

@@ -1,6 +1,5 @@
"""
Qdrant 向量数据库包装器。
支持稠密+稀疏双向量存储。
Qdrant 向量数据库包装器(完全异步实现)
"""
import logging
@@ -11,111 +10,91 @@ from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client import AsyncQdrantClient, QdrantClient
from qdrant_client.http.models import (
Distance, VectorParams, SparseVectorParams, PointStruct
Distance, VectorParams, SparseVectorParams, PointStruct, models
)
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from .client import create_qdrant_client
from .embedders import LlamaCppEmbedder
from .client import create_qdrant_client, create_async_qdrant_client
from .embedders import get_embeddings, get_embedding_dimension
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
logger = logging.getLogger(__name__)
class QdrantVectorStore:
"""Qdrant 向量数据库操作包装器 - 支持稠密+稀疏双向量存储。"""
class QdrantHybridStore:
"""
Qdrant 向量数据库操作包装器 - 稠密+稀疏混合检索(完全异步)。
直接使用 Qdrant 异步客户端实现,不依赖 LangChain。
"""
def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None, sparse_embedder: Optional[BM25SparseEmbedder] = None):
"""
Args:
collection_name: Qdrant 集合名称。
embeddings: 嵌入模型实例,默认 None使用内部默认的 LlamaCppEmbedder
sparse_embedder: 稀疏嵌入模型实例,默认 None自动加载BM25
"""
def __init__(
self,
collection_name: str,
embeddings: Optional[Embeddings] = None,
sparse_embedder: Optional[BM25SparseEmbedder] = None,
):
self.collection_name = collection_name
self._client: Optional[QdrantClient] = None
self._async_client: Optional[AsyncQdrantClient] = None
self._connection_attempts = 0
self._last_connection_time: Optional[float] = None
# 稠密嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
self._embedder = embedder
self.embeddings = get_embeddings()
else:
self.embeddings = embeddings
self._embedder = None
# 稀疏嵌入模型
self.sparse_embedder = sparse_embedder or get_sparse_embedder()
# 集合初始化
self.create_collection()
# 保留 LangChain 向量存储实例(用于兼容)
self.vector_store = LangchainQdrantVS(
client=self.get_client(),
collection_name=self.collection_name,
embedding=self.embeddings,
vector_name="dense",
)
# ---------- 同步连接管理 ----------
def get_client(self) -> QdrantClient:
if self._client is None:
self._client = create_qdrant_client(timeout=300)
self._connection_attempts += 1
self._last_connection_time = time.time()
logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts)
logger.debug("Qdrant 同步客户端已创建 (第 %d 次连接)", self._connection_attempts)
return self._client
def refresh_client(self):
"""关闭旧连接,创建新连接。"""
if self._client is not None:
try:
self._client.close()
logger.debug("Qdrant 旧连接已关闭")
logger.debug("Qdrant 旧同步连接已关闭")
except Exception as e:
logger.warning("关闭 Qdrant 连接时出现异常: %s", e)
logger.warning("关闭 Qdrant 同步连接时出现异常: %s", e)
finally:
self._client = None
self._last_connection_time = None
def check_connection_health(self) -> bool:
    """Report whether the synchronous Qdrant client is usable.

    Returns:
        ``False`` when no client has been created yet, or when listing
        collections raises one of the handled network errors (in which case
        ``refresh_client()`` is called); ``True`` when the call succeeds.
    """
    if self._client is None:
        logger.info("Qdrant 客户端未初始化,将创建新连接")
        return False

    try:
        # Listing collections serves as the liveness probe.
        self.get_client().get_collections()
    except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as exc:
        logger.warning("Qdrant 连接健康检查失败: %s", exc)
        self.refresh_client()
        return False
    else:
        logger.debug("Qdrant 连接健康检查通过")
        return True
# ---------- 异步连接管理 ----------
def get_async_client(self) -> AsyncQdrantClient:
if self._async_client is None:
self._async_client = create_async_qdrant_client(timeout=300)
logger.debug("Qdrant 异步客户端已创建")
return self._async_client
def get_connection_stats(self) -> Dict[str, Any]:
"""获取连接统计信息。"""
return {
"connection_attempts": self._connection_attempts,
"last_connection_time": self._last_connection_time,
"client_initialized": self._client is not None,
}
async def close_async_client(self):
if self._async_client is not None:
try:
await self._async_client.close()
logger.debug("Qdrant 异步连接已关闭")
except Exception as e:
logger.warning("关闭 Qdrant 异步连接时出现异常: %s", e)
finally:
self._async_client = None
# ---------- 集合创建(同步,用于初始化) ----------
def create_collection(self, force_recreate: bool = False):
"""创建集合,支持稠密+稀疏双向量"""
if self._embedder is not None:
# 使用内部的 embedder 获取维度
vector_size = self._embedder.get_embedding_dimension()
else:
# 使用外部传入的 embeddings通过测试获取维度
test_embedding = self.embeddings.embed_query("test")
vector_size = len(test_embedding)
"""创建集合,确保有 'dense''sparse' 两个命名向量字段"""
vector_size = get_embedding_dimension(self.embeddings)
max_retries = 3
base_delay = 2
@@ -130,90 +109,168 @@ class QdrantVectorStore:
exists = False
if not exists:
# 向量配置:稠密向量
vectors_config = {
"dense": VectorParams(
size=vector_size,
distance=Distance.COSINE
)
}
# 稀疏向量配置(简化版,不使用特殊索引类型)
sparse_vectors_config = {
"sparse": SparseVectorParams()
}
client.create_collection(
collection_name=self.collection_name,
vectors_config=vectors_config,
sparse_vectors_config=sparse_vectors_config
sparse_vectors_config=sparse_vectors_config,
)
logger.info(
"集合 '%s' 已创建(维度=%d,稠密+稀疏双向量)",
self.collection_name, vector_size,
)
logger.info("集合 '%s' 已创建(维度=%d,支持稠密+稀疏双向量)", self.collection_name, vector_size)
else:
logger.info("集合 '%s' 已存在", self.collection_name)
return
except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
if attempt == max_retries - 1:
logger.error("创建集合 '%s' 重试 %d 次后仍然失败: %s", self.collection_name, max_retries, e)
raise
wait_time = base_delay * (2 ** attempt)
error_type = type(e).__name__
logger.warning(
"创建集合 '%s' 遇到网络异常 [%s]%d秒后重试 (%d/%d): %s",
self.collection_name, error_type, wait_time, attempt + 1, max_retries, e
self.collection_name, type(e).__name__, wait_time, attempt + 1, max_retries, e,
)
self.refresh_client()
logger.debug("已刷新 Qdrant 客户端连接")
time.sleep(wait_time)
def add_documents(self, documents: List[Document], batch_size: int = 100):
"""将文档添加到向量数据库,自动生成稠密+稀疏双向量。"""
# ---------- 异步索引方法 ----------
async def aadd_documents(self, documents: List[Document], batch_size: int = 100) -> List[str]:
"""
异步添加文档(自动生成稠密+稀疏向量并批量写入)。
"""
if not documents:
return []
# 确保集合存在
self.create_collection()
client = self.get_client()
doc_ids = []
# 分批处理
for i in range(0, len(documents), batch_size):
batch_docs = documents[i:i+batch_size]
texts = [doc.page_content for doc in batch_docs]
# 生成双向量
dense_vectors = self.embeddings.embed_documents(texts)
sparse_vectors = self.sparse_embedder.embed_documents(texts)
all_ids = []
total_docs = len(documents)
points = []
for j, doc in enumerate(batch_docs):
point_id = doc.metadata.get("id", str(uuid.uuid4()))
doc_ids.append(point_id)
for i in range(0, total_docs, batch_size):
batch = documents[i:i+batch_size]
batch_ids = await self._aadd_batch(batch)
all_ids.extend(batch_ids)
logger.info("已向 '%s' 添加批次 %d/%d,共 %d 个文档",
self.collection_name,
i//batch_size + 1,
(total_docs + batch_size - 1)//batch_size,
len(batch))
# 构造双向量
named_vectors = {
"dense": dense_vectors[j],
"sparse": sparse_vectors[j]
}
logger.info("已向 '%s' 总共添加 %d 个文档(混合模式)", self.collection_name, len(all_ids))
return all_ids
points.append(PointStruct(
id=point_id,
vector=named_vectors,
payload={"text": doc.page_content, **doc.metadata}
))
async def _aadd_batch(self, documents: List[Document]) -> List[str]:
"""异步添加单个批次的文档"""
client = self.get_async_client()
# 批量插入
client.upsert(collection_name=self.collection_name, points=points)
logger.info("已向 '%s' 添加 %d 个文档(稠密+稀疏双向量)", self.collection_name, len(points))
# 提取文本
texts = [doc.page_content for doc in documents]
return doc_ids
# 生成稠密向量
dense_vectors = await self._aembed_texts(texts)
def similarity_search(self, query: str, k: int = 5) -> List[Document]:
"""基础稠密向量检索(兼容原有接口)。"""
return self.vector_store.similarity_search(query, k=k)
# 生成稀疏向量
sparse_vectors = self.sparse_embedder.embed_documents(texts)
def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]:
"""基础稠密向量检索带分数(兼容原有接口)。"""
return self.vector_store.similarity_search_with_score(query, k=k)
# 构建点结构
points = []
for doc, dense_vec, sparse_vec in zip(documents, dense_vectors, sparse_vectors):
point_id = str(uuid.uuid4())
payload = {
"page_content": doc.page_content,
**doc.metadata
}
point = PointStruct(
id=point_id,
vector={
"dense": dense_vec,
"sparse": models.SparseVector(
indices=sparse_vec["indices"],
values=sparse_vec["values"]
)
},
payload=payload
)
points.append(point)
# 写入 Qdrant
await client.upsert(
collection_name=self.collection_name,
points=points
)
return [p.id for p in points]
async def _aembed_texts(self, texts: List[str]) -> List[List[float]]:
"""异步生成稠密向量(适配同步 Embeddings 接口)"""
# 注意LangChain 的 Embeddings 接口目前主要是同步的
# 使用线程池或直接调用(如果 embedding 内部有异步支持)
import asyncio
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.embeddings.embed_documents, texts)
# ---------- 异步检索方法 ----------
async def asimilarity_search(self, query: str, k: int = 5) -> List[Document]:
    """Run a hybrid (dense + sparse) search and return matching documents.

    A dense and a sparse prefetch are issued through Qdrant's
    ``query_points`` call and combined with RRF fusion.

    Args:
        query: Query text.
        k: Limit applied to each prefetch and to the fused result.

    Returns:
        One ``Document`` per returned point. ``page_content`` is read from
        the payload key of the same name (empty string when absent); the
        remaining payload entries become the document metadata.
    """
    client = self.get_async_client()

    # Build both query representations.
    dense_query = await self._aembed_query(query)
    sparse_query = self.sparse_embedder.embed_query(query)
    sparse_vec = models.SparseVector(
        indices=sparse_query["indices"],
        values=sparse_query["values"],
    )

    response = await client.query_points(
        collection_name=self.collection_name,
        prefetch=[
            models.Prefetch(query=dense_query, using="dense", limit=k),
            models.Prefetch(query=sparse_vec, using="sparse", limit=k),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=k,
        with_payload=True,
    )

    results = []
    for point in response.points:
        # The payload may be None on a returned point; copy it so the
        # response object is not mutated when page_content is removed.
        payload = dict(point.payload or {})
        page_content = payload.pop("page_content", "")
        results.append(Document(page_content=page_content, metadata=payload))

    logger.debug("混合检索返回 %d 个文档", len(results))
    return results
async def _aembed_query(self, text: str) -> List[float]:
"""异步生成查询稠密向量"""
import asyncio
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.embeddings.embed_query, text)
# ---------- 同步管理方法(保留,用于初始化和管理) ----------
def delete_collection(self):
    """Delete this store's collection through the synchronous client."""
    sync_client = self.get_client()
    sync_client.delete_collection(self.collection_name)
    logger.info("集合 '%s' 已删除", self.collection_name)
@@ -233,13 +290,10 @@ class QdrantVectorStore:
"vector_size": vector_size,
}
def as_langchain_vectorstore(self):
    """Return the object held in ``self.vector_store``."""
    return self.vector_store
def get_langchain_vectorstore(self):
    """Return the LangChain Qdrant vector store object (alias)."""
    return self.vector_store
def get_qdrant_client(self):
"""返回原生 Qdrant 客户端(用于自定义检索逻辑)"""
"""返回原生 Qdrant 同步客户端(用于管理操作)。"""
return self.get_client()
def get_async_qdrant_client(self):
    """Return the native asynchronous Qdrant client (used for indexing and retrieval)."""
    return self.get_async_client()