主要修复: 1. 修复 RAG 推理无限循环问题(大小写不匹配 + 缺少已检索结果检查) 2. 修复 intent_classifier.py 的绝对导入错误 3. 删除旧的 start.sh 脚本,添加新的启动脚本 4. 优化路由逻辑和状态管理
This commit is contained in:
97
tools/test/test_fast_rag_fix.py
Normal file
97
tools/test/test_fast_rag_fix.py
Normal file
@@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
快速测试 - 测试 fast_rag 路径修复
|
||||
"""
|
||||
import sys
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 路径设置
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
sys.path.insert(0, str(project_root / "backend"))
|
||||
load_dotenv(project_root / ".env")
|
||||
|
||||
from app.main_graph.state import MainGraphState, CurrentAction
|
||||
from app.main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from app.model_services.chat_services import get_all_chat_services
|
||||
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
|
||||
|
||||
|
||||
async def test_fast_rag_path():
|
||||
"""测试 fast_rag 路径"""
|
||||
print("=" * 60)
|
||||
print("测试 fast_rag 路径修复")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 获取 LLM
|
||||
chat_services = get_all_chat_services()
|
||||
if not chat_services:
|
||||
print("✗ 没有可用的 LLM 服务")
|
||||
return
|
||||
|
||||
llm = list(chat_services.values())[0]
|
||||
print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}")
|
||||
|
||||
# 2. 构建图
|
||||
graph = build_react_main_graph(
|
||||
llm=llm,
|
||||
tools=AVAILABLE_TOOLS,
|
||||
use_hybrid_router=True
|
||||
).compile()
|
||||
print(f"✓ 图构建完成")
|
||||
|
||||
# 3. 测试问题
|
||||
test_query = "吕布和张飞谁厉害?"
|
||||
print(f"\n测试问题: {test_query}")
|
||||
|
||||
# 4. 创建状态
|
||||
input_state = {
|
||||
"user_query": test_query,
|
||||
"messages": [{"role": "user", "content": test_query}],
|
||||
"user_id": "test_user",
|
||||
"current_action": CurrentAction.NONE
|
||||
}
|
||||
|
||||
# 5. 执行
|
||||
print("开始执行...")
|
||||
try:
|
||||
result = await graph.ainvoke(
|
||||
input_state,
|
||||
config={"configurable": {"thread_id": "test_fast_rag"}}
|
||||
)
|
||||
|
||||
print(f"\n✓ 执行成功!")
|
||||
print(f"最终回答: {result.get('final_result', '')[:300]}")
|
||||
|
||||
# 调试信息
|
||||
debug_info = result.get("debug_info", {})
|
||||
print(f"\n调试信息:")
|
||||
if "fast_path_failed" in debug_info:
|
||||
print(f" - fast_path_failed: {debug_info['fast_path_failed']}")
|
||||
if "fast_path_fail_reason" in debug_info:
|
||||
print(f" - fast_path_fail_reason: {debug_info['fast_path_fail_reason']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 执行失败: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def main():
|
||||
success = await test_fast_rag_path()
|
||||
if success:
|
||||
print("\n🎉 测试通过!")
|
||||
else:
|
||||
print("\n⚠️ 测试失败")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\n测试被中断")
|
||||
Reference in New Issue
Block a user