#!/usr/bin/env python3
"""End-to-end test script for the main graph, covering its routing branches.

Builds the React main graph once, then feeds each entry of ``TEST_CASES``
through it sequentially and prints a pass/fail summary.
"""

import asyncio
import traceback
from typing import Optional

from backend.app.main_graph.state import MainGraphState, CurrentAction
from backend.app.main_graph.main_graph_builder import build_react_main_graph
from backend.app.model_services.chat_services import get_all_chat_services
from backend.app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
from backend.app.main_graph.utils.rag_initializer import init_rag_tool

# ========== Test case configuration ==========
# Each case needs "name", "query" and "description"; "thread_id" is optional
# and falls back to "test_thread" in run_single_test.
TEST_CASES = [
    # Test 1: simple chit-chat - expected to take the fast_chitchat path.
    {
        "name": "闲聊测试",
        "query": "你好!",
        "description": "测试快速闲聊分支",
    },
    # Test 2: knowledge lookup - expected to take fast_rag, and may then
    # escalate to the React loop.
    {
        "name": "知识查询测试",
        "query": "吕布的事迹?",
        "description": "测试快速 RAG 分支",
    },
    # Test 3: complex question that needs reasoning - expected to go
    # straight to the React loop.
    {
        "name": "复杂推理测试",
        "query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?",
        "description": "测试 React 循环推理分支",
    },
    # Test 4: question that needs a tool call.
    {
        "name": "联网工具调用测试",
        "query": "搜索一下今天的天气怎么样",
        "description": "测试工具调用分支",
    },
    # Test 5: conversation that relies on memory.
    # NOTE(review): this case uses its own thread_id, so it does not share a
    # thread with the earlier cases - confirm that is the intended setup.
    {
        "name": "记忆测试",
        "query": "你刚才回答了我什么问题?",
        "description": "测试记忆检索分支",
        "thread_id": "test_memory_thread",
    },
]


async def setup_test_environment():
    """Build everything the test cases need.

    Fetches the available chat services, initialises the RAG tool, and
    compiles the main graph with the hybrid router enabled.

    Returns:
        A ``(graph, rag_tool)`` tuple: the compiled graph and whatever
        ``init_rag_tool()`` returned (possibly a falsy value).

    Raises:
        RuntimeError: If ``get_all_chat_services()`` returns nothing usable.
    """
    print("=" * 60)
    print("设置测试环境...")
    print("=" * 60)

    # Fetch the LLM services; without any there is nothing to test.
    chat_services = get_all_chat_services()
    if not chat_services:
        raise RuntimeError("没有可用的 LLM 服务")
    print(f"✓ 可用模型: {list(chat_services.keys())}")

    # Initialise the RAG tool and add it to a copy of the shared tool list,
    # so AVAILABLE_TOOLS itself is never mutated.
    rag_tool = await init_rag_tool()
    tools = AVAILABLE_TOOLS.copy()
    if rag_tool:
        tools.append(rag_tool)
        print("✓ RAG 工具初始化成功")

    # Build the graph (new API: pass chat_services rather than a single llm).
    graph = build_react_main_graph(
        chat_services=chat_services,
        tools=tools,
        use_hybrid_router=True,
    ).compile()
    print("✓ 图构建完成")
    print()

    return graph, rag_tool


def create_test_state(query: str, thread_id: Optional[str] = None) -> dict:
    """Create the initial graph state for one test query.

    Args:
        query: The user query; stored as ``user_query`` and as the single
            user message.
        thread_id: Accepted for call-site compatibility but not written into
            the state; run_single_test passes the thread id via the
            invocation config instead.

    Returns:
        The initial state dict handed to the graph.
    """
    return {
        "user_query": query,
        "messages": [{"role": "user", "content": query}],
        "user_id": "test_user",
        "current_action": CurrentAction.NONE,
    }


