This commit is contained in:
@@ -21,7 +21,7 @@ from .nodes.fast_paths import (
|
||||
fast_tool_node,
|
||||
)
|
||||
from .nodes.llm_call import create_dynamic_llm_call_node
|
||||
from .nodes.rag_nodes import rag_retrieve_node, check_rag_confidence
|
||||
from .nodes.rag_nodes import rag_retrieve_node
|
||||
from .nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from .nodes.summarize import create_summarize_node
|
||||
@@ -164,7 +164,7 @@ def _add_routing_edges(graph: StateGraph, use_hybrid_router: bool, llm_node) ->
|
||||
}
|
||||
)
|
||||
|
||||
# 快速路径的完成检查
|
||||
# 快速路径的完成检查(fast_rag 失败直接走 react_reason)
|
||||
for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]:
|
||||
graph.add_conditional_edges(
|
||||
fast_node,
|
||||
@@ -198,17 +198,8 @@ def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) ->
|
||||
}
|
||||
)
|
||||
|
||||
# RAG 检索后的置信度判断分支
|
||||
graph.add_conditional_edges(
|
||||
"rag_retrieve",
|
||||
check_rag_confidence,
|
||||
{
|
||||
"high_confidence": "llm_call", # 高置信度 → 直接生成回答
|
||||
"retry_rag": "rag_retrieve", # 低置信度 → 再次检索
|
||||
"low_confidence": "web_search", # 两次RAG后仍低 → 联网搜索
|
||||
"no_rag": "web_search", # 无结果 → 联网搜索
|
||||
}
|
||||
)
|
||||
# RAG 检索后回到 react_reason,由意图识别决定下一步
|
||||
graph.add_edge("rag_retrieve", "react_reason")
|
||||
|
||||
# 循环边(回到 react_reason)
|
||||
loop_back_nodes = ["web_search", "handle_error"] + subgraph_names
|
||||
|
||||
@@ -103,8 +103,9 @@ async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig]
|
||||
# 注意:这里不设置 final_result,让 llm_call 节点处理
|
||||
return state
|
||||
|
||||
# 无效结果:升级到 React 循环
|
||||
# 检索结果无效:标记失败,升级到 React 循环
|
||||
info("[Fast RAG] 无有效检索结果,升级到 React 循环")
|
||||
await dispatch_custom_event("fast_path_end", {"path": "fast_rag", "success": False}, config)
|
||||
return _mark_fast_path_failed(state, "无有效检索结果")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -18,24 +18,20 @@ from backend.app.logger import debug, info, error
|
||||
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
|
||||
"""
|
||||
工厂函数:创建动态 LLM 调用节点(根据 state.current_model 选择模型)
|
||||
|
||||
|
||||
Args:
|
||||
chat_services: 模型名称 -> ChatModel 实例 的字典
|
||||
tools: 工具列表
|
||||
|
||||
tools: 工具列表(llm_call 不使用工具,只负责回答)
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
# 预构建所有模型的 tools 绑定(避免每次调用都 bind)
|
||||
bound_models: Dict[str, Any] = {}
|
||||
for name, llm in chat_services.items():
|
||||
if tools:
|
||||
bound_models[name] = llm.bind_tools(tools)
|
||||
else:
|
||||
bound_models[name] = llm
|
||||
|
||||
# 预构建 prompt
|
||||
prompt = create_system_prompt(tools)
|
||||
# llm_call 节点不使用工具,只负责生成回答
|
||||
# 直接使用原始模型,不绑定工具
|
||||
models = chat_services
|
||||
|
||||
# 预构建 prompt(不带工具描述)
|
||||
prompt = create_system_prompt()
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
@@ -70,14 +66,14 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
|
||||
# 动态选择模型
|
||||
model_name = getattr(state, "current_model", "")
|
||||
if not model_name or model_name not in bound_models:
|
||||
if not model_name or model_name not in models:
|
||||
# 回退到第一个可用模型
|
||||
fallback_name = next(iter(bound_models.keys()))
|
||||
fallback_name = next(iter(models.keys()))
|
||||
info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'")
|
||||
model_name = fallback_name
|
||||
|
||||
llm_with_tools = bound_models[model_name]
|
||||
info(f"[llm_call] 使用模型: {model_name}")
|
||||
|
||||
llm = models[model_name]
|
||||
info(f"[llm_call] 使用模型(无工具): {model_name}")
|
||||
|
||||
try:
|
||||
# 添加上下文到消息
|
||||
@@ -103,7 +99,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chain = prompt | llm_with_tools
|
||||
chain = prompt | llm
|
||||
chunks = []
|
||||
info(f"[llm_call] 开始调用 LLM astream...")
|
||||
async for chunk in chain.astream(
|
||||
@@ -115,8 +111,13 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks}")
|
||||
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks[0].content[:50]}...{chunks[-1].content[:50]}")
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0].content
|
||||
for chunk in chunks[1:]:
|
||||
response = response + chunk.content
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
@@ -167,9 +168,6 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
# 检查是否有工具调用
|
||||
has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0
|
||||
|
||||
result = {
|
||||
"messages": [response],
|
||||
"llm_calls": getattr(state, 'llm_calls', 0) + 1,
|
||||
@@ -179,7 +177,6 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
"final_result": response.content,
|
||||
"success": True,
|
||||
"current_phase": "done",
|
||||
"has_tool_calls": has_tool_calls,
|
||||
"current_model": model_name # 记录实际使用的模型
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,23 @@ from ._utils import dispatch_custom_event, make_react_event
|
||||
# 置信度阈值配置
|
||||
RAG_CONFIDENCE_THRESHOLD = 0.6 # 低于此值认为检索不相关
|
||||
|
||||
# 全局 pipeline 实例
|
||||
_rag_pipeline = None
|
||||
|
||||
|
||||
def _get_rag_pipeline():
|
||||
"""获取 RAG Pipeline 实例"""
|
||||
global _rag_pipeline
|
||||
if _rag_pipeline is None:
|
||||
from backend.app.rag.pipeline import RAGPipeline
|
||||
_rag_pipeline = RAGPipeline(
|
||||
num_queries=3,
|
||||
rerank_top_n=5,
|
||||
use_rerank=True,
|
||||
return_parent_docs=True,
|
||||
)
|
||||
return _rag_pipeline
|
||||
|
||||
|
||||
def _get_rag_tool() -> Optional[callable]:
|
||||
"""获取 RAG 工具"""
|
||||
@@ -27,7 +44,7 @@ def _get_rag_tool() -> Optional[callable]:
|
||||
|
||||
|
||||
# ========== RAG 检索核心逻辑 ==========
|
||||
async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainGraphState:
|
||||
async def _rag_retrieve_core(state: MainGraphState, pipeline) -> MainGraphState:
|
||||
"""执行 RAG 检索的核心逻辑"""
|
||||
retrieval_query = state.user_query
|
||||
|
||||
@@ -38,15 +55,20 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
|
||||
if cfg and cfg.retrieval_query:
|
||||
retrieval_query = cfg.retrieval_query
|
||||
|
||||
# 调用 RAG 工具
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
# 直接调用 pipeline 获取文档和上下文
|
||||
documents = await pipeline.aretrieve(retrieval_query)
|
||||
rag_context = pipeline.format_context(documents)
|
||||
|
||||
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
|
||||
info(f"[RAG Core] 获取到 rag_docs: {len(documents)} 个文档")
|
||||
|
||||
# 更新状态
|
||||
state.rag_context = rag_context
|
||||
state.rag_retrieved = True
|
||||
state.rag_docs = documents # 保存文档用于置信度评估
|
||||
state.rag_retrieved = bool(documents) # 有文档才算检索成功
|
||||
state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1
|
||||
state.debug_info["rag_source"] = "tool"
|
||||
state.debug_info["rag_source"] = "pipeline"
|
||||
state.debug_info["rag_scores"] = pipeline.last_scores # 保存分数信息
|
||||
|
||||
return state
|
||||
|
||||
@@ -57,12 +79,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf
|
||||
state.current_phase = "rag_retrieving"
|
||||
start_time = time.time()
|
||||
|
||||
rag_tool = _get_rag_tool()
|
||||
if not rag_tool:
|
||||
info("[RAG] RAG 工具未初始化")
|
||||
state.rag_confidence = 0.0
|
||||
state.rag_retrieved = False
|
||||
return state
|
||||
pipeline = _get_rag_pipeline()
|
||||
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
@@ -71,7 +88,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf
|
||||
)
|
||||
|
||||
try:
|
||||
state = await _rag_retrieve_core(state, rag_tool)
|
||||
state = await _rag_retrieve_core(state, pipeline)
|
||||
|
||||
# 评估置信度
|
||||
confidence = await _evaluate_rag_confidence(state)
|
||||
@@ -111,7 +128,7 @@ async def _evaluate_rag_confidence(state: MainGraphState) -> float:
|
||||
return 0.0
|
||||
|
||||
# 方式1: 向量相似度(从 rag_docs 中获取)
|
||||
embedding_score = _get_embedding_similarity(state, query)
|
||||
embedding_score = _get_embedding_similarity(state)
|
||||
info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")
|
||||
|
||||
# 方式2: 重排序分数(从 rag_docs 中获取)
|
||||
@@ -131,36 +148,43 @@ async def _evaluate_rag_confidence(state: MainGraphState) -> float:
|
||||
|
||||
|
||||
def _get_embedding_similarity(state: MainGraphState) -> float:
|
||||
"""从 rag_docs 中获取向量相似度分数"""
|
||||
rag_docs = getattr(state, "rag_docs", [])
|
||||
"""从 rag_scores 或 rag_docs 中获取向量相似度分数"""
|
||||
# 优先从 pipeline 提供的分数中获取
|
||||
rag_scores = state.debug_info.get("rag_scores", [])
|
||||
if rag_scores:
|
||||
scores = [s.get("embedding_score", 0.0) for s in rag_scores]
|
||||
if scores:
|
||||
# 归一化到 0-1
|
||||
normalized = [min(s / 10.0, 1.0) if s > 1.0 else s for s in scores]
|
||||
return max(normalized)
|
||||
|
||||
# 如果有多个文档,取最高分
|
||||
# 降级:从 rag_docs 中获取
|
||||
rag_docs = getattr(state, "rag_docs", [])
|
||||
scores = []
|
||||
for doc in rag_docs:
|
||||
if isinstance(doc, dict):
|
||||
score = doc.get("score", 0.0)
|
||||
# 向量相似度通常在 0-1 之间,RRF 分数可能更高
|
||||
# 归一化到 0-1
|
||||
if score > 1.0:
|
||||
score = min(score / 10.0, 1.0) # 假设 max 约 10
|
||||
scores.append(score)
|
||||
elif hasattr(doc, "metadata"):
|
||||
score = doc.metadata.get("score", 0.0)
|
||||
if score > 1.0:
|
||||
score = min(score / 10.0, 1.0)
|
||||
scores.append(score)
|
||||
score = doc.metadata.get("embedding_score", doc.metadata.get("score", 0.0))
|
||||
else:
|
||||
continue
|
||||
if score > 1.0:
|
||||
score = min(score / 10.0, 1.0)
|
||||
scores.append(score)
|
||||
|
||||
if scores:
|
||||
# 取平均或最高分
|
||||
return max(scores) # 使用最高分更准确
|
||||
return 0.0
|
||||
return max(scores) if scores else 0.0
|
||||
|
||||
|
||||
def _get_rerank_score(state: MainGraphState) -> float:
|
||||
"""从 rag_docs 中获取重排序分数"""
|
||||
rag_docs = getattr(state, "rag_docs", [])
|
||||
"""从 rag_scores 或 rag_docs 中获取重排序分数"""
|
||||
# 优先从 pipeline 提供的分数中获取
|
||||
rag_scores = state.debug_info.get("rag_scores", [])
|
||||
if rag_scores:
|
||||
scores = [s.get("rerank_score", 0.0) for s in rag_scores]
|
||||
return max(scores) if scores else 0.0
|
||||
|
||||
# 重排分数通常在 0-1 之间
|
||||
# 降级:从 rag_docs 中获取
|
||||
rag_docs = getattr(state, "rag_docs", [])
|
||||
scores = []
|
||||
for doc in rag_docs:
|
||||
if isinstance(doc, dict):
|
||||
@@ -168,14 +192,11 @@ def _get_rerank_score(state: MainGraphState) -> float:
|
||||
elif hasattr(doc, "metadata"):
|
||||
score = doc.metadata.get("rerank_score", 0.0)
|
||||
else:
|
||||
score = 0.0
|
||||
|
||||
continue
|
||||
if score > 0:
|
||||
scores.append(score)
|
||||
|
||||
if scores:
|
||||
return max(scores) # 使用最高分
|
||||
return 0.0
|
||||
return max(scores) if scores else 0.0
|
||||
|
||||
|
||||
async def _get_llm_score(state: MainGraphState) -> float:
|
||||
|
||||
@@ -23,6 +23,8 @@ async def react_reason_node(state: MainGraphState, config: Optional[RunnableConf
|
||||
# 步骤1: 准备上下文
|
||||
context = {
|
||||
"retrieved_docs": state.rag_docs,
|
||||
"rag_confidence": getattr(state, "rag_confidence", 0.0),
|
||||
"rag_attempts": getattr(state, "rag_attempts", 0),
|
||||
"previous_actions": [h.get("action") for h in state.reasoning_history],
|
||||
"reasoning_history": state.reasoning_history,
|
||||
"messages": state.messages,
|
||||
|
||||
@@ -112,8 +112,8 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
info(f"[条件路由] 检测到路由循环: {previous_actions[-4:]},强制终止")
|
||||
return "llm_call"
|
||||
|
||||
# 2. 状态停滞检测(连续相同动作)
|
||||
if len(previous_actions) >= 2 and previous_actions[-1] == previous_actions[-2]:
|
||||
# 2. 状态停滞检测(连续相同动作 TODO:本来应该是2)
|
||||
if len(previous_actions) >= 3 and previous_actions[-1] == previous_actions[-2] and previous_actions[-2] == previous_actions[-3]:
|
||||
info(f"[条件路由] 连续相同动作 '{previous_actions[-1]}',强制终止")
|
||||
return "llm_call"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user