206 lines
5.5 KiB
Python
206 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
极简 Agent 架构测试 - 适配新架构
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# 添加项目路径
|
||
project_root = Path(__file__).resolve().parent.parent
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
import asyncio
|
||
from dotenv import load_dotenv
|
||
|
||
# 加载环境变量
|
||
load_dotenv(project_root / ".env")
|
||
|
||
from backend.app.main_graph.state import AgentState
|
||
from backend.app.main_graph.main_graph_builder import build_agent_graph
|
||
from backend.app.model_services.chat_services import get_cached_chat_services
|
||
|
||
|
||
# ========== 测试用例配置 ==========
|
||
# Test case table. Each row is (name, query, description) and is expanded
# into a dict with exactly those three keys, in that order:
#   1. small talk
#   2. knowledge lookup
#   3. simple question
TEST_CASES = [
    {"name": case_name, "query": case_query, "description": case_description}
    for case_name, case_query, case_description in (
        ("闲聊测试", "你好!", "测试简单对话"),
        ("知识库测试", "吕布的事迹?", "测试 RAG 工具调用"),
        ("简单问答测试", "介绍一下你自己", "测试直接回答能力"),
    )
]
|
||
|
||
|
||
async def setup_test_environment():
    """Prepare the compiled agent graph used by the test run.

    Fetches the cached chat services, selects exactly one model and builds
    and compiles the graph with only that model.

    Returns:
        A ``(graph, test_chat_services)`` tuple: the compiled graph and the
        single-entry mapping of the selected model name to its service.

    Raises:
        RuntimeError: If ``get_cached_chat_services()`` returns a falsy value.
    """
    banner = "=" * 60
    print(banner)
    print("设置测试环境...")
    print(banner)

    # Look up the LLM services.
    available = get_cached_chat_services()
    if not available:
        raise RuntimeError("没有可用的 LLM 服务")

    print(f"✓ 可用模型: {list(available.keys())}")

    # Models are tried in this priority order (zhipu or deepseek first, to
    # avoid problems with the Baosi API); if none of them is available the
    # first key of the mapping is used instead.
    preferred = ("zhipu", "deepseek", "local")
    test_model = next(
        (candidate for candidate in preferred if candidate in available),
        None,
    )
    if test_model is None:
        test_model = list(available.keys())[0]
    print(f"✓ 选择 {test_model} 作为测试模型")

    # Keep only the selected model to simplify testing.
    test_chat_services = {test_model: available[test_model]}

    # Build the graph with build_agent_graph, then compile it.
    graph = build_agent_graph(chat_services=test_chat_services).compile()

    print("✓ 图构建完成")
    print()

    return graph, test_chat_services
|
||
|
||
|
||
def create_test_state(query: str, user_id: str = "test_user") -> dict:
    """Create the initial graph input for a single test query.

    Args:
        query: Text used as the content of the human message.
        user_id: Value stored under the ``user_id`` key.

    Returns:
        A dict with ``messages`` (a one-element list holding a
        ``HumanMessage`` built from ``query``) and ``user_id``.
    """
    # Function-scope import, kept as in the original.
    from langchain_core.messages import HumanMessage

    first_message = HumanMessage(content=query)
    state = {"messages": [first_message], "user_id": user_id}
    return state
|
||
|
||
|
||
async def run_single_test(graph, test_case: dict) -> dict:
    """Run one test case through the graph and report the outcome.

    Args:
        graph: Object exposing an awaitable ``ainvoke(input, config=...)``.
        test_case: Mapping with ``name``, ``query`` and ``description`` keys.

    Returns:
        On success, a dict with ``name``, ``success`` (True), ``reply`` and
        ``state`` (the raw graph result). If anything inside the run raises,
        a dict with ``name``, ``success`` (False) and ``error`` (the
        exception text); the traceback is printed.

    Raises:
        KeyError: If ``test_case`` lacks one of the required keys (these are
            read before the guarded section).
    """
    import traceback

    name = test_case["name"]
    query = test_case["query"]
    description = test_case["description"]

    separator = "=" * 60
    print(f"\n{separator}")
    print(f"测试: {name}")
    print(f"描述: {description}")
    print(f"查询: {query}")
    print(f"{separator}")

    try:
        # Build the initial state.
        input_state = create_test_state(query)

        # Per-test thread id so runs do not share a thread.
        config = {
            "configurable": {
                "thread_id": f"test_{name}",
            },
        }

        # Execute the graph.
        print("开始执行图...")
        result = await graph.ainvoke(input_state, config=config)

        # Prefer ``final_reply``; fall back to the last message's content.
        # ``or ""`` also covers a key that is present but set to None, which
        # previously made the slicing below raise TypeError and turned a
        # completed run into a reported failure.
        reply = result.get("final_reply") or ""
        if not reply and result.get("messages"):
            reply = result["messages"][-1].content
        if reply is None:
            reply = ""

        # Message content is not guaranteed to be a plain string (it can be
        # a list of content blocks), so truncate a string rendering of it.
        preview = reply if isinstance(reply, str) else str(reply)
        suffix = "..." if len(preview) > 500 else ""

        print("\n✓ 执行完成")
        print(f"最终回复: {preview[:500]}{suffix}")

        return {
            "name": name,
            "success": True,
            "reply": reply,
            "state": result,
        }

    except Exception as e:
        # Best-effort runner: report the failure and let the caller continue
        # with the remaining test cases.
        print(f"\n✗ 测试失败: {e}")
        print(f"堆栈: {traceback.format_exc()}")
        return {
            "name": name,
            "success": False,
            "error": str(e),
        }
|
||
|
||
|
||
async def main():
    """Run every entry of ``TEST_CASES`` and print a summary.

    Returns:
        ``0`` when all test cases succeeded; ``1`` when at least one test
        case failed or when setting up / running the suite raised. The value
        is meant to be used as the process exit status.
    """
    import traceback

    separator = "=" * 60
    print("\n" + separator)
    print("极简 Agent 架构测试")
    print(separator)

    try:
        # Set up the environment; only the graph is needed here.
        graph, _ = await setup_test_environment()

        # Run all test cases sequentially.
        results = []
        for test_case in TEST_CASES:
            results.append(await run_single_test(graph, test_case))

            # Short pause between test cases.
            await asyncio.sleep(0.5)

        # Summary.
        print("\n" + separator)
        print("测试总结")
        print(separator)

        total = len(results)
        passed = sum(1 for r in results if r["success"])
        failed = total - passed

        print(f"\n总测试数: {total}")
        print(f"通过: {passed}")
        print(f"失败: {failed}")

        print("\n详细结果:")
        for result in results:
            status = "✓ 通过" if result["success"] else "✗ 失败"
            print(f" {result['name']}: {status}")

        print("\n" + separator)
        if failed == 0:
            print("🎉 所有测试通过!")
        else:
            print(f"⚠️ 有 {failed} 个测试失败")
        print(separator)

        return 0 if failed == 0 else 1

    except Exception as e:
        # Top-level boundary: report the error and signal failure through
        # the return value instead of silently finishing with status 0.
        print(f"\n测试运行失败: {e}")
        print(traceback.format_exc())
        return 1


if __name__ == "__main__":
    # Propagate the result as the exit status so shells and CI can detect
    # failed runs.
    sys.exit(asyncio.run(main()))
|