refactor: 重构快速路径流程,统一通过 llm_call 输出
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m31s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m31s
- 重构 fast_paths.py,让 fast_chitchat 和 fast_rag 都进入 llm_call 而不是直接设置 final_result - 修改 check_fast_path_success 函数返回 'llm_call' 而不是 'success' - 更新 main_graph_builder.py 的条件边配置,支持路由到 llm_call - 在快速路径节点中添加清除 state.final_result 的逻辑,避免复用旧结果 - 重构 RAG 工具初始化方式,使用模块级变量管理 - 修改 finalize.py 让它返回 final_result - 更新 agent_service.py 的 RAG 工具注入方式 - 简化 hybrid_router.py 的代码结构 - 清理 rag_nodes.py 的全局变量相关代码 - 更新相关测试文件
This commit is contained in:
@@ -7,47 +7,49 @@ import asyncio
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 添加 backend 到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "backend"))
|
||||
|
||||
from backend.app.main_graph.state import MainGraphState, CurrentAction
|
||||
from backend.app.main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from backend.app.model_services.chat_services import get_all_chat_services
|
||||
from backend.app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
|
||||
from backend.app.main_graph.utils.rag_initializer import init_rag_tool
|
||||
from app.main_graph.state import MainGraphState, CurrentAction
|
||||
from app.main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from app.model_services.chat_services import get_all_chat_services
|
||||
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
|
||||
from app.main_graph.utils.rag_initializer import init_rag_tool
|
||||
|
||||
|
||||
# ========== 测试用例配置 ==========
|
||||
TEST_CASES = [
|
||||
# 测试1: 简单闲聊 - 应该走 fast_chitchat
|
||||
{
|
||||
"name": "闲聊测试",
|
||||
"query": "你好!",
|
||||
"description": "测试快速闲聊分支"
|
||||
},
|
||||
# # 测试1: 简单闲聊 - 应该走 fast_chitchat
|
||||
# {
|
||||
# "name": "闲聊测试",
|
||||
# "query": "你好!",
|
||||
# "description": "测试快速闲聊分支"
|
||||
# },
|
||||
# 测试2: 知识查询 - 应该走 fast_rag,然后可能升级到 react
|
||||
{
|
||||
"name": "知识查询测试",
|
||||
"query": "什么是机器学习?",
|
||||
"query": "吕布的事迹?",
|
||||
"description": "测试快速 RAG 分支"
|
||||
},
|
||||
# 测试3: 需要推理的复杂问题 - 应该直接到 React 循环
|
||||
{
|
||||
"name": "复杂推理测试",
|
||||
"query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?",
|
||||
"description": "测试 React 循环推理分支"
|
||||
},
|
||||
# 测试4: 需要工具调用的问题
|
||||
{
|
||||
"name": "工具调用测试",
|
||||
"query": "搜索一下今天的天气怎么样",
|
||||
"description": "测试工具调用分支"
|
||||
},
|
||||
# 测试5: 带记忆的对话
|
||||
{
|
||||
"name": "记忆测试",
|
||||
"query": "你刚才回答了我什么问题?",
|
||||
"description": "测试记忆检索分支",
|
||||
"thread_id": "test_memory_thread"
|
||||
}
|
||||
# # 测试3: 需要推理的复杂问题 - 应该直接到 React 循环
|
||||
# {
|
||||
# "name": "复杂推理测试",
|
||||
# "query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?",
|
||||
# "description": "测试 React 循环推理分支"
|
||||
# },
|
||||
# # 测试4: 需要工具调用的问题
|
||||
# {
|
||||
# "name": "工具调用测试",
|
||||
# "query": "搜索一下今天的天气怎么样",
|
||||
# "description": "测试工具调用分支"
|
||||
# },
|
||||
# # 测试5: 带记忆的对话
|
||||
# {
|
||||
# "name": "记忆测试",
|
||||
# "query": "你刚才回答了我什么问题?",
|
||||
# "description": "测试记忆检索分支",
|
||||
# "thread_id": "test_memory_thread"
|
||||
# }
|
||||
]
|
||||
|
||||
|
||||
@@ -56,36 +58,36 @@ async def setup_test_environment():
|
||||
print("=" * 60)
|
||||
print("设置测试环境...")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 获取 LLM 服务
|
||||
chat_services = get_all_chat_services()
|
||||
if not chat_services:
|
||||
raise RuntimeError("没有可用的 LLM 服务")
|
||||
|
||||
|
||||
llm = list(chat_services.values())[0]
|
||||
print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}")
|
||||
|
||||
|
||||
# 初始化 RAG 工具
|
||||
def create_local_llm():
|
||||
return llm
|
||||
|
||||
|
||||
rag_tool = await init_rag_tool(create_local_llm)
|
||||
tools = AVAILABLE_TOOLS.copy()
|
||||
if rag_tool:
|
||||
tools.append(rag_tool)
|
||||
print(f"✓ RAG 工具初始化成功")
|
||||
|
||||
|
||||
# 构建图
|
||||
graph = build_react_main_graph(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
use_hybrid_router=True
|
||||
).compile()
|
||||
|
||||
|
||||
print(f"✓ 图构建完成")
|
||||
print()
|
||||
|
||||
return graph
|
||||
|
||||
return graph, rag_tool
|
||||
|
||||
|
||||
def create_test_state(query: str, thread_id: str = None) -> dict:
|
||||
@@ -98,7 +100,7 @@ def create_test_state(query: str, thread_id: str = None) -> dict:
|
||||
}
|
||||
|
||||
|
||||
async def run_single_test(graph, test_case: dict) -> dict:
|
||||
async def run_single_test(graph, rag_tool, test_case: dict) -> dict:
|
||||
"""运行单个测试"""
|
||||
name = test_case["name"]
|
||||
query = test_case["query"]
|
||||
@@ -115,9 +117,12 @@ async def run_single_test(graph, test_case: dict) -> dict:
|
||||
# 创建初始状态
|
||||
input_state = create_test_state(query, thread_id)
|
||||
|
||||
# 配置
|
||||
# 配置(注入 RAG 工具)
|
||||
config = {
|
||||
"configurable": {"thread_id": thread_id}
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"rag_tool": rag_tool
|
||||
}
|
||||
}
|
||||
|
||||
# 执行图
|
||||
@@ -168,12 +173,12 @@ async def main():
|
||||
print("=" * 60)
|
||||
|
||||
# 设置环境
|
||||
graph = await setup_test_environment()
|
||||
graph, rag_tool = await setup_test_environment()
|
||||
|
||||
# 运行所有测试
|
||||
results = []
|
||||
for test_case in TEST_CASES:
|
||||
result = await run_single_test(graph, test_case)
|
||||
result = await run_single_test(graph, rag_tool, test_case)
|
||||
results.append(result)
|
||||
|
||||
# 稍微间隔一下
|
||||
|
||||
Reference in New Issue
Block a user