This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
RAG 检索节点模块
|
||||
使用模块级变量管理 RAG 工具
|
||||
包含:RAG 检索、置信度判断、重检索等节点
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -11,10 +11,15 @@ from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
|
||||
from ...logger import info
|
||||
from backend.app.logger import info, debug
|
||||
from ...model_services import get_small_llm_service
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
|
||||
# Confidence threshold: a retrieval scoring below this value is treated as
# not relevant to the user's question.
RAG_CONFIDENCE_THRESHOLD = 0.6
|
||||
|
||||
|
||||
def _get_rag_tool() -> Optional[callable]:
|
||||
"""获取 RAG 工具"""
|
||||
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
|
||||
@@ -36,43 +41,27 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
|
||||
# 调用 RAG 工具
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
|
||||
info(f"[RAG Core] ========== RAG 返回的知识内容 ==========")
|
||||
info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context)
|
||||
info(f"[RAG Core] ========================================")
|
||||
|
||||
# 更新状态
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
|
||||
state.rag_retrieved = True
|
||||
state.success = True
|
||||
state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1
|
||||
state.debug_info["rag_source"] = "tool"
|
||||
|
||||
info(f"[RAG Core] state.rag_docs 长度: {len(state.rag_docs)}")
|
||||
return state
|
||||
|
||||
|
||||
# ========== RAG 检索节点 ==========
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""RAG 检索节点:带超时和重试"""
|
||||
"""RAG 检索节点:检索 + 置信度评估"""
|
||||
state.current_phase = "rag_retrieving"
|
||||
start_time = time.time()
|
||||
last_error = None
|
||||
|
||||
# 获取 RAG 工具
|
||||
rag_tool = _get_rag_tool()
|
||||
|
||||
if not rag_tool:
|
||||
error_record = ErrorRecord(
|
||||
error_type="RAGRetrievalError",
|
||||
error_message="RAG 工具未初始化",
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source="rag_retrieve_node",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=RAG_RETRY_CONFIG.max_retries,
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
info("[RAG] RAG 工具未初始化")
|
||||
state.rag_confidence = 0.0
|
||||
state.rag_retrieved = False
|
||||
return state
|
||||
|
||||
await dispatch_custom_event(
|
||||
@@ -81,99 +70,184 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf
|
||||
config
|
||||
)
|
||||
|
||||
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
|
||||
try:
|
||||
result = await _rag_retrieve_core(state, rag_tool)
|
||||
try:
|
||||
state = await _rag_retrieve_core(state, rag_tool)
|
||||
|
||||
info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符")
|
||||
# 评估置信度
|
||||
confidence = await _evaluate_rag_confidence(state)
|
||||
state.rag_confidence = confidence
|
||||
|
||||
state.debug_info["rag_retrieval"] = {
|
||||
"attempt": attempt + 1,
|
||||
"success": True,
|
||||
"time": time.time() - start_time
|
||||
}
|
||||
info(f"[RAG] 检索完成,置信度={confidence:.2f},RAG尝试次数={state.rag_attempts}")
|
||||
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG",
|
||||
"confidence": 1.0,
|
||||
"reasoning": "RAG 检索完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG",
|
||||
"confidence": confidence,
|
||||
"reasoning": f"RAG 检索完成,置信度={confidence:.2f}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
doc_count = len(result.rag_docs) if result.rag_docs else 0
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0,
|
||||
f"RAG 检索完成,找到 {doc_count} 条相关文档"),
|
||||
config
|
||||
)
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_complete", confidence,
|
||||
f"RAG 检索完成,置信度={confidence:.2f}"),
|
||||
config
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt >= RAG_RETRY_CONFIG.max_retries:
|
||||
break
|
||||
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
|
||||
f"RAG 检索失败,第 {attempt + 1} 次重试..."),
|
||||
config
|
||||
)
|
||||
|
||||
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
||||
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||
|
||||
# 失败记录
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG",
|
||||
"confidence": 0.0,
|
||||
"reasoning": f"RAG 检索失败: {str(last_error) if last_error else '超时'}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
error_record = ErrorRecord(
|
||||
error_type="RAGRetrievalError",
|
||||
error_message=str(last_error) if last_error else "RAG 检索超时",
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source="rag_retrieve_node",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=RAG_RETRY_CONFIG.max_retries,
|
||||
max_retries=RAG_RETRY_CONFIG.max_retries,
|
||||
)
|
||||
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
|
||||
f"RAG 检索失败: {str(last_error)}"),
|
||||
config
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[RAG] 检索失败: {e}")
|
||||
state.rag_confidence = 0.0
|
||||
state.rag_retrieved = False
|
||||
|
||||
return state
|
||||
|
||||
|
||||
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
    """Re-retrieval node: note the previous outcome, then retrieve again.

    Args:
        state: Graph state. ``rag_retrieved`` and ``rag_docs`` are read for
            the debug snapshot; ``current_phase`` and ``debug_info`` are
            written.
        config: Runnable config forwarded unchanged to ``rag_retrieve_node``.

    Returns:
        Whatever ``rag_retrieve_node`` returns for this state.
    """
    state.current_phase = "rag_re_retrieving"

    # Snapshot of the previous attempt, kept only for debugging.
    state.debug_info["rag_re_retrieve"] = {
        "original_retrieved": state.rag_retrieved,
        "original_docs_count": len(state.rag_docs)
    }

    return await rag_retrieve_node(state, config)


async def _evaluate_rag_confidence(state: MainGraphState) -> float:
    """Score how trustworthy the current RAG result is.

    Combines three signals as a weighted average: vector similarity (0.3),
    rerank score (0.3) and a small-model relevance judgement (0.4).

    Args:
        state: Graph state; ``rag_context`` is read here and the state is
            passed on to the three scoring helpers.

    Returns:
        A confidence in [0.0, 1.0]; 0.0 when ``rag_context`` is empty.
    """
    rag_context = state.rag_context or ""

    if not rag_context:
        return 0.0

    # Signal 1: vector similarity taken from rag_docs.
    # The helper is declared with a single ``state`` parameter, so only the
    # state is passed (passing the query as well raised TypeError).
    embedding_score = _get_embedding_similarity(state)
    info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")

    # Signal 2: rerank score taken from rag_docs.
    rerank_score = _get_rerank_score(state)
    info(f"[RAG Confidence] 重排分数={rerank_score:.3f}")

    # Signal 3: small-model relevance judgement.
    llm_score = await _get_llm_score(state)
    info(f"[RAG Confidence] LLM评估={llm_score:.3f}")

    # Weighted average: embedding 0.3, rerank 0.3, LLM 0.4.
    final_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4
    # Keep the result on the same 0-1 scale as RAG_CONFIDENCE_THRESHOLD even
    # if a helper hands back an out-of-range value.
    final_score = max(0.0, min(1.0, final_score))
    info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)")

    return final_score
|
||||
|
||||
|
||||
def _get_embedding_similarity(state: "MainGraphState", query: str = "") -> float:
    """Return the best vector-similarity score found in ``state.rag_docs``.

    Each document is either a dict carrying a ``"score"`` key or an object
    whose ``metadata`` mapping carries one. Documents without a usable
    numeric score are ignored.

    Args:
        state: Graph state; only its ``rag_docs`` attribute is read.
        query: Accepted for call compatibility and not used. The caller in
            this module passes the user query as a second positional
            argument, which raised ``TypeError`` while this function took
            ``state`` only.

    Returns:
        The highest normalized score in [0.0, 1.0], or 0.0 when no document
        yields a numeric score.
    """
    best = 0.0
    for doc in getattr(state, "rag_docs", None) or []:
        if isinstance(doc, dict):
            raw = doc.get("score")
        else:
            metadata = getattr(doc, "metadata", None)
            raw = metadata.get("score") if hasattr(metadata, "get") else None

        # bool is an int subclass; a True/False "score" is not a similarity.
        if isinstance(raw, bool):
            continue
        try:
            value = float(raw)
        except (TypeError, ValueError):
            # Missing (None) or non-numeric score: skip this document.
            continue
        if value != value:  # NaN never compares equal to itself
            continue

        # Similarities normally fall in 0-1; fused scores (e.g. RRF) may be
        # larger. Scale those down assuming a maximum of roughly 10.
        if value > 1.0:
            value = min(value / 10.0, 1.0)

        # Starting from 0.0 also floors negative similarities at 0.0.
        best = max(best, value)

    return best
|
||||
|
||||
|
||||
def _get_rerank_score(state: "MainGraphState") -> float:
    """Return the best rerank score found in ``state.rag_docs``.

    Each document is either a dict carrying a ``"rerank_score"`` key or an
    object whose ``metadata`` mapping carries one. Missing, non-numeric and
    non-positive scores are ignored.

    Args:
        state: Graph state; only its ``rag_docs`` attribute is read.

    Returns:
        The highest rerank score capped at 1.0, or 0.0 when no document has
        a positive numeric rerank score.
    """
    best = 0.0
    for doc in getattr(state, "rag_docs", None) or []:
        if isinstance(doc, dict):
            raw = doc.get("rerank_score")
        else:
            metadata = getattr(doc, "metadata", None)
            raw = metadata.get("rerank_score") if hasattr(metadata, "get") else None

        # bool is an int subclass; a True/False value is not a score.
        if isinstance(raw, bool):
            continue
        try:
            value = float(raw)
        except (TypeError, ValueError):
            # Missing (None) or non-numeric score: skip this document.
            continue
        if value != value:  # NaN never compares equal to itself
            continue

        # Rerank scores are expected in 0-1; cap larger values so this signal
        # stays on the same scale as the other confidence inputs. Starting
        # from 0.0 drops non-positive scores.
        best = max(best, min(value, 1.0))

    return best
|
||||
|
||||
|
||||
async def _get_llm_score(state: MainGraphState) -> float:
    """Ask the small model how relevant the retrieved context is to the query.

    Args:
        state: Graph state; ``user_query`` and ``rag_context`` are read.

    Returns:
        A relevance score in [0.0, 1.0]. Returns 0.0 without calling the
        model when there is no retrieved context, and 0.5 (a neutral
        default) when the model call fails or its reply contains no number.
    """
    import re

    query = state.user_query or ""
    rag_context = state.rag_context or ""

    # Nothing was retrieved, so there is nothing to judge; skip the model call.
    if not rag_context:
        return 0.0

    try:
        llm = get_small_llm_service()
        prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:
- 1.0 = 完全相关,能直接回答问题
- 0.5 = 部分相关,有一定参考价值
- 0.0 = 完全不相关,无法回答问题

用户问题:{query}

检索结果:{rag_context[:1500]}

只返回一个数字:"""

        response = await llm.ainvoke(prompt)
        content = response.content.strip()

        # First number in the reply. The decimal alternative comes first so
        # ".8" is read as 0.8 rather than 8; an optional trailing "%" marks
        # a percentage ("80%" means 0.8, not 80).
        match = re.search(r"(\d*\.\d+|\d+)\s*(%?)", content)
        if match:
            score = float(match.group(1))
            if match.group(2):
                score /= 100.0
            return max(0.0, min(1.0, score))

    except Exception as e:
        # Best effort: a failed evaluation must not break retrieval, so log
        # it and fall through to the neutral default.
        info(f"[RAG Confidence] LLM评估失败: {e}")

    return 0.5  # neutral default confidence
