推理优化
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m36s

This commit is contained in:
2026-05-06 04:26:06 +08:00
parent 1260bef5cb
commit ef6fbc1521
12 changed files with 313 additions and 129 deletions

View File

@@ -19,6 +19,23 @@ from ._utils import dispatch_custom_event, make_react_event
# 置信度阈值配置
RAG_CONFIDENCE_THRESHOLD = 0.6 # 低于此值认为检索不相关
# 全局 pipeline 实例
_rag_pipeline = None
def _get_rag_pipeline():
"""获取 RAG Pipeline 实例"""
global _rag_pipeline
if _rag_pipeline is None:
from backend.app.rag.pipeline import RAGPipeline
_rag_pipeline = RAGPipeline(
num_queries=3,
rerank_top_n=5,
use_rerank=True,
return_parent_docs=True,
)
return _rag_pipeline
def _get_rag_tool() -> Optional[callable]:
"""获取 RAG 工具"""
@@ -27,7 +44,7 @@ def _get_rag_tool() -> Optional[callable]:
# ========== RAG 检索核心逻辑 ==========
async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainGraphState:
async def _rag_retrieve_core(state: MainGraphState, pipeline) -> MainGraphState:
"""执行 RAG 检索的核心逻辑"""
retrieval_query = state.user_query
@@ -38,15 +55,20 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
if cfg and cfg.retrieval_query:
retrieval_query = cfg.retrieval_query
# 调用 RAG 工具
rag_context = await rag_tool.ainvoke(retrieval_query)
# 直接调用 pipeline 获取文档和上下文
documents = await pipeline.aretrieve(retrieval_query)
rag_context = pipeline.format_context(documents)
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
info(f"[RAG Core] 获取到 rag_docs: {len(documents)} 个文档")
# 更新状态
state.rag_context = rag_context
state.rag_retrieved = True
state.rag_docs = documents # 保存文档用于置信度评估
state.rag_retrieved = bool(documents) # 有文档才算检索成功
state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1
state.debug_info["rag_source"] = "tool"
state.debug_info["rag_source"] = "pipeline"
state.debug_info["rag_scores"] = pipeline.last_scores # 保存分数信息
return state
@@ -57,12 +79,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf
state.current_phase = "rag_retrieving"
start_time = time.time()
rag_tool = _get_rag_tool()
if not rag_tool:
info("[RAG] RAG 工具未初始化")
state.rag_confidence = 0.0
state.rag_retrieved = False
return state
pipeline = _get_rag_pipeline()
await dispatch_custom_event(
"react_reasoning",
@@ -71,7 +88,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf
)
try:
state = await _rag_retrieve_core(state, rag_tool)
state = await _rag_retrieve_core(state, pipeline)
# 评估置信度
confidence = await _evaluate_rag_confidence(state)
@@ -111,7 +128,7 @@ async def _evaluate_rag_confidence(state: MainGraphState) -> float:
return 0.0
# 方式1: 向量相似度(从 rag_docs 中获取)
embedding_score = _get_embedding_similarity(state, query)
embedding_score = _get_embedding_similarity(state)
info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")
# 方式2: 重排序分数(从 rag_docs 中获取)
@@ -131,36 +148,43 @@ async def _evaluate_rag_confidence(state: MainGraphState) -> float:
def _get_embedding_similarity(state: MainGraphState) -> float:
"""从 rag_docs 中获取向量相似度分数"""
rag_docs = getattr(state, "rag_docs", [])
""" rag_scores 或 rag_docs 中获取向量相似度分数"""
# 优先从 pipeline 提供的分数中获取
rag_scores = state.debug_info.get("rag_scores", [])
if rag_scores:
scores = [s.get("embedding_score", 0.0) for s in rag_scores]
if scores:
# 归一化到 0-1
normalized = [min(s / 10.0, 1.0) if s > 1.0 else s for s in scores]
return max(normalized)
# 如果有多个文档,取最高分
# 降级:从 rag_docs 中获取
rag_docs = getattr(state, "rag_docs", [])
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
score = doc.get("score", 0.0)
# 向量相似度通常在 0-1 之间RRF 分数可能更高
# 归一化到 0-1
if score > 1.0:
score = min(score / 10.0, 1.0) # 假设 max 约 10
scores.append(score)
elif hasattr(doc, "metadata"):
score = doc.metadata.get("score", 0.0)
if score > 1.0:
score = min(score / 10.0, 1.0)
scores.append(score)
score = doc.metadata.get("embedding_score", doc.metadata.get("score", 0.0))
else:
continue
if score > 1.0:
score = min(score / 10.0, 1.0)
scores.append(score)
if scores:
# 取平均或最高分
return max(scores) # 使用最高分更准确
return 0.0
return max(scores) if scores else 0.0
def _get_rerank_score(state: MainGraphState) -> float:
"""从 rag_docs 中获取重排序分数"""
rag_docs = getattr(state, "rag_docs", [])
""" rag_scores 或 rag_docs 中获取重排序分数"""
# 优先从 pipeline 提供的分数中获取
rag_scores = state.debug_info.get("rag_scores", [])
if rag_scores:
scores = [s.get("rerank_score", 0.0) for s in rag_scores]
return max(scores) if scores else 0.0
# 重排分数通常在 0-1 之间
# 降级:从 rag_docs 中获取
rag_docs = getattr(state, "rag_docs", [])
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
@@ -168,14 +192,11 @@ def _get_rerank_score(state: MainGraphState) -> float:
elif hasattr(doc, "metadata"):
score = doc.metadata.get("rerank_score", 0.0)
else:
score = 0.0
continue
if score > 0:
scores.append(score)
if scores:
return max(scores) # 使用最高分
return 0.0
return max(scores) if scores else 0.0
async def _get_llm_score(state: MainGraphState) -> float: