#!/usr/bin/env python3
"""
Minimal Agent architecture test - adapted to the new architecture.

Runs each query in ``TEST_CASES`` through the compiled agent graph and prints
a pass/fail summary.
"""

import asyncio
import sys
from pathlib import Path

# Put the project root (the parent of this script's directory) first on
# sys.path so the ``backend`` package can be imported when the script is run
# directly.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

from dotenv import load_dotenv

# Load environment variables from <project_root>/.env before the backend
# modules are imported below.
# NOTE(review): presumably those modules read configuration at import time;
# confirm before reordering these statements.
load_dotenv(project_root / ".env")

from backend.app.main_graph.state import AgentState
from backend.app.main_graph.main_graph_builder import build_agent_graph
from backend.app.model_services.chat_services import get_cached_chat_services


# ========== Test case configuration ==========
# Each row is (name, query, description); the comprehension below expands it
# into a dict with keys "name", "query", "description" (in that order).
TEST_CASES = [
    {"name": case_name, "query": case_query, "description": case_description}
    for case_name, case_query, case_description in (
        # Test 1: simple small talk
        ("闲聊测试", "你好!", "测试简单对话"),
        # Test 2: knowledge lookup (described as exercising the RAG tool call)
        ("知识库测试", "吕布的事迹?", "测试 RAG 工具调用"),
        # Test 3: simple question (described as exercising direct answering)
        ("简单问答测试", "介绍一下你自己", "测试直接回答能力"),
    )
]


async def setup_test_environment(*, preferred_models=("zhipu", "deepseek", "local")):
    """Build and compile the agent graph used by the tests.

    Exactly one chat service is handed to the graph builder. It is the first
    name in ``preferred_models`` that is available; if none of them is
    available, the first service returned by ``get_cached_chat_services()`` is
    used.

    Args:
        preferred_models: Model keys to try, in priority order. The default
            order (zhipu, then deepseek, then local) is the order the original
            script hard-coded. The author's note says it was chosen to avoid
            problems with the Baosi API.

    Returns:
        A ``(graph, test_chat_services)`` tuple. ``graph`` is the result of
        ``build_agent_graph(...).compile()``. ``test_chat_services`` is the
        single-entry dict ``{model_name: service}`` the graph was built with.

    Raises:
        RuntimeError: If ``get_cached_chat_services()`` returns no services.
    """
    print("=" * 60)
    print("设置测试环境...")
    print("=" * 60)

    # Fetch the available LLM services.
    chat_services = get_cached_chat_services()
    if not chat_services:
        raise RuntimeError("没有可用的 LLM 服务")

    print(f"✓ 可用模型: {list(chat_services.keys())}")

    # First preferred model that is available, else the first available one.
    # The fallback is safe to evaluate eagerly: chat_services is non-empty here.
    test_model = next(
        (model for model in preferred_models if model in chat_services),
        next(iter(chat_services)),
    )
    print(f"✓ 选择 {test_model} 作为测试模型")

    # Keep only the selected model, to keep the test simple.
    test_chat_services = {test_model: chat_services[test_model]}

    # Build the graph with the new build_agent_graph entry point, then compile it.
    graph_builder = build_agent_graph(chat_services=test_chat_services)
    graph = graph_builder.compile()

    print("✓ 图构建完成")
    print()

    return graph, test_chat_services


def create_test_state(query: str, user_id: str = "test_user") -> dict:
    """Build the initial input state for one graph run.

    Args:
        query: User text, wrapped as the only ``HumanMessage`` in ``messages``.
        user_id: Value stored under the ``user_id`` key.

    Returns:
        ``{"messages": [HumanMessage(content=query)], "user_id": user_id}``.
    """
    # Local import: langchain_core is only loaded when a state is built.
    from langchain_core.messages import HumanMessage

    opening_message = HumanMessage(content=query)
    initial_state = {"messages": [opening_message], "user_id": user_id}
    return initial_state