|
||||
|
||||
|
||||
# ========== 置信度判断节点 ==========
|
||||
def check_rag_confidence(state: MainGraphState) -> str:
    """Choose the next step from the RAG confidence stored on the state.

    Args:
        state: Graph state; ``rag_attempts``, ``rag_confidence``,
            ``rag_retrieved`` and ``rag_context`` are read. Missing or
            ``None`` values are treated as 0 / 0.0 / not retrieved.

    Returns:
        "no_rag" - nothing was retrieved (flag unset or empty context).
        "retry_rag" - confidence is below ``RAG_CONFIDENCE_THRESHOLD`` and
            fewer than two RAG attempts have been made.
        "low_confidence" - confidence is below the threshold after two or
            more RAG attempts.
        "high_confidence" - confidence is at or above the threshold.

        The log messages indicate "no_rag" and "low_confidence" are meant to
        lead to web search; the actual routing lives in the graph wiring.
    """
    # Attempts allowed before a low-confidence result stops being retried.
    max_attempts = 2

    # ``or`` guards against attributes that exist but hold None, which would
    # otherwise break the ':.2f' format and the comparisons below.
    rag_attempts = getattr(state, 'rag_attempts', 0) or 0
    rag_confidence = getattr(state, 'rag_confidence', 0.0) or 0.0

    info(f"[Confidence Check] rag_attempts={rag_attempts}, rag_confidence={rag_confidence:.2f}")

    # Case 1: no retrieval result.
    if not getattr(state, 'rag_retrieved', False) or not getattr(state, 'rag_context', None):
        info("[Confidence Check] 无检索结果,走联网")
        return "no_rag"

    # Case 2: confidence below the threshold.
    if rag_confidence < RAG_CONFIDENCE_THRESHOLD:
        if rag_attempts >= max_attempts:
            info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},且RAG尝试{rag_attempts}次,走联网")
            return "low_confidence"
        info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},可再尝试RAG一次")
        return "retry_rag"

    # Case 3: high confidence.
    info(f"[Confidence Check] 高置信度={rag_confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答")
    return "high_confidence"
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
# Names exported by ``from <this module> import *``.
__all__ = [
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    "check_rag_confidence",
    "RAG_CONFIDENCE_THRESHOLD",
]
|
||||
|
||||
Reference in New Issue
Block a user