- react_reason_node: 直接发送自定义推理事件 - web_search_node: 添加开始/完成/错误事件 - rag_retrieve_node: 添加开始/完成/重试/错误事件 - 子图包装器: 添加子图开始/完成/错误事件
This commit is contained in:
@@ -159,18 +159,38 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
|||||||
|
|
||||||
|
|
||||||
# ========== RAG 检索节点(带超时和重试)==========
|
# ========== RAG 检索节点(带超时和重试)==========
|
||||||
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
|
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||||
"""
|
"""
|
||||||
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
|
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 主图状态
|
state: 主图状态
|
||||||
|
config: LangChain 配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
更新后的状态
|
更新后的状态
|
||||||
"""
|
"""
|
||||||
state.current_phase = "rag_retrieving"
|
state.current_phase = "rag_retrieving"
|
||||||
|
|
||||||
|
# 发送开始事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": "rag_retrieve_start",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": "开始执行 RAG 检索..."
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[rag_retrieve_node] 无法发送开始事件: {e}")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
last_error = None
|
last_error = None
|
||||||
|
|
||||||
@@ -185,6 +205,27 @@ def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
|
|||||||
"success": True,
|
"success": True,
|
||||||
"time": time.time() - start_time
|
"time": time.time() - start_time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 发送完成事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
doc_count = len(result.rag_docs) if result.rag_docs else 0
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": "rag_retrieve_complete",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": f"RAG 检索完成,找到 {doc_count} 条相关文档"
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[rag_retrieve_node] 无法发送完成事件: {e}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -193,6 +234,25 @@ def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
|
|||||||
if attempt >= RAG_RETRY_CONFIG.max_retries:
|
if attempt >= RAG_RETRY_CONFIG.max_retries:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 发送重试事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": "rag_retrieve_retry",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": f"RAG 检索失败,第 {attempt + 1} 次重试..."
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[rag_retrieve_node] 无法发送重试事件: {e}")
|
||||||
|
|
||||||
# 指数退避等待
|
# 指数退避等待
|
||||||
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
||||||
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||||
@@ -219,6 +279,25 @@ def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
|
|||||||
state.current_error = error_record
|
state.current_error = error_record
|
||||||
state.current_phase = "error_handling"
|
state.current_phase = "error_handling"
|
||||||
|
|
||||||
|
# 发送错误事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": "rag_retrieve_error",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": f"RAG 检索失败: {str(last_error)}"
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[rag_retrieve_node] 无法发送错误事件: {e}")
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -121,12 +121,31 @@ async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, An
|
|||||||
|
|
||||||
# ========== 2. 联网搜索节点 ==========
|
# ========== 2. 联网搜索节点 ==========
|
||||||
|
|
||||||
def web_search_node(state: MainGraphState) -> MainGraphState:
|
async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||||
"""
|
"""
|
||||||
联网搜索节点:执行搜索并将结果保存到状态
|
联网搜索节点:执行搜索并将结果保存到状态
|
||||||
"""
|
"""
|
||||||
state.current_phase = "web_searching"
|
state.current_phase = "web_searching"
|
||||||
|
|
||||||
|
# 发送开始事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": "web_search_start",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": "开始执行联网搜索..."
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[web_search_node] 无法发送开始事件: {e}")
|
||||||
|
|
||||||
# 获取搜索查询
|
# 获取搜索查询
|
||||||
reasoning_result = state.debug_info.get("reasoning_result")
|
reasoning_result = state.debug_info.get("reasoning_result")
|
||||||
search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query
|
search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query
|
||||||
@@ -151,6 +170,25 @@ def web_search_node(state: MainGraphState) -> MainGraphState:
|
|||||||
state.success = True
|
state.success = True
|
||||||
print(f"[WebSearch] 搜索完成")
|
print(f"[WebSearch] 搜索完成")
|
||||||
|
|
||||||
|
# 发送完成事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": "web_search_complete",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": f"联网搜索完成,找到 {len(search_result) if isinstance(search_result, list) else 1} 条结果"
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[web_search_node] 无法发送完成事件: {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from app.main_graph.state import ErrorRecord, ErrorSeverity
|
from app.main_graph.state import ErrorRecord, ErrorSeverity
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -170,6 +208,25 @@ def web_search_node(state: MainGraphState) -> MainGraphState:
|
|||||||
state.current_phase = "error_handling"
|
state.current_phase = "error_handling"
|
||||||
state.success = False
|
state.success = False
|
||||||
|
|
||||||
|
# 发送错误事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": "web_search_error",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": f"联网搜索失败: {str(e)}"
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[web_search_node] 无法发送错误事件: {e}")
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,26 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
|||||||
|
|
||||||
Returns: 包装后的节点函数
|
Returns: 包装后的节点函数
|
||||||
"""
|
"""
|
||||||
def wrapped_node(state: MainGraphState) -> MainGraphState:
|
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||||
|
# 发送子图开始事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": f"{name}_subgraph_start",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": f"开始执行 {name} 子图..."
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[{name}_subgraph] 无法发送开始事件: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用子图
|
# 调用子图
|
||||||
result = subgraph.invoke(state)
|
result = subgraph.invoke(state)
|
||||||
@@ -72,7 +91,7 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
|||||||
state.news_result = result
|
state.news_result = result
|
||||||
subgraph_result = result.get("final_result", "")
|
subgraph_result = result.get("final_result", "")
|
||||||
|
|
||||||
# 关键:设置最终结果,这样就不需要再回到 react_reason 了
|
# 关键:设置最终结果
|
||||||
if subgraph_result:
|
if subgraph_result:
|
||||||
state.final_result = subgraph_result
|
state.final_result = subgraph_result
|
||||||
else:
|
else:
|
||||||
@@ -89,6 +108,26 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
|||||||
"reasoning": f"{name}子图执行完成",
|
"reasoning": f"{name}子图执行完成",
|
||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# 发送子图完成事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": f"{name}_subgraph_complete",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": f"{name} 子图执行完成"
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[{name}_subgraph] 无法发送完成事件: {e}")
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -111,6 +150,25 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
|||||||
state.current_phase = "error_handling"
|
state.current_phase = "error_handling"
|
||||||
state.success = False
|
state.success = False
|
||||||
|
|
||||||
|
# 发送子图错误事件
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||||
|
callbacks = config.get("callbacks")
|
||||||
|
if callbacks:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"react_reasoning",
|
||||||
|
{
|
||||||
|
"step": state.reasoning_step,
|
||||||
|
"action": f"{name}_subgraph_error",
|
||||||
|
"confidence": 1.0,
|
||||||
|
"reasoning": f"{name} 子图执行失败: {str(e)}"
|
||||||
|
},
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
info(f"[{name}_subgraph] 无法发送错误事件: {e}")
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
return wrapped_node
|
return wrapped_node
|
||||||
|
|||||||
Reference in New Issue
Block a user