This commit is contained in:
@@ -7,7 +7,7 @@ from typing import Optional
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ..state import MainGraphState
|
||||
from ...logger import info, debug
|
||||
from backend.app.logger import info, debug
|
||||
from ...model_services.chat_services import get_small_llm_service, get_chat_service
|
||||
from .rag_nodes import rag_retrieve_node
|
||||
from ._utils import dispatch_custom_event
|
||||
@@ -113,10 +113,18 @@ async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig]
|
||||
|
||||
|
||||
def _has_valid_rag_results(state: MainGraphState) -> bool:
|
||||
"""检查 RAG 结果是否有效"""
|
||||
rag_docs = getattr(state, "rag_docs", [])
|
||||
"""检查 RAG 结果是否有效(基于置信度)"""
|
||||
from .rag_nodes import RAG_CONFIDENCE_THRESHOLD
|
||||
rag_context = getattr(state, "rag_context", "")
|
||||
return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10)
|
||||
rag_confidence = getattr(state, "rag_confidence", 0.0)
|
||||
|
||||
# 有结果且置信度足够
|
||||
has_content = rag_context and len(rag_context) > 0
|
||||
has_confidence = rag_confidence >= RAG_CONFIDENCE_THRESHOLD
|
||||
|
||||
info(f"[Fast RAG Check] has_content={has_content}, rag_confidence={rag_confidence:.2f}, threshold={RAG_CONFIDENCE_THRESHOLD}")
|
||||
|
||||
return has_content and has_confidence
|
||||
|
||||
|
||||
async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState:
|
||||
|
||||
Reference in New Issue
Block a user