async def run_single_test(graph, test_case: dict, *, preview_chars: int = 500) -> dict:
    """Run one test case through the compiled graph and report the outcome.

    The case's query is wrapped with ``create_test_state`` and passed to
    ``graph.ainvoke`` with ``thread_id`` set to ``"test_<name>"``. The content
    of the last message in the result is treated as the reply.

    Args:
        graph: Compiled graph; awaited via ``graph.ainvoke(state, config=...)``.
        test_case: Mapping with required ``name`` and ``query`` keys and an
            optional ``description`` (defaults to an empty string).
        preview_chars: Maximum number of reply characters to print.

    Returns:
        On success, ``{"name", "success": True, "reply", "state"}`` where
        ``reply`` is a ``str`` and ``state`` is the raw ``ainvoke`` result.
        If anything inside the run raises, ``{"name", "success": False,
        "error"}`` with the exception message.

    Raises:
        KeyError: If ``test_case`` lacks ``name`` or ``query``.
    """
    name = test_case["name"]
    query = test_case["query"]
    # Optional: a case without a description must not abort the whole run.
    description = test_case.get("description", "")

    print(f"\n{'=' * 60}")
    print(f"测试: {name}")
    print(f"描述: {description}")
    print(f"查询: {query}")
    print(f"{'=' * 60}")

    try:
        # Initial state and per-run config.
        input_state = create_test_state(query)
        config = {"configurable": {"thread_id": f"test_{name}"}}

        print("开始执行图...")
        result = await graph.ainvoke(input_state, config=config)

        # Extract the final reply from the last message, if any.
        reply = ""
        messages = result.get("messages")
        if messages:
            content = messages[-1].content
            # LangChain message content may be a list of content blocks
            # instead of a str. Normalise it so the preview below counts and
            # slices characters, not list elements.
            reply = content if isinstance(content, str) else str(content)

        print("\n✓ 执行完成")
        preview = reply[:preview_chars]
        ellipsis = "..." if len(reply) > preview_chars else ""
        print(f"最终回复: {preview}{ellipsis}")

        return {
            "name": name,
            "success": True,
            "reply": reply,
            "state": result,
        }

    except Exception as e:
        # Test-harness boundary: record the failure and keep running other cases.
        import traceback

        print(f"\n✗ 测试失败: {e}")
        print(f"堆栈: {traceback.format_exc()}")
        return {
            "name": name,
            "success": False,
            "error": str(e),
        }


def _print_summary(results: list) -> int:
    """Print the totals and per-test status for ``results``; return the failure count."""
    print("\n" + "=" * 60)
    print("测试总结")
    print("=" * 60)

    total = len(results)
    passed = sum(1 for r in results if r["success"])
    failed = total - passed

    print(f"\n总测试数: {total}")
    print(f"通过: {passed}")
    print(f"失败: {failed}")

    print("\n详细结果:")
    for result in results:
        status = "✓ 通过" if result["success"] else "✗ 失败"
        print(f" {result['name']}: {status}")

    print("\n" + "=" * 60)
    if failed == 0:
        print("🎉 所有测试通过!")
    else:
        print(f"⚠️ 有 {failed} 个测试失败")
    print("=" * 60)

    return failed


async def main() -> int:
    """Build the graph, run every case in ``TEST_CASES`` in order, and print a summary.

    Returns:
        Process exit status: 0 if every test passed, 1 if any test failed or
        an exception escaped (e.g. environment setup failed). The
        ``__main__`` guard passes this to ``sys.exit`` so callers such as CI
        can detect failures.
    """
    print("\n" + "=" * 60)
    print("极简 Agent 架构测试")
    print("=" * 60)

    try:
        # Set up the environment; only the compiled graph is needed here.
        graph, _ = await setup_test_environment()

        # Run all tests sequentially.
        results = []
        for test_case in TEST_CASES:
            results.append(await run_single_test(graph, test_case))
            # Brief pause between cases.
            await asyncio.sleep(0.5)

        failed = _print_summary(results)

    except Exception as e:
        # Top-level boundary: report the error and signal failure via the exit status.
        import traceback

        print(f"\n测试运行失败: {e}")
        print(traceback.format_exc())
        return 1

    return 1 if failed else 0


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))