2026-04-26 11:23:12 +08:00
|
|
|
|
"""
|
2026-05-05 00:54:04 +08:00
|
|
|
|
RAG 检索节点模块
|
2026-05-06 01:15:52 +08:00
|
|
|
|
包含:RAG 检索、置信度判断、重检索等节点
|
2026-04-26 11:23:12 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
import time
from datetime import datetime
from typing import Any, Callable, Optional

from langchain_core.runnables.config import RunnableConfig

from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from backend.app.logger import info, debug
from ...model_services import get_small_llm_service
from ._utils import dispatch_custom_event, make_react_event
|
2026-04-26 11:23:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-05-06 01:15:52 +08:00
|
|
|
|
# Confidence threshold configuration
RAG_CONFIDENCE_THRESHOLD = 0.6  # retrieval scoring below this value is treated as not relevant

# Module-level RAGPipeline instance; stays None until _get_rag_pipeline() builds it on first call
_rag_pipeline = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_rag_pipeline():
    """Return the shared RAGPipeline, constructing it on the first call.

    The instance is cached in the module-level ``_rag_pipeline`` and reused by
    every later call. The ``RAGPipeline`` import is deferred until the first
    construction.
    """
    global _rag_pipeline
    if _rag_pipeline is not None:
        return _rag_pipeline

    from backend.app.rag.pipeline import RAGPipeline

    pipeline_options = dict(
        num_queries=3,
        rerank_top_n=5,
        use_rerank=True,
        return_parent_docs=True,
    )
    _rag_pipeline = RAGPipeline(**pipeline_options)
    return _rag_pipeline
|
|
|
|
|
|
|
2026-05-06 01:15:52 +08:00
|
|
|
|
|
2026-05-05 04:32:42 +08:00
|
|
|
|
def _get_rag_tool() -> Optional[Callable[..., Any]]:
    """Return the RAG tool provided by ``rag_initializer.get_rag_tool()``.

    The initializer import is deferred to call time and its result is passed
    through unchanged. Per the annotation it may be None — presumably when the
    tool has not been initialized yet (TODO confirm in rag_initializer).
    """
    from backend.app.main_graph.utils.rag_initializer import get_rag_tool

    return get_rag_tool()
|
2026-04-26 11:23:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-05-05 00:54:04 +08:00
|
|
|
|
# ========== RAG 检索核心逻辑 ==========
|
2026-05-06 04:26:06 +08:00
|
|
|
|
async def _rag_retrieve_core(state: MainGraphState, pipeline) -> MainGraphState:
    """Run one RAG retrieval pass and record the results on ``state``.

    The query is ``state.user_query`` unless the ReAct reasoning result carries
    a non-empty ``retrieval_config.retrieval_query``, which takes precedence.

    Args:
        state: Graph state; mutated in place.
        pipeline: Object exposing ``await aretrieve(query)`` returning the
            documents and ``format_context(documents)`` returning the context.

    Returns:
        The same ``state`` with ``rag_context``, ``rag_docs``,
        ``rag_retrieved`` (True only if documents came back) and
        ``rag_attempts`` (incremented) updated.
    """
    debug("[RAG Core] _rag_retrieve_core 开始")

    retrieval_query = state.user_query

    # Prefer the optimized query from the structured reasoning result, if any.
    # getattr chains tolerate a missing reasoning result / retrieval config.
    react_reasoning = getattr(state, "react_reasoning", None)
    reasoning_result = getattr(react_reasoning, "reasoning_result", None)
    cfg = getattr(reasoning_result, "retrieval_config", None)
    if cfg and getattr(cfg, "retrieval_query", None):
        retrieval_query = cfg.retrieval_query

    # Guard the slice: user_query may be None (other code reads it via `or ""`).
    info(f"[RAG Core] 使用检索查询: {(retrieval_query or '')[:50]}...")

    debug("[RAG Core] 调用 pipeline.aretrieve")
    # Treat a None result as "no documents" so len()/bool() below stay safe.
    documents = await pipeline.aretrieve(retrieval_query) or []
    debug(f"[RAG Core] pipeline.aretrieve 返回,得到 {len(documents)} 个文档")

    debug("[RAG Core] 调用 pipeline.format_context")
    rag_context = pipeline.format_context(documents)
    debug("[RAG Core] pipeline.format_context 返回")

    info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
    info(f"[RAG Core] 获取到 rag_docs: {len(documents)} 个文档")

    # Update state
    state.rag_context = rag_context
    state.rag_docs = documents  # kept for confidence scoring
    state.rag_retrieved = bool(documents)  # retrieval counts as successful only with documents
    state.rag_attempts = getattr(state, "rag_attempts", 0) + 1

    debug("[RAG Core] _rag_retrieve_core 结束")
    return state
|
2026-04-26 11:25:01 +08:00
|
|
|
|
|
2026-05-05 00:54:04 +08:00
|
|
|
|
|
|
|
|
|
|
# ========== RAG 检索节点 ==========
|
2026-05-05 23:17:00 +08:00
|
|
|
|
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
    """Graph node: run RAG retrieval, score its confidence and stream progress.

    Emits a ``rag_retrieve_start`` event, retrieves via ``_rag_retrieve_core``,
    stores the ``_evaluate_rag_confidence`` result in ``state.rag_confidence``,
    appends a ``RETRIEVE_RAG`` entry to ``state.reasoning_history`` and emits a
    ``rag_retrieve_complete`` event.

    Any exception inside the retrieval/scoring section is logged and recorded
    as ``rag_confidence = 0.0`` / ``rag_retrieved = False`` rather than raised.
    Errors from building the pipeline or dispatching the start event are not
    caught here.

    Args:
        state: Graph state; mutated in place and returned.
        config: Runnable config forwarded to custom-event dispatch.

    Returns:
        The updated ``state``.
    """
    debug("[RAG] rag_retrieve_node 开始")
    state.current_phase = "rag_retrieving"
    start_time = time.perf_counter()  # monotonic clock for the duration log below

    debug("[RAG] 调用 _get_rag_pipeline")
    pipeline = _get_rag_pipeline()

    await dispatch_custom_event(
        "react_reasoning",
        make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."),
        config,
    )

    try:
        debug("[RAG] 调用 _rag_retrieve_core")
        state = await _rag_retrieve_core(state, pipeline)
        debug("[RAG] _rag_retrieve_core 返回")

        # Evaluate confidence
        debug("[RAG] 调用 _evaluate_rag_confidence")
        confidence = await _evaluate_rag_confidence(state)
        state.rag_confidence = confidence

        info(f"[RAG] 检索完成,置信度={confidence:.2f},RAG尝试次数={state.rag_attempts}")

        state.reasoning_history.append({
            "step": state.reasoning_step,
            "action": "RETRIEVE_RAG",
            "confidence": confidence,
            "reasoning": f"RAG 检索完成,置信度={confidence:.2f}",
            "timestamp": datetime.now().isoformat(),
        })

        await dispatch_custom_event(
            "react_reasoning",
            make_react_event(state.reasoning_step, "rag_retrieve_complete", confidence,
                             f"RAG 检索完成,置信度={confidence:.2f}"),
            config,
        )

    except Exception as e:
        # Best-effort node: record the failure on state so routing can fall back
        # instead of aborting the whole graph run.
        info(f"[RAG] 检索失败: {e}", exc_info=True)
        state.rag_confidence = 0.0
        state.rag_retrieved = False

    elapsed = time.perf_counter() - start_time
    info(f"[RAG] rag_retrieve_node 结束,耗时={elapsed:.2f}s")
    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-05-06 01:15:52 +08:00
|
|
|
|
async def _evaluate_rag_confidence(state: MainGraphState) -> float:
    """Combine three relevance signals into a single RAG confidence score.

    Weighted sum: embedding similarity x 0.3 + rerank score x 0.3 +
    small-LLM judgement x 0.4. The result is clamped to [0.0, 1.0] because the
    component scores come from retriever/reranker metadata and model output
    whose range is not guaranteed by this function.

    Returns:
        0.0 immediately (no scoring, no LLM call) when ``state.rag_context`` is
        empty; otherwise the clamped weighted score.
    """
    if not state.rag_context:
        return 0.0

    # Signal 1: vector similarity, read from state.rag_docs
    embedding_score = _get_embedding_similarity(state)
    info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")

    # Signal 2: rerank score, read from state.rag_docs
    rerank_score = _get_rerank_score(state)
    info(f"[RAG Confidence] 重排分数={rerank_score:.3f}")

    # Signal 3: small-model relevance judgement
    llm_score = await _get_llm_score(state)
    info(f"[RAG Confidence] LLM评估={llm_score:.3f}")

    # Weighted average: embedding 0.3, rerank 0.3, LLM 0.4; clamp so the value
    # compared against RAG_CONFIDENCE_THRESHOLD stays a true [0, 1] confidence.
    raw_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4
    final_score = max(0.0, min(1.0, raw_score))

    info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)")

    return final_score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_embedding_similarity(state: "MainGraphState") -> float:
    """Return the best vector-similarity score among ``state.rag_docs``, in [0.0, 1.0].

    For each document the score is read from ``embedding_score``, falling back
    to ``score``. The keys are looked up on the document itself when it is a
    dict, or on its ``metadata`` mapping otherwise; documents of any other
    shape are skipped.

    Scoring rules:
    - A missing score counts as 0.0.
    - Non-numeric values (e.g. None) are ignored.
    - Values above 1.0 are rescaled as ``score / 10`` and capped at 1.0.
    - Negative values are floored at 0.0, so the result stays in the range the
      confidence weighting expects.

    Returns 0.0 when there is no usable score.
    """
    rag_docs = getattr(state, "rag_docs", None) or []
    best = 0.0
    for doc in rag_docs:
        if isinstance(doc, dict):
            source = doc
        elif hasattr(doc, "metadata"):
            source = doc.metadata or {}
        else:
            continue

        raw = source.get("embedding_score", source.get("score", 0.0))
        try:
            score = float(raw)
        except (TypeError, ValueError):
            continue  # e.g. None or a non-numeric string

        if score > 1.0:
            score = min(score / 10.0, 1.0)
        best = max(best, score)  # starting at 0.0 also floors negatives
    return best
|
2026-05-06 01:15:52 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_rerank_score(state: "MainGraphState") -> float:
    """Return the best positive rerank score among ``state.rag_docs``, in [0.0, 1.0].

    Each document carries ``rerank_score`` either as a dict key or in its
    ``metadata`` mapping; documents of any other shape are skipped. Missing,
    non-numeric (e.g. None) and non-positive scores are ignored. Values above
    1.0 (e.g. raw reranker logits) are rescaled as ``score / 10`` and capped at
    1.0 — the same rule the embedding-similarity scorer in this module
    applies — so the value fits the confidence weighting.

    Returns 0.0 when no document has a usable positive score.
    """
    rag_docs = getattr(state, "rag_docs", None) or []
    best = 0.0
    for doc in rag_docs:
        if isinstance(doc, dict):
            raw = doc.get("rerank_score")
        elif hasattr(doc, "metadata"):
            raw = (doc.metadata or {}).get("rerank_score")
        else:
            continue

        try:
            score = float(raw)
        except (TypeError, ValueError):
            continue  # missing (None) or non-numeric

        if score <= 0:
            continue
        if score > 1.0:
            score = min(score / 10.0, 1.0)
        best = max(best, score)
    return best
|
2026-05-06 01:15:52 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _get_llm_score(state: MainGraphState) -> float:
    """Ask the small LLM to rate how relevant the retrieved context is to the query.

    Only the first 1500 characters of ``state.rag_context`` are sent.

    Returns:
        The first number in the model's reply, clamped to [0.0, 1.0]. Falls
        back to a neutral 0.5 when the model call fails or the reply contains
        no number (previously this case returned None and broke the caller's
        score formatting and arithmetic).
    """
    import re

    query = state.user_query or ""
    rag_context = state.rag_context or ""

    try:
        llm = get_small_llm_service()
        prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:
- 1.0 = 完全相关,能直接回答问题
- 0.5 = 部分相关,有一定参考价值
- 0.0 = 完全不相关,无法回答问题

用户问题:{query}

检索结果:{rag_context[:1500]}

只返回一个数字:"""

        response = await llm.ainvoke(prompt)
        content = response.content.strip()
    except Exception as e:
        info(f"[RAG Confidence] LLM评估失败: {e}")
        return 0.5  # default: medium confidence

    # Try a decimal first so ".8" parses as 0.8 rather than 8; then an integer.
    match = re.search(r"\d*\.\d+|\d+", content)
    if match is None:
        info(f"[RAG Confidence] LLM评估结果无法解析: {content[:50]}")
        return 0.5  # default: medium confidence

    return max(0.0, min(1.0, float(match.group(0))))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== 置信度判断节点 ==========
