""" RAG 检索节点模块 包含:RAG 检索、置信度判断、重检索等节点 """ import time import asyncio from typing import Optional from datetime import datetime from langchain_core.runnables.config import RunnableConfig from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG from backend.app.logger import info, debug from ...model_services import get_small_llm_service from ._utils import dispatch_custom_event, make_react_event # 置信度阈值配置 RAG_CONFIDENCE_THRESHOLD = 0.6 # 低于此值认为检索不相关 # 全局 pipeline 实例 _rag_pipeline = None def _get_rag_pipeline(): """获取 RAG Pipeline 实例""" global _rag_pipeline if _rag_pipeline is None: from backend.app.rag.pipeline import RAGPipeline _rag_pipeline = RAGPipeline( num_queries=3, rerank_top_n=5, use_rerank=True, return_parent_docs=True, ) return _rag_pipeline def _get_rag_tool() -> Optional[callable]: """获取 RAG 工具""" from backend.app.main_graph.utils.rag_initializer import get_rag_tool return get_rag_tool() # ========== RAG 检索核心逻辑 ========== async def _rag_retrieve_core(state: MainGraphState, pipeline) -> MainGraphState: info(f"[RAG Core] _rag_retrieve_core 开始") retrieval_query = state.user_query # 优先使用推理结果中的优化查询 - 从新的结构化字段获取 reasoning_result = state.react_reasoning.reasoning_result if reasoning_result and hasattr(reasoning_result, "retrieval_config"): cfg = reasoning_result.retrieval_config if cfg and cfg.retrieval_query: retrieval_query = cfg.retrieval_query info(f"[RAG Core] 使用检索查询: {retrieval_query[:50]}...") # 直接调用 pipeline 获取文档和上下文 info(f"[RAG Core] 调用 pipeline.aretrieve") documents = await pipeline.aretrieve(retrieval_query) info(f"[RAG Core] pipeline.aretrieve 返回,得到 {len(documents)} 个文档") info(f"[RAG Core] 调用 pipeline.format_context") rag_context = pipeline.format_context(documents) info(f"[RAG Core] pipeline.format_context 返回") info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}") info(f"[RAG Core] 获取到 rag_docs: {len(documents)} 个文档") # 
更新状态 state.rag_context = rag_context state.rag_docs = documents # 保存文档用于置信度评估 state.rag_retrieved = bool(documents) # 有文档才算检索成功 state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1 # 移除对 debug_info 的依赖,不再保存 rag_scores info(f"[RAG Core] _rag_retrieve_core 结束") return state # ========== RAG 检索节点 ========== async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: info(f"[RAG] rag_retrieve_node 开始") state.current_phase = "rag_retrieving" start_time = time.time() info(f"[RAG] 调用 _get_rag_pipeline") pipeline = _get_rag_pipeline() await dispatch_custom_event( "react_reasoning", make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."), config ) try: info(f"[RAG] 调用 _rag_retrieve_core") state = await _rag_retrieve_core(state, pipeline) info(f"[RAG] _rag_retrieve_core 返回") # 评估置信度 info(f"[RAG] 调用 _evaluate_rag_confidence") confidence = await _evaluate_rag_confidence(state) state.rag_confidence = confidence info(f"[RAG] 检索完成,置信度={confidence:.2f},RAG尝试次数={state.rag_attempts}") state.reasoning_history.append({ "step": state.reasoning_step, "action": "RETRIEVE_RAG", "confidence": confidence, "reasoning": f"RAG 检索完成,置信度={confidence:.2f}", "timestamp": datetime.now().isoformat() }) await dispatch_custom_event( "react_reasoning", make_react_event(state.reasoning_step, "rag_retrieve_complete", confidence, f"RAG 检索完成,置信度={confidence:.2f}"), config ) except Exception as e: info(f"[RAG] 检索失败: {e}", exc_info=True) state.rag_confidence = 0.0 state.rag_retrieved = False info(f"[RAG] rag_retrieve_node 结束") return state async def _evaluate_rag_confidence(state: MainGraphState) -> float: """评估 RAG 检索结果置信度(综合向量相似度 + 重排分数 + 小模型判断)""" query = state.user_query or "" rag_context = state.rag_context or "" if not rag_context: return 0.0 # 方式1: 向量相似度(从 rag_docs 中获取) embedding_score = _get_embedding_similarity(state) info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}") # 方式2: 重排序分数(从 rag_docs 中获取) rerank_score = 
_get_rerank_score(state) info(f"[RAG Confidence] 重排分数={rerank_score:.3f}") # 方式3: 小模型判断 llm_score = await _get_llm_score(state) info(f"[RAG Confidence] LLM评估={llm_score:.3f}") # 综合得分(加权平均) # 向量相似度权重 0.3,重排权重 0.3,LLM 权重 0.4 final_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4 info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)") return final_score def _get_embedding_similarity(state: MainGraphState) -> float: """从 rag_docs 中获取向量相似度分数(不再从 debug_info 获取)""" # 降级:从 rag_docs 中获取 rag_docs = getattr(state, "rag_docs", []) scores = [] for doc in rag_docs: if isinstance(doc, dict): score = doc.get("score", 0.0) elif hasattr(doc, "metadata"): score = doc.metadata.get("embedding_score", doc.metadata.get("score", 0.0)) else: continue if score > 1.0: score = min(score / 10.0, 1.0) scores.append(score) return max(scores) if scores else 0.0 def _get_rerank_score(state: MainGraphState) -> float: """从 rag_docs 中获取重排序分数(不再从 debug_info 获取)""" # 降级:从 rag_docs 中获取 rag_docs = getattr(state, "rag_docs", []) scores = [] for doc in rag_docs: if isinstance(doc, dict): score = doc.get("rerank_score", 0.0) elif hasattr(doc, "metadata"): score = doc.metadata.get("rerank_score", 0.0) else: continue if score > 0: scores.append(score) return max(scores) if scores else 0.0 async def _get_llm_score(state: MainGraphState) -> float: """使用小模型评估检索结果相关性""" query = state.user_query or "" rag_context = state.rag_context or "" try: llm = get_small_llm_service() prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数: - 1.0 = 完全相关,能直接回答问题 - 0.5 = 部分相关,有一定参考价值 - 0.0 = 完全不相关,无法回答问题 用户问题:{query} 检索结果:{rag_context[:1500]} 只返回一个数字:""" response = await llm.ainvoke(prompt) content = response.content.strip() import re match = re.search(r'(\d+\.?\d*)', content) if match: score = float(match.group(1)) return max(0.0, min(1.0, score)) except Exception as e: info(f"[RAG Confidence] LLM评估失败: {e}") return 0.5 # 默认中等置信度 
# ========== Confidence check node ==========

