From c9bf21be0e7ebf3916c8577f7a2f2e80d8871280 Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Mon, 4 May 2026 18:59:15 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20RAG=20=E6=97=A0?= =?UTF-8?q?=E9=99=90=E5=BE=AA=E7=8E=AF=E9=97=AE=E9=A2=98=E5=92=8C=E5=AF=BC?= =?UTF-8?q?=E5=85=A5=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 主要修复: 1. 修复 RAG 推理无限循环问题(大小写不匹配 + 缺少已检索结果检查) 2. 修复 intent_classifier.py 的绝对导入错误 3. 删除旧的 start.sh 脚本,添加新的启动脚本 4. 优化路由逻辑和状态管理 --- .gitignore | 2 - backend/app/agent/agent_service.py | 20 +- backend/app/core/intent.py | 11 +- backend/app/core/intent_classifier.py | 2 +- backend/app/main_graph/nodes/hybrid_router.py | 12 +- backend/app/main_graph/nodes/react_nodes.py | 14 +- .../main_graph/utils/main_graph_builder.py | 32 +-- scripts/start.sh | 117 ---------- tools/run.py | 6 +- tools/start.py | 125 ++++++++++ tools/test/test_fast_rag_fix.py | 97 ++++++++ tools/test/test_graph_branches.py | 221 ++++++++++++++++++ tools/test/test_rag_pipeline.py | 8 +- 13 files changed, 503 insertions(+), 164 deletions(-) delete mode 100755 scripts/start.sh create mode 100755 tools/start.py create mode 100644 tools/test/test_fast_rag_fix.py create mode 100644 tools/test/test_graph_branches.py diff --git a/.gitignore b/.gitignore index 5391e4e..2c5016e 100644 --- a/.gitignore +++ b/.gitignore @@ -11,8 +11,6 @@ !backend/** !frontend/ !frontend/** -!scripts/ -!scripts/** !rag_indexer/ !rag_indexer/** !docker/ diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index 24c134b..6512ada 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -7,14 +7,14 @@ import json import asyncio # 本地模块 -from app.main_graph.utils.main_graph_builder import build_react_main_graph -from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME -from app.main_graph.config import set_stream_writer +from ..main_graph.utils.main_graph_builder import build_react_main_graph +from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME +from ..main_graph.config import set_stream_writer from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider -from app.main_graph.utils.rag_initializer import init_rag_tool -from app.core.intent_classifier import get_intent_classifier -from app.logger import info, warning, error -from app.main_graph.state import MainGraphState, CurrentAction +from ..main_graph.utils.rag_initializer import init_rag_tool +from ..core.intent_classifier import get_intent_classifier +from ..logger import info, warning, error +from ..main_graph.state import MainGraphState, CurrentAction class AIAgentService: @@ -32,7 +32,7 @@ class AIAgentService: async def initialize(self): # 0. 初始化 Mem0 客户端 - from app.memory.mem0_client import Mem0Client + from ..memory.mem0_client import Mem0Client # 创建一个临时的 LLM 用于 Mem0(用第一个可用的) chat_services = get_all_chat_services() temp_llm = None @@ -49,7 +49,7 @@ class AIAgentService: self.tools.append(rag_tool) self.tools_by_name[rag_tool.name] = rag_tool # 关键:设置全局 RAG 工具,供 rag_nodes.py 使用 - from app.main_graph.nodes.rag_nodes import set_global_rag_tool + from ..main_graph.nodes.rag_nodes import set_global_rag_tool set_global_rag_tool(rag_tool) # 2. 构建各模型的 Graph(使用新版 React 模式) @@ -86,7 +86,7 @@ class AIAgentService: "metadata": {"user_id": user_id} } # 新版状态输入:传入完整的 MainGraphState,关键是 user_query - from app.main_graph.state import MainGraphState, CurrentAction + from ..main_graph.state import MainGraphState, CurrentAction input_state = { "user_query": message, "messages": [{"role": "user", "content": message}], diff --git a/backend/app/core/intent.py b/backend/app/core/intent.py index dd7639f..47bf56c 100644 --- a/backend/app/core/intent.py +++ b/backend/app/core/intent.py @@ -132,8 +132,17 @@ class ReactIntentReasoner: # 关键修改:不要在第一次 rag_retrieve 后就直接回答,允许再推理一次 # 让推理逻辑有机会判断 RAG 结果好不好,要不要再检索或转 web search - rag_count = previous_actions.count("rag_retrieve") + previous_actions = context.get("previous_actions", []) + rag_count = previous_actions.count("RETRIEVE_RAG") # 修复:大写 web_search_count = previous_actions.count("web_search") + retrieved_docs = context.get("retrieved_docs", []) + + # 如果已经有检索文档了,直接回答 + if retrieved_docs and len(retrieved_docs) > 0: + result.action = ReasoningAction.DIRECT_RESPONSE + result.confidence = 0.95 + result.reasoning = "已获取检索文档,直接回答" + return result # 只有当 rag 或 web search 已经超过 1 次,或者已经有推理在 rag 之后,才直接回答 if rag_count >= 2 or web_search_count >= 1: diff --git a/backend/app/core/intent_classifier.py b/backend/app/core/intent_classifier.py index b6c0493..9452ae5 100644 --- a/backend/app/core/intent_classifier.py +++ b/backend/app/core/intent_classifier.py @@ -6,7 +6,7 @@ from typing import Optional, Dict, Any import sys import os -from backend.app.model_services.chat_services import get_small_llm_service +from ..model_services.chat_services import get_small_llm_service class IntentType(Enum): diff --git a/backend/app/main_graph/nodes/hybrid_router.py b/backend/app/main_graph/nodes/hybrid_router.py index 7d3ad84..f626b8c 100644 --- a/backend/app/main_graph/nodes/hybrid_router.py +++ b/backend/app/main_graph/nodes/hybrid_router.py @@ -8,10 +8,10 @@ from typing import Dict, Any, Optional, List from dataclasses import dataclass, field from datetime import datetime -from app.main_graph.state import MainGraphState -from app.logger import info, debug -from app.model_services.chat_services import get_small_llm_service, get_chat_service -from app.main_graph.nodes.rag_nodes import rag_retrieve_node +from ..state import MainGraphState +from ...logger import info, debug +from ...model_services.chat_services import get_small_llm_service, get_chat_service +from .rag_nodes import rag_retrieve_node # ========== 核心数据类型 ========== @@ -367,8 +367,8 @@ async def fast_rag_node(state: MainGraphState, config: Optional[Dict[str, Any]] debug(f"[Fast RAG] 发送事件失败: {e}") try: - # 先尝试 RAG 检索 - state = rag_retrieve_node(state, config) + # 先尝试 RAG 检索 - 注意:rag_retrieve_node 是异步函数,需要 await + state = await rag_retrieve_node(state, config) # 检查检索结果 rag_docs = getattr(state, "rag_docs", []) diff --git a/backend/app/main_graph/nodes/react_nodes.py b/backend/app/main_graph/nodes/react_nodes.py index 7e6f102..251262c 100644 --- a/backend/app/main_graph/nodes/react_nodes.py +++ b/backend/app/main_graph/nodes/react_nodes.py @@ -364,11 +364,15 @@ def route_by_reasoning(state: MainGraphState) -> str: if "subgraph_completed" in previous_actions or state.final_result: return "llm_call" - # 关键修复:如果已经执行过 rag_retrieve 并且又执行过推理,直接去 LLM_CALL - # 这样的流程:推理1 → RAG → 推理2(判断 RAG 结果) → LLM_CALL - rag_count = previous_actions.count("rag_retrieve") - if rag_count >= 1 and len(previous_actions) >= rag_count + 1: - info(f"[route_by_reasoning] 已完成 RAG 检索和结果判断,直接去 llm_call") + # 关键修复:检测 RAG 重复循环 - 如果发现"RETRIEVE_RAG"出现超过1次,直接去 LLM + rag_count = previous_actions.count("RETRIEVE_RAG") + if rag_count >= 2: + info(f"[route_by_reasoning] 检测到 RAG 重复循环({rag_count}次),直接去 llm_call") + return "llm_call" + + # 关键修复:如果已经有 rag_docs 或 rag_context,说明已经检索过了,直接去 LLM + if (state.rag_docs and len(state.rag_docs) > 0) or (state.rag_context and len(state.rag_context) > 0): + info(f"[route_by_reasoning] 检测到已存在 RAG 检索结果,直接去 llm_call") return "llm_call" # 关键修复:限制最多 3 次推理,避免无限循环 diff --git a/backend/app/main_graph/utils/main_graph_builder.py b/backend/app/main_graph/utils/main_graph_builder.py index e1ab392..6180ea2 100644 --- a/backend/app/main_graph/utils/main_graph_builder.py +++ b/backend/app/main_graph/utils/main_graph_builder.py @@ -2,19 +2,19 @@ 整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState """ -from app.main_graph.graph import StateGraph, START, END +from ..graph import StateGraph, START, END from typing import Dict, Any, Optional from langchain_core.runnables.config import RunnableConfig -from app.main_graph.state import MainGraphState -from app.main_graph.nodes.react_nodes import ( +from ..state import MainGraphState +from ..nodes.react_nodes import ( init_state_node, react_reason_node, web_search_node, error_handling_node, route_by_reasoning ) -from app.main_graph.nodes.hybrid_router import ( +from ..nodes.hybrid_router import ( hybrid_router_node, fast_chitchat_node, fast_rag_node, @@ -22,17 +22,17 @@ from app.main_graph.nodes.hybrid_router import ( route_from_hybrid_decision, check_fast_path_success ) -from app.main_graph.nodes.llm_call import create_llm_call_node -from app.main_graph.nodes.rag_nodes import rag_retrieve_node -from app.main_graph.nodes.retrieve_memory import create_retrieve_memory_node -from app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client -from app.main_graph.nodes.summarize import create_summarize_node -from app.main_graph.nodes.finalize import finalize_node -from app.subgraphs.contact import build_contact_subgraph -from app.subgraphs.dictionary import build_dictionary_subgraph -from app.subgraphs.news_analysis import build_news_analysis_subgraph -from app.memory.mem0_client import Mem0Client -from app.logger import info, debug +from ..nodes.llm_call import create_llm_call_node +from ..nodes.rag_nodes import rag_retrieve_node +from ..nodes.retrieve_memory import create_retrieve_memory_node +from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client +from ..nodes.summarize import create_summarize_node +from ..nodes.finalize import finalize_node +from ...subgraphs.contact import build_contact_subgraph +from ...subgraphs.dictionary import build_dictionary_subgraph +from ...subgraphs.news_analysis import build_news_analysis_subgraph +from ...memory.mem0_client import Mem0Client +from ...logger import info, debug # ========== 检查是否需要总结 ========== @@ -140,7 +140,7 @@ def wrap_subgraph_for_error_handling(subgraph, name: str): except Exception as e: # 捕获子图错误,传递给主图 - from app.main_graph.state import ErrorRecord, ErrorSeverity + from ..state import ErrorRecord, ErrorSeverity from datetime import datetime error_record = ErrorRecord( diff --git a/scripts/start.sh b/scripts/start.sh deleted file mode 100755 index 86d895b..0000000 --- a/scripts/start.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/bin/bash -# ============================================================================= -# AI Agent 启动与管理脚本 - 简化版 -# 用法: ./scripts/start.sh [check|backend|frontend|both] -# ============================================================================= - -set -e - -# 颜色定义 -GREEN='\033[0;32m' -BLUE='\033[0;34m' -RED='\033[0;31m' -YELLOW='\033[1;33m' -NC='\033[0m' # No Color - -# 项目根目录 -PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" - -echo -e "${BLUE}========================================${NC}" -echo -e "${BLUE} AI Agent - 个人生活助手${NC}" -echo -e "${BLUE}========================================${NC}" -echo "" - -# ============================================================================= -# 启动 Python 服务 -# ============================================================================= -start_backend() { - echo -e "\n${BLUE}🚀 启动后端服务 (端口 10079)...${NC}" - cd "$PROJECT_DIR" - - # 加载 .env 文件中的环境变量 - set -a - source .env 2>/dev/null || true - set +a - - export PYTHONPATH="$PROJECT_DIR/backend" - export BACKEND_PORT=8079 - python -m app.backend & - BACKEND_PID=$! - echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}" - sleep 2 -} - -start_frontend() { - echo -e "\n${BLUE}🎨 启动前端界面 (端口 10501)...${NC}" - cd "$PROJECT_DIR" - - # 加载 .env 文件中的环境变量 - set -a - source .env 2>/dev/null || true - set +a - - export PYTHONPATH="$PROJECT_DIR/frontend/src" - export API_URL="http://127.0.0.1:8079/chat" - streamlit run frontend/src/frontend_main.py --server.port 10501 --server.address 0.0.0.0 & - FRONTEND_PID=$! - echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}" - echo -e "${GREEN}✓ 访问地址:${NC}" - echo -e " 本地开发: http://127.0.0.1:10501" -} - -# ============================================================================= -# 清理函数 -# ============================================================================= -cleanup() { - echo -e "\n${RED}🛑 正在停止 Python 服务...${NC}" - if [ ! -z "$BACKEND_PID" ]; then - kill $BACKEND_PID 2>/dev/null || true - echo -e "${GREEN}✓ 后端服务已停止${NC}" - fi - if [ ! -z "$FRONTEND_PID" ]; then - kill $FRONTEND_PID 2>/dev/null || true - echo -e "${GREEN}✓ 前端服务已停止${NC}" - fi - exit 0 -} - -# 捕获 Ctrl+C -trap cleanup SIGINT SIGTERM - -# ============================================================================= -# 主逻辑 -# ============================================================================= -case "${1:-help}" in - backend) - start_backend - echo -e "\n${GREEN}后端服务正在运行,按 Ctrl+C 停止${NC}" - wait $BACKEND_PID - ;; - - frontend) - start_frontend - echo -e "\n${GREEN}前端服务正在运行,按 Ctrl+C 停止${NC}" - wait $FRONTEND_PID - ;; - - both) - start_backend - sleep 3 - start_frontend - echo -e "\n${GREEN}所有服务正在运行,按 Ctrl+C 停止${NC}" - wait - ;; - - help|*) - echo -e "${BLUE}用法:${NC} $0 [command]" - echo "" - echo -e "${BLUE}命令:${NC}" - echo " backend 仅启动后端服务" - echo " frontend 仅启动前端服务" - echo " both 启动前后端服务(默认)" - echo " help 显示此帮助信息" - echo "" - echo -e "${BLUE}示例:${NC}" - echo " $0 both # 启动本地开发环境" - ;; -esac diff --git a/tools/run.py b/tools/run.py index f3acdb9..8e4ca60 100644 --- a/tools/run.py +++ b/tools/run.py @@ -11,8 +11,10 @@ sys.path.insert(0, str(project_root / "backend")) load_dotenv(project_root / ".env") if __name__ == "__main__": - from rag_indexer.cli import main + #from rag_indexer.cli import main #from tools.test.test_rag_indexer_result import main #from tools.test.test_rag_pipeline import main + from tools.test.test_fast_rag_fix import main + #from tools.test.test_graph_branches import main import asyncio - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/tools/start.py b/tools/start.py new file mode 100755 index 0000000..5848df7 --- /dev/null +++ b/tools/start.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +""" +AI Agent 启动与管理脚本 - Python版 +用法: python tools/testrun.py [check|backend|frontend|both] +""" +import sys +import os +import time +import signal +import subprocess +from pathlib import Path +from dotenv import load_dotenv + +# 路径设置 +project_root = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(project_root)) +sys.path.insert(0, str(project_root / "backend")) +load_dotenv(project_root / ".env") + +# 全局变量 +processes = [] + + +def start_backend(): + """启动后端服务""" + print("\n🚀 启动后端服务 (端口 8079)...") + + env = os.environ.copy() + env["PYTHONPATH"] = str(project_root / "backend") + env["BACKEND_PORT"] = "8079" + + proc = subprocess.Popen( + [sys.executable, "-m", "app.backend"], + cwd=str(project_root), + env=env + ) + processes.append(proc) + print(f"✓ 后端服务已启动 (PID: {proc.pid})") + time.sleep(2) + return proc + + +def start_frontend(): + """启动前端服务""" + print("\n🎨 启动前端界面 (端口 10501)...") + + env = os.environ.copy() + env["PYTHONPATH"] = str(project_root / "frontend/src") + env["API_URL"] = "http://127.0.0.1:8079/chat" + + frontend_main = str(project_root / "frontend" / "src" / "frontend_main.py") + proc = subprocess.Popen( + [ + sys.executable, "-m", "streamlit", "run", frontend_main, + "--server.port", "10501", "--server.address", "0.0.0.0" + ], + cwd=str(project_root), + env=env + ) + processes.append(proc) + print(f"✓ 前端服务已启动 (PID: {proc.pid})") + print("✓ 访问地址:") + print(" 本地开发: http://127.0.0.1:10501") + return proc + + +def cleanup(signum, frame): + """清理函数 - 停止所有进程""" + print("\n🛑 正在停止服务...") + for i, proc in enumerate(processes): + if proc.poll() is None: # 进程还在运行 + proc.terminate() + proc.wait(timeout=5) + print(f"✓ 服务 {i+1} 已停止") + sys.exit(0) + + +def print_help(): + """显示帮助信息""" + print("========================================") + print(" AI Agent - 个人生活助手") + print("========================================") + print("\n用法: python tools/testrun.py [command]") + print("\n命令:") + print(" backend 仅启动后端服务") + print(" frontend 仅启动前端服务") + print(" both 启动前后端服务(默认)") + print(" help 显示此帮助信息") + print("\n示例:") + print(" python tools/testrun.py both # 启动本地开发环境") + + +def main(): + """主函数""" + print("========================================") + print(" AI Agent - 个人生活助手") + print("========================================") + + # 捕获信号 + signal.signal(signal.SIGINT, cleanup) + signal.signal(signal.SIGTERM, cleanup) + + cmd = sys.argv[1] if len(sys.argv) > 1 else "both" + + if cmd == "backend": + start_backend() + print("\n后端服务正在运行,按 Ctrl+C 停止") + processes[0].wait() + elif cmd == "frontend": + start_frontend() + print("\n前端服务正在运行,按 Ctrl+C 停止") + processes[0].wait() + elif cmd == "both": + start_backend() + time.sleep(3) + start_frontend() + print("\n所有服务正在运行,按 Ctrl+C 停止") + for proc in processes: + proc.wait() + else: + print_help() + + +if __name__ == "__main__": + main() diff --git a/tools/test/test_fast_rag_fix.py b/tools/test/test_fast_rag_fix.py new file mode 100644 index 0000000..d16da8b --- /dev/null +++ b/tools/test/test_fast_rag_fix.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +快速测试 - 测试 fast_rag 路径修复 +""" +import sys +import asyncio +from pathlib import Path +from dotenv import load_dotenv + +# 路径设置 +project_root = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(project_root)) +sys.path.insert(0, str(project_root / "backend")) +load_dotenv(project_root / ".env") + +from app.main_graph.state import MainGraphState, CurrentAction +from app.main_graph.utils.main_graph_builder import build_react_main_graph +from app.model_services.chat_services import get_all_chat_services +from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS + + +async def test_fast_rag_path(): + """测试 fast_rag 路径""" + print("=" * 60) + print("测试 fast_rag 路径修复") + print("=" * 60) + + # 1. 获取 LLM + chat_services = get_all_chat_services() + if not chat_services: + print("✗ 没有可用的 LLM 服务") + return + + llm = list(chat_services.values())[0] + print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}") + + # 2. 构建图 + graph = build_react_main_graph( + llm=llm, + tools=AVAILABLE_TOOLS, + use_hybrid_router=True + ).compile() + print(f"✓ 图构建完成") + + # 3. 测试问题 + test_query = "吕布和张飞谁厉害?" + print(f"\n测试问题: {test_query}") + + # 4. 创建状态 + input_state = { + "user_query": test_query, + "messages": [{"role": "user", "content": test_query}], + "user_id": "test_user", + "current_action": CurrentAction.NONE + } + + # 5. 执行 + print("开始执行...") + try: + result = await graph.ainvoke( + input_state, + config={"configurable": {"thread_id": "test_fast_rag"}} + ) + + print(f"\n✓ 执行成功!") + print(f"最终回答: {result.get('final_result', '')[:300]}") + + # 调试信息 + debug_info = result.get("debug_info", {}) + print(f"\n调试信息:") + if "fast_path_failed" in debug_info: + print(f" - fast_path_failed: {debug_info['fast_path_failed']}") + if "fast_path_fail_reason" in debug_info: + print(f" - fast_path_fail_reason: {debug_info['fast_path_fail_reason']}") + + except Exception as e: + print(f"\n✗ 执行失败: {e}") + import traceback + print(traceback.format_exc()) + return False + + return True + + +async def main(): + success = await test_fast_rag_path() + if success: + print("\n🎉 测试通过!") + else: + print("\n⚠️ 测试失败") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\n测试被中断") diff --git a/tools/test/test_graph_branches.py b/tools/test/test_graph_branches.py new file mode 100644 index 0000000..5ed05be --- /dev/null +++ b/tools/test/test_graph_branches.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +""" +主图完整测试 - 覆盖各个分支 +""" +import sys +import asyncio +from pathlib import Path +from dotenv import load_dotenv + +# 路径设置 +project_root = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(project_root)) +sys.path.insert(0, str(project_root / "backend")) +load_dotenv(project_root / ".env") + +from app.main_graph.state import MainGraphState, CurrentAction +from app.main_graph.utils.main_graph_builder import build_react_main_graph +from app.model_services.chat_services import get_all_chat_services +from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS +from app.main_graph.utils.rag_initializer import init_rag_tool + + +# ========== 测试用例配置 ========== +TEST_CASES = [ + # 测试1: 简单闲聊 - 应该走 fast_chitchat + { + "name": "闲聊测试", + "query": "你好!", + "description": "测试快速闲聊分支" + }, + # 测试2: 知识查询 - 应该走 fast_rag,然后可能升级到 react + { + "name": "知识查询测试", + "query": "什么是机器学习?", + "description": "测试快速 RAG 分支" + }, + # 测试3: 需要推理的复杂问题 - 应该直接到 React 循环 + { + "name": "复杂推理测试", + "query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?", + "description": "测试 React 循环推理分支" + }, + # 测试4: 需要工具调用的问题 + { + "name": "工具调用测试", + "query": "搜索一下今天的天气怎么样", + "description": "测试工具调用分支" + }, + # 测试5: 带记忆的对话 + { + "name": "记忆测试", + "query": "你刚才回答了我什么问题?", + "description": "测试记忆检索分支", + "thread_id": "test_memory_thread" + } +] + + +async def setup_test_environment(): + """设置测试环境""" + print("=" * 60) + print("设置测试环境...") + print("=" * 60) + + # 获取 LLM 服务 + chat_services = get_all_chat_services() + if not chat_services: + raise RuntimeError("没有可用的 LLM 服务") + + llm = list(chat_services.values())[0] + print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}") + + # 初始化 RAG 工具 + def create_local_llm(): + return llm + + rag_tool = await init_rag_tool(create_local_llm) + tools = AVAILABLE_TOOLS.copy() + if rag_tool: + tools.append(rag_tool) + print(f"✓ RAG 工具初始化成功") + + # 构建图 + graph = build_react_main_graph( + llm=llm, + tools=tools, + use_hybrid_router=True + ).compile() + + print(f"✓ 图构建完成") + print() + + return graph + + +def create_test_state(query: str, thread_id: str = None) -> dict: + """创建测试状态""" + return { + "user_query": query, + "messages": [{"role": "user", "content": query}], + "user_id": "test_user", + "current_action": CurrentAction.NONE + } + + +async def run_single_test(graph, test_case: dict) -> dict: + """运行单个测试""" + name = test_case["name"] + query = test_case["query"] + description = test_case["description"] + thread_id = test_case.get("thread_id", "test_thread") + + print(f"\n{'=' * 60}") + print(f"测试: {name}") + print(f"描述: {description}") + print(f"查询: {query}") + print(f"{'=' * 60}") + + try: + # 创建初始状态 + input_state = create_test_state(query, thread_id) + + # 配置 + config = { + "configurable": {"thread_id": thread_id} + } + + # 执行图 + print("开始执行图...") + result = await graph.ainvoke(input_state, config=config) + + # 检查结果 + success = result.get("success", False) + final_result = result.get("final_result", "") + + print(f"\n结果:") + print(f" 成功: {'✓' if success else '✗'}") + print(f" 最终回答: {final_result[:200]}{'...' if len(final_result) > 200 else ''}") + + # 调试信息 + if "debug_info" in result: + debug_info = result["debug_info"] + print(f" 调试信息:") + if "fast_path_failed" in debug_info: + print(f" - 快速路径失败: {debug_info['fast_path_failed']}") + if "fast_path_fail_reason" in debug_info: + print(f" - 失败原因: {debug_info['fast_path_fail_reason']}") + if "hybrid_decision" in debug_info: + decision = debug_info["hybrid_decision"] + print(f" - 路由决策: {decision.path if hasattr(decision, 'path') else 'unknown'}") + + return { + "name": name, + "success": success, + "result": result + } + + except Exception as e: + print(f"\n✗ 测试失败: {e}") + import traceback + print(f"堆栈: {traceback.format_exc()}") + return { + "name": name, + "success": False, + "error": str(e) + } + + +async def main(): + """主函数""" + print("\n" + "=" * 60) + print("主图完整测试套件") + print("=" * 60) + + # 设置环境 + graph = await setup_test_environment() + + # 运行所有测试 + results = [] + for test_case in TEST_CASES: + result = await run_single_test(graph, test_case) + results.append(result) + + # 稍微间隔一下 + await asyncio.sleep(1) + + # 总结 + print("\n" + "=" * 60) + print("测试总结") + print("=" * 60) + + total = len(results) + passed = sum(1 for r in results if r["success"]) + failed = total - passed + + print(f"\n总测试数: {total}") + print(f"通过: {passed}") + print(f"失败: {failed}") + + print("\n详细结果:") + for result in results: + status = "✓ 通过" if result["success"] else "✗ 失败" + print(f" {result['name']}: {status}") + + print("\n" + "=" * 60) + if failed == 0: + print("🎉 所有测试通过!") + else: + print(f"⚠️ 有 {failed} 个测试失败") + print("=" * 60) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\n\n测试被用户中断") + except Exception as e: + print(f"\n\n测试运行失败: {e}") + import traceback + print(traceback.format_exc()) diff --git a/tools/test/test_rag_pipeline.py b/tools/test/test_rag_pipeline.py index cbe60a6..1df959a 100644 --- a/tools/test/test_rag_pipeline.py +++ b/tools/test/test_rag_pipeline.py @@ -22,7 +22,7 @@ async def test_rag_pipeline_direct(): rerank_top_n=5 ) - query = "黄双银的经历" + query = "吕布的经历" print(f"\n用户查询: {query}") print("-" * 80) @@ -64,7 +64,7 @@ async def test_rag_tool(): rerank_top_n=5 ) - query = "黄双银的经历" + query = "吕布的经历" print(f"\n用户查询: {query}") print("-" * 80) @@ -91,7 +91,7 @@ async def test_custom_pipeline(): rerank_top_n=3 # 只返回前 3 个最相关文档 ) - query = "黄双银的经历" + query = "吕布的经历" print(f"\n用户查询: {query}") print(f"配置: num_queries=2, rerank_top_n=3") @@ -124,7 +124,7 @@ async def main(): """主测试函数""" print("\n" + "="*80) print("完整 RAG Pipeline 测试") - print("查询: '黄双银的经历'") + print("查询: '吕布的经历'") print("="*80) # 测试 1: 直接使用 pipeline