添加rag置信度判断
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m31s

This commit is contained in:
2026-05-06 01:15:52 +08:00
parent 3ae9daa01a
commit 1260bef5cb
35 changed files with 335 additions and 221 deletions

View File

@@ -8,7 +8,7 @@ from .web_search import web_search_node
from .error_handling import error_handling_node
from .routing import init_state_node, route_by_reasoning, should_summarize
from .llm_call import create_dynamic_llm_call_node
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
from .rag_nodes import rag_retrieve_node
# 记忆节点
from .retrieve_memory import create_retrieve_memory_node

View File

@@ -3,7 +3,7 @@
"""
from ...main_graph.state import MainGraphState, ErrorSeverity
from ...logger import info
from backend.app.logger import info
def error_handling_node(state: MainGraphState) -> MainGraphState:

View File

@@ -7,7 +7,7 @@ from typing import Optional
from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState
from ...logger import info, debug
from backend.app.logger import info, debug
from ...model_services.chat_services import get_small_llm_service, get_chat_service
from .rag_nodes import rag_retrieve_node
from ._utils import dispatch_custom_event
@@ -113,10 +113,18 @@ async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig]
def _has_valid_rag_results(state: MainGraphState) -> bool:
"""检查 RAG 结果是否有效"""
rag_docs = getattr(state, "rag_docs", [])
"""检查 RAG 结果是否有效(基于置信度)"""
from .rag_nodes import RAG_CONFIDENCE_THRESHOLD
rag_context = getattr(state, "rag_context", "")
return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10)
rag_confidence = getattr(state, "rag_confidence", 0.0)
# 有结果且置信度足够
has_content = rag_context and len(rag_context) > 0
has_confidence = rag_confidence >= RAG_CONFIDENCE_THRESHOLD
info(f"[Fast RAG Check] has_content={has_content}, rag_confidence={rag_confidence:.2f}, threshold={RAG_CONFIDENCE_THRESHOLD}")
return has_content and has_confidence
async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState:

View File

@@ -8,7 +8,7 @@ from typing import Any, Dict
# 本地模块
from ...main_graph.state import MainGraphState
from ...utils.logging import log_state_change
from ...logger import info, warning
from backend.app.logger import info, warning
from langchain_core.runnables.config import RunnableConfig

View File

@@ -11,7 +11,7 @@ from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState
from ...logger import info, debug
from backend.app.logger import info, debug
from ...model_services.chat_services import get_small_llm_service
from ._utils import dispatch_custom_event

View File

@@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage
from ...main_graph.state import MainGraphState
from ...agent.prompts import create_system_prompt
from ...utils.logging import log_state_change
from ...logger import debug, info, error
from backend.app.logger import debug, info, error
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
@@ -115,24 +115,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
):
chunks.append(chunk)
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks")
for i, chunk in enumerate(chunks[:10]): # 只打印前10个避免日志过多
chunk_type = type(chunk).__name__
chunk_content = getattr(chunk, 'content', '') if hasattr(chunk, 'content') else str(chunk)
# 打印更多属性
additional_kwargs = getattr(chunk, 'additional_kwargs', {}) or {}
response_metadata = getattr(chunk, 'response_metadata', {}) or {}
# 打印所有属性
info(f"[llm_call] chunk[{i}] type={chunk_type}")
info(f"[llm_call] chunk[{i}] content长度={len(chunk_content) if chunk_content else 0}, content={repr(chunk_content[:200] if chunk_content else '')}")
info(f"[llm_call] chunk[{i}] additional_kwargs={additional_kwargs}")
info(f"[llm_call] chunk[{i}] response_metadata keys={list(response_metadata.keys()) if response_metadata else []}")
info(f"[llm_call] chunk[{i}] response_metadata={response_metadata}")
# 检查是否有其他可能存储内容的属性
if hasattr(chunk, 'tool_call_chunks'):
info(f"[llm_call] chunk[{i}] tool_call_chunks={chunk.tool_call_chunks}")
if hasattr(chunk, 'usage_metadata'):
info(f"[llm_call] chunk[{i}] usage_metadata={chunk.usage_metadata}")
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks}")
# 将所有 chunk 合并成最终的 AIMessage
if chunks:

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client
from ...logger import info
from backend.app.logger import info
# 全局变量,在 GraphBuilder 中注入

View File

@@ -1,6 +1,6 @@
"""
RAG 检索节点模块
使用模块级变量管理 RAG 工具
包含RAG 检索、置信度判断、重检索等节点
"""
import time
@@ -11,10 +11,15 @@ from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from ...logger import info
from backend.app.logger import info, debug
from ...model_services import get_small_llm_service
from ._utils import dispatch_custom_event, make_react_event
# 置信度阈值配置
RAG_CONFIDENCE_THRESHOLD = 0.6 # 低于此值认为检索不相关
def _get_rag_tool() -> Optional[callable]:
"""获取 RAG 工具"""
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
@@ -36,43 +41,27 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
# 调用 RAG 工具
rag_context = await rag_tool.ainvoke(retrieval_query)
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
info(f"[RAG Core] ========== RAG 返回的知识内容 ==========")
info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context)
info(f"[RAG Core] ========================================")
# 更新状态
state.rag_context = rag_context
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
state.rag_retrieved = True
state.success = True
state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1
state.debug_info["rag_source"] = "tool"
info(f"[RAG Core] state.rag_docs 长度: {len(state.rag_docs)}")
return state
# ========== RAG 检索节点 ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""RAG 检索节点:带超时和重试"""
"""RAG 检索节点:检索 + 置信度评估"""
state.current_phase = "rag_retrieving"
start_time = time.time()
last_error = None
# 获取 RAG 工具
rag_tool = _get_rag_tool()
if not rag_tool:
error_record = ErrorRecord(
error_type="RAGRetrievalError",
error_message="RAG 工具未初始化",
severity=ErrorSeverity.WARNING,
source="rag_retrieve_node",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=RAG_RETRY_CONFIG.max_retries,
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
info("[RAG] RAG 工具未初始化")
state.rag_confidence = 0.0
state.rag_retrieved = False
return state
await dispatch_custom_event(
@@ -81,99 +70,184 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf
config
)
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
try:
result = await _rag_retrieve_core(state, rag_tool)
try:
state = await _rag_retrieve_core(state, rag_tool)
info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符")
# 评估置信度
confidence = await _evaluate_rag_confidence(state)
state.rag_confidence = confidence
state.debug_info["rag_retrieval"] = {
"attempt": attempt + 1,
"success": True,
"time": time.time() - start_time
}
info(f"[RAG] 检索完成,置信度={confidence:.2f}RAG尝试次数={state.rag_attempts}")
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG",
"confidence": 1.0,
"reasoning": "RAG 检索完成",
"timestamp": datetime.now().isoformat()
})
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG",
"confidence": confidence,
"reasoning": f"RAG 检索完成,置信度={confidence:.2f}",
"timestamp": datetime.now().isoformat()
})
doc_count = len(result.rag_docs) if result.rag_docs else 0
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0,
f"RAG 检索完成,找到 {doc_count} 条相关文档"),
config
)
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_complete", confidence,
f"RAG 检索完成,置信度={confidence:.2f}"),
config
)
return result
except Exception as e:
last_error = e
if attempt >= RAG_RETRY_CONFIG.max_retries:
break
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
f"RAG 检索失败,第 {attempt + 1} 次重试..."),
config
)
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 失败记录
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG",
"confidence": 0.0,
"reasoning": f"RAG 检索失败: {str(last_error) if last_error else '超时'}",
"timestamp": datetime.now().isoformat()
})
error_record = ErrorRecord(
error_type="RAGRetrievalError",
error_message=str(last_error) if last_error else "RAG 检索超时",
severity=ErrorSeverity.WARNING,
source="rag_retrieve_node",
timestamp=datetime.now().isoformat(),
retry_count=RAG_RETRY_CONFIG.max_retries,
max_retries=RAG_RETRY_CONFIG.max_retries,
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
f"RAG 检索失败: {str(last_error)}"),
config
)
except Exception as e:
info(f"[RAG] 检索失败: {e}")
state.rag_confidence = 0.0
state.rag_retrieved = False
return state
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""重新检索节点"""
state.current_phase = "rag_re_retrieving"
async def _evaluate_rag_confidence(state: MainGraphState) -> float:
"""评估 RAG 检索结果置信度(综合向量相似度 + 重排分数 + 小模型判断)"""
query = state.user_query or ""
rag_context = state.rag_context or ""
state.debug_info["rag_re_retrieve"] = {
"original_retrieved": state.rag_retrieved,
"original_docs_count": len(state.rag_docs)
}
if not rag_context:
return 0.0
return await rag_retrieve_node(state, config)
# 方式1: 向量相似度(从 rag_docs 中获取)
embedding_score = _get_embedding_similarity(state, query)
info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")
# 方式2: 重排序分数(从 rag_docs 中获取)
rerank_score = _get_rerank_score(state)
info(f"[RAG Confidence] 重排分数={rerank_score:.3f}")
# 方式3: 小模型判断
llm_score = await _get_llm_score(state)
info(f"[RAG Confidence] LLM评估={llm_score:.3f}")
# 综合得分(加权平均)
# 向量相似度权重 0.3,重排权重 0.3LLM 权重 0.4
final_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4
info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)")
return final_score
def _get_embedding_similarity(state: MainGraphState) -> float:
"""从 rag_docs 中获取向量相似度分数"""
rag_docs = getattr(state, "rag_docs", [])
# 如果有多个文档,取最高分
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
score = doc.get("score", 0.0)
# 向量相似度通常在 0-1 之间RRF 分数可能更高
# 归一化到 0-1
if score > 1.0:
score = min(score / 10.0, 1.0) # 假设 max 约 10
scores.append(score)
elif hasattr(doc, "metadata"):
score = doc.metadata.get("score", 0.0)
if score > 1.0:
score = min(score / 10.0, 1.0)
scores.append(score)
if scores:
# 取平均或最高分
return max(scores) # 使用最高分更准确
return 0.0
def _get_rerank_score(state: MainGraphState) -> float:
"""从 rag_docs 中获取重排序分数"""
rag_docs = getattr(state, "rag_docs", [])
# 重排分数通常在 0-1 之间
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
score = doc.get("rerank_score", 0.0)
elif hasattr(doc, "metadata"):
score = doc.metadata.get("rerank_score", 0.0)
else:
score = 0.0
if score > 0:
scores.append(score)
if scores:
return max(scores) # 使用最高分
return 0.0
async def _get_llm_score(state: MainGraphState) -> float:
"""使用小模型评估检索结果相关性"""
query = state.user_query or ""
rag_context = state.rag_context or ""
try:
llm = get_small_llm_service()
prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:
- 1.0 = 完全相关,能直接回答问题
- 0.5 = 部分相关,有一定参考价值
- 0.0 = 完全不相关,无法回答问题
用户问题:{query}
检索结果:{rag_context[:1500]}
只返回一个数字:"""
response = await llm.ainvoke(prompt)
content = response.content.strip()
import re
match = re.search(r'(\d+\.?\d*)', content)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
except Exception as e:
info(f"[RAG Confidence] LLM评估失败: {e}")
return 0.5 # 默认中等置信度
# ========== 置信度判断节点 ==========
def check_rag_confidence(state: MainGraphState) -> str:
"""
根据 RAG 置信度判断下一步
Returns:
"high_confidence" - 高置信度(>=0.6),可直接生成回答
"low_confidence" - 低置信度(<0.6),需要联网搜索
"no_rag" - 无检索结果,需要联网搜索
"""
rag_attempts = getattr(state, 'rag_attempts', 0)
rag_confidence = getattr(state, 'rag_confidence', 0.0)
info(f"[Confidence Check] rag_attempts={rag_attempts}, rag_confidence={rag_confidence:.2f}")
# 情况1: 没有检索结果
if not getattr(state, 'rag_retrieved', False) or not state.rag_context:
info("[Confidence Check] 无检索结果,走联网")
return "no_rag"
# 情况2: 置信度低于阈值
if rag_confidence < RAG_CONFIDENCE_THRESHOLD:
if rag_attempts >= 2:
info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}且RAG尝试{rag_attempts}次,走联网")
return "low_confidence"
else:
info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}可再尝试RAG一次")
return "retry_rag"
# 情况3: 高置信度
info(f"[Confidence Check] 高置信度={rag_confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答")
return "high_confidence"
# ========== 导出 ==========
__all__ = [
"rag_retrieve_node",
"rag_re_retrieve_node",
"check_rag_confidence",
"RAG_CONFIDENCE_THRESHOLD",
]

