2026-05-04 18:59:15 +08:00
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
"""
|
|
|
|
|
|
主图完整测试 - 覆盖各个分支
|
|
|
|
|
|
"""
|
|
|
|
|
|
import sys
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
|
2026-05-05 04:32:42 +08:00
|
|
|
|
# Add the backend directory to the import path
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "backend"))
|
2026-05-04 18:59:15 +08:00
|
|
|
|
|
2026-05-05 04:32:42 +08:00
|
|
|
|
from app.main_graph.state import MainGraphState, CurrentAction
|
|
|
|
|
|
from app.main_graph.utils.main_graph_builder import build_react_main_graph
|
|
|
|
|
|
from app.model_services.chat_services import get_all_chat_services
|
|
|
|
|
|
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
|
|
|
|
|
|
from app.main_graph.utils.rag_initializer import init_rag_tool
|
2026-05-04 18:59:15 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== Test case configuration ==========
# Each case is a dict with "name", "query" and "description"; an optional
# "thread_id" overrides the default thread used by run_single_test.
# Disabled cases are kept below as comments so they can be re-enabled.
TEST_CASES = [
    # # Case 1: simple chit-chat - should take the fast_chitchat path
    # {
    #     "name": "闲聊测试",
    #     "query": "你好!",
    #     "description": "测试快速闲聊分支"
    # },

    # Case 2: knowledge lookup - should take fast_rag, then possibly
    # escalate to the ReAct loop
    {
        "name": "知识查询测试",
        "query": "吕布的事迹?",
        "description": "测试快速 RAG 分支",
    },

    # # Case 3: complex question needing reasoning - should go straight to
    # # the ReAct loop
    # {
    #     "name": "复杂推理测试",
    #     "query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?",
    #     "description": "测试 React 循环推理分支"
    # },

    # # Case 4: question that requires a tool call
    # {
    #     "name": "工具调用测试",
    #     "query": "搜索一下今天的天气怎么样",
    #     "description": "测试工具调用分支"
    # },

    # # Case 5: conversation with memory
    # {
    #     "name": "记忆测试",
    #     "query": "你刚才回答了我什么问题?",
    #     "description": "测试记忆检索分支",
    #     "thread_id": "test_memory_thread"
    # }
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def setup_test_environment():
    """Build the compiled main graph and the RAG tool used by the tests.

    Takes the first chat service returned by ``get_all_chat_services()``,
    initialises the RAG tool with a factory that returns that LLM, appends
    the RAG tool to a copy of ``AVAILABLE_TOOLS`` when it is truthy, and
    compiles the graph built with ``use_hybrid_router=True``.

    Returns:
        A ``(graph, rag_tool)`` tuple.

    Raises:
        RuntimeError: If no chat service is available.
    """
    # NOTE(review): load_dotenv is imported at file level but is not called
    # anywhere in the visible code - confirm the environment is loaded
    # elsewhere before the chat services are created.
    banner = "=" * 60
    print(banner)
    print("设置测试环境...")
    print(banner)

    # Resolve the LLM service
    services = get_all_chat_services()
    if not services:
        raise RuntimeError("没有可用的 LLM 服务")

    # The first registered service is the one under test
    service_name, llm = next(iter(services.items()))
    print(f"✓ 使用 LLM: {service_name}")

    # Initialise the RAG tool; init_rag_tool receives a factory, not the LLM
    def create_local_llm():
        return llm

    rag_tool = await init_rag_tool(create_local_llm)

    # Work on a copy so the shared AVAILABLE_TOOLS object is not mutated
    tools = AVAILABLE_TOOLS.copy()
    if rag_tool:
        tools.append(rag_tool)
        print("✓ RAG 工具初始化成功")

    # Build and compile the graph
    builder = build_react_main_graph(
        llm=llm,
        tools=tools,
        use_hybrid_router=True,
    )
    graph = builder.compile()

    print("✓ 图构建完成")
    print()

    return graph, rag_tool
|
2026-05-04 18:59:15 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_test_state(query: str, thread_id: str = None) -> dict:
    """Return the initial graph state for a single user query.

    Args:
        query: The user's question; stored as ``user_query`` and as the
            content of the single user message.
        thread_id: Accepted for call-site compatibility; it is not written
            into the returned state.

    Returns:
        A dict with ``user_query``, ``messages``, ``user_id`` and
        ``current_action`` (set to ``CurrentAction.NONE``).
    """
    user_message = {"role": "user", "content": query}
    state = {
        "user_query": query,
        "messages": [user_message],
        "user_id": "test_user",
        "current_action": CurrentAction.NONE,
    }
    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-05-05 04:32:42 +08:00
|
|
|
|
def _print_debug_info(debug_info) -> None:
    """Print the routing-related entries found in a result's ``debug_info``."""
    print(" 调试信息:")
    if "fast_path_failed" in debug_info:
        print(f" - 快速路径失败: {debug_info['fast_path_failed']}")
    if "fast_path_fail_reason" in debug_info:
        print(f" - 失败原因: {debug_info['fast_path_fail_reason']}")
    if "hybrid_decision" in debug_info:
        decision = debug_info["hybrid_decision"]
        # The decision object may not expose ``path``; fall back to a marker
        print(f" - 路由决策: {getattr(decision, 'path', 'unknown')}")


async def run_single_test(graph, rag_tool, test_case: dict) -> dict:
    """Run one test case through the graph and report the outcome.

    Args:
        graph: Compiled graph exposing an async ``ainvoke(state, config=...)``.
        rag_tool: RAG tool passed to the graph via ``config["configurable"]``.
        test_case: Dict with ``name``, ``query`` and ``description``; an
            optional ``thread_id`` overrides the default ``"test_thread"``.

    Returns:
        ``{"name", "success", "result"}`` when the graph ran to completion,
        or ``{"name", "success": False, "error"}`` when an exception was
        raised while running it.
    """
    name = test_case["name"]
    query = test_case["query"]
    description = test_case["description"]
    thread_id = test_case.get("thread_id", "test_thread")

    separator = "=" * 60
    print(f"\n{separator}")
    print(f"测试: {name}")
    print(f"描述: {description}")
    print(f"查询: {query}")
    print(separator)

    try:
        # Build the initial state
        input_state = create_test_state(query, thread_id)

        # Runtime config (injects the RAG tool)
        config = {
            "configurable": {
                "thread_id": thread_id,
                "rag_tool": rag_tool,
            }
        }

        # Execute the graph
        print("开始执行图...")
        result = await graph.ainvoke(input_state, config=config)

        # Inspect the result
        success = result.get("success", False)

        # dict.get only applies its default when the key is missing; a key
        # that is present with a None (or non-string) value would otherwise
        # break the slicing below and turn a finished run into a bogus
        # failure, so normalise to str first.
        raw_answer = result.get("final_result", "")
        final_result = "" if raw_answer is None else str(raw_answer)
        preview = final_result[:200]
        ellipsis = "..." if len(final_result) > 200 else ""

        print("\n结果:")
        print(f" 成功: {'✓' if success else '✗'}")
        print(f" 最终回答: {preview}{ellipsis}")

        # Debug information; tolerate an explicit None value for the key
        if "debug_info" in result:
            _print_debug_info(result["debug_info"] or {})

        return {
            "name": name,
            "success": success,
            "result": result,
        }

    except Exception as e:
        # Harness boundary: any failure is recorded so the remaining cases
        # can still run.
        import traceback

        print(f"\n✗ 测试失败: {e}")
        print(f"堆栈: {traceback.format_exc()}")
        return {
            "name": name,
            "success": False,
            "error": str(e),
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main():
    """Run every case in ``TEST_CASES`` and print a pass/fail summary.

    Sets up the graph once, runs the cases sequentially with a one-second
    pause after each, then prints totals and a per-case status line.
    """
    rule = "=" * 60

    print("\n" + rule)
    print("主图完整测试套件")
    print(rule)

    # Prepare the environment
    graph, rag_tool = await setup_test_environment()

    # Run all cases in order
    results = []
    for case in TEST_CASES:
        outcome = await run_single_test(graph, rag_tool, case)
        results.append(outcome)

        # Short pause between cases
        await asyncio.sleep(1)

    # Summary
    print("\n" + rule)
    print("测试总结")
    print(rule)

    total = len(results)
    passed = len([entry for entry in results if entry["success"]])
    failed = total - passed

    print(f"\n总测试数: {total}")
    print(f"通过: {passed}")
    print(f"失败: {failed}")

    print("\n详细结果:")
    for entry in results:
        if entry["success"]:
            status = "✓ 通过"
        else:
            status = "✗ 失败"
        print(f" {entry['name']}: {status}")

    print("\n" + rule)
    if failed:
        print(f"⚠️ 有 {failed} 个测试失败")
    else:
        print("🎉 所有测试通过!")
    print(rule)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async suite; report Ctrl-C and any other
    # failure on stdout instead of letting the exception propagate.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n\n测试被用户中断")
    except Exception as exc:
        import traceback

        print(f"\n\n测试运行失败: {exc}")
        print(traceback.format_exc())
|