From 128aad0c2214932db9583706d1dbabad2905bff0 Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Tue, 5 May 2026 04:32:42 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E5=BF=AB?= =?UTF-8?q?=E9=80=9F=E8=B7=AF=E5=BE=84=E6=B5=81=E7=A8=8B=EF=BC=8C=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E9=80=9A=E8=BF=87=20llm=5Fcall=20=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 重构 fast_paths.py,让 fast_chitchat 和 fast_rag 都进入 llm_call 而不是直接设置 final_result - 修改 check_fast_path_success 函数返回 'llm_call' 而不是 'success' - 更新 main_graph_builder.py 的条件边配置,支持路由到 llm_call - 在快速路径节点中添加清除 state.final_result 的逻辑,避免复用旧结果 - 重构 RAG 工具初始化方式,使用模块级变量管理 - 修改 finalize.py 让它返回 final_result - 更新 agent_service.py 的 RAG 工具注入方式 - 简化 hybrid_router.py 的代码结构 - 清理 rag_nodes.py 的全局变量相关代码 - 更新相关测试文件 --- backend/app/agent/agent_service.py | 20 +- backend/app/main_graph/nodes/__init__.py | 6 + backend/app/main_graph/nodes/fast_paths.py | 203 ++++++ backend/app/main_graph/nodes/finalize.py | 22 +- backend/app/main_graph/nodes/hybrid_router.py | 611 +++++------------- backend/app/main_graph/nodes/rag_nodes.py | 125 ++-- .../main_graph/utils/main_graph_builder.py | 8 +- .../app/main_graph/utils/rag_initializer.py | 57 +- tools/run.py | 14 +- tools/start.py | 1 - tools/test/test_fast_rag_fix.py | 88 --- tools/test/test_graph_branches.py | 93 +-- tools/test/test_rag_pipeline.py | 1 - 13 files changed, 533 insertions(+), 716 deletions(-) create mode 100644 backend/app/main_graph/nodes/fast_paths.py delete mode 100644 tools/test/test_fast_rag_fix.py diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index 6512ada..d5d29b6 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -48,9 +48,7 @@ class AIAgentService: if rag_tool: self.tools.append(rag_tool) self.tools_by_name[rag_tool.name] = rag_tool - # 关键:设置全局 RAG 工具,供 rag_nodes.py 使用 - from ..main_graph.nodes.rag_nodes import set_global_rag_tool - set_global_rag_tool(rag_tool) + self.rag_tool = rag_tool # 保存到实例变量,供 config 注入 # 2. 构建各模型的 Graph(使用新版 React 模式) for name, llm in chat_services.items(): @@ -82,7 +80,10 @@ class AIAgentService: graph = self.graphs[model] config = { - "configurable": {"thread_id": thread_id}, + "configurable": { + "thread_id": thread_id, + "rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具 + }, "metadata": {"user_id": user_id} } # 新版状态输入:传入完整的 MainGraphState,关键是 user_query @@ -136,7 +137,10 @@ class AIAgentService: raise ValueError(f"模型 '{model_name}' 未找到或未初始化") config = { - "configurable": {"thread_id": thread_id}, + "configurable": { + "thread_id": thread_id, + "rag_tool": getattr(self, "rag_tool", None), # 注入 RAG 工具 + }, "metadata": {"user_id": user_id} } input_state = { @@ -250,9 +254,11 @@ class AIAgentService: elif chunk_type == "updates": updates_data = chunk["data"] + info(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}") + # 特别检查 final_result + if isinstance(updates_data, dict) and "final_result" in updates_data: + info(f"[Stream] 收到 final_result: {str(updates_data['final_result'])[:100]}...") serialized_data = self._serialize_value(updates_data) - - # 检查是否有人工审核请求 if "review_pending" in serialized_data and serialized_data["review_pending"]: review_id = serialized_data.get("review_id", "") content_to_review = serialized_data.get("content_to_review", "") diff --git a/backend/app/main_graph/nodes/__init__.py b/backend/app/main_graph/nodes/__init__.py index caf44f7..a3cb165 100644 --- a/backend/app/main_graph/nodes/__init__.py +++ b/backend/app/main_graph/nodes/__init__.py @@ -19,6 +19,10 @@ from .finalize import finalize_node # 混合路由节点 from .hybrid_router import ( hybrid_router_node, + route_from_hybrid_decision, + check_fast_path_success, +) +from .fast_paths import ( fast_chitchat_node, fast_rag_node, fast_tool_node, @@ -45,6 +49,8 @@ __all__ = [ "finalize_node", # 混合路由节点 "hybrid_router_node", + "route_from_hybrid_decision", + "check_fast_path_success", "fast_chitchat_node", "fast_rag_node", "fast_tool_node", diff --git a/backend/app/main_graph/nodes/fast_paths.py b/backend/app/main_graph/nodes/fast_paths.py new file mode 100644 index 0000000..77139f6 --- /dev/null +++ b/backend/app/main_graph/nodes/fast_paths.py @@ -0,0 +1,203 @@ +""" +快速路径节点模块 +包含闲聊、RAG、工具等快速处理节点 +""" + +from typing import Optional + +from ..state import MainGraphState +from ...logger import info, debug +from ...model_services.chat_services import get_small_llm_service, get_chat_service +from .rag_nodes import rag_retrieve_node +from ._utils import dispatch_custom_event + + +# ========== 闲聊回复模板 ========== +CHITCHAT_TEMPLATES = { + "谢谢": "不客气!如果还有其他问题,请随时告诉我 😊", + "再见": "再见!期待下次为您服务 👋", + "你好": "你好!有什么我可以帮您的吗?", + "默认": None # 使用 LLM 生成 +} + +CHITCHAT_KEYWORDS = { + "谢谢": ["谢谢", "感谢", "thanks", "thank you"], + "再见": ["再见", "拜拜", "bye", "goodbye"], + "你好": ["你好", "您好", "hi", "hello", "hey", "早上好", "晚上好", "下午好"], +} + + +# ========== 闲聊节点 ========== +async def fast_chitchat_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: + """快速闲聊节点""" + state.current_phase = "fast_chitchat" + query = state.user_query or "" + info(f"[Fast Chitchat] 处理: {query[:50]}") + + # 发送开始事件 + await dispatch_custom_event("fast_path_start", {"path": "fast_chitchat"}, config) + + # 清除之前的 final_result,让 llm_call 生成新回答 + state.final_result = None + + # 标记快速路径成功,但不设置 final_result,让 llm_call 生成回答 + state.success = True + state.current_phase = "llm_call" + state.debug_info["fast_chitchat_success"] = True + + # 发送完成事件 + await dispatch_custom_event("fast_path_end", {"path": "fast_chitchat", "success": True}, config) + + return state + + +def _match_chitchat_template(query: str) -> str: + """匹配闲聊模板""" + query_clean = query.strip().lower() + + for intent, keywords in CHITCHAT_KEYWORDS.items(): + if any(kw in query_clean for kw in keywords): + return CHITCHAT_TEMPLATES[intent] + + # 默认:使用 LLM 生成 + try: + llm = get_small_llm_service() + response = llm.invoke(f"你是一个友好的助手。用户说:{query}。请简短友好地回复:") + return response.content + except Exception: + return "你好!有什么我可以帮您的吗?" + + +# ========== 快速 RAG 节点 ========== +async def fast_rag_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: + """快速 RAG 节点:只负责 RAG 检索,然后交给 llm_call 生成回答""" + state.current_phase = "fast_rag" + query = state.user_query or "" + info(f"[Fast RAG] 开始处理: {query[:50]}") + + # 获取 RAG 工具 + from app.main_graph.utils.rag_initializer import get_rag_tool + rag_tool = get_rag_tool() + info(f"[Fast RAG] 获取到 rag_tool: {rag_tool is not None}") + + # 发送开始事件 + await dispatch_custom_event("fast_path_start", {"path": "fast_rag"}, config) + + # 清除之前的 final_result,让 llm_call 生成新回答 + state.final_result = None + + # 如果没有 rag_tool,升级到 React 循环 + if not rag_tool: + info("[Fast RAG] 未找到 RAG 工具,升级到 React 循环") + return _mark_fast_path_failed(state, "未找到 RAG 工具") + + try: + # 尝试 RAG 检索 + state = await rag_retrieve_node(state, config) + + # 检查检索结果 + if _has_valid_rag_results(state): + info(f"[Fast RAG] 检索有效,进入 llm_call 生成回答") + await dispatch_custom_event("fast_path_end", {"path": "fast_rag", "success": True}, config) + # 注意:这里不设置 final_result,让 llm_call 节点处理 + return state + + # 无效结果:升级到 React 循环 + info("[Fast RAG] 无有效检索结果,升级到 React 循环") + return _mark_fast_path_failed(state, "无有效检索结果") + + except Exception as e: + info(f"[Fast RAG] 执行失败: {e}") + return _mark_fast_path_failed(state, str(e)) + + +def _has_valid_rag_results(state: MainGraphState) -> bool: + """检查 RAG 结果是否有效""" + rag_docs = getattr(state, "rag_docs", []) + rag_context = getattr(state, "rag_context", "") + return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10) + + +async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState: + """使用小模型快速生成回答""" + try: + chat_llm = get_chat_service() + rag_context = state.rag_context or str(state.rag_docs)[:2000] + + prompt = f"""请根据以下信息回答用户问题: + +检索到的信息: +{rag_context} + +用户问题:{query} + +请给出简洁、准确的回答:""" + + # 使用流式输出 + from app.main_graph.config import get_stream_writer + writer = get_stream_writer() + + full_content = "" + async for chunk in chat_llm.astream(prompt): + content = getattr(chunk, 'content', '') + if content: + full_content += content + # 流式输出 + if writer and hasattr(writer, '__call__'): + try: + writer({ + "type": "llm_token", + "token": content + }) + except Exception: + pass + + state.final_result = full_content + state.success = True + state.current_phase = "finalizing" + state.debug_info["fast_rag_success"] = True + return state + + except Exception as e: + info(f"[Fast RAG] 快速回答生成失败: {e}") + return _mark_fast_path_failed(state, "回答生成失败") + + +# ========== 快速工具节点 ========== +async def fast_tool_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: + """快速工具节点""" + state.current_phase = "fast_tool" + + decision = state.debug_info.get("hybrid_decision", {}) + suggested_tools = decision.get("suggested_tools", []) + info(f"[Fast Tool] 开始处理,建议工具: {suggested_tools}") + + await dispatch_custom_event("fast_path_start", {"path": "fast_tool", "suggested_tools": suggested_tools}, config) + + # 无明确工具建议,升级到 React 循环 + if not suggested_tools: + info("[Fast Tool] 无明确工具建议,升级到 React 循环") + return _mark_fast_path_failed(state, "无明确工具建议") + + # 当前版本暂不支持快速工具调用,升级到 React 循环 + info("[Fast Tool] 快速工具调用暂未完善,升级到 React 循环") + return _mark_fast_path_failed(state, "快速工具调用暂未完善") + + +# ========== 公共函数 ========== +def _mark_fast_path_failed(state: MainGraphState, reason: str = "") -> MainGraphState: + """标记快速路径失败,准备升级到 React 循环""" + state.debug_info["fast_path_failed"] = True + state.debug_info["fast_path_fail_reason"] = reason + state.success = False + info(f"[Fast Path] 标记失败,准备升级: {reason}") + return state + + +# ========== 导出 ========== +__all__ = [ + "fast_chitchat_node", + "fast_rag_node", + "fast_tool_node", + "_mark_fast_path_failed", +] diff --git a/backend/app/main_graph/nodes/finalize.py b/backend/app/main_graph/nodes/finalize.py index 5e6af14..25882d4 100644 --- a/backend/app/main_graph/nodes/finalize.py +++ b/backend/app/main_graph/nodes/finalize.py @@ -22,31 +22,39 @@ async def finalize_node(state: MainGraphState, config: RunnableConfig) -> Dict[s config: 运行时配置 Returns: - 空字典(完成节点,无状态更新) + 更新后的状态(包含 final_result) """ log_state_change("finalize", state, "进入") + # 确保 final_result 被传递出去 + result = { + "final_result": state.final_result, + "success": state.success, + "current_phase": "done" + } + try: # 获取流式写入器并发送完成事件 from app.main_graph.config import get_stream_writer writer = get_stream_writer() - + # 只在 writer 存在且不是 noop 时才发送 if writer and hasattr(writer, '__call__'): try: writer({ - "type": "custom", + "type": "custom", "data": { "type": "done", "token_usage": state.last_token_usage, - "elapsed_time": state.last_elapsed_time + "elapsed_time": state.last_elapsed_time, + "final_result": state.final_result } }) - info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息") + info("🏁 [完成事件] 已发送完成事件") except Exception as e: warning(f"⚠️ [完成事件] 发送完成事件失败 (非致命): {e}") except Exception as e: warning(f"⚠️ [完成事件] 处理失败 (非致命): {e}") - + log_state_change("finalize", state, "离开") - return {} \ No newline at end of file + return result \ No newline at end of file diff --git a/backend/app/main_graph/nodes/hybrid_router.py b/backend/app/main_graph/nodes/hybrid_router.py index f626b8c..8d8d7b8 100644 --- a/backend/app/main_graph/nodes/hybrid_router.py +++ b/backend/app/main_graph/nodes/hybrid_router.py @@ -1,107 +1,47 @@ """ -混合路由节点模块 - 前置路由 + 快速路径 +混合路由节点模块 - 前置路由决策 +负责决定走快速路径还是 React 循环 """ import re import json -from typing import Dict, Any, Optional, List +from typing import Optional from dataclasses import dataclass, field from datetime import datetime from ..state import MainGraphState from ...logger import info, debug -from ...model_services.chat_services import get_small_llm_service, get_chat_service -from .rag_nodes import rag_retrieve_node +from ...model_services.chat_services import get_small_llm_service +from ._utils import dispatch_custom_event # ========== 核心数据类型 ========== - @dataclass class HybridRouterResult: """混合路由结果""" intent: str = "complex" # chitchat / knowledge / tool / complex confidence: float = 0.0 - suggested_tools: List[str] = field(default_factory=list) + suggested_tools: list = field(default_factory=list) path: str = "react_loop" # fast_chitchat / fast_rag / fast_tool / react_loop reasoning: str = "" -# ========== 规则分流(无 LLM,<5ms) ========== - -# 问候、感谢等直接返回的关键词 -AL_CHITCHAT = { +# ========== 规则配置 ========== +CHITCHAT_KEYWORDS = { "你好", "您好", "hi", "hello", "hey", "早上好", "晚上好", "下午好", "谢谢", "感谢", "多谢", "thanks", "thank you", "再见", "拜拜", "goodbye", "bye" } -# 子图关键词映射 SUBGRAPH_KEYWORDS = { "contact": ["通讯录", "联系人", "contact", "email", "邮件", "邮箱"], "dictionary": ["词典", "单词", "翻译", "dictionary", "translate", "生词"], "news_analysis": ["资讯", "新闻", "分析", "news", "report", "热点"] } -def _rule_based_redirect(query: str) -> Optional[HybridRouterResult]: - """ - 规则分流:处理明显不需要推理的情况(超快速) - - Args: - query: 用户查询 - - Returns: - HybridRouterResult 或 None - """ - query_clean = query.strip().lower() - - # 1. 检查闲聊 - if query_clean in AL_CHITCHAT or any(keyword in query_clean for keyword in AL_CHITCHAT): - return HybridRouterResult( - intent="chitchat", - confidence=1.0, - path="fast_chitchat", - reasoning=f"规则匹配:闲聊类请求" - ) - - # 2. 检查子图关键词(直接调用工具) - for subgraph_name, keywords in SUBGRAPH_KEYWORDS.items(): - if any(kw in query_clean for kw in keywords): - return HybridRouterResult( - intent="tool", - confidence=0.9, - suggested_tools=[subgraph_name], - path="fast_tool", - reasoning=f"规则匹配:{subgraph_name} 子图关键词" - ) - - # 3. 检查是否是纯问号或很短的问题(可能需要澄清) - if len(query_clean) < 3 or (query_clean.endswith("?") and len(query_clean) < 5): - return HybridRouterResult( - intent="complex", - confidence=0.3, - path="react_loop", - reasoning="规则匹配:问题过于简短或不确定" - ) - - return None - -# ========== 轻量级 LLM 分类 ========== - -async def _classify_with_small_llm(query: str) -> HybridRouterResult: - """ - 使用轻量级 LLM 进行意图分类 - - Args: - query: 用户查询 - - Returns: - HybridRouterResult - """ - try: - llm = get_small_llm_service() - - prompt = f"""你是一个专业的意图分类助手。请分析用户的查询,并输出 JSON 格式的结果。 +# ========== 意图分类 Prompt 模板 ========== +INTENT_CLASSIFICATION_PROMPT = """你是一个专业的意图分类助手。请分析用户的查询,并输出 JSON 格式的结果。 意图类型(4选一): - chitchat: 闲聊、问候、感谢、道别(不需要工具) @@ -120,52 +60,94 @@ async def _classify_with_small_llm(query: str) -> HybridRouterResult: "suggested_tools": ["contact|dictionary|news_analysis", "other"] }} -注意:如果不能100%确定意图,请选择 "complex",置信度设低一些。 -""" - +注意:如果不能100%确定意图,请选择 "complex",置信度设低一些。""" + + +# ========== 规则分流(<5ms) ========== +def _rule_based_redirect(query: str) -> Optional[HybridRouterResult]: + """规则分流:处理明显不需要推理的情况""" + query_clean = query.strip().lower() + + # 1. 闲聊 + if query_clean in CHITCHAT_KEYWORDS or any(kw in query_clean for kw in CHITCHAT_KEYWORDS): + return HybridRouterResult( + intent="chitchat", + confidence=1.0, + path="fast_chitchat", + reasoning="规则匹配:闲聊类请求" + ) + + # 2. 子图关键词 + for subgraph_name, keywords in SUBGRAPH_KEYWORDS.items(): + if any(kw in query_clean for kw in keywords): + return HybridRouterResult( + intent="tool", + confidence=0.9, + suggested_tools=[subgraph_name], + path="fast_tool", + reasoning=f"规则匹配:{subgraph_name} 子图关键词" + ) + + # 3. 短问题 + if len(query_clean) < 3 or (query_clean.endswith("?") and len(query_clean) < 5): + return HybridRouterResult( + intent="complex", + confidence=0.3, + path="react_loop", + reasoning="规则匹配:问题过于简短" + ) + + return None + + +# ========== LLM 分类 ========== +async def _classify_with_llm(query: str) -> HybridRouterResult: + """使用轻量级 LLM 进行意图分类""" + try: + llm = get_small_llm_service() + prompt = INTENT_CLASSIFICATION_PROMPT.format(query=query) response = await llm.ainvoke(prompt) - content = response.content - + # 解析 JSON - json_match = re.search(r'(\{[^{}]*\{[^{}]*\}[^{}]*\})|(\{[^{}]*\})', content) - if json_match: - try: - data = json.loads(json_match.group(0)) - - intent = data.get("intent", "complex") - confidence = float(data.get("confidence", 0.3)) - reasoning = data.get("reasoning", "") - suggested_tools = data.get("suggested_tools", []) - - # 置信度低于 0.5 一律走 complex - if confidence < 0.5: - intent = "complex" - path = "react_loop" - elif intent == "chitchat": - path = "fast_chitchat" - elif intent == "knowledge": - path = "fast_rag" - elif intent == "tool": - path = "fast_tool" - else: - intent = "complex" - path = "react_loop" - - return HybridRouterResult( - intent=intent, - confidence=confidence, - suggested_tools=suggested_tools, - path=path, - reasoning=reasoning - ) - except Exception as e: - debug(f"轻量 LLM 响应解析失败: {e}") - pass - + json_match = re.search(r'\{[\s\S]*?\}', response.content) + if not json_match: + return _default_result() + + data = json.loads(json_match.group()) + return _parse_classification_result(data) + except Exception as e: - debug(f"轻量 LLM 调用失败: {e}") - - # LLM 失败,降级到规则+默认 + debug(f"LLM 分类失败: {e}") + return _default_result() + + +def _parse_classification_result(data: dict) -> HybridRouterResult: + """解析分类结果""" + intent = data.get("intent", "complex") + confidence = float(data.get("confidence", 0.3)) + + # 置信度低于阈值,走 complex + if confidence < 0.5: + intent = "complex" + + # intent -> path 映射 + path_map = { + "chitchat": "fast_chitchat", + "knowledge": "fast_rag", + "tool": "fast_tool", + } + + return HybridRouterResult( + intent=intent, + confidence=confidence, + suggested_tools=data.get("suggested_tools", []), + path=path_map.get(intent, "react_loop"), + reasoning=data.get("reasoning", "") + ) + + +def _default_result() -> HybridRouterResult: + """默认结果(LLM 失败时)""" return HybridRouterResult( intent="complex", confidence=0.3, @@ -174,372 +156,73 @@ async def _classify_with_small_llm(query: str) -> HybridRouterResult: ) -# ========== 路由决策 ========== - -def _make_decision(classification_result: HybridRouterResult) -> HybridRouterResult: - """ - 根据分类结果最终决策 - - Args: - classification_result: 分类结果 - - Returns: - 最终决策结果 - """ - if classification_result.confidence < 0.5: - classification_result.intent = "complex" - classification_result.path = "react_loop" - return classification_result - - return classification_result - - -# ========== 混合路由主节点 ========== - -async def hybrid_router_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: - """ - 混合路由节点:前置路由,决定走快速路径还是 React 循环 - - Args: - state: 当前状态 - config: LangChain 配置(用于发送自定义事件) - - Returns: - 更新后的状态 - """ +# ========== 主路由节点 ========== +async def hybrid_router_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: + """混合路由节点:前置路由,决定走快速路径还是 React 循环""" state.current_phase = "hybrid_router" - query = state.user_query or "" + info(f"[Hybrid Router] 开始路由: {query[:50]}...") - - # 1. 规则分流(超快速) + + # 1. 规则分流 rule_result = _rule_based_redirect(query) if rule_result: - info(f"[Hybrid Router] 规则分流命中: {rule_result.path}") decision = rule_result + info(f"[Hybrid Router] 规则命中: {decision.path}") else: - # 2. 轻量 LLM 分类 - info(f"[Hybrid Router] 规则未命中,使用轻量 LLM 分类") - classification_result = await _classify_with_small_llm(query) - decision = _make_decision(classification_result) - - # 3. 发送 SSE 事件 - if config: - try: - from langchain_core.callbacks.manager import adispatch_custom_event - - callbacks = config.get("callbacks") - if callbacks: - await adispatch_custom_event( - "intent_classified", - { - "intent": decision.intent, - "confidence": decision.confidence, - "reasoning": decision.reasoning, - "suggested_tools": decision.suggested_tools - }, - callbacks=callbacks - ) - - await adispatch_custom_event( - "path_decision", - { - "path": decision.path, - "intent": decision.intent, - "reasoning": decision.reasoning - }, - callbacks=callbacks - ) - except Exception as e: - debug(f"[Hybrid Router] 发送 SSE 事件失败: {e}") - - # 4. 更新状态 - state.debug_info["hybrid_decision"] = decision + # 2. LLM 分类 + info("[Hybrid Router] 规则未命中,使用 LLM 分类") + decision = await _classify_with_llm(query) + + # 3. 更新状态 + state.debug_info["hybrid_decision"] = { + "intent": decision.intent, + "confidence": decision.confidence, + "path": decision.path, + "reasoning": decision.reasoning, + "suggested_tools": decision.suggested_tools + } state.debug_info["hybrid_start_time"] = datetime.now().isoformat() - + + # 4. 发送事件 + await dispatch_custom_event("intent_classified", { + "intent": decision.intent, + "confidence": decision.confidence, + "reasoning": decision.reasoning, + "suggested_tools": decision.suggested_tools + }, config) + + await dispatch_custom_event("path_decision", { + "path": decision.path, + "intent": decision.intent, + "reasoning": decision.reasoning + }, config) + info(f"[Hybrid Router] 路由决策: {decision.path} (intent={decision.intent}, confidence={decision.confidence})") - return state -# ========== 快速路径:闲聊 ========== - -async def fast_chitchat_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: - """ - 快速闲聊节点:直接返回回复,不走 RAG/工具/循环 - - Args: - state: 当前状态 - config: LangChain 配置 - - Returns: - 更新后的状态 - """ - state.current_phase = "fast_chitchat" - - query = state.user_query or "" - info(f"[Fast Chitchat] 处理: {query[:50]}") - - # 发送 SSE 事件 - if config: - try: - from langchain_core.callbacks.manager import adispatch_custom_event - callbacks = config.get("callbacks") - if callbacks: - await adispatch_custom_event( - "fast_path_start", - {"path": "fast_chitchat"}, - callbacks=callbacks - ) - except Exception as e: - debug(f"[Fast Chitchat] 发送事件失败: {e}") - - # 快速回复(可以扩展为模板库) - query_clean = query.strip().lower() - - if any(kw in query_clean for kw in ["谢谢", "感谢", "thanks", "thank you"]): - reply = "不客气!如果还有其他问题,请随时告诉我 😊" - elif any(kw in query_clean for kw in ["再见", "拜拜", "bye", "goodbye"]): - reply = "再见!期待下次为您服务 👋" - elif any(kw in query_clean for kw in ["你好", "您好", "hi", "hello", "hey", "早上好", "晚上好", "下午好"]): - reply = "你好!有什么我可以帮您的吗?" - else: - # 兜底:用轻量 LLM 生成 - try: - llm = get_small_llm_service() - response = await llm.ainvoke(f"你是一个友好的助手。用户说:{query}。请简短友好地回复:") - reply = response.content - except: - reply = "你好!有什么我可以帮您的吗?" - - state.final_result = reply - state.success = True - state.current_phase = "finalizing" - state.debug_info["fast_chitchat_success"] = True - - # 发送 fast_path_end 事件 - if config: - try: - from langchain_core.callbacks.manager import adispatch_custom_event - callbacks = config.get("callbacks") - if callbacks: - await adispatch_custom_event( - "fast_path_end", - {"path": "fast_chitchat", "success": True}, - callbacks=callbacks - ) - except Exception as e: - debug(f"[Fast Chitchat] 发送完成事件失败: {e}") - - return state - - -# ========== 快速路径:RAG(带自动升级) ========== - -async def fast_rag_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: - """ - 快速 RAG 节点:先尝试快速检索,失败自动升级到 React 循环 - - Args: - state: 当前状态 - config: LangChain 配置 - - Returns: - 更新后的状态 - """ - state.current_phase = "fast_rag" - - query = state.user_query or "" - info(f"[Fast RAG] 开始处理: {query[:50]}") - - # 发送 SSE 事件 - if config: - try: - from langchain_core.callbacks.manager import adispatch_custom_event - callbacks = config.get("callbacks") - if callbacks: - await adispatch_custom_event( - "fast_path_start", - {"path": "fast_rag"}, - callbacks=callbacks - ) - except Exception as e: - debug(f"[Fast RAG] 发送事件失败: {e}") - - try: - # 先尝试 RAG 检索 - 注意:rag_retrieve_node 是异步函数,需要 await - state = await rag_retrieve_node(state, config) - - # 检查检索结果 - rag_docs = getattr(state, "rag_docs", []) - rag_context = getattr(state, "rag_context", "") - - # 检查是否有有效结果 - has_valid_results = (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10) - - if has_valid_results: - # 快速 RAG 成功!使用小模型快速生成回答 - try: - llm = get_chat_service() - prompt = f"""请根据以下信息回答用户问题: - -检索到的信息: -{rag_context or str(rag_docs)[:2000]} - -用户问题:{query} - -请给出简洁、准确的回答:""" - - response = await llm.ainvoke(prompt) - - state.final_result = response.content - state.success = True - state.current_phase = "finalizing" - state.debug_info["fast_rag_success"] = True - - # 发送成功事件 - if config: - try: - from langchain_core.callbacks.manager import adispatch_custom_event - callbacks = config.get("callbacks") - if callbacks: - await adispatch_custom_event( - "fast_path_end", - {"path": "fast_rag", "success": True}, - callbacks=callbacks - ) - except Exception as e: - debug(f"[Fast RAG] 发送完成事件失败: {e}") - - return state - - except Exception as e: - info(f"[Fast RAG] 快速回答生成失败: {e}") - # 继续往下走,升级到 React 循环 - - # RAG 失败或无结果:标记升级 - info(f"[Fast RAG] 无有效检索结果,升级到 React 循环") - return mark_fast_path_failed(state, reason="无有效检索结果") - - except Exception as e: - info(f"[Fast RAG] 执行失败: {e}") - return mark_fast_path_failed(state, reason=str(e)) - - -# ========== 快速路径:工具(带自动升级) ========== - -async def fast_tool_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: - """ - 快速工具节点:尝试直接调用工具,失败自动升级到 React 循环 - - Args: - state: 当前状态 - config: LangChain 配置 - - Returns: - 更新后的状态 - """ - state.current_phase = "fast_tool" - - decision: HybridRouterResult = state.debug_info.get("hybrid_decision", HybridRouterResult()) - suggested_tools = decision.suggested_tools or [] - - query = state.user_query or "" - info(f"[Fast Tool] 开始处理,建议工具: {suggested_tools}") - - # 发送 SSE 事件 - if config: - try: - from langchain_core.callbacks.manager import adispatch_custom_event - callbacks = config.get("callbacks") - if callbacks: - await adispatch_custom_event( - "fast_path_start", - {"path": "fast_tool", "suggested_tools": suggested_tools}, - callbacks=callbacks - ) - except Exception as e: - debug(f"[Fast Tool] 发送事件失败: {e}") - - # 检查是否有明确的工具建议 - if not suggested_tools: - info(f"[Fast Tool] 无明确工具建议,升级到 React 循环") - return mark_fast_path_failed(state, reason="无明确工具建议") - - # 工具调用逻辑(这里暂时先标记升级,让 React 循环去处理) - # 后续可以扩展为直接调用子图 - info(f"[Fast Tool] 快速工具调用暂未完善,升级到 React 循环") - return mark_fast_path_failed(state, reason="快速工具调用暂未完善") - - -# ========== 标记快速路径失败(用于自动升级) ========== - -def mark_fast_path_failed(state: MainGraphState, reason: str = "") -> MainGraphState: - """ - 标记快速路径失败,准备升级到 React 循环 - - Args: - state: 当前状态 - reason: 失败原因 - - Returns: - 更新后的状态 - """ - state.debug_info["fast_path_failed"] = True - state.debug_info["fast_path_fail_reason"] = reason - state.success = False - - # 发送 escalation 事件 - config = state.debug_info.get("config") - if config: - try: - from langchain_core.callbacks.manager import adispatch_custom_event - callbacks = config.get("callbacks") - if callbacks: - # 这里需要在异步上下文中调用 - pass - except Exception as e: - debug(f"[Fast Path] 发送升级事件失败: {e}") - - info(f"[Fast Path] 标记失败,准备升级: {reason}") - return state - - -# ========== 快速路径检查器(自动升级机制) ========== - +# ========== 条件路由函数 ========== def route_from_hybrid_decision(state: MainGraphState) -> str: - """ - 从混合路由决策获取下一步的节点名称 - - Args: - state: 当前状态 - - Returns: - 节点名称 - """ - decision: HybridRouterResult = state.debug_info.get("hybrid_decision", HybridRouterResult()) - return decision.path + """从混合路由决策获取下一步节点""" + decision = state.debug_info.get("hybrid_decision", {}) + return decision.get("path", "react_loop") def check_fast_path_success(state: MainGraphState) -> str: - """ - 检查快速路径是否成功,成功直接到 finalize,失败升级到 react_reason - - Args: - state: 当前状态 - - Returns: - "success" 或 "escalate" - """ - # 检查是否有错误标记 + """检查快速路径是否成功""" if state.debug_info.get("fast_path_failed"): - info(f"[Fast Path Check] 快速路径失败,升级到 React 循环") + info("[Fast Path Check] 快速路径失败,升级到 React 循环") return "escalate" - - # 检查是否成功设置了 final_result - if state.final_result: - info(f"[Fast Path Check] 快速路径成功,进入 finalize") - return "success" - - # 默认:认为成功(某些快速路径可能直接在节点中完成) - return "success" + + info("[Fast Path Check] 快速路径成功,进入 llm_call") + return "llm_call" + + +# ========== 导出 ========== +__all__ = [ + "hybrid_router_node", + "route_from_hybrid_decision", + "check_fast_path_success", + "HybridRouterResult", +] diff --git a/backend/app/main_graph/nodes/rag_nodes.py b/backend/app/main_graph/nodes/rag_nodes.py index 1eae99b..b362a5b 100644 --- a/backend/app/main_graph/nodes/rag_nodes.py +++ b/backend/app/main_graph/nodes/rag_nodes.py @@ -1,11 +1,11 @@ """ RAG 检索节点模块 -包含 RAG 检索节点(带超时重试) +使用模块级变量管理 RAG 工具 """ import time import asyncio -from typing import Dict, Any, Optional +from typing import Optional from datetime import datetime from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity @@ -13,43 +13,15 @@ from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG from app.logger import info from ._utils import dispatch_custom_event, make_react_event -from app.rag.tools import create_rag_tool -from app.rag.pipeline import RAGPipeline - -# ========== 全局 RAG 工具实例 ========== -_GLOBAL_RAG_TOOL: Optional[Any] = None -_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None - - -def get_global_rag_tool() -> Optional[Any]: - return _GLOBAL_RAG_TOOL - - -def set_global_rag_tool(tool: Any) -> None: - global _GLOBAL_RAG_TOOL - _GLOBAL_RAG_TOOL = tool - - -def set_global_rag_pipeline(pipeline: RAGPipeline) -> None: - global _GLOBAL_RAG_PIPELINE - _GLOBAL_RAG_PIPELINE = pipeline - - -def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]: - """从状态或全局获取 RAG 工具""" - return state.debug_info.get("rag_tool") or get_global_rag_tool() - - -def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState: - """将 RAG 工具注入到状态中""" - state.debug_info["rag_tool"] = rag_tool - state.debug_info["rag_tool_injected"] = datetime.now().isoformat() - return state +def _get_rag_tool() -> Optional[callable]: + """获取 RAG 工具""" + from app.main_graph.utils.rag_initializer import get_rag_tool + return get_rag_tool() # ========== RAG 检索核心逻辑 ========== -async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: +async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainGraphState: """执行 RAG 检索的核心逻辑""" retrieval_query = state.user_query @@ -60,55 +32,54 @@ async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: if cfg and cfg.retrieval_query: retrieval_query = cfg.retrieval_query - rag_tool = get_rag_tool_from_state(state) + # 调用 RAG 工具 + rag_context = await rag_tool.ainvoke(retrieval_query) + info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}") - if rag_tool: - rag_context = await rag_tool.ainvoke(retrieval_query) - state.rag_context = rag_context - state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}] - state.rag_retrieved = True - state.success = True - state.debug_info["rag_source"] = "rag_tool" - return state + state.rag_context = rag_context + state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}] + state.rag_retrieved = True + state.success = True + state.debug_info["rag_source"] = "tool" - if _GLOBAL_RAG_PIPELINE: - documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query) - if documents: - rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents) - state.rag_context = rag_context - state.rag_docs = [ - {"source": doc.metadata.get("source", "unknown"), "content": doc.page_content} - for doc in documents - ] - else: - state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。" - state.rag_docs = [] - state.rag_retrieved = True - state.success = True - state.debug_info["rag_source"] = "rag_pipeline" - return state - - raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()") + info(f"[RAG Core] state.rag_docs 长度: {len(state.rag_docs)}") + return state # ========== RAG 检索节点 ========== -async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: +async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: """RAG 检索节点:带超时和重试""" state.current_phase = "rag_retrieving" start_time = time.time() last_error = None - # 步骤1: 发送开始事件 + # 获取 RAG 工具 + rag_tool = _get_rag_tool() + + if not rag_tool: + error_record = ErrorRecord( + error_type="RAGRetrievalError", + error_message="RAG 工具未初始化", + severity=ErrorSeverity.WARNING, + source="rag_retrieve_node", + timestamp=datetime.now().isoformat(), + retry_count=0, + max_retries=RAG_RETRY_CONFIG.max_retries, + ) + state.errors.append(error_record) + state.current_error = error_record + state.current_phase = "error_handling" + return state + await dispatch_custom_event( "react_reasoning", make_react_event(state.reasoning_step, "rag_retrieve_start", 1.0, "开始执行 RAG 检索..."), config ) - # 步骤2: 执行检索(带重试) for attempt in range(RAG_RETRY_CONFIG.max_retries + 1): try: - result = await _rag_retrieve_core(state) + result = await _rag_retrieve_core(state, rag_tool) info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符") @@ -118,7 +89,6 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An "time": time.time() - start_time } - # 记录成功到历史 state.reasoning_history.append({ "step": state.reasoning_step, "action": "RETRIEVE_RAG", @@ -127,7 +97,6 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An "timestamp": datetime.now().isoformat() }) - # 发送完成事件 doc_count = len(result.rag_docs) if result.rag_docs else 0 await dispatch_custom_event( "react_reasoning", @@ -144,7 +113,6 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An if attempt >= RAG_RETRY_CONFIG.max_retries: break - # 发送重试事件 await dispatch_custom_event( "react_reasoning", make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0, @@ -152,11 +120,10 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An config ) - # 指数退避 delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt) await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay)) - # 步骤3: 所有重试失败,记录到历史(避免推理循环) + # 失败记录 state.reasoning_history.append({ "step": state.reasoning_step, "action": "RETRIEVE_RAG", @@ -165,7 +132,6 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An "timestamp": datetime.now().isoformat() }) - # 步骤4: 记录错误 error_record = ErrorRecord( error_type="RAGRetrievalError", error_message=str(last_error) if last_error else "RAG 检索超时", @@ -174,19 +140,12 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An timestamp=datetime.now().isoformat(), retry_count=RAG_RETRY_CONFIG.max_retries, max_retries=RAG_RETRY_CONFIG.max_retries, - context={ - "query": state.user_query, - "total_time": time.time() - start_time, - "has_rag_tool": get_global_rag_tool() is not None, - "has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None - } ) state.errors.append(error_record) state.current_error = error_record state.current_phase = "error_handling" - # 发送错误事件 await dispatch_custom_event( "react_reasoning", make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0, @@ -197,8 +156,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An return state -# ========== 重新检索节点 ========== -async def rag_re_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: +async def rag_re_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: """重新检索节点""" state.current_phase = "rag_re_retrieving" @@ -214,9 +172,4 @@ async def rag_re_retrieve_node(state: MainGraphState, config: Optional[Dict[str, __all__ = [ "rag_retrieve_node", "rag_re_retrieve_node", - "inject_rag_tool_to_state", - "get_rag_tool_from_state", - "get_global_rag_tool", - "set_global_rag_tool", - "set_global_rag_pipeline", ] diff --git a/backend/app/main_graph/utils/main_graph_builder.py b/backend/app/main_graph/utils/main_graph_builder.py index cd14f7f..cae948f 100644 --- a/backend/app/main_graph/utils/main_graph_builder.py +++ b/backend/app/main_graph/utils/main_graph_builder.py @@ -13,11 +13,13 @@ from ..nodes.error_handling import error_handling_node from ..nodes.routing import init_state_node, route_by_reasoning from ..nodes.hybrid_router import ( hybrid_router_node, + route_from_hybrid_decision, + check_fast_path_success, +) +from ..nodes.fast_paths import ( fast_chitchat_node, fast_rag_node, fast_tool_node, - route_from_hybrid_decision, - check_fast_path_success ) from ..nodes.llm_call import create_llm_call_node from ..nodes.rag_nodes import rag_retrieve_node @@ -294,7 +296,7 @@ def build_react_main_graph(llm=None, tools=None, mem0_client=None, use_hybrid_ro fast_node, check_fast_path_success, { - "success": "finalize", + "llm_call": "llm_call", "escalate": "react_reason" } ) diff --git a/backend/app/main_graph/utils/rag_initializer.py b/backend/app/main_graph/utils/rag_initializer.py index 6707a6e..7351635 100644 --- a/backend/app/main_graph/utils/rag_initializer.py +++ b/backend/app/main_graph/utils/rag_initializer.py @@ -3,12 +3,43 @@ from app.rag.tools import create_rag_tool from app.rag.retriever import create_parent_hybrid_retriever from app.model_services import get_embedding_service from app.logger import info, warning +import sys + +# 全局 RAG 工具 +_rag_tool = None +_initialized = False + + +def get_rag_tool() -> callable: + """获取全局 RAG 工具""" + return _rag_tool + + +def is_initialized() -> bool: + """检查是否已初始化""" + return _initialized + + +async def init_rag_tool(local_llm_creator, force: bool = False): + """ + 初始化 RAG 工具(注册到模块级变量) + + Args: + local_llm_creator: 返回 LLM 实例的函数 + force: 是否强制重新初始化 + + Returns: + RAG 工具(@tool 装饰函数)或 None + """ + global _rag_tool, _initialized + + # 防止重复初始化 + if _initialized and not force: + info("[RAG] 已初始化,跳过") + return _rag_tool -async def init_rag_tool(local_llm_creator): - """初始化 RAG 工具,失败返回 None""" try: info("🔄 正在初始化 RAG 检索系统...") - # 使用统一的嵌入服务获取接口 embeddings = get_embedding_service() retriever = create_parent_hybrid_retriever( collection_name="rag_documents", @@ -16,12 +47,26 @@ async def init_rag_tool(local_llm_creator): embeddings=embeddings ) rewrite_llm = local_llm_creator() + rag_tool = create_rag_tool( - retriever, rewrite_llm, - num_queries=3, rerank_top_n=5 + retriever=retriever, + llm=rewrite_llm, + num_queries=3, + rerank_top_n=5 ) - info("✅ RAG 检索工具初始化成功(全异步版本)") + + _rag_tool = rag_tool + _initialized = True + info(f"✅ RAG 检索工具初始化成功 (id={id(rag_tool)})") return rag_tool + except Exception as e: warning(f"⚠️ RAG 检索工具初始化失败: {e}") return None + + +def reset(): + """重置(用于测试)""" + global _rag_tool, _initialized + _rag_tool = None + _initialized = False diff --git a/tools/run.py b/tools/run.py index 8e4ca60..446ee55 100644 --- a/tools/run.py +++ b/tools/run.py @@ -1,20 +1,16 @@ #!/usr/bin/env python3 -"""统一入口:设置路径后运行 RAG 索引构建 CLI""" +"""统一入口:设置路径后运行测试""" import sys from pathlib import Path from dotenv import load_dotenv -# 路径设置 +# 路径设置 - 只添加 backend 目录 project_root = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(project_root)) -sys.path.insert(0, str(project_root / "backend")) +backend_path = project_root +sys.path.insert(0, str(backend_path)) load_dotenv(project_root / ".env") if __name__ == "__main__": - #from rag_indexer.cli import main - #from tools.test.test_rag_indexer_result import main - #from tools.test.test_rag_pipeline import main - from tools.test.test_fast_rag_fix import main - #from tools.test.test_graph_branches import main + from tools.test.test_graph_branches import main import asyncio asyncio.run(main()) diff --git a/tools/start.py b/tools/start.py index 5848df7..1e3cc43 100755 --- a/tools/start.py +++ b/tools/start.py @@ -14,7 +14,6 @@ from dotenv import load_dotenv # 路径设置 project_root = Path(__file__).resolve().parent.parent sys.path.insert(0, str(project_root)) -sys.path.insert(0, str(project_root / "backend")) load_dotenv(project_root / ".env") # 全局变量 diff --git a/tools/test/test_fast_rag_fix.py b/tools/test/test_fast_rag_fix.py deleted file mode 100644 index 7febf66..0000000 --- a/tools/test/test_fast_rag_fix.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python3 -""" -快速测试 - 测试 fast_rag 路径修复 -""" - -import asyncio -from backend.app.main_graph.state import MainGraphState, CurrentAction -from backend.app.main_graph.utils.main_graph_builder import build_react_main_graph -from backend.app.model_services.chat_services import get_all_chat_services -from backend.app.main_graph.tools.graph_tools import AVAILABLE_TOOLS - - -async def test_fast_rag_path(): - """测试 fast_rag 路径""" - print("=" * 60) - print("测试 fast_rag 路径修复") - print("=" * 60) - - # 1. 获取 LLM - chat_services = get_all_chat_services() - if not chat_services: - print("✗ 没有可用的 LLM 服务") - return - - llm = list(chat_services.values())[0] - print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}") - - # 2. 构建图 - graph = build_react_main_graph( - llm=llm, - tools=AVAILABLE_TOOLS, - use_hybrid_router=True - ).compile() - print(f"✓ 图构建完成") - - # 3. 测试问题 - test_query = "吕布和张飞谁厉害?" - print(f"\n测试问题: {test_query}") - - # 4. 创建状态 - input_state = { - "user_query": test_query, - "messages": [{"role": "user", "content": test_query}], - "user_id": "test_user", - "current_action": CurrentAction.NONE - } - - # 5. 执行 - print("开始执行...") - try: - result = await graph.ainvoke( - input_state, - config={"configurable": {"thread_id": "test_fast_rag"}} - ) - - print(f"\n✓ 执行成功!") - print(f"最终回答: {result.get('final_result', '')[:300]}") - - # 调试信息 - debug_info = result.get("debug_info", {}) - print(f"\n调试信息:") - if "fast_path_failed" in debug_info: - print(f" - fast_path_failed: {debug_info['fast_path_failed']}") - if "fast_path_fail_reason" in debug_info: - print(f" - fast_path_fail_reason: {debug_info['fast_path_fail_reason']}") - - except Exception as e: - print(f"\n✗ 执行失败: {e}") - import traceback - print(traceback.format_exc()) - return False - - return True - - -async def main(): - success = await test_fast_rag_path() - if success: - print("\n🎉 测试通过!") - else: - print("\n⚠️ 测试失败") - - -if __name__ == "__main__": - try: - asyncio.run(main()) - except KeyboardInterrupt: - print("\n测试被中断") diff --git a/tools/test/test_graph_branches.py b/tools/test/test_graph_branches.py index 6b8f04c..32ca0f3 100644 --- a/tools/test/test_graph_branches.py +++ b/tools/test/test_graph_branches.py @@ -7,47 +7,49 @@ import asyncio from pathlib import Path from dotenv import load_dotenv +# 添加 backend 到路径 +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "backend")) -from backend.app.main_graph.state import MainGraphState, CurrentAction -from backend.app.main_graph.utils.main_graph_builder import build_react_main_graph -from backend.app.model_services.chat_services import get_all_chat_services -from backend.app.main_graph.tools.graph_tools import AVAILABLE_TOOLS -from backend.app.main_graph.utils.rag_initializer import init_rag_tool +from app.main_graph.state import MainGraphState, CurrentAction +from app.main_graph.utils.main_graph_builder import build_react_main_graph +from app.model_services.chat_services import get_all_chat_services +from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS +from app.main_graph.utils.rag_initializer import init_rag_tool # ========== 测试用例配置 ========== TEST_CASES = [ - # 测试1: 简单闲聊 - 应该走 fast_chitchat - { - "name": "闲聊测试", - "query": "你好!", - "description": "测试快速闲聊分支" - }, + # # 测试1: 简单闲聊 - 应该走 fast_chitchat + # { + # "name": "闲聊测试", + # "query": "你好!", + # "description": "测试快速闲聊分支" + # }, # 测试2: 知识查询 - 应该走 fast_rag,然后可能升级到 react { "name": "知识查询测试", - "query": "什么是机器学习?", + "query": "吕布的事迹?", "description": "测试快速 RAG 分支" }, - # 测试3: 需要推理的复杂问题 - 应该直接到 React 循环 - { - "name": "复杂推理测试", - "query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?", - "description": "测试 React 循环推理分支" - }, - # 测试4: 需要工具调用的问题 - { - "name": "工具调用测试", - "query": "搜索一下今天的天气怎么样", - "description": "测试工具调用分支" - }, - # 测试5: 带记忆的对话 - { - "name": "记忆测试", - "query": "你刚才回答了我什么问题?", - "description": "测试记忆检索分支", - "thread_id": "test_memory_thread" - } + # # 测试3: 需要推理的复杂问题 - 应该直接到 React 循环 + # { + # "name": "复杂推理测试", + # "query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?", + # "description": "测试 React 循环推理分支" + # }, + # # 测试4: 需要工具调用的问题 + # { + # "name": "工具调用测试", + # "query": "搜索一下今天的天气怎么样", + # "description": "测试工具调用分支" + # }, + # # 测试5: 带记忆的对话 + # { + # "name": "记忆测试", + # "query": "你刚才回答了我什么问题?", + # "description": "测试记忆检索分支", + # "thread_id": "test_memory_thread" + # } ] @@ -56,36 +58,36 @@ async def setup_test_environment(): print("=" * 60) print("设置测试环境...") print("=" * 60) - + # 获取 LLM 服务 chat_services = get_all_chat_services() if not chat_services: raise RuntimeError("没有可用的 LLM 服务") - + llm = list(chat_services.values())[0] print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}") - + # 初始化 RAG 工具 def create_local_llm(): return llm - + rag_tool = await init_rag_tool(create_local_llm) tools = AVAILABLE_TOOLS.copy() if rag_tool: tools.append(rag_tool) print(f"✓ RAG 工具初始化成功") - + # 构建图 graph = build_react_main_graph( llm=llm, tools=tools, use_hybrid_router=True ).compile() - + print(f"✓ 图构建完成") print() - - return graph + + return graph, rag_tool def create_test_state(query: str, thread_id: str = None) -> dict: @@ -98,7 +100,7 @@ def create_test_state(query: str, thread_id: str = None) -> dict: } -async def run_single_test(graph, test_case: dict) -> dict: +async def run_single_test(graph, rag_tool, test_case: dict) -> dict: """运行单个测试""" name = test_case["name"] query = test_case["query"] @@ -115,9 +117,12 @@ async def run_single_test(graph, test_case: dict) -> dict: # 创建初始状态 input_state = create_test_state(query, thread_id) - # 配置 + # 配置(注入 RAG 工具) config = { - "configurable": {"thread_id": thread_id} + "configurable": { + "thread_id": thread_id, + "rag_tool": rag_tool + } } # 执行图 @@ -168,12 +173,12 @@ async def main(): print("=" * 60) # 设置环境 - graph = await setup_test_environment() + graph, rag_tool = await setup_test_environment() # 运行所有测试 results = [] for test_case in TEST_CASES: - result = await run_single_test(graph, test_case) + result = await run_single_test(graph, rag_tool, test_case) results.append(result) # 稍微间隔一下 diff --git a/tools/test/test_rag_pipeline.py b/tools/test/test_rag_pipeline.py index 1df959a..98d5081 100644 --- a/tools/test/test_rag_pipeline.py +++ b/tools/test/test_rag_pipeline.py @@ -63,7 +63,6 @@ async def test_rag_tool(): num_queries=3, rerank_top_n=5 ) - query = "吕布的经历" print(f"\n用户查询: {query}")