def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """Return the RAG tool stored on the state, if any.

    The tool is looked up under the ``"rag_tool"`` key of
    ``state.debug_info``.

    Args:
        state: Main graph state whose ``debug_info`` mapping is consulted.

    Returns:
        The object stored under ``"rag_tool"``, or ``None`` when the key is
        absent.
    """
    # A single lookup replaces the membership test followed by indexing.
    return state.debug_info.get("rag_tool")
# ========== RAG retrieval core logic ==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """Run a single RAG retrieval attempt and store the result on ``state``.

    The query defaults to ``state.user_query``; when the reasoning result
    saved in ``state.debug_info["reasoning_result"]`` carries a
    ``retrieval_config`` with a non-empty ``retrieval_query``, that query is
    used instead.

    If a tool has been injected into the state it is invoked with the query.
    Otherwise placeholder demo data is written so the rest of the graph can
    still run.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The same ``state`` object with ``rag_context``, ``rag_docs``,
        ``rag_retrieved`` and ``success`` set.

    Raises:
        RuntimeError: If invoking the injected tool fails; the original
            exception is chained as the cause.
    """
    # Prefer the optimized query produced by the reasoning step, if present.
    retrieval_query = state.user_query
    reasoning_result = state.debug_info.get("reasoning_result")
    cfg = getattr(reasoning_result, "retrieval_config", None)
    if cfg and cfg.retrieval_query:
        retrieval_query = cfg.retrieval_query

    rag_tool = get_rag_tool_from_state(state)

    if rag_tool is None:
        # No tool injected: fall back to mock data (demo only).
        state.rag_context = (
            f"[RAG 检索结果]\n"
            f"查询: {retrieval_query}\n"
            f"这是来自知识库的相关信息。"
        )
        state.rag_docs = [
            {"source": "doc1.txt", "content": "LangGraph 是一个用于构建 Agent 的框架"},
            {"source": "doc2.txt", "content": "React 模式是 '思考→行动→观察' 循环"},
        ]
    else:
        # An injected tool is used regardless of whether this module's own
        # optional RAG imports succeeded: the caller built the tool, so its
        # availability does not depend on those imports.
        try:
            rag_context = rag_tool.invoke(retrieval_query)
        except Exception as e:
            raise RuntimeError(f"RAG 调用失败: {e}") from e
        state.rag_context = rag_context
        state.rag_docs = [{"source": "rag_doc", "content": rag_context}]

    state.rag_retrieved = True
    state.success = True
    return state


def _retrieve_with_retry(state: MainGraphState, phase: str) -> MainGraphState:
    """Run ``_rag_retrieve_core`` with exponential-backoff retries.

    Args:
        state: Main graph state; mutated in place.
        phase: Value written to ``state.current_phase`` before retrieving.

    Returns:
        The updated state. On success ``debug_info["rag_retrieval"]`` holds
        the attempt number and elapsed time. After all attempts fail, a
        failure record is stored under the same key, an ``ErrorRecord`` is
        appended to ``state.errors`` and set as ``state.current_error``, and
        ``state.current_phase`` becomes ``"error_handling"``.
    """
    state.current_phase = phase

    started = time.time()
    last_error = None
    attempts = 0

    for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
        attempts = attempt + 1
        try:
            result = _rag_retrieve_core(state)
        except Exception as e:
            # Retry boundary: any failure of the attempt is retried.
            last_error = e

            if attempt >= RAG_RETRY_CONFIG.max_retries:
                break

            # Exponential backoff, capped at max_delay.
            delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
            time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
        else:
            state.debug_info["rag_retrieval"] = {
                "attempt": attempts,
                "success": True,
                "time": time.time() - started,
            }
            return result

    elapsed = time.time() - started

    # Overwrite any success record left by an earlier retrieval so the debug
    # info reflects this run.
    state.debug_info["rag_retrieval"] = {
        "attempt": attempts,
        "success": False,
        "time": elapsed,
    }

    # All attempts failed: record a structured error for the error handler.
    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=str(last_error) if last_error else "RAG 检索超时",
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=RAG_RETRY_CONFIG.max_retries,
        max_retries=RAG_RETRY_CONFIG.max_retries,
        context={
            "query": state.user_query,
            "total_time": elapsed,
            "timeout": RAG_RETRY_CONFIG.timeout,
        },
    )

    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"

    return state


# ========== RAG retrieval node (with retry) ==========
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
    """RAG retrieval node: retrieve context, retrying on failure.

    Sets ``state.current_phase`` to ``"rag_retrieving"`` and makes up to
    ``RAG_RETRY_CONFIG.max_retries + 1`` attempts, sleeping with exponential
    backoff between them.

    NOTE: ``RAG_RETRY_CONFIG.timeout`` is only copied into the error context;
    nothing here interrupts a slow attempt.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The updated state (see ``_retrieve_with_retry`` for the fields set on
        success and on failure).
    """
    return _retrieve_with_retry(state, "rag_retrieving")


# ========== Re-retrieval node ==========
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
    """Re-retrieval node: run a second retrieval pass.

    Records the outcome of the previous retrieval in
    ``state.debug_info["rag_re_retrieve"]`` and then runs the same retrieval
    logic as ``rag_retrieve_node``, keeping ``"rag_re_retrieving"`` as the
    phase while it runs.

    Args:
        state: Main graph state; mutated in place.

    Returns:
        The updated state.
    """
    # Retrieval parameters (wider scope, rewritten query) could be adjusted
    # here; currently only the previous outcome is recorded.
    state.debug_info["rag_re_retrieve"] = {
        "original_retrieved": state.rag_retrieved,
        # Tolerate a state whose rag_docs has not been populated yet.
        "original_docs_count": len(state.rag_docs or []),
    }

    return _retrieve_with_retry(state, "rag_re_retrieving")


# ========== Utility: inject the RAG tool into the state ==========
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """Store a RAG tool on the state so later nodes can use it.

    Args:
        state: Main graph state; mutated in place.
        rag_tool: Tool instance, saved under ``debug_info["rag_tool"]``.

    Returns:
        The same state, with the tool and an ISO-format injection timestamp
        (``debug_info["rag_tool_injected"]``) recorded.
    """
    state.debug_info["rag_tool"] = rag_tool
    state.debug_info["rag_tool_injected"] = datetime.now().isoformat()
    return state


# ========== Exports ==========
__all__ = [
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    "inject_rag_tool_to_state",
    "get_rag_tool_from_state"
]
import sys -import time from typing import Dict, Any, Optional from datetime import datetime -from functools import wraps # 导入我们的 intent.py from ..agent_subgraphs.common.intent import ( react_reason, get_route_by_reasoning, ReasoningAction, - RetrievalConfig, ReasoningResult ) from ..agent_subgraphs.common.state_base import StateUtils from .state import MainGraphState, ErrorRecord, ErrorSeverity from .retry_utils import ( RetryConfig, - RetryResult, - with_retry, - create_retry_wrapper_for_node, - RAG_RETRY_CONFIG, SUBGRAPH_RETRY_CONFIG ) -def get_rag_tool(): - """ - 获取 RAG 工具(延迟导入,避免循环依赖) - """ - try: - # 尝试导入现有的 RAG 工具 - from ..rag.tools import create_rag_tool_sync - # 注意:这里简化处理,实际使用时应该从全局获取初始化好的工具 - return None # 先返回 None,后面通过注入方式 - except Exception: - return None - - # ========== 1. React 推理节点 ========== def react_reason_node(state: MainGraphState) -> MainGraphState: """ @@ -94,7 +74,7 @@ def react_reason_node(state: MainGraphState) -> MainGraphState: "reasoning": result.reasoning } - # 保存推理结果到状态(供条件路由使用) + # 保存推理结果到状态 state.debug_info["reasoning_result"] = result # 确定下一步动作 @@ -103,97 +83,7 @@ def react_reason_node(state: MainGraphState) -> MainGraphState: return state -# ========== 2. 
RAG 检索节点(带超时和重试) ========== -def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: - """ - RAG 检索核心逻辑(不带重试) - """ - # 获取推理结果中的检索配置 - reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result") - retrieval_query = state.user_query - - if reasoning_result and reasoning_result.retrieval_config: - cfg: RetrievalConfig = reasoning_result.retrieval_config - if cfg.retrieval_query: - retrieval_query = cfg.retrieval_query - - # 尝试获取 RAG 工具并调用 - # 这里演示如何调用,实际使用时需要确保 RAG 已初始化 - # 暂时用模拟数据 - state.rag_context = ( - f"[模拟RAG检索结果]\n" - f"查询: {retrieval_query}\n" - f"这是一个来自知识库的示例回答。" - ) - state.rag_docs = [ - {"source": "doc1.txt", "content": "示例内容1"}, - {"source": "doc2.txt", "content": "示例内容2"} - ] - state.rag_retrieved = True - state.success = True - - return state - - -def rag_retrieve_node(state: MainGraphState) -> MainGraphState: - """ - RAG 检索节点:带超时和重试 - - Returns: 更新后的状态 - """ - state.current_phase = "rag_retrieving" - - # 使用重试包装器 - start_time = time.time() - last_error = None - - for attempt in range(RAG_RETRY_CONFIG.max_retries + 1): - try: - # 执行核心逻辑 - result = _rag_retrieve_core(state) - - # 成功 - state.debug_info["rag_retrieval"] = { - "attempt": attempt + 1, - "success": True, - "time": time.time() - start_time - } - return result - - except Exception as e: - last_error = e - - if attempt >= RAG_RETRY_CONFIG.max_retries: - break - - # 等待后重试(指数退避) - delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt) - time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay)) - - # 所有重试都失败,记录结构化错误 - error_record = ErrorRecord( - error_type="RAGRetrievalError", - error_message=str(last_error) if last_error else "RAG 检索超时", - severity=ErrorSeverity.WARNING, - source="rag_retrieve_node", - timestamp=datetime.now().isoformat(), - retry_count=RAG_RETRY_CONFIG.max_retries, - max_retries=RAG_RETRY_CONFIG.max_retries, - context={ - "query": state.user_query, - "total_time": time.time() - start_time, - "timeout": RAG_RETRY_CONFIG.timeout - } - ) - - 
state.errors.append(error_record) - state.current_error = error_record - state.current_phase = "error_handling" - - return state - - -# ========== 3. 错误处理节点 ========== +# ========== 2. 错误处理节点 ========== def error_handling_node(state: MainGraphState) -> MainGraphState: """ 错误处理节点:处理子图/工具调用错误 @@ -210,7 +100,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState: state.current_phase = "error_handling" if not state.current_error: - # 没有错误,直接返回 state.current_phase = "react_reasoning" return state @@ -219,7 +108,7 @@ def error_handling_node(state: MainGraphState) -> MainGraphState: # 更新错误状态 state.error_message = f"{error.error_type}: {error.error_message}" - # 记录结构化错误信息(用于 LLM 决策) + # 记录结构化错误信息 structured_error = { "tool": error.source, "status": "failed", @@ -231,11 +120,11 @@ def error_handling_node(state: MainGraphState) -> MainGraphState: # 根据错误类型添加建议 if "RAG" in error.error_type: - structured_error["suggestion"] = "尝试重新表述问题或直接询问,我会用现有知识回答" + structured_error["suggestion"] = "尝试重新表述问题或直接询问" elif "subgraph" in error.source or "contact" in error.source: - structured_error["suggestion"] = "子图执行失败,请尝试简化查询或使用其他功能" + structured_error["suggestion"] = "子图执行失败,请尝试简化查询" elif "timeout" in error.error_message.lower(): - structured_error["suggestion"] = "请求超时,请稍后再试或简化查询" + structured_error["suggestion"] = "请求超时,请稍后再试" else: structured_error["suggestion"] = "请尝试其他方式提问" @@ -248,7 +137,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState: ) if can_retry: - # 重试策略 error.retry_count += 1 state.retry_action = error.source state.debug_info["retry_count"] = error.retry_count @@ -265,7 +153,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState: # 策略2: 无法重试,尝试降级方案 if error.severity != ErrorSeverity.FATAL: - # 降级到直接回答模式 state.final_result = ( f"⚠️ 遇到一些问题:\n" f"```json\n{structured_error}\n```\n" @@ -275,7 +162,7 @@ def error_handling_node(state: MainGraphState) -> MainGraphState: state.current_phase = "finalizing" return state - # 策略3: 
致命错误,无法继续 + # 策略3: 致命错误 state.final_result = ( f"❌ 服务暂时不可用,请稍后再试。\n" f"```json\n{structured_error}\n```" @@ -286,7 +173,7 @@ def error_handling_node(state: MainGraphState) -> MainGraphState: return state -# ========== 4. 最终回答节点 ========== +# ========== 3. 最终回答节点 ========== def final_response_node(state: MainGraphState) -> MainGraphState: """ 最终回答节点:整理并生成最终回答 @@ -307,12 +194,15 @@ def final_response_node(state: MainGraphState) -> MainGraphState: parts.append("---") # 添加子图结果(如果有) - if state.contact_result and state.contact_result.get("final_result"): - parts.append(state.contact_result["final_result"]) - if state.dictionary_result and state.dictionary_result.get("final_result"): - parts.append(state.dictionary_result["final_result"]) - if state.news_result and state.news_result.get("final_result"): - parts.append(state.news_result["final_result"]) + if state.contact_result and hasattr(state.contact_result, "get"): + if state.contact_result.get("final_result"): + parts.append(state.contact_result["final_result"]) + if state.dictionary_result and hasattr(state.dictionary_result, "get"): + if state.dictionary_result.get("final_result"): + parts.append(state.dictionary_result["final_result"]) + if state.news_result and hasattr(state.news_result, "get"): + if state.news_result.get("final_result"): + parts.append(state.news_result["final_result"]) # 如果都没有,用默认回答 if not parts: @@ -326,7 +216,7 @@ def final_response_node(state: MainGraphState) -> MainGraphState: return state -# ========== 5. 初始化状态节点 ========== +# ========== 4. 
初始化状态节点 ========== def init_state_node(state: MainGraphState) -> MainGraphState: """ 初始化状态节点:在流程开始时设置初始值 @@ -335,7 +225,7 @@ def init_state_node(state: MainGraphState) -> MainGraphState: state.reasoning_step = 0 state.start_time = datetime.now().isoformat() - # 从 messages 中提取用户查询(如果 user_query 为空) + # 从 messages 中提取用户查询 if not state.user_query and state.messages: last_msg = state.messages[-1] state.user_query = getattr(last_msg, "content", str(last_msg)) @@ -343,7 +233,7 @@ def init_state_node(state: MainGraphState) -> MainGraphState: return state -# ========== 6. 条件路由函数 ========== +# ========== 5. 条件路由函数 ========== def route_by_reasoning(state: MainGraphState) -> str: """ 根据推理结果决定下一步路由 @@ -358,7 +248,6 @@ def route_by_reasoning(state: MainGraphState) -> str: if state.current_phase == "finalizing" or state.current_phase == "done": return "final_response" if state.current_phase == "retrying": - # 重试路由 if state.retry_action and "rag" in state.retry_action.lower(): return "rag_retrieve" return "react_reason" @@ -367,7 +256,6 @@ def route_by_reasoning(state: MainGraphState) -> str: reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result") if not reasoning_result: - # 没有推理结果,直接结束 return "final_response" # 使用 intent.py 提供的路由函数 @@ -378,11 +266,21 @@ def route_by_reasoning(state: MainGraphState) -> str: "direct_response": "final_response", "retrieve_rag": "rag_retrieve", "re_retrieve_rag": "rag_retrieve", - "clarify": "final_response", # 简化:澄清直接回答让用户补充 - "call_tool": "final_response", # 简化:工具调用暂未实现 + "clarify": "final_response", + "call_tool": "final_response", "contact": "contact_subgraph", "dictionary": "dictionary_subgraph", "news_analysis": "news_analysis_subgraph", } return route_mapping.get(route, "final_response") + + +# ========== 导出 ========== +__all__ = [ + "init_state_node", + "react_reason_node", + "error_handling_node", + "final_response_node", + "route_by_reasoning" +] diff --git a/backend/app/graph/subgraph_builder.py 
b/backend/app/graph/subgraph_builder.py index 1a468d6..b3ebd7a 100644 --- a/backend/app/graph/subgraph_builder.py +++ b/backend/app/graph/subgraph_builder.py @@ -10,11 +10,11 @@ from .state import MainGraphState, CurrentAction from .react_nodes import ( init_state_node, react_reason_node, - rag_retrieve_node, error_handling_node, final_response_node, route_by_reasoning ) +from .rag_nodes import rag_retrieve_node from ..agent_subgraphs.contact import build_contact_subgraph from ..agent_subgraphs.dictionary import build_dictionary_subgraph from ..agent_subgraphs.news_analysis import build_news_analysis_subgraph