主要修复:
1. 修复 RAG 推理无限循环问题(大小写不匹配 + 缺少已检索结果检查)
2. 修复 intent_classifier.py 的绝对导入错误
3. 删除旧的 start.sh 脚本,添加新的启动脚本
4. 优化路由逻辑和状态管理
This commit is contained in:
@@ -7,14 +7,14 @@ import json
|
||||
import asyncio
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from app.main_graph.config import set_stream_writer
|
||||
from ..main_graph.utils.main_graph_builder import build_react_main_graph
|
||||
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from ..main_graph.config import set_stream_writer
|
||||
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
|
||||
from app.main_graph.utils.rag_initializer import init_rag_tool
|
||||
from app.core.intent_classifier import get_intent_classifier
|
||||
from app.logger import info, warning, error
|
||||
from app.main_graph.state import MainGraphState, CurrentAction
|
||||
from ..main_graph.utils.rag_initializer import init_rag_tool
|
||||
from ..core.intent_classifier import get_intent_classifier
|
||||
from ..logger import info, warning, error
|
||||
from ..main_graph.state import MainGraphState, CurrentAction
|
||||
|
||||
|
||||
class AIAgentService:
|
||||
@@ -32,7 +32,7 @@ class AIAgentService:
|
||||
|
||||
async def initialize(self):
|
||||
# 0. 初始化 Mem0 客户端
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from ..memory.mem0_client import Mem0Client
|
||||
# 创建一个临时的 LLM 用于 Mem0(用第一个可用的)
|
||||
chat_services = get_all_chat_services()
|
||||
temp_llm = None
|
||||
@@ -49,7 +49,7 @@ class AIAgentService:
|
||||
self.tools.append(rag_tool)
|
||||
self.tools_by_name[rag_tool.name] = rag_tool
|
||||
# 关键:设置全局 RAG 工具,供 rag_nodes.py 使用
|
||||
from app.main_graph.nodes.rag_nodes import set_global_rag_tool
|
||||
from ..main_graph.nodes.rag_nodes import set_global_rag_tool
|
||||
set_global_rag_tool(rag_tool)
|
||||
|
||||
# 2. 构建各模型的 Graph(使用新版 React 模式)
|
||||
@@ -86,7 +86,7 @@ class AIAgentService:
|
||||
"metadata": {"user_id": user_id}
|
||||
}
|
||||
# 新版状态输入:传入完整的 MainGraphState,关键是 user_query
|
||||
from app.main_graph.state import MainGraphState, CurrentAction
|
||||
from ..main_graph.state import MainGraphState, CurrentAction
|
||||
input_state = {
|
||||
"user_query": message,
|
||||
"messages": [{"role": "user", "content": message}],
|
||||
|
||||
@@ -132,8 +132,17 @@ class ReactIntentReasoner:
|
||||
|
||||
# 关键修改:不要在第一次 rag_retrieve 后就直接回答,允许再推理一次
|
||||
# 让推理逻辑有机会判断 RAG 结果好不好,要不要再检索或转 web search
|
||||
rag_count = previous_actions.count("rag_retrieve")
|
||||
previous_actions = context.get("previous_actions", [])
|
||||
rag_count = previous_actions.count("RETRIEVE_RAG") # 修复:大写
|
||||
web_search_count = previous_actions.count("web_search")
|
||||
retrieved_docs = context.get("retrieved_docs", [])
|
||||
|
||||
# 如果已经有检索文档了,直接回答
|
||||
if retrieved_docs and len(retrieved_docs) > 0:
|
||||
result.action = ReasoningAction.DIRECT_RESPONSE
|
||||
result.confidence = 0.95
|
||||
result.reasoning = "已获取检索文档,直接回答"
|
||||
return result
|
||||
|
||||
# 只有当 rag 已执行至少 2 次、web search 已执行至少 1 次,或者已经有推理在 rag 之后,才直接回答
|
||||
if rag_count >= 2 or web_search_count >= 1:
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional, Dict, Any
|
||||
import sys
|
||||
import os
|
||||
|
||||
from backend.app.model_services.chat_services import get_small_llm_service
|
||||
from ..model_services.chat_services import get_small_llm_service
|
||||
|
||||
|
||||
class IntentType(Enum):
|
||||
|
||||
@@ -8,10 +8,10 @@ from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.logger import info, debug
|
||||
from app.model_services.chat_services import get_small_llm_service, get_chat_service
|
||||
from app.main_graph.nodes.rag_nodes import rag_retrieve_node
|
||||
from ..state import MainGraphState
|
||||
from ...logger import info, debug
|
||||
from ...model_services.chat_services import get_small_llm_service, get_chat_service
|
||||
from .rag_nodes import rag_retrieve_node
|
||||
|
||||
|
||||
# ========== 核心数据类型 ==========
|
||||
@@ -367,8 +367,8 @@ async def fast_rag_node(state: MainGraphState, config: Optional[Dict[str, Any]]
|
||||
debug(f"[Fast RAG] 发送事件失败: {e}")
|
||||
|
||||
try:
|
||||
# 先尝试 RAG 检索
|
||||
state = rag_retrieve_node(state, config)
|
||||
# 先尝试 RAG 检索 - 注意:rag_retrieve_node 是异步函数,需要 await
|
||||
state = await rag_retrieve_node(state, config)
|
||||
|
||||
# 检查检索结果
|
||||
rag_docs = getattr(state, "rag_docs", [])
|
||||
|
||||
@@ -364,11 +364,15 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
if "subgraph_completed" in previous_actions or state.final_result:
|
||||
return "llm_call"
|
||||
|
||||
# 关键修复:如果已经执行过 rag_retrieve 并且又执行过推理,直接去 LLM_CALL
|
||||
# 这样的流程:推理1 → RAG → 推理2(判断 RAG 结果) → LLM_CALL
|
||||
rag_count = previous_actions.count("rag_retrieve")
|
||||
if rag_count >= 1 and len(previous_actions) >= rag_count + 1:
|
||||
info(f"[route_by_reasoning] 已完成 RAG 检索和结果判断,直接去 llm_call")
|
||||
# 关键修复:检测 RAG 重复循环 - 如果发现"RETRIEVE_RAG"出现超过1次,直接去 LLM
|
||||
rag_count = previous_actions.count("RETRIEVE_RAG")
|
||||
if rag_count >= 2:
|
||||
info(f"[route_by_reasoning] 检测到 RAG 重复循环({rag_count}次),直接去 llm_call")
|
||||
return "llm_call"
|
||||
|
||||
# 关键修复:如果已经有 rag_docs 或 rag_context,说明已经检索过了,直接去 LLM
|
||||
if (state.rag_docs and len(state.rag_docs) > 0) or (state.rag_context and len(state.rag_context) > 0):
|
||||
info(f"[route_by_reasoning] 检测到已存在 RAG 检索结果,直接去 llm_call")
|
||||
return "llm_call"
|
||||
|
||||
# 关键修复:限制最多 3 次推理,避免无限循环
|
||||
|
||||
@@ -2,19 +2,19 @@
|
||||
整合后的完整主图构建器 - 所有节点都直接操作 MainGraphState
|
||||
"""
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
from ..graph import StateGraph, START, END
|
||||
from typing import Dict, Any, Optional
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.main_graph.nodes.react_nodes import (
|
||||
from ..state import MainGraphState
|
||||
from ..nodes.react_nodes import (
|
||||
init_state_node,
|
||||
react_reason_node,
|
||||
web_search_node,
|
||||
error_handling_node,
|
||||
route_by_reasoning
|
||||
)
|
||||
from app.main_graph.nodes.hybrid_router import (
|
||||
from ..nodes.hybrid_router import (
|
||||
hybrid_router_node,
|
||||
fast_chitchat_node,
|
||||
fast_rag_node,
|
||||
@@ -22,17 +22,17 @@ from app.main_graph.nodes.hybrid_router import (
|
||||
route_from_hybrid_decision,
|
||||
check_fast_path_success
|
||||
)
|
||||
from app.main_graph.nodes.llm_call import create_llm_call_node
|
||||
from app.main_graph.nodes.rag_nodes import rag_retrieve_node
|
||||
from app.main_graph.nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from app.main_graph.nodes.summarize import create_summarize_node
|
||||
from app.main_graph.nodes.finalize import finalize_node
|
||||
from app.subgraphs.contact import build_contact_subgraph
|
||||
from app.subgraphs.dictionary import build_dictionary_subgraph
|
||||
from app.subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.logger import info, debug
|
||||
from ..nodes.llm_call import create_llm_call_node
|
||||
from ..nodes.rag_nodes import rag_retrieve_node
|
||||
from ..nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from ..nodes.summarize import create_summarize_node
|
||||
from ..nodes.finalize import finalize_node
|
||||
from ...subgraphs.contact import build_contact_subgraph
|
||||
from ...subgraphs.dictionary import build_dictionary_subgraph
|
||||
from ...subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...logger import info, debug
|
||||
|
||||
|
||||
# ========== 检查是否需要总结 ==========
|
||||
@@ -140,7 +140,7 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
|
||||
except Exception as e:
|
||||
# 捕获子图错误,传递给主图
|
||||
from app.main_graph.state import ErrorRecord, ErrorSeverity
|
||||
from ..state import ErrorRecord, ErrorSeverity
|
||||
from datetime import datetime
|
||||
|
||||
error_record = ErrorRecord(
|
||||
|
||||
Reference in New Issue
Block a user