Files
ailine/backend/app/main_graph/nodes/rag_nodes.py
root ef6fbc1521
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m36s
推理优化
2026-05-06 04:26:06 +08:00

275 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 检索节点模块
包含RAG 检索、置信度判断、重检索等节点
"""
import time
import asyncio
from typing import Optional
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from backend.app.logger import info, debug
from ...model_services import get_small_llm_service
from ._utils import dispatch_custom_event, make_react_event
# 置信度阈值配置
RAG_CONFIDENCE_THRESHOLD = 0.6 # 低于此值认为检索不相关
# 全局 pipeline 实例
_rag_pipeline = None
def _get_rag_pipeline():
"""获取 RAG Pipeline 实例"""
global _rag_pipeline
if _rag_pipeline is None:
from backend.app.rag.pipeline import RAGPipeline
_rag_pipeline = RAGPipeline(
num_queries=3,
rerank_top_n=5,
use_rerank=True,
return_parent_docs=True,
)
return _rag_pipeline
def _get_rag_tool() -> Optional[callable]:
"""获取 RAG 工具"""
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
return get_rag_tool()
# ========== RAG 检索核心逻辑 ==========
async def _rag_retrieve_core(state: MainGraphState, pipeline) -> MainGraphState:
"""执行 RAG 检索的核心逻辑"""
retrieval_query = state.user_query
# 优先使用推理结果中的优化查询
reasoning_result = state.debug_info.get("reasoning_result")
if reasoning_result and hasattr(reasoning_result, "retrieval_config"):
cfg = reasoning_result.retrieval_config
if cfg and cfg.retrieval_query:
retrieval_query = cfg.retrieval_query
# 直接调用 pipeline 获取文档和上下文
documents = await pipeline.aretrieve(retrieval_query)
rag_context = pipeline.format_context(documents)
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
info(f"[RAG Core] 获取到 rag_docs: {len(documents)} 个文档")
# 更新状态
state.rag_context = rag_context
state.rag_docs = documents # 保存文档用于置信度评估
state.rag_retrieved = bool(documents) # 有文档才算检索成功
state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1
state.debug_info["rag_source"] = "pipeline"
state.debug_info["rag_scores"] = pipeline.last_scores # 保存分数信息
return state
# ========== RAG 检索节点 ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""RAG 检索节点:检索 + 置信度评估"""
state.current_phase = "rag_retrieving"
start_time = time.time()
pipeline = _get_rag_pipeline()
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."),
config
)
try:
state = await _rag_retrieve_core(state, pipeline)
# 评估置信度
confidence = await _evaluate_rag_confidence(state)
state.rag_confidence = confidence
info(f"[RAG] 检索完成,置信度={confidence:.2f}RAG尝试次数={state.rag_attempts}")
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG",
"confidence": confidence,
"reasoning": f"RAG 检索完成,置信度={confidence:.2f}",
"timestamp": datetime.now().isoformat()
})
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_complete", confidence,
f"RAG 检索完成,置信度={confidence:.2f}"),
config
)
except Exception as e:
info(f"[RAG] 检索失败: {e}")
state.rag_confidence = 0.0
state.rag_retrieved = False
return state
async def _evaluate_rag_confidence(state: MainGraphState) -> float:
"""评估 RAG 检索结果置信度(综合向量相似度 + 重排分数 + 小模型判断)"""
query = state.user_query or ""
rag_context = state.rag_context or ""
if not rag_context:
return 0.0
# 方式1: 向量相似度(从 rag_docs 中获取)
embedding_score = _get_embedding_similarity(state)
info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")
# 方式2: 重排序分数(从 rag_docs 中获取)
rerank_score = _get_rerank_score(state)
info(f"[RAG Confidence] 重排分数={rerank_score:.3f}")
# 方式3: 小模型判断
llm_score = await _get_llm_score(state)
info(f"[RAG Confidence] LLM评估={llm_score:.3f}")
# 综合得分(加权平均)
# 向量相似度权重 0.3,重排权重 0.3LLM 权重 0.4
final_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4
info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)")
return final_score
def _get_embedding_similarity(state: MainGraphState) -> float:
"""从 rag_scores 或 rag_docs 中获取向量相似度分数"""
# 优先从 pipeline 提供的分数中获取
rag_scores = state.debug_info.get("rag_scores", [])
if rag_scores:
scores = [s.get("embedding_score", 0.0) for s in rag_scores]
if scores:
# 归一化到 0-1
normalized = [min(s / 10.0, 1.0) if s > 1.0 else s for s in scores]
return max(normalized)
# 降级:从 rag_docs 中获取
rag_docs = getattr(state, "rag_docs", [])
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
score = doc.get("score", 0.0)
elif hasattr(doc, "metadata"):
score = doc.metadata.get("embedding_score", doc.metadata.get("score", 0.0))
else:
continue
if score > 1.0:
score = min(score / 10.0, 1.0)
scores.append(score)
return max(scores) if scores else 0.0
def _get_rerank_score(state: MainGraphState) -> float:
"""从 rag_scores 或 rag_docs 中获取重排序分数"""
# 优先从 pipeline 提供的分数中获取
rag_scores = state.debug_info.get("rag_scores", [])
if rag_scores:
scores = [s.get("rerank_score", 0.0) for s in rag_scores]
return max(scores) if scores else 0.0
# 降级:从 rag_docs 中获取
rag_docs = getattr(state, "rag_docs", [])
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
score = doc.get("rerank_score", 0.0)
elif hasattr(doc, "metadata"):
score = doc.metadata.get("rerank_score", 0.0)
else:
continue
if score > 0:
scores.append(score)
return max(scores) if scores else 0.0
async def _get_llm_score(state: MainGraphState) -> float:
"""使用小模型评估检索结果相关性"""
query = state.user_query or ""
rag_context = state.rag_context or ""
try:
llm = get_small_llm_service()
prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:
- 1.0 = 完全相关,能直接回答问题
- 0.5 = 部分相关,有一定参考价值
- 0.0 = 完全不相关,无法回答问题
用户问题:{query}
检索结果:{rag_context[:1500]}
只返回一个数字:"""
response = await llm.ainvoke(prompt)
content = response.content.strip()
import re
match = re.search(r'(\d+\.?\d*)', content)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
except Exception as e:
info(f"[RAG Confidence] LLM评估失败: {e}")
return 0.5 # 默认中等置信度
# ========== 置信度判断节点 ==========
def check_rag_confidence(state: MainGraphState) -> str:
"""
根据 RAG 置信度判断下一步
Returns:
"high_confidence" - 高置信度(>=0.6),可直接生成回答
"low_confidence" - 低置信度(<0.6),需要联网搜索
"no_rag" - 无检索结果,需要联网搜索
"""
rag_attempts = getattr(state, 'rag_attempts', 0)
rag_confidence = getattr(state, 'rag_confidence', 0.0)
info(f"[Confidence Check] rag_attempts={rag_attempts}, rag_confidence={rag_confidence:.2f}")
# 情况1: 没有检索结果
if not getattr(state, 'rag_retrieved', False) or not state.rag_context:
info("[Confidence Check] 无检索结果,走联网")
return "no_rag"
# 情况2: 置信度低于阈值
if rag_confidence < RAG_CONFIDENCE_THRESHOLD:
if rag_attempts >= 2:
info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}且RAG尝试{rag_attempts}次,走联网")
return "low_confidence"
else:
info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}可再尝试RAG一次")
return "retry_rag"
# 情况3: 高置信度
info(f"[Confidence Check] 高置信度={rag_confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答")
return "high_confidence"
# ========== 导出 ==========
__all__ = [
"rag_retrieve_node",
"check_rag_confidence",
"RAG_CONFIDENCE_THRESHOLD",
]