主要修复: 1. 修复 RAG 推理无限循环问题(大小写不匹配 + 缺少已检索结果检查) 2. 修复 intent_classifier.py 的绝对导入错误 3. 删除旧的 start.sh 脚本,添加新的启动脚本 4. 优化路由逻辑和状态管理
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -11,8 +11,6 @@
|
|||||||
!backend/**
|
!backend/**
|
||||||
!frontend/
|
!frontend/
|
||||||
!frontend/**
|
!frontend/**
|
||||||
!scripts/
|
|
||||||
!scripts/**
|
|
||||||
!rag_indexer/
|
!rag_indexer/
|
||||||
!rag_indexer/**
|
!rag_indexer/**
|
||||||
!docker/
|
!docker/
|
||||||
|
|||||||
@@ -7,14 +7,14 @@ import json
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# 本地模块
|
# 本地模块
|
||||||
from app.main_graph.utils.main_graph_builder import build_react_main_graph
|
from ..main_graph.utils.main_graph_builder import build_react_main_graph
|
||||||
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||||
from app.main_graph.config import set_stream_writer
|
from ..main_graph.config import set_stream_writer
|
||||||
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
|
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
|
||||||
from app.main_graph.utils.rag_initializer import init_rag_tool
|
from ..main_graph.utils.rag_initializer import init_rag_tool
|
||||||
from app.core.intent_classifier import get_intent_classifier
|
from ..core.intent_classifier import get_intent_classifier
|
||||||
from app.logger import info, warning, error
|
from ..logger import info, warning, error
|
||||||
from app.main_graph.state import MainGraphState, CurrentAction
|
from ..main_graph.state import MainGraphState, CurrentAction
|
||||||
|
|
||||||
|
|
||||||
class AIAgentService:
|
class AIAgentService:
|
||||||
@@ -32,7 +32,7 @@ class AIAgentService:
|
|||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
# 0. 初始化 Mem0 客户端
|
# 0. 初始化 Mem0 客户端
|
||||||
from app.memory.mem0_client import Mem0Client
|
from ..memory.mem0_client import Mem0Client
|
||||||
# 创建一个临时的 LLM 用于 Mem0(用第一个可用的)
|
# 创建一个临时的 LLM 用于 Mem0(用第一个可用的)
|
||||||
chat_services = get_all_chat_services()
|
chat_services = get_all_chat_services()
|
||||||
temp_llm = None
|
temp_llm = None
|
||||||
@@ -49,7 +49,7 @@ class AIAgentService:
|
|||||||
self.tools.append(rag_tool)
|
self.tools.append(rag_tool)
|
||||||
self.tools_by_name[rag_tool.name] = rag_tool
|
self.tools_by_name[rag_tool.name] = rag_tool
|
||||||
# 关键:设置全局 RAG 工具,供 rag_nodes.py 使用
|
# 关键:设置全局 RAG 工具,供 rag_nodes.py 使用
|
||||||
from app.main_graph.nodes.rag_nodes import set_global_rag_tool
|
from ..main_graph.nodes.rag_nodes import set_global_rag_tool
|
||||||
set_global_rag_tool(rag_tool)
|
set_global_rag_tool(rag_tool)
|
||||||
|
|
||||||
# 2. 构建各模型的 Graph(使用新版 React 模式)
|
# 2. 构建各模型的 Graph(使用新版 React 模式)
|
||||||
@@ -86,7 +86,7 @@ class AIAgentService:
|
|||||||
"metadata": {"user_id": user_id}
|
"metadata": {"user_id": user_id}
|
||||||
}
|
}
|
||||||
# 新版状态输入:传入完整的 MainGraphState,关键是 user_query
|
# 新版状态输入:传入完整的 MainGraphState,关键是 user_query
|
||||||
from app.main_graph.state import MainGraphState, CurrentAction
|
from ..main_graph.state import MainGraphState, CurrentAction
|
||||||
input_state = {
|
input_state = {
|
||||||
"user_query": message,
|
"user_query": message,
|
||||||
"messages": [{"role": "user", "content": message}],
|
"messages": [{"role": "user", "content": message}],
|
||||||
|
|||||||
@@ -132,8 +132,17 @@ class ReactIntentReasoner:
|
|||||||
|
|
||||||
# 关键修改:不要在第一次 rag_retrieve 后就直接回答,允许再推理一次
|
# 关键修改:不要在第一次 rag_retrieve 后就直接回答,允许再推理一次
|
||||||
# 让推理逻辑有机会判断 RAG 结果好不好,要不要再检索或转 web search
|
# 让推理逻辑有机会判断 RAG 结果好不好,要不要再检索或转 web search
|
||||||
rag_count = previous_actions.count("rag_retrieve")
|
previous_actions = context.get("previous_actions", [])
|
||||||
|
rag_count = previous_actions.count("RETRIEVE_RAG") # 修复:大写
|
||||||
web_search_count = previous_actions.count("web_search")
|
web_search_count = previous_actions.count("web_search")
|
||||||
|
retrieved_docs = context.get("retrieved_docs", [])
|
||||||
|
|
||||||
|
# 如果已经有检索文档了,直接回答
|
||||||
|
if retrieved_docs and len(retrieved_docs) > 0:
|
||||||
|
result.action = ReasoningAction.DIRECT_RESPONSE
|
||||||
|
result.confidence = 0.95
|
||||||
|
result.reasoning = "已获取检索文档,直接回答"
|
||||||
|
return result
|
||||||
|
|
||||||
# 只有当 rag 或 web search 已经超过 1 次,或者已经有推理在 rag 之后,才直接回答
|
# 只有当 rag 或 web search 已经超过 1 次,或者已经有推理在 rag 之后,才直接回答
|
||||||
if rag_count >= 2 or web_search_count >= 1:
|
if rag_count >= 2 or web_search_count >= 1:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Optional, Dict, Any
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from backend.app.model_services.chat_services import get_small_llm_service
|
from ..model_services.chat_services import get_small_llm_service
|
||||||
|
|
||||||
|
|
||||||
class IntentType(Enum):
|
class IntentType(Enum):
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ from typing import Dict, Any, Optional, List
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.main_graph.state import MainGraphState
|
from ..state import MainGraphState
|
||||||
from app.logger import info, debug
|
from ...logger import info, debug
|
||||||
from app.model_services.chat_services import get_small_llm_service, get_chat_service
|
from ...model_services.chat_services import get_small_llm_service, get_chat_service
|
||||||
from app.main_graph.nodes.rag_nodes import rag_retrieve_node
|
from .rag_nodes import rag_retrieve_node
|
||||||
|
|
||||||
|
|
||||||
# ========== 核心数据类型 ==========
|
# ========== 核心数据类型 ==========
|
||||||
@@ -367,8 +367,8 @@ async def fast_rag_node(state: MainGraphState, config: Optional[Dict[str, Any]]
|
|||||||
debug(f"[Fast RAG] 发送事件失败: {e}")
|
debug(f"[Fast RAG] 发送事件失败: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 先尝试 RAG 检索
|
# 先尝试 RAG 检索 - 注意:rag_retrieve_node 是异步函数,需要 await
|
||||||
state = rag_retrieve_node(state, config)
|
state = await rag_retrieve_node(state, config)
|
||||||
|
|
||||||
# 检查检索结果
|
# 检查检索结果
|
||||||
rag_docs = getattr(state, "rag_docs", [])
|
rag_docs = getattr(state, "rag_docs", [])
|
||||||
|
|||||||
@@ -364,11 +364,15 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
|||||||
if "subgraph_completed" in previous_actions or state.final_result:
|
if "subgraph_completed" in previous_actions or state.final_result:
|
||||||
return "llm_call"
|
return "llm_call"
|
||||||
|
|
||||||
# 关键修复:如果已经执行过 rag_retrieve 并且又执行过推理,直接去 LLM_CALL
|
# 关键修复:检测 RAG 重复循环 - 如果发现"RETRIEVE_RAG"出现超过1次,直接去 LLM
|
||||||
# 这样的流程:推理1 → RAG → 推理2(判断 RAG 结果) → LLM_CALL
|
rag_count = previous_actions.count("RETRIEVE_RAG")
|
||||||
rag_count = previous_actions.count("rag_retrieve")
|
if rag_count >= 2:
|
||||||
if rag_count >= 1 and len(previous_actions) >= rag_count + 1:
|
info(f"[route_by_reasoning] 检测到 RAG 重复循环({rag_count}次),直接去 llm_call")
|
||||||
info(f"[route_by_reasoning] 已完成 RAG 检索和结果判断,直接去 llm_call")
|
return "llm_call"
|
||||||
|
|
||||||
|
# 关键修复:如果已经有 rag_docs 或 rag_context,说明已经检索过了,直接去 LLM
|
||||||
|
if (state.rag_docs and len(state.rag_docs) > 0) or (state.rag_context and len(state.rag_context) > 0):
|
||||||
|
info(f"[route_by_reasoning] 检测到已存在 RAG 检索结果,直接去 llm_call")
|
||||||
return "llm_call"
|
return "llm_call"
|
||||||
|
|
||||||
# 关键修复:限制最多 3 次推理,避免无限循环
|
# 关键修复:限制最多 3 次推理,避免无限循环
|
||||||
|
|||||||
@@ -2,19 +2,19 @@
|
|||||||
整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState
|
整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.main_graph.graph import StateGraph, START, END
|
from ..graph import StateGraph, START, END
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
from app.main_graph.state import MainGraphState
|
from ..state import MainGraphState
|
||||||
from app.main_graph.nodes.react_nodes import (
|
from ..nodes.react_nodes import (
|
||||||
init_state_node,
|
init_state_node,
|
||||||
react_reason_node,
|
react_reason_node,
|
||||||
web_search_node,
|
web_search_node,
|
||||||
error_handling_node,
|
error_handling_node,
|
||||||
route_by_reasoning
|
route_by_reasoning
|
||||||
)
|
)
|
||||||
from app.main_graph.nodes.hybrid_router import (
|
from ..nodes.hybrid_router import (
|
||||||
hybrid_router_node,
|
hybrid_router_node,
|
||||||
fast_chitchat_node,
|
fast_chitchat_node,
|
||||||
fast_rag_node,
|
fast_rag_node,
|
||||||
@@ -22,17 +22,17 @@ from app.main_graph.nodes.hybrid_router import (
|
|||||||
route_from_hybrid_decision,
|
route_from_hybrid_decision,
|
||||||
check_fast_path_success
|
check_fast_path_success
|
||||||
)
|
)
|
||||||
from app.main_graph.nodes.llm_call import create_llm_call_node
|
from ..nodes.llm_call import create_llm_call_node
|
||||||
from app.main_graph.nodes.rag_nodes import rag_retrieve_node
|
from ..nodes.rag_nodes import rag_retrieve_node
|
||||||
from app.main_graph.nodes.retrieve_memory import create_retrieve_memory_node
|
from ..nodes.retrieve_memory import create_retrieve_memory_node
|
||||||
from app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||||
from app.main_graph.nodes.summarize import create_summarize_node
|
from ..nodes.summarize import create_summarize_node
|
||||||
from app.main_graph.nodes.finalize import finalize_node
|
from ..nodes.finalize import finalize_node
|
||||||
from app.subgraphs.contact import build_contact_subgraph
|
from ...subgraphs.contact import build_contact_subgraph
|
||||||
from app.subgraphs.dictionary import build_dictionary_subgraph
|
from ...subgraphs.dictionary import build_dictionary_subgraph
|
||||||
from app.subgraphs.news_analysis import build_news_analysis_subgraph
|
from ...subgraphs.news_analysis import build_news_analysis_subgraph
|
||||||
from app.memory.mem0_client import Mem0Client
|
from ...memory.mem0_client import Mem0Client
|
||||||
from app.logger import info, debug
|
from ...logger import info, debug
|
||||||
|
|
||||||
|
|
||||||
# ========== 检查是否需要总结 ==========
|
# ========== 检查是否需要总结 ==========
|
||||||
@@ -140,7 +140,7 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 捕获子图错误,传递给主图
|
# 捕获子图错误,传递给主图
|
||||||
from app.main_graph.state import ErrorRecord, ErrorSeverity
|
from ..state import ErrorRecord, ErrorSeverity
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
error_record = ErrorRecord(
|
error_record = ErrorRecord(
|
||||||
|
|||||||
117
scripts/start.sh
117
scripts/start.sh
@@ -1,117 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# =============================================================================
|
|
||||||
# AI Agent 启动与管理脚本 - 简化版
|
|
||||||
# 用法: ./scripts/start.sh [check|backend|frontend|both]
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
# 颜色定义
|
|
||||||
GREEN='\033[0;32m'
|
|
||||||
BLUE='\033[0;34m'
|
|
||||||
RED='\033[0;31m'
|
|
||||||
YELLOW='\033[1;33m'
|
|
||||||
NC='\033[0m' # No Color
|
|
||||||
|
|
||||||
# 项目根目录
|
|
||||||
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
|
||||||
|
|
||||||
echo -e "${BLUE}========================================${NC}"
|
|
||||||
echo -e "${BLUE} AI Agent - 个人生活助手${NC}"
|
|
||||||
echo -e "${BLUE}========================================${NC}"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 启动 Python 服务
|
|
||||||
# =============================================================================
|
|
||||||
start_backend() {
|
|
||||||
echo -e "\n${BLUE}🚀 启动后端服务 (端口 10079)...${NC}"
|
|
||||||
cd "$PROJECT_DIR"
|
|
||||||
|
|
||||||
# 加载 .env 文件中的环境变量
|
|
||||||
set -a
|
|
||||||
source .env 2>/dev/null || true
|
|
||||||
set +a
|
|
||||||
|
|
||||||
export PYTHONPATH="$PROJECT_DIR/backend"
|
|
||||||
export BACKEND_PORT=8079
|
|
||||||
python -m app.backend &
|
|
||||||
BACKEND_PID=$!
|
|
||||||
echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"
|
|
||||||
sleep 2
|
|
||||||
}
|
|
||||||
|
|
||||||
start_frontend() {
|
|
||||||
echo -e "\n${BLUE}🎨 启动前端界面 (端口 10501)...${NC}"
|
|
||||||
cd "$PROJECT_DIR"
|
|
||||||
|
|
||||||
# 加载 .env 文件中的环境变量
|
|
||||||
set -a
|
|
||||||
source .env 2>/dev/null || true
|
|
||||||
set +a
|
|
||||||
|
|
||||||
export PYTHONPATH="$PROJECT_DIR/frontend/src"
|
|
||||||
export API_URL="http://127.0.0.1:8079/chat"
|
|
||||||
streamlit run frontend/src/frontend_main.py --server.port 10501 --server.address 0.0.0.0 &
|
|
||||||
FRONTEND_PID=$!
|
|
||||||
echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}"
|
|
||||||
echo -e "${GREEN}✓ 访问地址:${NC}"
|
|
||||||
echo -e " 本地开发: http://127.0.0.1:10501"
|
|
||||||
}
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 清理函数
|
|
||||||
# =============================================================================
|
|
||||||
cleanup() {
|
|
||||||
echo -e "\n${RED}🛑 正在停止 Python 服务...${NC}"
|
|
||||||
if [ ! -z "$BACKEND_PID" ]; then
|
|
||||||
kill $BACKEND_PID 2>/dev/null || true
|
|
||||||
echo -e "${GREEN}✓ 后端服务已停止${NC}"
|
|
||||||
fi
|
|
||||||
if [ ! -z "$FRONTEND_PID" ]; then
|
|
||||||
kill $FRONTEND_PID 2>/dev/null || true
|
|
||||||
echo -e "${GREEN}✓ 前端服务已停止${NC}"
|
|
||||||
fi
|
|
||||||
exit 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# 捕获 Ctrl+C
|
|
||||||
trap cleanup SIGINT SIGTERM
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# 主逻辑
|
|
||||||
# =============================================================================
|
|
||||||
case "${1:-help}" in
|
|
||||||
backend)
|
|
||||||
start_backend
|
|
||||||
echo -e "\n${GREEN}后端服务正在运行,按 Ctrl+C 停止${NC}"
|
|
||||||
wait $BACKEND_PID
|
|
||||||
;;
|
|
||||||
|
|
||||||
frontend)
|
|
||||||
start_frontend
|
|
||||||
echo -e "\n${GREEN}前端服务正在运行,按 Ctrl+C 停止${NC}"
|
|
||||||
wait $FRONTEND_PID
|
|
||||||
;;
|
|
||||||
|
|
||||||
both)
|
|
||||||
start_backend
|
|
||||||
sleep 3
|
|
||||||
start_frontend
|
|
||||||
echo -e "\n${GREEN}所有服务正在运行,按 Ctrl+C 停止${NC}"
|
|
||||||
wait
|
|
||||||
;;
|
|
||||||
|
|
||||||
help|*)
|
|
||||||
echo -e "${BLUE}用法:${NC} $0 [command]"
|
|
||||||
echo ""
|
|
||||||
echo -e "${BLUE}命令:${NC}"
|
|
||||||
echo " backend 仅启动后端服务"
|
|
||||||
echo " frontend 仅启动前端服务"
|
|
||||||
echo " both 启动前后端服务(默认)"
|
|
||||||
echo " help 显示此帮助信息"
|
|
||||||
echo ""
|
|
||||||
echo -e "${BLUE}示例:${NC}"
|
|
||||||
echo " $0 both # 启动本地开发环境"
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
@@ -11,8 +11,10 @@ sys.path.insert(0, str(project_root / "backend"))
|
|||||||
load_dotenv(project_root / ".env")
|
load_dotenv(project_root / ".env")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from rag_indexer.cli import main
|
#from rag_indexer.cli import main
|
||||||
#from tools.test.test_rag_indexer_result import main
|
#from tools.test.test_rag_indexer_result import main
|
||||||
#from tools.test.test_rag_pipeline import main
|
#from tools.test.test_rag_pipeline import main
|
||||||
|
from tools.test.test_fast_rag_fix import main
|
||||||
|
#from tools.test.test_graph_branches import main
|
||||||
import asyncio
|
import asyncio
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
125
tools/start.py
Executable file
125
tools/start.py
Executable file
@@ -0,0 +1,125 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
AI Agent 启动与管理脚本 - Python版
|
||||||
|
用法: python tools/testrun.py [check|backend|frontend|both]
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 路径设置
|
||||||
|
project_root = Path(__file__).resolve().parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
sys.path.insert(0, str(project_root / "backend"))
|
||||||
|
load_dotenv(project_root / ".env")
|
||||||
|
|
||||||
|
# 全局变量
|
||||||
|
processes = []
|
||||||
|
|
||||||
|
|
||||||
|
def start_backend():
|
||||||
|
"""启动后端服务"""
|
||||||
|
print("\n🚀 启动后端服务 (端口 8079)...")
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["PYTHONPATH"] = str(project_root / "backend")
|
||||||
|
env["BACKEND_PORT"] = "8079"
|
||||||
|
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
[sys.executable, "-m", "app.backend"],
|
||||||
|
cwd=str(project_root),
|
||||||
|
env=env
|
||||||
|
)
|
||||||
|
processes.append(proc)
|
||||||
|
print(f"✓ 后端服务已启动 (PID: {proc.pid})")
|
||||||
|
time.sleep(2)
|
||||||
|
return proc
|
||||||
|
|
||||||
|
|
||||||
|
def start_frontend():
|
||||||
|
"""启动前端服务"""
|
||||||
|
print("\n🎨 启动前端界面 (端口 10501)...")
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["PYTHONPATH"] = str(project_root / "frontend/src")
|
||||||
|
env["API_URL"] = "http://127.0.0.1:8079/chat"
|
||||||
|
|
||||||
|
frontend_main = str(project_root / "frontend" / "src" / "frontend_main.py")
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
[
|
||||||
|
sys.executable, "-m", "streamlit", "run", frontend_main,
|
||||||
|
"--server.port", "10501", "--server.address", "0.0.0.0"
|
||||||
|
],
|
||||||
|
cwd=str(project_root),
|
||||||
|
env=env
|
||||||
|
)
|
||||||
|
processes.append(proc)
|
||||||
|
print(f"✓ 前端服务已启动 (PID: {proc.pid})")
|
||||||
|
print("✓ 访问地址:")
|
||||||
|
print(" 本地开发: http://127.0.0.1:10501")
|
||||||
|
return proc
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup(signum, frame):
|
||||||
|
"""清理函数 - 停止所有进程"""
|
||||||
|
print("\n🛑 正在停止服务...")
|
||||||
|
for i, proc in enumerate(processes):
|
||||||
|
if proc.poll() is None: # 进程还在运行
|
||||||
|
proc.terminate()
|
||||||
|
proc.wait(timeout=5)
|
||||||
|
print(f"✓ 服务 {i+1} 已停止")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def print_help():
|
||||||
|
"""显示帮助信息"""
|
||||||
|
print("========================================")
|
||||||
|
print(" AI Agent - 个人生活助手")
|
||||||
|
print("========================================")
|
||||||
|
print("\n用法: python tools/testrun.py [command]")
|
||||||
|
print("\n命令:")
|
||||||
|
print(" backend 仅启动后端服务")
|
||||||
|
print(" frontend 仅启动前端服务")
|
||||||
|
print(" both 启动前后端服务(默认)")
|
||||||
|
print(" help 显示此帮助信息")
|
||||||
|
print("\n示例:")
|
||||||
|
print(" python tools/testrun.py both # 启动本地开发环境")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
print("========================================")
|
||||||
|
print(" AI Agent - 个人生活助手")
|
||||||
|
print("========================================")
|
||||||
|
|
||||||
|
# 捕获信号
|
||||||
|
signal.signal(signal.SIGINT, cleanup)
|
||||||
|
signal.signal(signal.SIGTERM, cleanup)
|
||||||
|
|
||||||
|
cmd = sys.argv[1] if len(sys.argv) > 1 else "both"
|
||||||
|
|
||||||
|
if cmd == "backend":
|
||||||
|
start_backend()
|
||||||
|
print("\n后端服务正在运行,按 Ctrl+C 停止")
|
||||||
|
processes[0].wait()
|
||||||
|
elif cmd == "frontend":
|
||||||
|
start_frontend()
|
||||||
|
print("\n前端服务正在运行,按 Ctrl+C 停止")
|
||||||
|
processes[0].wait()
|
||||||
|
elif cmd == "both":
|
||||||
|
start_backend()
|
||||||
|
time.sleep(3)
|
||||||
|
start_frontend()
|
||||||
|
print("\n所有服务正在运行,按 Ctrl+C 停止")
|
||||||
|
for proc in processes:
|
||||||
|
proc.wait()
|
||||||
|
else:
|
||||||
|
print_help()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
97
tools/test/test_fast_rag_fix.py
Normal file
97
tools/test/test_fast_rag_fix.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
快速测试 - 测试 fast_rag 路径修复
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 路径设置
|
||||||
|
project_root = Path(__file__).resolve().parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
sys.path.insert(0, str(project_root / "backend"))
|
||||||
|
load_dotenv(project_root / ".env")
|
||||||
|
|
||||||
|
from app.main_graph.state import MainGraphState, CurrentAction
|
||||||
|
from app.main_graph.utils.main_graph_builder import build_react_main_graph
|
||||||
|
from app.model_services.chat_services import get_all_chat_services
|
||||||
|
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
|
||||||
|
|
||||||
|
|
||||||
|
async def test_fast_rag_path():
|
||||||
|
"""测试 fast_rag 路径"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("测试 fast_rag 路径修复")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 1. 获取 LLM
|
||||||
|
chat_services = get_all_chat_services()
|
||||||
|
if not chat_services:
|
||||||
|
print("✗ 没有可用的 LLM 服务")
|
||||||
|
return
|
||||||
|
|
||||||
|
llm = list(chat_services.values())[0]
|
||||||
|
print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}")
|
||||||
|
|
||||||
|
# 2. 构建图
|
||||||
|
graph = build_react_main_graph(
|
||||||
|
llm=llm,
|
||||||
|
tools=AVAILABLE_TOOLS,
|
||||||
|
use_hybrid_router=True
|
||||||
|
).compile()
|
||||||
|
print(f"✓ 图构建完成")
|
||||||
|
|
||||||
|
# 3. 测试问题
|
||||||
|
test_query = "吕布和张飞谁厉害?"
|
||||||
|
print(f"\n测试问题: {test_query}")
|
||||||
|
|
||||||
|
# 4. 创建状态
|
||||||
|
input_state = {
|
||||||
|
"user_query": test_query,
|
||||||
|
"messages": [{"role": "user", "content": test_query}],
|
||||||
|
"user_id": "test_user",
|
||||||
|
"current_action": CurrentAction.NONE
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. 执行
|
||||||
|
print("开始执行...")
|
||||||
|
try:
|
||||||
|
result = await graph.ainvoke(
|
||||||
|
input_state,
|
||||||
|
config={"configurable": {"thread_id": "test_fast_rag"}}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n✓ 执行成功!")
|
||||||
|
print(f"最终回答: {result.get('final_result', '')[:300]}")
|
||||||
|
|
||||||
|
# 调试信息
|
||||||
|
debug_info = result.get("debug_info", {})
|
||||||
|
print(f"\n调试信息:")
|
||||||
|
if "fast_path_failed" in debug_info:
|
||||||
|
print(f" - fast_path_failed: {debug_info['fast_path_failed']}")
|
||||||
|
if "fast_path_fail_reason" in debug_info:
|
||||||
|
print(f" - fast_path_fail_reason: {debug_info['fast_path_fail_reason']}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n✗ 执行失败: {e}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
success = await test_fast_rag_path()
|
||||||
|
if success:
|
||||||
|
print("\n🎉 测试通过!")
|
||||||
|
else:
|
||||||
|
print("\n⚠️ 测试失败")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n测试被中断")
|
||||||
221
tools/test/test_graph_branches.py
Normal file
221
tools/test/test_graph_branches.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
主图完整测试 - 覆盖各个分支
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 路径设置
|
||||||
|
project_root = Path(__file__).resolve().parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
sys.path.insert(0, str(project_root / "backend"))
|
||||||
|
load_dotenv(project_root / ".env")
|
||||||
|
|
||||||
|
from app.main_graph.state import MainGraphState, CurrentAction
|
||||||
|
from app.main_graph.utils.main_graph_builder import build_react_main_graph
|
||||||
|
from app.model_services.chat_services import get_all_chat_services
|
||||||
|
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
|
||||||
|
from app.main_graph.utils.rag_initializer import init_rag_tool
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 测试用例配置 ==========
|
||||||
|
TEST_CASES = [
|
||||||
|
# 测试1: 简单闲聊 - 应该走 fast_chitchat
|
||||||
|
{
|
||||||
|
"name": "闲聊测试",
|
||||||
|
"query": "你好!",
|
||||||
|
"description": "测试快速闲聊分支"
|
||||||
|
},
|
||||||
|
# 测试2: 知识查询 - 应该走 fast_rag,然后可能升级到 react
|
||||||
|
{
|
||||||
|
"name": "知识查询测试",
|
||||||
|
"query": "什么是机器学习?",
|
||||||
|
"description": "测试快速 RAG 分支"
|
||||||
|
},
|
||||||
|
# 测试3: 需要推理的复杂问题 - 应该直接到 React 循环
|
||||||
|
{
|
||||||
|
"name": "复杂推理测试",
|
||||||
|
"query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?",
|
||||||
|
"description": "测试 React 循环推理分支"
|
||||||
|
},
|
||||||
|
# 测试4: 需要工具调用的问题
|
||||||
|
{
|
||||||
|
"name": "工具调用测试",
|
||||||
|
"query": "搜索一下今天的天气怎么样",
|
||||||
|
"description": "测试工具调用分支"
|
||||||
|
},
|
||||||
|
# 测试5: 带记忆的对话
|
||||||
|
{
|
||||||
|
"name": "记忆测试",
|
||||||
|
"query": "你刚才回答了我什么问题?",
|
||||||
|
"description": "测试记忆检索分支",
|
||||||
|
"thread_id": "test_memory_thread"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_test_environment():
|
||||||
|
"""设置测试环境"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("设置测试环境...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 获取 LLM 服务
|
||||||
|
chat_services = get_all_chat_services()
|
||||||
|
if not chat_services:
|
||||||
|
raise RuntimeError("没有可用的 LLM 服务")
|
||||||
|
|
||||||
|
llm = list(chat_services.values())[0]
|
||||||
|
print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}")
|
||||||
|
|
||||||
|
# 初始化 RAG 工具
|
||||||
|
def create_local_llm():
|
||||||
|
return llm
|
||||||
|
|
||||||
|
rag_tool = await init_rag_tool(create_local_llm)
|
||||||
|
tools = AVAILABLE_TOOLS.copy()
|
||||||
|
if rag_tool:
|
||||||
|
tools.append(rag_tool)
|
||||||
|
print(f"✓ RAG 工具初始化成功")
|
||||||
|
|
||||||
|
# 构建图
|
||||||
|
graph = build_react_main_graph(
|
||||||
|
llm=llm,
|
||||||
|
tools=tools,
|
||||||
|
use_hybrid_router=True
|
||||||
|
).compile()
|
||||||
|
|
||||||
|
print(f"✓ 图构建完成")
|
||||||
|
print()
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_state(query: str, thread_id: str = None) -> dict:
|
||||||
|
"""创建测试状态"""
|
||||||
|
return {
|
||||||
|
"user_query": query,
|
||||||
|
"messages": [{"role": "user", "content": query}],
|
||||||
|
"user_id": "test_user",
|
||||||
|
"current_action": CurrentAction.NONE
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def run_single_test(graph, test_case: dict) -> dict:
|
||||||
|
"""运行单个测试"""
|
||||||
|
name = test_case["name"]
|
||||||
|
query = test_case["query"]
|
||||||
|
description = test_case["description"]
|
||||||
|
thread_id = test_case.get("thread_id", "test_thread")
|
||||||
|
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print(f"测试: {name}")
|
||||||
|
print(f"描述: {description}")
|
||||||
|
print(f"查询: {query}")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建初始状态
|
||||||
|
input_state = create_test_state(query, thread_id)
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
config = {
|
||||||
|
"configurable": {"thread_id": thread_id}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 执行图
|
||||||
|
print("开始执行图...")
|
||||||
|
result = await graph.ainvoke(input_state, config=config)
|
||||||
|
|
||||||
|
# 检查结果
|
||||||
|
success = result.get("success", False)
|
||||||
|
final_result = result.get("final_result", "")
|
||||||
|
|
||||||
|
print(f"\n结果:")
|
||||||
|
print(f" 成功: {'✓' if success else '✗'}")
|
||||||
|
print(f" 最终回答: {final_result[:200]}{'...' if len(final_result) > 200 else ''}")
|
||||||
|
|
||||||
|
# 调试信息
|
||||||
|
if "debug_info" in result:
|
||||||
|
debug_info = result["debug_info"]
|
||||||
|
print(f" 调试信息:")
|
||||||
|
if "fast_path_failed" in debug_info:
|
||||||
|
print(f" - 快速路径失败: {debug_info['fast_path_failed']}")
|
||||||
|
if "fast_path_fail_reason" in debug_info:
|
||||||
|
print(f" - 失败原因: {debug_info['fast_path_fail_reason']}")
|
||||||
|
if "hybrid_decision" in debug_info:
|
||||||
|
decision = debug_info["hybrid_decision"]
|
||||||
|
print(f" - 路由决策: {decision.path if hasattr(decision, 'path') else 'unknown'}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"success": success,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n✗ 测试失败: {e}")
|
||||||
|
import traceback
|
||||||
|
print(f"堆栈: {traceback.format_exc()}")
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""主函数"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("主图完整测试套件")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 设置环境
|
||||||
|
graph = await setup_test_environment()
|
||||||
|
|
||||||
|
# 运行所有测试
|
||||||
|
results = []
|
||||||
|
for test_case in TEST_CASES:
|
||||||
|
result = await run_single_test(graph, test_case)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# 稍微间隔一下
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# 总结
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("测试总结")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
total = len(results)
|
||||||
|
passed = sum(1 for r in results if r["success"])
|
||||||
|
failed = total - passed
|
||||||
|
|
||||||
|
print(f"\n总测试数: {total}")
|
||||||
|
print(f"通过: {passed}")
|
||||||
|
print(f"失败: {failed}")
|
||||||
|
|
||||||
|
print("\n详细结果:")
|
||||||
|
for result in results:
|
||||||
|
status = "✓ 通过" if result["success"] else "✗ 失败"
|
||||||
|
print(f" {result['name']}: {status}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
if failed == 0:
|
||||||
|
print("🎉 所有测试通过!")
|
||||||
|
else:
|
||||||
|
print(f"⚠️ 有 {failed} 个测试失败")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\n测试被用户中断")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n\n测试运行失败: {e}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
@@ -22,7 +22,7 @@ async def test_rag_pipeline_direct():
|
|||||||
rerank_top_n=5
|
rerank_top_n=5
|
||||||
)
|
)
|
||||||
|
|
||||||
query = "黄双银的经历"
|
query = "吕布的经历"
|
||||||
|
|
||||||
print(f"\n用户查询: {query}")
|
print(f"\n用户查询: {query}")
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
@@ -64,7 +64,7 @@ async def test_rag_tool():
|
|||||||
rerank_top_n=5
|
rerank_top_n=5
|
||||||
)
|
)
|
||||||
|
|
||||||
query = "黄双银的经历"
|
query = "吕布的经历"
|
||||||
|
|
||||||
print(f"\n用户查询: {query}")
|
print(f"\n用户查询: {query}")
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
@@ -91,7 +91,7 @@ async def test_custom_pipeline():
|
|||||||
rerank_top_n=3 # 只返回前 3 个最相关文档
|
rerank_top_n=3 # 只返回前 3 个最相关文档
|
||||||
)
|
)
|
||||||
|
|
||||||
query = "黄双银的经历"
|
query = "吕布的经历"
|
||||||
|
|
||||||
print(f"\n用户查询: {query}")
|
print(f"\n用户查询: {query}")
|
||||||
print(f"配置: num_queries=2, rerank_top_n=3")
|
print(f"配置: num_queries=2, rerank_top_n=3")
|
||||||
@@ -124,7 +124,7 @@ async def main():
|
|||||||
"""主测试函数"""
|
"""主测试函数"""
|
||||||
print("\n" + "="*80)
|
print("\n" + "="*80)
|
||||||
print("完整 RAG Pipeline 测试")
|
print("完整 RAG Pipeline 测试")
|
||||||
print("查询: '黄双银的经历'")
|
print("查询: '吕布的经历'")
|
||||||
print("="*80)
|
print("="*80)
|
||||||
|
|
||||||
# 测试 1: 直接使用 pipeline
|
# 测试 1: 直接使用 pipeline
|
||||||
|
|||||||
Reference in New Issue
Block a user