2026-05-04 02:01:22 +08:00
|
|
|
|
"""
|
2026-05-05 23:17:00 +08:00
|
|
|
|
RAG 检索流水线
|
2026-05-08 00:29:12 +08:00
|
|
|
|
流程: 检索子文档 → 重排 → 获取父文档 → 置信度评估 → 返回
|
2026-05-04 02:01:22 +08:00
|
|
|
|
"""
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
|
|
|
|
|
import asyncio
|
2026-05-08 00:29:12 +08:00
|
|
|
|
import re
|
|
|
|
|
|
from dataclasses import dataclass
|
2026-05-05 23:17:00 +08:00
|
|
|
|
from typing import List
|
2026-04-21 11:02:16 +08:00
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
|
|
2026-05-06 16:15:09 +08:00
|
|
|
|
from backend.app.logger import info, warning
|
2026-05-08 00:29:12 +08:00
|
|
|
|
from ..model_services import get_small_llm_service
|
2026-05-05 23:17:00 +08:00
|
|
|
|
from ..rag.rerank import create_document_reranker
|
|
|
|
|
|
from ..rag.query_transform import MultiQueryGenerator
|
|
|
|
|
|
from ..rag.fusion import reciprocal_rank_fusion
|
|
|
|
|
|
from ..rag.retriever import create_parent_hybrid_retriever
|
|
|
|
|
|
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
2026-05-08 00:29:12 +08:00
|
|
|
|
@dataclass
class RAGResult:
    """Structured RAG retrieval result, including its confidence assessment."""

    # Context text formatted from the retrieved documents.
    content: str
    # The retrieved documents themselves.
    documents: List[Document]
    # Overall confidence, in the range 0.0-1.0.
    confidence: float
    # Per-dimension scores keyed by "embedding", "rerank", "llm" and "final".
    scores: dict
    # Whether the result is usable: the pipeline sets this to
    # ``confidence >= confidence_threshold`` (the threshold defaults to 0.6).
    is_useful: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-21 11:02:16 +08:00