|
|
|
|
|
|
def check_rag_confidence(state: MainGraphState) -> str:
    """Route after RAG retrieval based on its outcome and confidence.

    Returns:
        "no_rag"          - nothing retrieved (``rag_retrieved`` falsy or empty
                            ``rag_context``); fall back to web search.
        "retry_rag"       - confidence < RAG_CONFIDENCE_THRESHOLD and fewer than
                            2 RAG attempts so far; try RAG once more.
        "low_confidence"  - confidence < RAG_CONFIDENCE_THRESHOLD after 2 or
                            more attempts; fall back to web search.
        "high_confidence" - confidence >= RAG_CONFIDENCE_THRESHOLD; answer directly.
    """
    max_rag_attempts = 2  # stop retrying RAG after this many low-confidence attempts

    # `or` fallbacks also cover attributes that exist but hold None, which would
    # otherwise break the `:.2f` formatting and numeric comparisons below.
    rag_attempts = getattr(state, "rag_attempts", 0) or 0
    rag_confidence = getattr(state, "rag_confidence", 0.0) or 0.0

    info(f"[Confidence Check] rag_attempts={rag_attempts}, rag_confidence={rag_confidence:.2f}")

    # Case 1: no retrieval result
    if not getattr(state, "rag_retrieved", False) or not getattr(state, "rag_context", None):
        info("[Confidence Check] 无检索结果,走联网")
        return "no_rag"

    # Case 2: confidence below the threshold
    if rag_confidence < RAG_CONFIDENCE_THRESHOLD:
        if rag_attempts >= max_rag_attempts:
            info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},且RAG尝试{rag_attempts}次,走联网")
            return "low_confidence"
        info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},可再尝试RAG一次")
        return "retry_rag"

    # Case 3: high confidence
    info(f"[Confidence Check] 高置信度={rag_confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答")
    return "high_confidence"
|
2026-04-26 11:23:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Exports ==========
# Public API of this module: the retrieval node, its routing function and the threshold it uses.
__all__ = [
    "rag_retrieve_node",
    "check_rag_confidence",
    "RAG_CONFIDENCE_THRESHOLD",
]
|