This commit is contained in:
@@ -81,11 +81,17 @@ class RAGPipeline:
|
||||
return await self.retriever.ainvoke(query)
|
||||
|
||||
async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
|
||||
parent_map = {}
|
||||
# 收集 parent_id 和对应的分数
|
||||
parent_map = {} # parent_id -> (embedding_score, rerank_score)
|
||||
|
||||
for doc in child_docs:
|
||||
pid = doc.metadata.get("parent_id")
|
||||
if pid and pid not in parent_map:
|
||||
parent_map[pid] = doc.metadata.get("score", 0.0)
|
||||
# embedding 分数
|
||||
embedding_score = doc.metadata.get("score", 0.0)
|
||||
# rerank 分数(如果有的话)
|
||||
rerank_score = doc.metadata.get("rerank_score", 0.0)
|
||||
parent_map[pid] = (embedding_score, rerank_score)
|
||||
|
||||
if not parent_map:
|
||||
logger.warning("[Pipeline] 未找到 parent_id,返回子文档")
|
||||
@@ -94,10 +100,19 @@ class RAGPipeline:
|
||||
try:
|
||||
from backend.rag_core import create_docstore
|
||||
docstore, _ = create_docstore()
|
||||
# 同步获取(异步版本不存在)
|
||||
parent_docs = docstore.mget(list(parent_map.keys()))
|
||||
parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d}
|
||||
result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2]
|
||||
|
||||
# 构建结果,保持分数信息
|
||||
result = []
|
||||
for doc in parent_docs:
|
||||
if doc:
|
||||
pid = doc.metadata.get("id")
|
||||
scores = parent_map.get(pid, (0.0, 0.0))
|
||||
# 将分数添加到 metadata 中
|
||||
doc.metadata["embedding_score"] = scores[0]
|
||||
doc.metadata["rerank_score"] = scores[1]
|
||||
result.append((doc, scores[0] + scores[1] * 2)) # 综合分数,rerank 权重更高
|
||||
|
||||
result.sort(key=lambda x: x[1], reverse=True)
|
||||
docs = [d for d, _ in result]
|
||||
logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")
|
||||
|
||||
@@ -49,44 +49,38 @@ class DocumentReranker:
|
||||
top_n: 返回前 N 个结果
|
||||
|
||||
Returns:
|
||||
List[Document]: 排序后的文档列表
|
||||
List[Document]: 排序后的文档列表,每个文档的 metadata 中包含 rerank_score
|
||||
"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 1. 从 Document 提取内容(业务逻辑)
|
||||
# 1. 从 Document 提取内容
|
||||
doc_contents = [doc.page_content for doc in documents]
|
||||
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}")
|
||||
total_chars = sum(len(c) for c in doc_contents)
|
||||
logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}")
|
||||
# 粗略估算 tokens (中文约 0.75 tokens/字符)
|
||||
estimated_tokens = int(total_chars * 0.75)
|
||||
logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)")
|
||||
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排")
|
||||
|
||||
# 2. 调用纯服务层计算得分
|
||||
logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}")
|
||||
# 2. 调用重排服务计算得分
|
||||
scores = self._rerank_service.compute_scores(query, doc_contents)
|
||||
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}")
|
||||
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分")
|
||||
|
||||
# 3. 根据得分排序(业务逻辑)
|
||||
# 3. 构建 (文档, 分数) 对并排序
|
||||
doc_score_pairs = list(zip(documents, scores))
|
||||
doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
|
||||
|
||||
logger.info(f"[Rerank] 排序后的结果:")
|
||||
for i, (doc, score) in enumerate(doc_score_pairs_sorted):
|
||||
logger.info(f" [{i}] score={score:.4f}, content={doc.page_content[:80]}...")
|
||||
|
||||
# 4. 取 top_n
|
||||
top_docs = [pair[0] for pair in doc_score_pairs_sorted[:top_n]]
|
||||
# 4. 取 top_n,并添加 rerank_score 到 metadata
|
||||
top_docs = []
|
||||
for doc, score in doc_score_pairs_sorted[:top_n]:
|
||||
# 创建新文档,添加 rerank_score
|
||||
new_doc = Document(
|
||||
page_content=doc.page_content,
|
||||
metadata={**doc.metadata, "rerank_score": score}
|
||||
)
|
||||
top_docs.append(new_doc)
|
||||
|
||||
return top_docs
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"重排过程出错,返回原始前 {top_n} 个结果: {e}")
|
||||
logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
logger.warning(f"[Rerank] 堆栈: {traceback.format_exc()}")
|
||||
logger.warning(f"[Rerank] 重排失败,返回原始结果: {e}")
|
||||
return documents[:top_n]
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from pydantic import Field, PrivateAttr
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
|
||||
from backend.rag_core.client import create_async_qdrant_client
|
||||
from ..model_services import get_embedding_service
|
||||
from ..logger import info, warning, debug
|
||||
from backend.app.logger import info, warning, debug
|
||||
|
||||
|
||||
# 模块级常量
|
||||
|
||||
Reference in New Issue
Block a user