async def run_single_test(graph, rag_tool, test_case: dict) -> dict:
    """Run one test case through the graph and print its outcome.

    Args:
        graph: Compiled graph exposing an async ``ainvoke(state, config=...)``.
        rag_tool: RAG tool injected through ``config["configurable"]``.
        test_case: Dict with "name", "query", "description" and an optional
            "thread_id" (defaults to "test_thread").

    Returns:
        ``{"name", "success", "result"}`` when the graph ran, or
        ``{"name", "success": False, "error"}`` when an exception was raised.
        Exceptions are caught here so one failing case does not stop the run.
    """
    name = test_case["name"]
    query = test_case["query"]
    description = test_case["description"]
    thread_id = test_case.get("thread_id", "test_thread")

    print(f"\n{'=' * 60}")
    print(f"测试: {name}")
    print(f"描述: {description}")
    print(f"查询: {query}")
    print(f"{'=' * 60}")

    try:
        # Create the initial state.
        input_state = create_test_state(query, thread_id)

        # Invocation config (this is where the RAG tool is injected).
        config = {
            "configurable": {
                "thread_id": thread_id,
                "rag_tool": rag_tool,
            }
        }

        # Execute the graph.
        print("开始执行图...")
        result = await graph.ainvoke(input_state, config=config)

        # Inspect the result.
        success = result.get("success", False)
        # The key may be present with a None (or non-string) value; normalise
        # it so the preview slice below cannot raise and mask the real outcome.
        final_result = result.get("final_result") or ""
        if not isinstance(final_result, str):
            final_result = str(final_result)

        print("\n结果:")
        print(f" 成功: {'✓' if success else '✗'}")
        print(f" 最终回答: {final_result[:200]}{'...' if len(final_result) > 200 else ''}")

        # Debug information; a None value is treated the same as a missing key.
        debug_info = result.get("debug_info")
        if debug_info is not None:
            print(" 调试信息:")
            if "fast_path_failed" in debug_info:
                print(f" - 快速路径失败: {debug_info['fast_path_failed']}")
            if "fast_path_fail_reason" in debug_info:
                print(f" - 失败原因: {debug_info['fast_path_fail_reason']}")
            if "hybrid_decision" in debug_info:
                decision = debug_info["hybrid_decision"]
                print(f" - 路由决策: {getattr(decision, 'path', 'unknown')}")

        return {
            "name": name,
            "success": success,
            "result": result,
        }

    except Exception as e:
        # Top-level boundary for a single case: report it and keep going.
        print(f"\n✗ 测试失败: {e}")
        print(f"堆栈: {traceback.format_exc()}")
        return {
            "name": name,
            "success": False,
            "error": str(e),
        }


async def main():
    """Set up the environment, run every test case in order, print a summary."""
    print("\n" + "=" * 60)
    print("主图完整测试套件")
    print("=" * 60)

    # Set up the environment.
    graph, rag_tool = await setup_test_environment()

    # Run all test cases sequentially.
    results = []
    for test_case in TEST_CASES:
        result = await run_single_test(graph, rag_tool, test_case)
        results.append(result)
        # Short pause between cases.
        await asyncio.sleep(1)

    # Summary.
    print("\n" + "=" * 60)
    print("测试总结")
    print("=" * 60)

    total = len(results)
    passed = sum(1 for r in results if r["success"])
    failed = total - passed

    print(f"\n总测试数: {total}")
    print(f"通过: {passed}")
    print(f"失败: {failed}")

    print("\n详细结果:")
    for result in results:
        status = "✓ 通过" if result["success"] else "✗ 失败"
        print(f" {result['name']}: {status}")

    print("\n" + "=" * 60)
    # NOTE(review): the process exit status does not reflect `failed`;
    # consider returning a non-zero status if this script is used in CI.
    if failed == 0:
        print("🎉 所有测试通过!")
    else:
        print(f"⚠️ 有 {failed} 个测试失败")
    print("=" * 60)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n\n测试被用户中断")
    except Exception as e:
        print(f"\n\n测试运行失败: {e}")
        print(traceback.format_exc())