Files
ailine/tools/test/test_minimal_agent.py

206 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
极简 Agent 架构测试 - 适配新架构
"""
import sys
from pathlib import Path
# 添加项目路径
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
import asyncio
from dotenv import load_dotenv
# 加载环境变量
load_dotenv(project_root / ".env")
from backend.app.main_graph.state import AgentState
from backend.app.main_graph.main_graph_builder import build_agent_graph
from backend.app.model_services.chat_services import get_cached_chat_services
# ========== 测试用例配置 ==========
TEST_CASES = [
# 测试1: 简单闲聊
{
"name": "闲聊测试",
"query": "你好!",
"description": "测试简单对话"
},
# 测试2: 知识查询
{
"name": "知识库测试",
"query": "吕布的事迹?",
"description": "测试 RAG 工具调用"
},
# 测试3: 简单问题
{
"name": "简单问答测试",
"query": "介绍一下你自己",
"description": "测试直接回答能力"
},
]
async def setup_test_environment():
"""设置测试环境"""
print("=" * 60)
print("设置测试环境...")
print("=" * 60)
# 获取 LLM 服务
chat_services = get_cached_chat_services()
if not chat_services:
raise RuntimeError("没有可用的 LLM 服务")
print(f"✓ 可用模型: {list(chat_services.keys())}")
# 选择 zhipu 或 deepseek 作为测试模型,避免 Baosi API 的问题
test_model = None
if "zhipu" in chat_services:
test_model = "zhipu"
print(f"✓ 选择 zhipu 作为测试模型")
elif "deepseek" in chat_services:
test_model = "deepseek"
print(f"✓ 选择 deepseek 作为测试模型")
elif "local" in chat_services:
test_model = "local"
print(f"✓ 选择 local 作为测试模型")
else:
# 用第一个可用的
test_model = list(chat_services.keys())[0]
print(f"✓ 选择 {test_model} 作为测试模型")
# 只保留选中的模型,方便测试
test_chat_services = {test_model: chat_services[test_model]}
# 构建图(使用新的 build_agent_graph
graph_builder = build_agent_graph(
chat_services=test_chat_services
)
graph = graph_builder.compile()
print(f"✓ 图构建完成")
print()
return graph, test_chat_services
def create_test_state(query: str, user_id: str = "test_user") -> dict:
"""创建测试状态"""
from langchain_core.messages import HumanMessage
return {
"messages": [HumanMessage(content=query)],
"user_id": user_id,
}
async def run_single_test(graph, test_case: dict) -> dict:
"""运行单个测试"""
name = test_case["name"]
query = test_case["query"]
description = test_case["description"]
print(f"\n{'=' * 60}")
print(f"测试: {name}")
print(f"描述: {description}")
print(f"查询: {query}")
print(f"{'=' * 60}")
try:
# 创建初始状态
input_state = create_test_state(query)
# 配置
config = {
"configurable": {
"thread_id": f"test_{name}"
}
}
# 执行图
print("开始执行图...")
result = await graph.ainvoke(input_state, config=config)
# 提取最终回复(优先使用 final_reply
reply = result.get("final_reply", "")
if not reply and result.get("messages"):
reply = result["messages"][-1].content
print(f"\n✓ 执行完成")
print(f"最终回复: {reply[:500]}{'...' if len(reply) > 500 else ''}")
return {
"name": name,
"success": True,
"reply": reply,
"state": result
}
except Exception as e:
print(f"\n✗ 测试失败: {e}")
import traceback
print(f"堆栈: {traceback.format_exc()}")
return {
"name": name,
"success": False,
"error": str(e)
}
async def main():
"""主函数"""
print("\n" + "=" * 60)
print("极简 Agent 架构测试")
print("=" * 60)
try:
# 设置环境
graph, chat_services = await setup_test_environment()
# 运行所有测试
results = []
for test_case in TEST_CASES:
result = await run_single_test(graph, test_case)
results.append(result)
# 稍微间隔一下
await asyncio.sleep(0.5)
# 总结
print("\n" + "=" * 60)
print("测试总结")
print("=" * 60)
total = len(results)
passed = sum(1 for r in results if r["success"])
failed = total - passed
print(f"\n总测试数: {total}")
print(f"通过: {passed}")
print(f"失败: {failed}")
print("\n详细结果:")
for result in results:
status = "✓ 通过" if result["success"] else "✗ 失败"
print(f" {result['name']}: {status}")
print("\n" + "=" * 60)
if failed == 0:
print("🎉 所有测试通过!")
else:
print(f"⚠️ 有 {failed} 个测试失败")
print("=" * 60)
except Exception as e:
print(f"\n测试运行失败: {e}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
asyncio.run(main())