This commit is contained in:
@@ -19,10 +19,10 @@ from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
|
||||
from rag_core.client import create_async_qdrant_client
|
||||
from app.model_services import get_embedding_service
|
||||
from app.logger import info, warning, debug
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
|
||||
from backend.rag_core.client import create_async_qdrant_client
|
||||
from ..model_services import get_embedding_service
|
||||
from ..logger import info, warning, debug
|
||||
|
||||
|
||||
# 模块级常量
|
||||
@@ -131,20 +131,20 @@ class HybridRetriever(BaseRetriever):
|
||||
class ParentHybridRetriever(BaseRetriever):
|
||||
"""
|
||||
父子文档混合检索器(异步):
|
||||
|
||||
|
||||
1. 先用混合检索找到相关子文档
|
||||
2. 根据子文档的 parent_id 找到对应的父文档
|
||||
3. 去重并返回父文档
|
||||
"""
|
||||
|
||||
|
||||
collection_name: str = Field(description="Qdrant 集合名称")
|
||||
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
|
||||
|
||||
|
||||
_vector_store: Any = PrivateAttr()
|
||||
_client: Any = PrivateAttr()
|
||||
_sparse_embedder: Any = PrivateAttr()
|
||||
_docstore: Any = PrivateAttr()
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
@@ -188,7 +188,7 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
异步检索相关父文档
|
||||
异步检索相关子文档
|
||||
"""
|
||||
# 1. 生成查询向量
|
||||
dense_query = await self._vector_store.aembed_query(query)
|
||||
@@ -197,10 +197,10 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
|
||||
# 2. 多取一些子文档,避免去重后数量不足
|
||||
search_limit = self.search_k * 2
|
||||
|
||||
|
||||
# 3. 使用 query_points API 进行混合检索
|
||||
response = await self._client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
@@ -220,87 +220,27 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
limit=search_limit,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
|
||||
if not response.points:
|
||||
debug("混合检索未找到任何文档")
|
||||
return []
|
||||
|
||||
# 4. 收集 parent_id 和对应最高得分
|
||||
parent_score_map = {}
|
||||
parent_ids = set()
|
||||
child_point_map = {} # 保存子文档点用于降级
|
||||
|
||||
|
||||
# 4. 构建子文档列表
|
||||
child_docs = []
|
||||
for point in response.points:
|
||||
payload_copy = point.payload.copy()
|
||||
parent_id = payload_copy.get("parent_id", point.id)
|
||||
score = point.score
|
||||
|
||||
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
|
||||
parent_score_map[parent_id] = score
|
||||
parent_ids.add(parent_id)
|
||||
child_point_map[parent_id] = point
|
||||
|
||||
# 5. 批量查询父文档
|
||||
parent_docs = []
|
||||
found_parent_ids = set()
|
||||
|
||||
# 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中)
|
||||
try:
|
||||
parent_points = await self._client.retrieve(
|
||||
collection_name=self.collection_name,
|
||||
ids=list(parent_ids),
|
||||
with_payload=True
|
||||
doc = Document(
|
||||
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
|
||||
metadata={
|
||||
**payload_copy,
|
||||
"child_id": point.id,
|
||||
"score": point.score
|
||||
}
|
||||
)
|
||||
|
||||
for point in parent_points:
|
||||
payload_copy = point.payload.copy()
|
||||
doc = Document(
|
||||
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
|
||||
metadata=payload_copy
|
||||
)
|
||||
parent_docs.append(doc)
|
||||
found_parent_ids.add(point.id)
|
||||
|
||||
except Exception as e:
|
||||
warning(f"从 Qdrant 查询父文档失败: {e}")
|
||||
|
||||
# 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档
|
||||
if self._docstore and len(found_parent_ids) < len(parent_ids):
|
||||
missing_parent_ids = parent_ids - found_parent_ids
|
||||
try:
|
||||
docstore_docs = await self._docstore.amget(missing_parent_ids)
|
||||
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
|
||||
if doc is not None:
|
||||
parent_docs.append(doc)
|
||||
found_parent_ids.add(doc_id)
|
||||
except Exception as e:
|
||||
warning(f"从 docstore 查询父文档失败: {e}")
|
||||
|
||||
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
|
||||
missing_parent_ids = parent_ids - found_parent_ids
|
||||
if missing_parent_ids:
|
||||
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
|
||||
for parent_id in missing_parent_ids:
|
||||
child_point = child_point_map.get(parent_id)
|
||||
if child_point:
|
||||
payload_copy = child_point.payload.copy()
|
||||
doc = Document(
|
||||
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
|
||||
metadata=payload_copy
|
||||
)
|
||||
parent_docs.append(doc)
|
||||
|
||||
# 8. 按照得分降序排序,返回前 k 个
|
||||
parent_docs_with_scores = [
|
||||
(doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0))
|
||||
for doc in parent_docs
|
||||
]
|
||||
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
|
||||
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
|
||||
|
||||
return final_docs
|
||||
child_docs.append(doc)
|
||||
|
||||
debug(f"父子文档混合检索返回 {len(child_docs)} 个子文档")
|
||||
return child_docs
|
||||
|
||||
|
||||
def create_hybrid_retriever(
|
||||
|
||||
Reference in New Issue
Block a user