This commit is contained in:
@@ -7,16 +7,17 @@ import time
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG
|
||||
from app.logger import info
|
||||
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
|
||||
from ...logger import info
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
|
||||
def _get_rag_tool() -> Optional[callable]:
|
||||
"""获取 RAG 工具"""
|
||||
from app.main_graph.utils.rag_initializer import get_rag_tool
|
||||
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
|
||||
return get_rag_tool()
|
||||
|
||||
|
||||
@@ -35,6 +36,9 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
|
||||
# 调用 RAG 工具
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
|
||||
info(f"[RAG Core] ========== RAG 返回的知识内容 ==========")
|
||||
info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context)
|
||||
info(f"[RAG Core] ========================================")
|
||||
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
|
||||
@@ -47,7 +51,7 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
|
||||
|
||||
|
||||
# ========== RAG 检索节点 ==========
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""RAG 检索节点:带超时和重试"""
|
||||
state.current_phase = "rag_retrieving"
|
||||
start_time = time.time()
|
||||
@@ -156,7 +160,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None
|
||||
return state
|
||||
|
||||
|
||||
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""重新检索节点"""
|
||||
state.current_phase = "rag_re_retrieving"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user