# Number of RAG attempts after which a low-confidence result stops retrying.
_MAX_RAG_ATTEMPTS = 2


def check_rag_confidence(state: MainGraphState) -> str:
    """Decide the next step from the RAG retrieval outcome stored on ``state``.

    Missing or ``None`` values for ``rag_attempts`` / ``rag_confidence`` are
    treated as ``0`` / ``0.0``.

    Returns:
        "no_rag"          - nothing was retrieved (``rag_retrieved`` is falsy or
                            ``rag_context`` is empty)
        "retry_rag"       - confidence is below RAG_CONFIDENCE_THRESHOLD and
                            fewer than two RAG attempts have been made
        "low_confidence"  - confidence is below RAG_CONFIDENCE_THRESHOLD after
                            two or more RAG attempts
        "high_confidence" - confidence is at or above RAG_CONFIDENCE_THRESHOLD
    """
    # "or" fallbacks keep the formatted log lines below from failing on None.
    rag_attempts = getattr(state, 'rag_attempts', 0) or 0
    rag_confidence = getattr(state, 'rag_confidence', 0.0) or 0.0

    info(f"[Confidence Check] rag_attempts={rag_attempts}, rag_confidence={rag_confidence:.2f}")

    # Case 1: no retrieval result
    if not getattr(state, 'rag_retrieved', False) or not getattr(state, 'rag_context', None):
        info("[Confidence Check] 无检索结果,走联网")
        return "no_rag"

    # Case 2: confidence below the threshold
    if rag_confidence < RAG_CONFIDENCE_THRESHOLD:
        if rag_attempts >= _MAX_RAG_ATTEMPTS:
            info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},且RAG尝试{rag_attempts}次,走联网")
            return "low_confidence"
        info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},可再尝试RAG一次")
        return "retry_rag"

    # Case 3: high confidence
    info(f"[Confidence Check] 高置信度={rag_confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答")
    return "high_confidence"


# ========== Exports ==========

__all__ = [
    "rag_retrieve_node",
    "check_rag_confidence",
    "RAG_CONFIDENCE_THRESHOLD",
]