2026-04-21 11:02:16 +08:00
|
|
|
|
# rag/fusion.py
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Dict
|
|
|
|
|
|
from langchain_core.documents import Document
|
2026-05-06 16:15:09 +08:00
|
|
|
|
from backend.app.logger import info
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
2026-05-06 16:02:53 +08:00
|
|
|
|
|
2026-04-21 11:02:16 +08:00
|
|
|
|
def reciprocal_rank_fusion(
|
|
|
|
|
|
doc_lists: List[List[Document]],
|
|
|
|
|
|
k: int = 60
|
|
|
|
|
|
) -> List[Document]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
对多个检索结果列表进行 RRF 融合。
|
2026-05-06 16:15:09 +08:00
|
|
|
|
|
2026-04-21 11:02:16 +08:00
|
|
|
|
Args:
|
|
|
|
|
|
doc_lists: 多个检索结果列表,每个列表来自一个查询
|
|
|
|
|
|
k: RRF 常数,通常设为 60
|
2026-05-06 16:15:09 +08:00
|
|
|
|
|
2026-04-21 11:02:16 +08:00
|
|
|
|
Returns:
|
|
|
|
|
|
融合后按 RRF 得分降序排列的文档列表
|
|
|
|
|
|
"""
|
2026-05-06 16:15:09 +08:00
|
|
|
|
info(f"[RRF] reciprocal_rank_fusion 开始: {len(doc_lists)} 组文档")
|
2026-04-21 11:02:16 +08:00
|
|
|
|
# 使用文档内容作为唯一标识(如果内容相同但 metadata 不同,视为同一文档)
|
|
|
|
|
|
# 更好的做法是用 docstore 的 ID,这里简化处理:用内容 hash
|
|
|
|
|
|
doc_to_score: Dict[str, float] = {}
|
|
|
|
|
|
doc_map: Dict[str, Document] = {}
|
2026-05-06 16:15:09 +08:00
|
|
|
|
|
2026-05-06 16:02:53 +08:00
|
|
|
|
for list_idx, docs in enumerate(doc_lists):
|
2026-05-06 16:15:09 +08:00
|
|
|
|
info(f"[RRF] 处理第 {list_idx} 组: {len(docs)} 个文档")
|
2026-04-21 11:02:16 +08:00
|
|
|
|
for rank, doc in enumerate(docs, start=1):
|
|
|
|
|
|
# 生成唯一标识符(内容+来源组合,避免不同文件相同内容混淆)
|
|
|
|
|
|
doc_id = f"{doc.page_content[:200]}_{doc.metadata.get('source', '')}"
|
|
|
|
|
|
if doc_id not in doc_map:
|
|
|
|
|
|
doc_map[doc_id] = doc
|
|
|
|
|
|
score = doc_to_score.get(doc_id, 0.0) + 1.0 / (k + rank)
|
|
|
|
|
|
doc_to_score[doc_id] = score
|
2026-05-06 16:15:09 +08:00
|
|
|
|
|
|
|
|
|
|
info(f"[RRF] 去重后共 {len(doc_map)} 个唯一文档")
|
2026-04-21 11:02:16 +08:00
|
|
|
|
# 按得分排序
|
|
|
|
|
|
sorted_ids = sorted(doc_to_score.keys(), key=lambda x: doc_to_score[x], reverse=True)
|
2026-05-06 16:02:53 +08:00
|
|
|
|
result = [doc_map[doc_id] for doc_id in sorted_ids]
|
2026-05-06 16:15:09 +08:00
|
|
|
|
info(f"[RRF] reciprocal_rank_fusion 结束: 返回 {len(result)} 个文档")
|
|
|
|
|
|
return result
|