#!/usr/bin/env python3
"""
Quick test - verify the fast_rag path fix.

Builds the ReAct main graph with the hybrid router enabled, sends a single
query through it, and prints the final answer plus any fast-path diagnostics
found in the result's ``debug_info``. The process exit status is 0 when the
test passes and non-zero otherwise, so the script can be used from a shell/CI.
"""
import asyncio
import sys
import traceback
from pathlib import Path

from dotenv import load_dotenv

# Path setup: make both the project root and its ``backend`` directory
# importable. ``backend`` is inserted last, so it ends up first on sys.path.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / "backend"))

# Load .env before importing app modules — presumably because they read
# configuration (e.g. API keys) at import time. NOTE(review): confirm.
load_dotenv(project_root / ".env")

from app.main_graph.state import MainGraphState, CurrentAction  # noqa: E402,F401
from app.main_graph.utils.main_graph_builder import build_react_main_graph  # noqa: E402
from app.model_services.chat_services import get_all_chat_services  # noqa: E402
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS  # noqa: E402

# Number of characters of the final answer echoed to stdout.
_PREVIEW_CHARS = 300

# Keys in the result's ``debug_info`` that describe the fast-path outcome.
_FAST_PATH_DEBUG_KEYS = ("fast_path_failed", "fast_path_fail_reason")


async def test_fast_rag_path():
    """Run one query through the main graph to exercise the fast_rag path.

    Returns:
        True if the graph ran to completion; False if no LLM service is
        available or ``graph.ainvoke`` raised.
    """
    print("=" * 60)
    print("测试 fast_rag 路径修复")
    print("=" * 60)

    # 1. Pick the first available LLM service.
    chat_services = get_all_chat_services()
    if not chat_services:
        print("✗ 没有可用的 LLM 服务")
        # Explicit False (the original returned None) so callers get a bool.
        return False
    llm_name, llm = next(iter(chat_services.items()))
    print(f"✓ 使用 LLM: {llm_name}")

    # 2. Build and compile the graph.
    graph = build_react_main_graph(
        llm=llm,
        tools=AVAILABLE_TOOLS,
        use_hybrid_router=True,
    ).compile()
    print("✓ 图构建完成")

    # 3. Test question.
    test_query = "吕布和张飞谁厉害?"
    print(f"\n测试问题: {test_query}")

    # 4. Initial graph state.
    input_state = {
        "user_query": test_query,
        "messages": [{"role": "user", "content": test_query}],
        "user_id": "test_user",
        "current_action": CurrentAction.NONE,
    }

    # 5. Execute. Only the graph invocation is guarded, so that a bug in the
    # reporting code below is not misreported as a graph execution failure.
    print("开始执行...")
    try:
        result = await graph.ainvoke(
            input_state,
            config={"configurable": {"thread_id": "test_fast_rag"}},
        )
    except Exception as e:
        print(f"\n✗ 执行失败: {e}")
        print(traceback.format_exc())
        return False

    print("\n✓ 执行成功!")
    # `or ""` also covers a present-but-None value, which `.get(key, "")`
    # would pass through and then fail to slice.
    final_result = result.get("final_result") or ""
    print(f"最终回答: {str(final_result)[:_PREVIEW_CHARS]}")

    # Debug info; tolerate a missing or None `debug_info` entry.
    debug_info = result.get("debug_info") or {}
    print("\n调试信息:")
    for key in _FAST_PATH_DEBUG_KEYS:
        if key in debug_info:
            print(f"  - {key}: {debug_info[key]}")

    return True


async def main():
    """Run the fast_rag test, print a summary line, and return its outcome.

    Returns:
        True if the test passed, False otherwise.
    """
    success = await test_fast_rag_path()
    if success:
        print("\n🎉 测试通过!")
    else:
        print("\n⚠️ 测试失败")
    return success


if __name__ == "__main__":
    try:
        passed = asyncio.run(main())
    except KeyboardInterrupt:
        print("\n测试被中断")
        # 130 = 128 + SIGINT, the conventional status for Ctrl-C.
        sys.exit(130)
    # Non-zero exit status on failure so shells/CI can detect it.
    sys.exit(0 if passed else 1)