Files
ailine/backend/app/rag/pipeline.py
root 6dfa9f572e
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m24s
重构:清理废弃代码 + 优化 Agent 架构
主要变更:
- 删除 deprecated 文件夹(intent/hybrid_router/rag_nodes 等)
- 删除 intent_classifier.py(未使用)
- 删除 subgraph_wrapper.py(死代码)
- 重构 agent.py:简化工厂函数,支持动态模型切换
- 重构 prompts.py:添加信息获取优先级、思维链要求、工具调用约束
- 优化 tools:统一位置,rag_search 返回置信度评估
- 新增 RAG 置信度评估:embedding(25%) + rerank(25%) + LLM(50%)
- 添加循环检测:防止工具无限重复调用

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-08 00:29:12 +08:00

309 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 检索流水线
流程: 检索子文档 → 重排 → 获取父文档 → 置信度评估 → 返回
"""
import asyncio
import re
from dataclasses import dataclass
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from backend.app.logger import info, warning
from ..model_services import get_small_llm_service
from ..rag.rerank import create_document_reranker
from ..rag.query_transform import MultiQueryGenerator
from ..rag.fusion import reciprocal_rank_fusion
from ..rag.retriever import create_parent_hybrid_retriever
@dataclass
class RAGResult:
    """Result of a confidence-scored RAG retrieval.

    Built by ``RAGPipeline.aretrieve_with_confidence``; when nothing is
    retrieved, ``content`` is empty, ``documents`` is empty and every score is 0.0.
    """
    content: str  # context string formatted from `documents` ("" when nothing was retrieved)
    documents: List[Document]  # the retrieved documents `content` was rendered from
    confidence: float  # weighted confidence, nominally in 0.0-1.0 (same value as scores["final"])
    scores: dict[str, float]  # per-signal scores keyed "embedding", "rerank", "llm", "final"
    is_useful: bool  # True when confidence >= the pipeline's confidence_threshold (default 0.6)
class RAGPipeline:
    """Retrieval pipeline: child retrieval -> rerank -> parent expansion -> confidence.

    Steps performed by ``aretrieve`` / ``aretrieve_with_confidence``:

    1. Retrieve child chunks with ``retriever``. When an LLM is available, the
       query is expanded by ``MultiQueryGenerator`` and the per-query result
       lists are merged with reciprocal rank fusion.
    2. Keep the first ``vector_top_n`` candidates.
    3. Rerank down to ``rerank_top_n``, or simply truncate to ``rerank_top_n``
       when reranking is disabled or fails.
    4. Optionally replace children with their parent documents.
    5. (``aretrieve_with_confidence`` only) Score the result with a weighted
       mix of embedding, rerank and LLM-judge signals.
    """

    # Weights of the three confidence signals in the final score (they sum to 1.0).
    # NOTE(review): with use_rerank=False the rerank signal is always 0, so the
    # final score is capped at EMBEDDING_WEIGHT + LLM_WEIGHT.
    EMBEDDING_WEIGHT = 0.25
    RERANK_WEIGHT = 0.25
    LLM_WEIGHT = 0.50
    # Neutral LLM score used when no LLM is configured or its reply is unusable.
    DEFAULT_LLM_SCORE = 0.5
    # First number in the LLM reply; accepts "0.8", ".8" and "1".
    _SCORE_PATTERN = re.compile(r"\d*\.\d+|\d+")

    def __init__(
        self,
        retriever=None,
        llm: BaseLanguageModel | str = "default_small",
        num_queries: int = 3,
        rerank_top_n: int = 5,
        collection_name: str = "rag_documents",
        use_rerank: bool = True,
        return_parent_docs: bool = True,
        confidence_threshold: float = 0.6,
        vector_top_n: int = 20,
    ):
        """Configure the pipeline.

        Args:
            retriever: Object exposing ``ainvoke(query)`` that returns documents.
                Defaults to ``create_parent_hybrid_retriever`` over
                ``collection_name`` with ``search_k = rerank_top_n * 4``.
            llm: ``"default_small"`` loads ``get_small_llm_service()``. A falsy
                value disables the LLM. Anything else is used as-is. Without an
                LLM, there is no query expansion and no LLM judging.
            num_queries: Number of expanded queries requested from the generator.
            rerank_top_n: Final number of child documents kept after reranking
                (or after truncation when reranking is off or fails).
            collection_name: Collection for the default retriever.
            use_rerank: Whether to build and use the document reranker.
            return_parent_docs: Replace child chunks with their parent documents.
            confidence_threshold: Minimum confidence for ``RAGResult.is_useful``.
            vector_top_n: Cap on candidates passed to the reranker.
        """
        self.retriever = retriever or create_parent_hybrid_retriever(
            collection_name=collection_name, search_k=rerank_top_n * 4
        )
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        self.vector_top_n = vector_top_n
        self.use_rerank = use_rerank
        self.return_parent_docs = return_parent_docs
        self.confidence_threshold = confidence_threshold
        # Results of the most recent retrieval. Shared per instance, so concurrent
        # calls on one pipeline overwrite each other.
        self._last_docs: List[Document] = []
        self._last_scores: List[dict] = []
        if llm == "default_small":
            try:
                self.llm = get_small_llm_service()
            except Exception as e:
                # Best effort: the pipeline still works without an LLM, but say so.
                warning(f"[Pipeline] default small LLM unavailable, multi-query and LLM scoring disabled: {e}")
                self.llm = None
        else:
            self.llm = llm if llm else None
        self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None
        self.reranker = create_document_reranker() if use_rerank else None
        info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}, threshold={confidence_threshold}")

    @property
    def last_docs(self) -> List[Document]:
        """Documents returned by the most recent retrieval call."""
        return self._last_docs

    @property
    def last_scores(self) -> List[dict]:
        """Per-document ``embedding_score``/``rerank_score`` of the most recent call."""
        return self._last_scores

    async def aretrieve(self, query: str) -> List[Document]:
        """Retrieve documents for ``query`` (legacy interface, no confidence scoring)."""
        docs = await self._do_retrieve(query)
        self._last_docs = docs
        self._last_scores = self._extract_scores(docs)
        return docs

    async def aretrieve_with_confidence(self, query: str, original_query: str = "") -> RAGResult:
        """Retrieve documents and score how useful they are.

        Args:
            query: Query used for retrieval.
            original_query: The user's original question. If given, it is used
                instead of ``query`` when judging confidence.

        Returns:
            RAGResult with the formatted context, the documents, and the
            per-signal and final confidence scores.
        """
        info(f"[Pipeline] aretrieve_with_confidence: query={query[:50]}...")
        # 1. Retrieve.
        docs = await self._do_retrieve(query)
        self._last_docs = docs
        self._last_scores = self._extract_scores(docs)
        # 2. Format the context.
        content = self.format_context(docs)
        if not docs or not content:
            info("[Pipeline] 无检索结果,置信度=0")
            return RAGResult(
                content="",
                documents=[],
                confidence=0.0,
                scores={"embedding": 0.0, "rerank": 0.0, "llm": 0.0, "final": 0.0},
                is_useful=False
            )
        # 3. Score confidence from three signals.
        scores = await self._evaluate_confidence(
            query=original_query or query,
            docs=docs,
            content=content
        )
        confidence = scores["final"]
        is_useful = confidence >= self.confidence_threshold
        info(f"[Pipeline] 置信度评估完成: confidence={confidence:.3f}, is_useful={is_useful}")
        return RAGResult(
            content=content,
            documents=docs,
            confidence=confidence,
            scores=scores,
            is_useful=is_useful
        )

    async def _do_retrieve(self, query: str) -> List[Document]:
        """Run retrieve -> truncate -> rerank -> parent expansion for ``query``."""
        # Step 1: retrieve child chunks (optionally multi-query + RRF).
        child_docs = await self._retrieve(query)
        # Step 1.5: cap the candidate set before reranking.
        child_docs = child_docs[:self.vector_top_n]
        # Step 2: rerank down to rerank_top_n.
        if self.reranker:
            try:
                # compress_documents is synchronous; keep it off the event-loop thread.
                child_docs = await asyncio.to_thread(
                    self.reranker.compress_documents, child_docs, query, self.rerank_top_n
                )
            except Exception as e:
                warning(f"[Pipeline] 重排失败: {e}")
                child_docs = child_docs[:self.rerank_top_n]
        else:
            # Same final size as the rerank-failure fallback above.
            child_docs = child_docs[:self.rerank_top_n]
        # Step 3: expand children into their parent documents.
        if self.return_parent_docs:
            return await self._get_parents(child_docs)
        return child_docs

    @staticmethod
    def _to_float(raw) -> float:
        """Coerce a metadata score to float; None, non-numeric and NaN become 0.0."""
        import math

        try:
            value = float(raw)
        except (TypeError, ValueError):
            return 0.0
        return 0.0 if math.isnan(value) else value

    @classmethod
    def _normalize_score(cls, raw) -> float:
        """Map a raw metadata score into [0, 1].

        Scores above 1.0 are assumed to be on a 0-10 scale and divided by 10
        (the existing heuristic). The result is clamped so negative scores
        cannot pull the weighted confidence below zero.
        """
        score = cls._to_float(raw)
        if score > 1.0:
            score /= 10.0
        return min(max(score, 0.0), 1.0)

    @staticmethod
    def _embedding_score(doc: Document):
        """Raw embedding score of ``doc``: ``embedding_score``, else ``score``, else 0.0."""
        metadata = doc.metadata
        return metadata.get("embedding_score", metadata.get("score", 0.0))

    async def _evaluate_confidence(
        self,
        query: str,
        docs: List[Document],
        content: str
    ) -> dict:
        """Score retrieval confidence from three signals.

        Returns:
            ``{"embedding", "rerank", "llm", "final"}``, each in [0, 1].
            "embedding" and "rerank" are the best normalized per-document scores.
            "llm" is the LLM judge's score (``DEFAULT_LLM_SCORE`` without an LLM).
            "final" is their weighted sum.
        """
        scores = {"embedding": 0.0, "rerank": 0.0, "llm": self.DEFAULT_LLM_SCORE, "final": 0.0}
        # 1. Vector similarity.
        if docs:
            scores["embedding"] = max(self._normalize_score(self._embedding_score(doc)) for doc in docs)
        info(f"[Confidence] embedding={scores['embedding']:.3f}")
        # 2. Rerank score.
        if docs:
            scores["rerank"] = max(
                self._normalize_score(doc.metadata.get("rerank_score", 0.0)) for doc in docs
            )
        info(f"[Confidence] rerank={scores['rerank']:.3f}")
        # 3. LLM judgement.
        if self.llm and content:
            scores["llm"] = await self._get_llm_confidence(query, content)
        info(f"[Confidence] llm={scores['llm']:.3f}")
        # 4. Weighted combination.
        scores["final"] = (
            scores["embedding"] * self.EMBEDDING_WEIGHT
            + scores["rerank"] * self.RERANK_WEIGHT
            + scores["llm"] * self.LLM_WEIGHT
        )
        info(f"[Confidence] final={scores['final']:.3f}")
        return scores

    async def _get_llm_confidence(self, query: str, context: str) -> float:
        """Ask the LLM to rate context relevance; returns a score clamped to [0, 1].

        Falls back to ``DEFAULT_LLM_SCORE`` when the call fails or the reply has
        no number. Only the first 1500 characters of ``context`` are sent.
        """
        prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:
- 1.0 = 完全相关,能直接回答问题
- 0.7 = 高度相关,有很大参考价值
- 0.5 = 部分相关,有一定参考价值
- 0.3 = 低度相关,参考价值有限
- 0.0 = 完全不相关,无法回答问题
用户问题:{query}
检索结果:{context[:1500]}
只返回一个数字0.0-1.0"""
        try:
            response = await self.llm.ainvoke(prompt)
        except Exception as e:
            warning(f"[Confidence] LLM评估失败: {e}")
            return self.DEFAULT_LLM_SCORE
        # Chat models return a message with `.content`; plain LLMs return a str.
        text = getattr(response, "content", response)
        if not isinstance(text, str):
            text = str(text)
        match = self._SCORE_PATTERN.search(text)
        if match:
            return max(0.0, min(1.0, float(match.group(0))))
        return self.DEFAULT_LLM_SCORE

    def _extract_scores(self, docs: List[Document]) -> List[dict]:
        """Collect the raw ``embedding_score``/``rerank_score`` of each document."""
        return [
            {
                "embedding_score": self._embedding_score(doc),
                "rerank_score": doc.metadata.get("rerank_score", 0.0),
            }
            for doc in docs
        ]

    async def _retrieve(self, query: str) -> List[Document]:
        """Retrieve child docs, fusing multi-query results with RRF when possible."""
        if not self.query_generator:
            return await self.retriever.ainvoke(query)
        try:
            generated = await self.query_generator.agenerate(query)
        except Exception as e:
            # Query expansion is an enhancement; fall back to the plain query.
            warning(f"[Pipeline] multi-query generation failed, using original query only: {e}")
            return await self.retriever.ainvoke(query)
        # Original query first; drop blank and duplicate queries so RRF does not
        # double-count the same result list.
        queries = list(dict.fromkeys([query, *(q for q in generated if q and q.strip())]))
        doc_lists = await asyncio.gather(*(self.retriever.ainvoke(q) for q in queries))
        return reciprocal_rank_fusion(doc_lists)

    def _parent_rank_score(self, doc: Document) -> float:
        """Ordering key after parent expansion: embedding + 2 * rerank (rerank weighted higher)."""
        return (
            self._to_float(self._embedding_score(doc))
            + self._to_float(doc.metadata.get("rerank_score", 0.0)) * 2
        )

    async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
        """Replace child chunks with their parent documents.

        Each parent inherits the scores of its first (best-ranked) child.
        Children with no ``parent_id``, and children whose parent is missing from
        the docstore, are kept as-is so retrieved evidence is not dropped. If the
        docstore cannot be reached, ``child_docs`` is returned unchanged.
        """
        # parent_id -> best-ranked child pointing at it (child_docs is in rank order).
        best_child: dict = {}
        orphans: List[Document] = []
        for doc in child_docs:
            pid = doc.metadata.get("parent_id")
            if not pid:
                orphans.append(doc)
            elif pid not in best_child:
                best_child[pid] = doc
        if not best_child:
            return child_docs
        pids = list(best_child)
        try:
            from backend.rag_core import create_docstore
            docstore, _ = create_docstore()
            parent_docs = await docstore.amget(pids)
        except Exception as e:
            warning(f"[Pipeline] 获取父文档失败: {e}")
            return child_docs
        expanded: List[Document] = []
        # amget returns values aligned with the requested keys (None for misses),
        # so pair them by position instead of relying on a metadata "id" field.
        for pid, parent in zip(pids, parent_docs):
            child = best_child[pid]
            if parent is None:
                expanded.append(child)
                continue
            # NOTE(review): this mutates the docstore's object in place (as before).
            parent.metadata["embedding_score"] = self._embedding_score(child)
            parent.metadata["rerank_score"] = child.metadata.get("rerank_score", 0.0)
            expanded.append(parent)
        expanded.extend(orphans)
        expanded.sort(key=self._parent_rank_score, reverse=True)
        return expanded

    def format_context(self, documents: List[Document]) -> str:
        """Render documents as numbered "【资料 i】来源:..." sections separated by "---"."""
        if not documents:
            return ""
        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)
def create_rag_pipeline(**kwargs) -> RAGPipeline:
    """Build a ``RAGPipeline``, forwarding every keyword argument to its constructor."""
    pipeline = RAGPipeline(**kwargs)
    return pipeline