refactor: 重构RAG核心组件,简化代码结构和测试文件
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s

This commit is contained in:
2026-05-04 17:58:10 +08:00
parent a07e398739
commit 9841f47432
31 changed files with 578 additions and 1496 deletions

View File

@@ -33,16 +33,16 @@ DEFAULT_PARENT_SEARCH_K = 5
class HybridRetriever(BaseRetriever):
"""
混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合(异步)
使用 Qdrant Universal Query API (query_points)
"""
collection_name: str = Field(description="Qdrant 集合名称")
search_k: int = Field(default=DEFAULT_SEARCH_K, description="检索返回结果数")
_vector_store: Any = PrivateAttr()
_client: Any = PrivateAttr()
_sparse_embedder: Any = PrivateAttr()
def __init__(
self,
collection_name: str,
@@ -62,21 +62,39 @@ class HybridRetriever(BaseRetriever):
self._vector_store = vector_store
self._client = vector_store.get_async_qdrant_client()
self._sparse_embedder = get_sparse_embedder()
def _get_relevant_documents(
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
同步检索(不推荐使用,仅供兼容性)
注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke
"""
import asyncio
try:
loop = asyncio.get_running_loop()
# 已有事件循环,使用 create_task
task = loop.create_task(self._aget_relevant_documents(query))
return loop.run_until_complete(task)
except RuntimeError:
# 没有事件循环,创建新的
return asyncio.run(self._aget_relevant_documents(query))
async def _aget_relevant_documents(
self, query: str, **kwargs
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
异步混合检索相关文档
"""
# 1. 生成查询向量
dense_query = await self._vector_store._aembed_query(query)
dense_query = await self._vector_store.aembed_query(query)
sparse_query = self._sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 2. 使用 Qdrant 的 query_points API
response = await self._client.query_points(
collection_name=self.collection_name,
@@ -96,7 +114,7 @@ class HybridRetriever(BaseRetriever):
limit=self.search_k,
with_payload=True
)
# 3. 转换结果
results = []
for point in response.points:
@@ -105,28 +123,28 @@ class HybridRetriever(BaseRetriever):
metadata=point.payload
)
results.append(doc)
debug(f"混合检索返回 %d 个文档", len(results))
debug(f"混合检索返回 {len(results)} 个文档")
return results
class ParentHybridRetriever(BaseRetriever):
"""
父子文档混合检索器(异步):
1. 先用混合检索找到相关子文档
2. 根据子文档的 parent_id 找到对应的父文档
3. 去重并返回父文档
"""
collection_name: str = Field(description="Qdrant 集合名称")
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
_vector_store: Any = PrivateAttr()
_client: Any = PrivateAttr()
_sparse_embedder: Any = PrivateAttr()
_docstore: Any = PrivateAttr()
def __init__(
self,
collection_name: str,
@@ -149,24 +167,40 @@ class ParentHybridRetriever(BaseRetriever):
self._client = vector_store.get_async_qdrant_client()
self._sparse_embedder = get_sparse_embedder()
self._docstore = docstore
def _get_relevant_documents(
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
同步检索(不推荐使用,仅供兼容性)
注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke
"""
import asyncio
try:
loop = asyncio.get_running_loop()
task = loop.create_task(self._aget_relevant_documents(query))
return loop.run_until_complete(task)
except RuntimeError:
return asyncio.run(self._aget_relevant_documents(query))
async def _aget_relevant_documents(
self, query: str, **kwargs
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
异步检索相关父文档
"""
# 1. 生成查询向量
dense_query = await self._vector_store._aembed_query(query)
dense_query = await self._vector_store.aembed_query(query)
sparse_query = self._sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 2. 多取一些子文档,避免去重后数量不足
search_limit = self.search_k * 2
# 3. 使用 query_points API 进行混合检索
response = await self._client.query_points(
collection_name=self.collection_name,
@@ -186,30 +220,30 @@ class ParentHybridRetriever(BaseRetriever):
limit=search_limit,
with_payload=True
)
if not response.points:
debug("混合检索未找到任何文档")
return []
# 4. 收集 parent_id 和对应最高得分
parent_score_map = {}
parent_ids = set()
child_point_map = {} # 保存子文档点用于降级
for point in response.points:
payload_copy = point.payload.copy()
parent_id = payload_copy.get("parent_id", point.id)
score = point.score
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
parent_score_map[parent_id] = score
parent_ids.add(parent_id)
child_point_map[parent_id] = point
# 5. 批量查询父文档
parent_docs = []
found_parent_ids = set()
# 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中)
try:
parent_points = await self._client.retrieve(
@@ -217,7 +251,7 @@ class ParentHybridRetriever(BaseRetriever):
ids=list(parent_ids),
with_payload=True
)
for point in parent_points:
payload_copy = point.payload.copy()
doc = Document(
@@ -226,10 +260,10 @@ class ParentHybridRetriever(BaseRetriever):
)
parent_docs.append(doc)
found_parent_ids.add(point.id)
except Exception as e:
warning(f"从 Qdrant 查询父文档失败: %s", e)
warning(f"从 Qdrant 查询父文档失败: {e}")
# 6. 如果有 docstore尝试从 docstore 查询剩余的父文档
if self._docstore and len(found_parent_ids) < len(parent_ids):
missing_parent_ids = parent_ids - found_parent_ids
@@ -240,12 +274,12 @@ class ParentHybridRetriever(BaseRetriever):
parent_docs.append(doc)
found_parent_ids.add(doc_id)
except Exception as e:
warning(f"从 docstore 查询父文档失败: %s", e)
warning(f"从 docstore 查询父文档失败: {e}")
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
missing_parent_ids = parent_ids - found_parent_ids
if missing_parent_ids:
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: %s", missing_parent_ids)
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
for parent_id in missing_parent_ids:
child_point = child_point_map.get(parent_id)
if child_point:
@@ -255,17 +289,17 @@ class ParentHybridRetriever(BaseRetriever):
metadata=payload_copy
)
parent_docs.append(doc)
# 8. 按照得分降序排序,返回前 k 个
parent_docs_with_scores = [
(doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0))
for doc in parent_docs
]
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
debug(f"父子文档混合检索返回 %d 个父文档", len(final_docs))
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
return final_docs
@@ -291,7 +325,7 @@ def create_hybrid_retriever(
embeddings = get_embedding_service()
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
vector_store = QdrantHybridStore(collection_name=collection_name)
try:
vector_store.get_client().get_collection(collection_name)
@@ -336,7 +370,7 @@ def create_parent_hybrid_retriever(
embeddings = get_embedding_service()
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
vector_store = QdrantHybridStore(collection_name=collection_name)
try:
vector_store.get_client().get_collection(collection_name)