Files
ailine/tools/test/test_graph_branches.py
root ef6fbc1521
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m36s
推理优化
2026-05-06 04:26:06 +08:00

213 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
主图完整测试 - 覆盖各个分支
"""
import asyncio
from backend.app.main_graph.state import MainGraphState, CurrentAction
from backend.app.main_graph.main_graph_builder import build_react_main_graph
from backend.app.model_services.chat_services import get_all_chat_services
from backend.app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
from backend.app.main_graph.utils.rag_initializer import init_rag_tool
# ========== Test case configuration ==========
# Each entry needs "name", "query" and "description"; "thread_id" is optional
# and is read by run_single_test with a fallback of "test_thread".
TEST_CASES = [
    # Case 1: simple small talk - expected to take the fast_chitchat path.
    {
        "name": "闲聊测试",
        "query": "你好!",
        "description": "测试快速闲聊分支",
    },
    # Case 2: knowledge lookup - expected to take fast_rag, possibly
    # escalating to the ReAct loop afterwards.
    {
        "name": "知识查询测试",
        "query": "吕布的事迹?",
        "description": "测试快速 RAG 分支",
    },
    # Case 3 (disabled): complex reasoning - expected to go straight to the
    # ReAct loop.
    # {
    #     "name": "复杂推理测试",
    #     "query": "请帮我分析如果我有10万元想要在一年内获得15%的收益,有哪些低风险的投资方案?",
    #     "description": "测试 React 循环推理分支"
    # },
    # Case 4 (disabled): question that requires an online tool call.
    # {
    #     "name": "联网工具调用测试",
    #     "query": "搜索一下今天的天气怎么样",
    #     "description": "测试工具调用分支"
    # },
    # Case 5: conversation that relies on memory; pins its own thread id.
    {
        "name": "记忆测试",
        "query": "你刚才回答了我什么问题?",
        "description": "测试记忆检索分支",
        "thread_id": "test_memory_thread",
    },
]
async def setup_test_environment():
    """Prepare everything the test run needs and return ``(graph, rag_tool)``.

    Looks up the available chat services, initialises the RAG tool, builds the
    main graph with the hybrid router enabled and compiles it.

    Returns:
        A tuple of the compiled graph and whatever ``init_rag_tool`` returned
        (the tool is only added to the tool list when it is truthy).

    Raises:
        RuntimeError: If ``get_all_chat_services`` returns nothing usable.
    """
    banner = "=" * 60
    print(banner)
    print("设置测试环境...")
    print(banner)

    # Resolve the LLM services first; nothing else makes sense without them.
    services = get_all_chat_services()
    if not services:
        raise RuntimeError("没有可用的 LLM 服务")
    model_names = list(services.keys())
    print(f"✓ 可用模型: {model_names}")

    # Initialise the RAG tool and, when available, add it to a private copy
    # of the shared tool list so AVAILABLE_TOOLS itself is not mutated.
    rag_tool = await init_rag_tool()
    tools = AVAILABLE_TOOLS.copy()
    if rag_tool:
        tools.append(rag_tool)
        print("✓ RAG 工具初始化成功")

    # Build the graph (new API: pass chat_services rather than a single llm),
    # then compile it into a runnable.
    builder = build_react_main_graph(
        chat_services=services,
        tools=tools,
        use_hybrid_router=True,
    )
    graph = builder.compile()
    print("✓ 图构建完成")
    print()

    return graph, rag_tool
def create_test_state(query: str, thread_id: str = None) -> dict:
    """Build the initial graph input state for a single user query.

    Args:
        query: The user's question; stored both as ``user_query`` and as the
            content of the single user message.
        thread_id: Accepted for call-site compatibility but not used here;
            the caller passes the thread id through the run config instead.

    Returns:
        A dict with ``user_query``, ``messages``, ``user_id`` and
        ``current_action`` keys.
    """
    user_message = {"role": "user", "content": query}
    state = {
        "user_query": query,
        "messages": [user_message],
        "user_id": "test_user",
        "current_action": CurrentAction.NONE,
    }
    return state
async def run_single_test(graph, rag_tool, test_case: dict) -> dict:
    """Run one test case through the graph and report the outcome.

    Args:
        graph: Compiled graph object; only its ``ainvoke`` coroutine is used.
        rag_tool: RAG tool (possibly falsy) injected into the run config under
            ``configurable["rag_tool"]``.
        test_case: Mapping with ``name``, ``query`` and ``description`` keys
            and an optional ``thread_id`` (defaults to ``"test_thread"``).

    Returns:
        ``{"name", "success", "result"}`` when the graph call completed, or
        ``{"name", "success": False, "error"}`` when any exception was raised.
        Exceptions are reported, never propagated, so one failing case does
        not abort the remaining cases.
    """
    name = test_case["name"]
    query = test_case["query"]
    description = test_case["description"]
    thread_id = test_case.get("thread_id", "test_thread")

    print(f"\n{'=' * 60}")
    print(f"测试: {name}")
    print(f"描述: {description}")
    print(f"查询: {query}")
    print(f"{'=' * 60}")

    try:
        # Initial state for this query.
        input_state = create_test_state(query, thread_id)

        # Run config: thread id plus the injected RAG tool.
        config = {
            "configurable": {
                "thread_id": thread_id,
                "rag_tool": rag_tool,
            }
        }

        print("开始执行图...")
        result = await graph.ainvoke(input_state, config=config)

        success = result.get("success", False)
        # The answer may be missing, None or a non-string object; normalise it
        # so the preview below cannot raise and turn a successful run into a
        # reported failure.
        final_result = str(result.get("final_result") or "")

        print("\n结果:")
        # Use the same ✓/✗ markers as the rest of the script; previously both
        # branches were empty strings, so the outcome was never shown.
        print(f" 成功: {'✓' if success else '✗'}")
        print(f" 最终回答: {final_result[:200]}{'...' if len(final_result) > 200 else ''}")

        # Optional debug details; a key that is present but None is skipped
        # instead of crashing on the membership checks.
        debug_info = result.get("debug_info")
        if debug_info is not None:
            print(" 调试信息:")
            if "fast_path_failed" in debug_info:
                print(f" - 快速路径失败: {debug_info['fast_path_failed']}")
            if "fast_path_fail_reason" in debug_info:
                print(f" - 失败原因: {debug_info['fast_path_fail_reason']}")
            if "hybrid_decision" in debug_info:
                decision = debug_info["hybrid_decision"]
                print(f" - 路由决策: {getattr(decision, 'path', 'unknown')}")

        return {
            "name": name,
            "success": success,
            "result": result,
        }
    except Exception as e:
        # Deliberately broad: this is the per-case boundary of the runner, and
        # every failure is printed with its traceback and recorded.
        import traceback

        print(f"\n✗ 测试失败: {e}")
        print(f"堆栈: {traceback.format_exc()}")
        return {
            "name": name,
            "success": False,
            "error": str(e),
        }
async def main():
    """Run every case in ``TEST_CASES`` sequentially and print a summary.

    Sets up the environment once, executes the cases one after another with a
    one-second pause after each, then prints totals and a per-case status.
    """
    banner = "=" * 60

    print("\n" + banner)
    print("主图完整测试套件")
    print(banner)

    # One shared graph / RAG tool for the whole run.
    graph, rag_tool = await setup_test_environment()

    # Execute the cases in order, pausing briefly after each one.
    outcomes = []
    for case in TEST_CASES:
        outcomes.append(await run_single_test(graph, rag_tool, case))
        await asyncio.sleep(1)

    # Summary.
    print("\n" + banner)
    print("测试总结")
    print(banner)

    total = len(outcomes)
    failed = sum(1 for outcome in outcomes if not outcome["success"])
    passed = total - failed

    print(f"\n总测试数: {total}")
    print(f"通过: {passed}")
    print(f"失败: {failed}")

    print("\n详细结果:")
    for outcome in outcomes:
        if outcome["success"]:
            status = "✓ 通过"
        else:
            status = "✗ 失败"
        print(f" {outcome['name']}: {status}")

    print("\n" + banner)
    if failed != 0:
        print(f"⚠️ 有 {failed} 个测试失败")
    else:
        print("🎉 所有测试通过!")
    print(banner)
if __name__ == "__main__":
    # Script entry point: drive the async main() and report, rather than
    # propagate, a user interrupt or any unexpected failure.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n\n测试被用户中断")
    except Exception as exc:
        import traceback

        print(f"\n\n测试运行失败: {exc}")
        print(traceback.format_exc())