refactor: 重构快速路径流程,统一通过 llm_call 输出
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m31s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m31s
- 重构 fast_paths.py,让 fast_chitchat 和 fast_rag 都进入 llm_call 而不是直接设置 final_result - 修改 check_fast_path_success 函数返回 'llm_call' 而不是 'success' - 更新 main_graph_builder.py 的条件边配置,支持路由到 llm_call - 在快速路径节点中添加清除 state.final_result 的逻辑,避免复用旧结果 - 重构 RAG 工具初始化方式,使用模块级变量管理 - 修改 finalize.py 让它返回 final_result - 更新 agent_service.py 的 RAG 工具注入方式 - 简化 hybrid_router.py 的代码结构 - 清理 rag_nodes.py 的全局变量相关代码 - 更新相关测试文件
This commit is contained in:
14
tools/run.py
14
tools/run.py
@@ -1,20 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
"""统一入口:设置路径后运行 RAG 索引构建 CLI"""
|
||||
"""统一入口:设置路径后运行测试"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 路径设置
|
||||
# 路径设置 - 只添加 backend 目录
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
sys.path.insert(0, str(project_root / "backend"))
|
||||
backend_path = project_root
|
||||
sys.path.insert(0, str(backend_path))
|
||||
load_dotenv(project_root / ".env")
|
||||
|
||||
if __name__ == "__main__":
|
||||
#from rag_indexer.cli import main
|
||||
#from tools.test.test_rag_indexer_result import main
|
||||
#from tools.test.test_rag_pipeline import main
|
||||
from tools.test.test_fast_rag_fix import main
|
||||
#from tools.test.test_graph_branches import main
|
||||
from tools.test.test_graph_branches import main
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -14,7 +14,6 @@ from dotenv import load_dotenv
|
||||
# 路径设置
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
sys.path.insert(0, str(project_root / "backend"))
|
||||
load_dotenv(project_root / ".env")
|
||||
|
||||
# 全局变量
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
快速测试 - 测试 fast_rag 路径修复
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from backend.app.main_graph.state import MainGraphState, CurrentAction
|
||||
from backend.app.main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from backend.app.model_services.chat_services import get_all_chat_services
|
||||
from backend.app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
|
||||
|
||||
|
||||
async def test_fast_rag_path():
|
||||
"""测试 fast_rag 路径"""
|
||||
print("=" * 60)
|
||||
print("测试 fast_rag 路径修复")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 获取 LLM
|
||||
chat_services = get_all_chat_services()
|
||||
if not chat_services:
|
||||
print("✗ 没有可用的 LLM 服务")
|
||||
return
|
||||
|
||||
llm = list(chat_services.values())[0]
|
||||
print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}")
|
||||
|
||||
# 2. 构建图
|
||||
graph = build_react_main_graph(
|
||||
llm=llm,
|
||||
tools=AVAILABLE_TOOLS,
|
||||
use_hybrid_router=True
|
||||
).compile()
|
||||
print(f"✓ 图构建完成")
|
||||
|
||||
# 3. 测试问题
|
||||
test_query = "吕布和张飞谁厉害?"
|
||||
print(f"\n测试问题: {test_query}")
|
||||
|
||||
# 4. 创建状态
|
||||
input_state = {
|
||||
"user_query": test_query,
|
||||
"messages": [{"role": "user", "content": test_query}],
|
||||
"user_id": "test_user",
|
||||
"current_action": CurrentAction.NONE
|
||||
}
|
||||
|
||||
# 5. 执行
|
||||
print("开始执行...")
|
||||
try:
|
||||
result = await graph.ainvoke(
|
||||
input_state,
|
||||
config={"configurable": {"thread_id": "test_fast_rag"}}
|
||||
)
|
||||
|
||||
print(f"\n✓ 执行成功!")
|
||||
print(f"最终回答: {result.get('final_result', '')[:300]}")
|
||||
|
||||
# 调试信息
|
||||
debug_info = result.get("debug_info", {})
|
||||
print(f"\n调试信息:")
|
||||
if "fast_path_failed" in debug_info:
|
||||
print(f" - fast_path_failed: {debug_info['fast_path_failed']}")
|
||||
if "fast_path_fail_reason" in debug_info:
|
||||
print(f" - fast_path_fail_reason: {debug_info['fast_path_fail_reason']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 执行失败: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def main():
|
||||
success = await test_fast_rag_path()
|
||||
if success:
|
||||
print("\n🎉 测试通过!")
|
||||
else:
|
||||
print("\n⚠️ 测试失败")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\n测试被中断")
|
||||
@@ -7,47 +7,49 @@ import asyncio
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 添加 backend 到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "backend"))
|
||||
|
||||
from backend.app.main_graph.state import MainGraphState, CurrentAction
|
||||
from backend.app.main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from backend.app.model_services.chat_services import get_all_chat_services
|
||||
from backend.app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
|
||||
from backend.app.main_graph.utils.rag_initializer import init_rag_tool
|
||||
from app.main_graph.state import MainGraphState, CurrentAction
|
||||
from app.main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from app.model_services.chat_services import get_all_chat_services
|
||||
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
|
||||
from app.main_graph.utils.rag_initializer import init_rag_tool
|
||||
|
||||
|
||||
# ========== 测试用例配置 ==========
|
||||
TEST_CASES = [
|
||||
# 测试1: 简单闲聊 - 应该走 fast_chitchat
|
||||
{
|
||||
"name": "闲聊测试",
|
||||
"query": "你好!",
|
||||
"description": "测试快速闲聊分支"
|
||||
},
|
||||
# # 测试1: 简单闲聊 - 应该走 fast_chitchat
|
||||
# {
|
||||
# "name": "闲聊测试",
|
||||
# "query": "你好!",
|
||||
# "description": "测试快速闲聊分支"
|
||||
# },
|
||||
# 测试2: 知识查询 - 应该走 fast_rag,然后可能升级到 react
|
||||
{
|
||||
"name": "知识查询测试",
|
||||
"query": "什么是机器学习?",
|
||||
"query": "吕布的事迹?",
|
||||
"description": "测试快速 RAG 分支"
|
||||
},
|
||||
# 测试3: 需要推理的复杂问题 - 应该直接到 React 循环
|
||||
{
|
||||
"name": "复杂推理测试",
|
||||
"query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?",
|
||||
"description": "测试 React 循环推理分支"
|
||||
},
|
||||
# 测试4: 需要工具调用的问题
|
||||
{
|
||||
"name": "工具调用测试",
|
||||
"query": "搜索一下今天的天气怎么样",
|
||||
"description": "测试工具调用分支"
|
||||
},
|
||||
# 测试5: 带记忆的对话
|
||||
{
|
||||
"name": "记忆测试",
|
||||
"query": "你刚才回答了我什么问题?",
|
||||
"description": "测试记忆检索分支",
|
||||
"thread_id": "test_memory_thread"
|
||||
}
|
||||
# # 测试3: 需要推理的复杂问题 - 应该直接到 React 循环
|
||||
# {
|
||||
# "name": "复杂推理测试",
|
||||
# "query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?",
|
||||
# "description": "测试 React 循环推理分支"
|
||||
# },
|
||||
# # 测试4: 需要工具调用的问题
|
||||
# {
|
||||
# "name": "工具调用测试",
|
||||
# "query": "搜索一下今天的天气怎么样",
|
||||
# "description": "测试工具调用分支"
|
||||
# },
|
||||
# # 测试5: 带记忆的对话
|
||||
# {
|
||||
# "name": "记忆测试",
|
||||
# "query": "你刚才回答了我什么问题?",
|
||||
# "description": "测试记忆检索分支",
|
||||
# "thread_id": "test_memory_thread"
|
||||
# }
|
||||
]
|
||||
|
||||
|
||||
@@ -56,36 +58,36 @@ async def setup_test_environment():
|
||||
print("=" * 60)
|
||||
print("设置测试环境...")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 获取 LLM 服务
|
||||
chat_services = get_all_chat_services()
|
||||
if not chat_services:
|
||||
raise RuntimeError("没有可用的 LLM 服务")
|
||||
|
||||
|
||||
llm = list(chat_services.values())[0]
|
||||
print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}")
|
||||
|
||||
|
||||
# 初始化 RAG 工具
|
||||
def create_local_llm():
|
||||
return llm
|
||||
|
||||
|
||||
rag_tool = await init_rag_tool(create_local_llm)
|
||||
tools = AVAILABLE_TOOLS.copy()
|
||||
if rag_tool:
|
||||
tools.append(rag_tool)
|
||||
print(f"✓ RAG 工具初始化成功")
|
||||
|
||||
|
||||
# 构建图
|
||||
graph = build_react_main_graph(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
use_hybrid_router=True
|
||||
).compile()
|
||||
|
||||
|
||||
print(f"✓ 图构建完成")
|
||||
print()
|
||||
|
||||
return graph
|
||||
|
||||
return graph, rag_tool
|
||||
|
||||
|
||||
def create_test_state(query: str, thread_id: str = None) -> dict:
|
||||
@@ -98,7 +100,7 @@ def create_test_state(query: str, thread_id: str = None) -> dict:
|
||||
}
|
||||
|
||||
|
||||
async def run_single_test(graph, test_case: dict) -> dict:
|
||||
async def run_single_test(graph, rag_tool, test_case: dict) -> dict:
|
||||
"""运行单个测试"""
|
||||
name = test_case["name"]
|
||||
query = test_case["query"]
|
||||
@@ -115,9 +117,12 @@ async def run_single_test(graph, test_case: dict) -> dict:
|
||||
# 创建初始状态
|
||||
input_state = create_test_state(query, thread_id)
|
||||
|
||||
# 配置
|
||||
# 配置(注入 RAG 工具)
|
||||
config = {
|
||||
"configurable": {"thread_id": thread_id}
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"rag_tool": rag_tool
|
||||
}
|
||||
}
|
||||
|
||||
# 执行图
|
||||
@@ -168,12 +173,12 @@ async def main():
|
||||
print("=" * 60)
|
||||
|
||||
# 设置环境
|
||||
graph = await setup_test_environment()
|
||||
graph, rag_tool = await setup_test_environment()
|
||||
|
||||
# 运行所有测试
|
||||
results = []
|
||||
for test_case in TEST_CASES:
|
||||
result = await run_single_test(graph, test_case)
|
||||
result = await run_single_test(graph, rag_tool, test_case)
|
||||
results.append(result)
|
||||
|
||||
# 稍微间隔一下
|
||||
|
||||
@@ -63,7 +63,6 @@ async def test_rag_tool():
|
||||
num_queries=3,
|
||||
rerank_top_n=5
|
||||
)
|
||||
|
||||
query = "吕布的经历"
|
||||
|
||||
print(f"\n用户查询: {query}")
|
||||
|
||||
Reference in New Issue
Block a user