|
|
|
|
class RAGPipeline:
    """RAG retrieval pipeline.

    Flow: retrieve child chunks -> rerank -> expand to parent documents ->
    evaluate confidence -> return.

    Every optional stage (multi-query expansion, reranking, parent lookup,
    LLM confidence) is best-effort: a failure is logged and the pipeline
    continues with what it already has.
    """

    # Weights of the three confidence dimensions; they sum to 1.0.
    _EMBEDDING_WEIGHT = 0.25
    _RERANK_WEIGHT = 0.25
    _LLM_WEIGHT = 0.50

    def __init__(
        self,
        retriever=None,
        llm: BaseLanguageModel | str = "default_small",
        num_queries: int = 3,
        rerank_top_n: int = 5,
        collection_name: str = "rag_documents",
        use_rerank: bool = True,
        return_parent_docs: bool = True,
        confidence_threshold: float = 0.6,
        vector_top_n: int = 20,
    ):
        """Configure the pipeline.

        Args:
            retriever: Retriever exposing ``ainvoke(query)``. When omitted, a
                parent hybrid retriever is created for ``collection_name``
                with ``search_k = rerank_top_n * 4``.
            llm: Language model used for query expansion and the LLM
                confidence check. The string ``"default_small"`` requests the
                shared small-model service; a falsy value disables both
                LLM-backed stages.
            num_queries: Number of queries handed to the multi-query generator.
            rerank_top_n: Number of documents kept after reranking.
            collection_name: Collection used when building the default retriever.
            use_rerank: Whether to create and use a document reranker.
            return_parent_docs: Whether to replace child chunks by their parents.
            confidence_threshold: Minimum final confidence for ``is_useful``.
            vector_top_n: Maximum number of retrieved candidates passed on to
                the rerank stage.
        """
        self.retriever = retriever or create_parent_hybrid_retriever(
            collection_name=collection_name, search_k=rerank_top_n * 4
        )
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        self.use_rerank = use_rerank
        self.return_parent_docs = return_parent_docs
        self.confidence_threshold = confidence_threshold
        self.vector_top_n = vector_top_n
        self._last_docs: List[Document] = []
        self._last_scores: List[dict] = []

        if isinstance(llm, str) and llm == "default_small":
            try:
                self.llm = get_small_llm_service()
            except Exception as e:
                # Best-effort: the pipeline still works without an LLM, but
                # the reason must be visible in the logs.
                warning(f"[Pipeline] 小模型初始化失败,禁用多查询与LLM置信度评估: {e}")
                self.llm = None
        else:
            self.llm = llm or None

        self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None
        self.reranker = create_document_reranker() if use_rerank else None
        info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}, threshold={confidence_threshold}")

    @property
    def last_docs(self) -> List[Document]:
        """Documents returned by the most recent retrieval call."""
        return self._last_docs

    @property
    def last_scores(self) -> List[dict]:
        """Per-document score dicts of the most recent retrieval call."""
        return self._last_scores

    async def aretrieve(self, query: str) -> List[Document]:
        """Original interface, kept for backward compatibility.

        Runs the retrieval flow, records the result on the instance and
        returns the documents without any confidence evaluation.
        """
        docs = await self._do_retrieve(query)
        self._last_docs = docs
        self._last_scores = self._extract_scores(docs)
        return docs

    async def aretrieve_with_confidence(self, query: str, original_query: str = "") -> RAGResult:
        """Retrieve documents and evaluate how trustworthy the result is.

        Args:
            query: Query used for retrieval.
            original_query: Original user question used for the confidence
                evaluation; falls back to ``query`` when empty.

        Returns:
            RAGResult: Formatted content plus confidence information. When
            nothing was retrieved, an empty result with confidence 0.0.
        """
        info(f"[Pipeline] aretrieve_with_confidence: query={query[:50]}...")

        # 1. Run retrieval and remember the outcome.
        docs = await self._do_retrieve(query)
        self._last_docs = docs
        self._last_scores = self._extract_scores(docs)

        # 2. Format the context.
        content = self.format_context(docs)

        if not docs or not content:
            info(f"[Pipeline] 无检索结果,置信度=0")
            return RAGResult(
                content="",
                documents=[],
                confidence=0.0,
                scores={"embedding": 0.0, "rerank": 0.0, "llm": 0.0, "final": 0.0},
                is_useful=False,
            )

        # 3. Evaluate confidence across the three dimensions.
        scores = await self._evaluate_confidence(
            query=original_query or query,
            docs=docs,
            content=content,
        )

        confidence = scores["final"]
        is_useful = confidence >= self.confidence_threshold

        info(f"[Pipeline] 置信度评估完成: confidence={confidence:.3f}, is_useful={is_useful}")

        return RAGResult(
            content=content,
            documents=docs,
            confidence=confidence,
            scores=scores,
            is_useful=is_useful,
        )

    async def _do_retrieve(self, query: str) -> List[Document]:
        """Run the retrieval flow: retrieve -> pre-filter -> rerank -> parents."""
        # Step 1: retrieve child documents.
        child_docs = await self._retrieve(query)

        # Step 1.5: keep only the leading candidates before reranking.
        if len(child_docs) > self.vector_top_n:
            child_docs = child_docs[:self.vector_top_n]

        # Step 2: rerank. Skipped when there is nothing to rerank.
        # NOTE(review): compress_documents is called synchronously inside a
        # coroutine; if the reranker is slow this blocks the event loop.
        if self.reranker and child_docs:
            try:
                child_docs = self.reranker.compress_documents(child_docs, query, self.rerank_top_n)
            except Exception as e:
                warning(f"[Pipeline] 重排失败: {e}")
                child_docs = child_docs[:self.rerank_top_n]

        # Step 3: expand to parent documents.
        if self.return_parent_docs:
            return await self._get_parents(child_docs)
        return child_docs

    @staticmethod
    def _normalize_score(raw) -> float:
        """Map a raw relevance score into [0.0, 1.0].

        Values above 1.0 are treated as being on a 0-10 scale and divided by
        ten. Missing, non-numeric, NaN or negative values become 0.0.
        """
        try:
            value = float(raw)
        except (TypeError, ValueError):
            return 0.0
        if value != value:  # NaN is the only value not equal to itself
            return 0.0
        if value > 1.0:
            value /= 10.0
        return max(0.0, min(value, 1.0))

    async def _evaluate_confidence(
        self,
        query: str,
        docs: List[Document],
        content: str,
    ) -> dict:
        """Evaluate confidence along three dimensions.

        Returns:
            {
                "embedding": float,  # best vector similarity (0-1)
                "rerank": float,     # best rerank score (0-1)
                "llm": float,        # LLM judgement (0-1), 0.5 when no LLM ran
                "final": float       # weighted combination (0-1)
            }
        """
        normalize = self._normalize_score
        scores = {"embedding": 0.0, "rerank": 0.0, "llm": 0.5, "final": 0.0}

        # 1. Vector similarity: best normalised score over all documents.
        scores["embedding"] = max(
            (
                normalize(doc.metadata.get("embedding_score", doc.metadata.get("score", 0.0)))
                for doc in docs
            ),
            default=0.0,
        )
        info(f"[Confidence] embedding={scores['embedding']:.3f}")

        # 2. Rerank score: best normalised score over all documents.
        # NOTE(review): when reranking is disabled or failed this stays 0.0
        # and still carries its full weight in the final score.
        scores["rerank"] = max(
            (normalize(doc.metadata.get("rerank_score", 0.0)) for doc in docs),
            default=0.0,
        )
        info(f"[Confidence] rerank={scores['rerank']:.3f}")

        # 3. LLM judgement, only when an LLM and some content are available.
        if self.llm and content:
            scores["llm"] = await self._get_llm_confidence(query, content)
            info(f"[Confidence] llm={scores['llm']:.3f}")

        # 4. Final score: weighted average of the three dimensions.
        scores["final"] = (
            scores["embedding"] * self._EMBEDDING_WEIGHT
            + scores["rerank"] * self._RERANK_WEIGHT
            + scores["llm"] * self._LLM_WEIGHT
        )
        info(f"[Confidence] final={scores['final']:.3f}")

        return scores

    async def _get_llm_confidence(self, query: str, context: str) -> float:
        """Ask the LLM how relevant the retrieved context is to the question.

        Returns:
            The first number found in the reply, clamped to [0.0, 1.0], or
            0.5 (medium confidence) when the call fails or no number is found.
        """
        prompt = (
            "评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:\n"
            "- 1.0 = 完全相关,能直接回答问题\n"
            "- 0.7 = 高度相关,有很大参考价值\n"
            "- 0.5 = 部分相关,有一定参考价值\n"
            "- 0.3 = 低度相关,参考价值有限\n"
            "- 0.0 = 完全不相关,无法回答问题\n"
            "\n"
            f"用户问题:{query}\n"
            "\n"
            f"检索结果:{context[:1500]}\n"
            "\n"
            "只返回一个数字(0.0-1.0):"
        )
        try:
            response = await self.llm.ainvoke(prompt)
            # Chat models reply with a message object carrying ``content``;
            # plain text-completion models reply with a bare string.
            reply = getattr(response, "content", response)
            if not isinstance(reply, str):
                reply = str(reply)

            match = re.search(r'(\d+\.?\d*)', reply.strip())
            if match:
                return max(0.0, min(1.0, float(match.group(1))))
        except Exception as e:
            warning(f"[Confidence] LLM评估失败: {e}")

        return 0.5  # default: medium confidence

    def _extract_scores(self, docs: List[Document]) -> List[dict]:
        """Collect the raw embedding and rerank scores of each document."""
        return [
            {
                "embedding_score": doc.metadata.get("embedding_score", doc.metadata.get("score", 0.0)),
                "rerank_score": doc.metadata.get("rerank_score", 0.0),
            }
            for doc in docs
        ]

    async def _retrieve(self, query: str) -> List[Document]:
        """Retrieve candidates, using multi-query expansion when available.

        With a query generator, the original query plus its distinct,
        non-empty variants are retrieved concurrently and merged with
        reciprocal rank fusion. If generating the variants fails, only the
        original query is used.
        """
        if not self.query_generator:
            return await self.retriever.ainvoke(query)

        try:
            generated = await self.query_generator.agenerate(query)
        except Exception as e:
            # Query expansion is an enhancement; losing it must not abort retrieval.
            warning(f"[Pipeline] 多查询生成失败,回退到原始查询: {e}")
            generated = []

        # Original query first, then each new variant once.
        queries = [query]
        for candidate in generated or []:
            if candidate and candidate not in queries:
                queries.append(candidate)

        doc_lists = await asyncio.gather(*[self.retriever.ainvoke(q) for q in queries])
        return reciprocal_rank_fusion(doc_lists)

    async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
        """Replace child chunks by their parent documents.

        Each parent inherits the scores of its first child in ``child_docs``
        and the parents are ordered by ``embedding + 2 * rerank``. The child
        documents are returned unchanged when no child names a parent, when
        no parent can be loaded, or when the lookup raises.
        """
        parent_scores = {}
        for child in child_docs:
            meta = child.metadata
            pid = meta.get("parent_id")
            if pid and pid not in parent_scores:
                parent_scores[pid] = (
                    meta.get("embedding_score", meta.get("score", 0.0)),
                    meta.get("rerank_score", 0.0),
                )

        if not parent_scores:
            return child_docs

        try:
            from backend.rag_core import create_docstore
            docstore, _ = create_docstore()
            parent_ids = list(parent_scores)
            parent_docs = await docstore.amget(parent_ids)

            ranked = []
            for requested_id, parent in zip(parent_ids, parent_docs):
                if not parent:
                    continue
                # Prefer the id stored on the parent; otherwise fall back to
                # the key requested at this position (assumes amget returns
                # values in key order, as LangChain's BaseStore does).
                pid = parent.metadata.get("id")
                if pid not in parent_scores:
                    pid = requested_id
                embedding_score, rerank_score = parent_scores[pid]
                parent.metadata["embedding_score"] = embedding_score
                parent.metadata["rerank_score"] = rerank_score
                # Combined score; rerank carries the larger weight.
                ranked.append((parent, embedding_score + rerank_score * 2))

            if not ranked:
                # Parents were referenced but none could be loaded: keep the
                # child chunks rather than returning nothing.
                warning("[Pipeline] 未找到任何父文档,返回子文档")
                return child_docs

            ranked.sort(key=lambda pair: pair[1], reverse=True)
            return [parent for parent, _ in ranked]
        except Exception as e:
            warning(f"[Pipeline] 获取父文档失败: {e}")
            return child_docs

    def format_context(self, documents: List[Document]) -> str:
        """Render documents as numbered, source-labelled context blocks.

        Returns an empty string when ``documents`` is empty.
        """
        if not documents:
            return ""
        return "\n".join(
            f"【资料 {i}】来源:{doc.metadata.get('source', '未知来源')}\n{doc.page_content}\n---\n"
            for i, doc in enumerate(documents, 1)
        )
|
2026-05-04 02:01:22 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-05-05 23:17:00 +08:00
|
|
|
|
def create_rag_pipeline(**kwargs) -> RAGPipeline:
    """Build a RAGPipeline, forwarding every keyword argument to its constructor."""
    pipeline = RAGPipeline(**kwargs)
    return pipeline
|