This commit is contained in:
@@ -36,6 +36,8 @@ class RAGPipeline:
|
||||
self.rerank_top_n = rerank_top_n
|
||||
self.use_rerank = use_rerank
|
||||
self.return_parent_docs = return_parent_docs
|
||||
self._last_docs = [] # 保存最后一次检索的文档
|
||||
self._last_scores = [] # 保存最后一次检索的分数
|
||||
|
||||
if llm == "default_small":
|
||||
try:
|
||||
@@ -49,6 +51,16 @@ class RAGPipeline:
|
||||
self.reranker = create_document_reranker() if use_rerank else None
|
||||
logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")
|
||||
|
||||
@property
def last_docs(self) -> List[Document]:
    """Return the documents kept from the most recent retrieval.

    Returns:
        The list object held in ``self._last_docs`` (not a copy).
    """
    # NOTE(review): no assignment to ``_last_docs`` other than the initial
    # empty list is visible in this diff -- confirm it is populated elsewhere.
    return self._last_docs
|
||||
|
||||
@property
def last_scores(self) -> List[dict]:
    """Return the score information kept from the most recent retrieval.

    Returns:
        The list object held in ``self._last_scores`` (not a copy).
    """
    return self._last_scores
|
||||
|
||||
async def aretrieve(self, query: str) -> List[Document]:
|
||||
# Step 1: 检索
|
||||
child_docs = await self._retrieve(query)
|
||||
@@ -69,9 +81,24 @@ class RAGPipeline:
|
||||
|
||||
# Step 3: 获取父文档
|
||||
if self.return_parent_docs:
|
||||
return await self._get_parents(child_docs)
|
||||
parent_docs = await self._get_parents(child_docs)
|
||||
# 保存分数信息到 last_scores 供外部访问
|
||||
self._last_scores = self._extract_scores(parent_docs)
|
||||
return parent_docs
|
||||
|
||||
self._last_scores = self._extract_scores(child_docs)
|
||||
return child_docs
|
||||
|
||||
def _extract_scores(self, docs: List[Document]) -> List[dict]:
    """Collect the score fields stored in each document's metadata.

    Args:
        docs: Documents whose ``metadata`` mapping is read.

    Returns:
        One dict per document, in input order, with the keys
        ``"embedding_score"`` and ``"rerank_score"``. A missing
        ``"rerank_score"`` is reported as ``0.0``.
    """
    return [
        {
            # Prefer "embedding_score"; fall back to "score", then to 0.0.
            "embedding_score": meta.get("embedding_score", meta.get("score", 0.0)),
            "rerank_score": meta.get("rerank_score", 0.0),
        }
        for meta in (doc.metadata for doc in docs)
    ]
|
||||
|
||||
async def _retrieve(self, query: str) -> List[Document]:
|
||||
if self.query_generator:
|
||||
queries = await self.query_generator.agenerate(query)
|
||||
@@ -100,7 +127,7 @@ class RAGPipeline:
|
||||
try:
|
||||
from backend.rag_core import create_docstore
|
||||
docstore, _ = create_docstore()
|
||||
parent_docs = docstore.mget(list(parent_map.keys()))
|
||||
parent_docs =await docstore.amget(list(parent_map.keys()))
|
||||
|
||||
# 构建结果,保持分数信息
|
||||
result = []
|
||||
|
||||
Reference in New Issue
Block a user