View File

@@ -7,9 +7,9 @@ from typing import Optional
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ...core.intent import react_reason_async, ReasoningResult
from backend.app.core.intent import react_reason_async, ReasoningResult
from ...main_graph.state import MainGraphState
from ...logger import info
from backend.app.logger import info
from ._utils import dispatch_custom_event, make_react_event

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict
from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client
from ...utils.logging import log_state_change
from ...logger import debug
from backend.app.logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client):

View File

@@ -10,9 +10,9 @@
from datetime import datetime
from ...core.intent import get_route_by_reasoning, ReasoningAction
from backend.app.core.intent import get_route_by_reasoning, ReasoningAction
from ...main_graph.state import MainGraphState
from ...logger import info
from backend.app.logger import info
# ========== 初始化状态节点 ==========
@@ -97,6 +97,12 @@ def route_by_reasoning(state: MainGraphState) -> str:
}
target = route_mapping.get(route, "llm_call")
# ========== RAG 次数硬限制 ==========
rag_attempts = getattr(state, 'rag_attempts', 0)
if target == "rag_retrieve" and rag_attempts >= 2:
info(f"[条件路由] RAG已尝试{rag_attempts}次,强制走联网搜索")
target = "web_search"
# ========== 循环防护检测 ==========
# 1. 路由模式检测A→B→A→B 交替)
if len(previous_actions) >= 4:

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict
from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client
from ...utils.logging import log_state_change
from ...logger import debug, info, error, warning
from backend.app.logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client):

View File

@@ -11,7 +11,7 @@ from ...main_graph.config import get_stream_writer
# 本地模块
from ...main_graph.state import MainGraphState
from ...utils.logging import log_state_change
from ...logger import debug, info
from backend.app.logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
"""

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...logger import info
from backend.app.logger import info
async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: