205 lines
7.2 KiB
Python
205 lines
7.2 KiB
Python
"""
|
||
快速路径节点模块
|
||
包含闲聊、RAG、工具等快速处理节点
|
||
"""
|
||
|
||
from typing import Optional
|
||
from langchain_core.runnables.config import RunnableConfig
|
||
|
||
from ..state import MainGraphState
|
||
from ...logger import info, debug
|
||
from ...model_services.chat_services import get_small_llm_service, get_chat_service
|
||
from .rag_nodes import rag_retrieve_node
|
||
from ._utils import dispatch_custom_event
|
||
|
||
|
||
# ========== 闲聊回复模板 ==========
|
||
CHITCHAT_TEMPLATES = {
|
||
"谢谢": "不客气!如果还有其他问题,请随时告诉我 😊",
|
||
"再见": "再见!期待下次为您服务 👋",
|
||
"你好": "你好!有什么我可以帮您的吗?",
|
||
"默认": None # 使用 LLM 生成
|
||
}
|
||
|
||
CHITCHAT_KEYWORDS = {
|
||
"谢谢": ["谢谢", "感谢", "thanks", "thank you"],
|
||
"再见": ["再见", "拜拜", "bye", "goodbye"],
|
||
"你好": ["你好", "您好", "hi", "hello", "hey", "早上好", "晚上好", "下午好"],
|
||
}
|
||
|
||
|
||
# ========== 闲聊节点 ==========
|
||
async def fast_chitchat_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||
"""快速闲聊节点"""
|
||
state.current_phase = "fast_chitchat"
|
||
query = state.user_query or ""
|
||
info(f"[Fast Chitchat] 处理: {query[:50]}")
|
||
|
||
# 发送开始事件
|
||
await dispatch_custom_event("fast_path_start", {"path": "fast_chitchat"}, config)
|
||
|
||
# 清除之前的 final_result,让 llm_call 生成新回答
|
||
state.final_result = None
|
||
|
||
# 标记快速路径成功,但不设置 final_result,让 llm_call 生成回答
|
||
state.success = True
|
||
state.current_phase = "llm_call"
|
||
state.debug_info["fast_chitchat_success"] = True
|
||
|
||
# 发送完成事件
|
||
await dispatch_custom_event("fast_path_end", {"path": "fast_chitchat", "success": True}, config)
|
||
|
||
return state
|
||
|
||
|
||
def _match_chitchat_template(query: str) -> str:
|
||
"""匹配闲聊模板"""
|
||
query_clean = query.strip().lower()
|
||
|
||
for intent, keywords in CHITCHAT_KEYWORDS.items():
|
||
if any(kw in query_clean for kw in keywords):
|
||
return CHITCHAT_TEMPLATES[intent]
|
||
|
||
# 默认:使用 LLM 生成
|
||
try:
|
||
llm = get_small_llm_service()
|
||
response = llm.invoke(f"你是一个友好的助手。用户说:{query}。请简短友好地回复:")
|
||
return response.content
|
||
except Exception:
|
||
return "你好!有什么我可以帮您的吗?"
|
||
|
||
|
||
# ========== 快速 RAG 节点 ==========
|
||
async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||
"""快速 RAG 节点:只负责 RAG 检索,然后交给 llm_call 生成回答"""
|
||
state.current_phase = "fast_rag"
|
||
query = state.user_query or ""
|
||
info(f"[Fast RAG] 开始处理: {query[:50]}")
|
||
|
||
# 获取 RAG 工具
|
||
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
|
||
rag_tool = get_rag_tool()
|
||
info(f"[Fast RAG] 获取到 rag_tool: {rag_tool is not None}")
|
||
|
||
# 发送开始事件
|
||
await dispatch_custom_event("fast_path_start", {"path": "fast_rag"}, config)
|
||
|
||
# 清除之前的 final_result,让 llm_call 生成新回答
|
||
state.final_result = None
|
||
|
||
# 如果没有 rag_tool,升级到 React 循环
|
||
if not rag_tool:
|
||
info("[Fast RAG] 未找到 RAG 工具,升级到 React 循环")
|
||
return _mark_fast_path_failed(state, "未找到 RAG 工具")
|
||
|
||
try:
|
||
# 尝试 RAG 检索
|
||
state = await rag_retrieve_node(state, config)
|
||
|
||
# 检查检索结果
|
||
if _has_valid_rag_results(state):
|
||
info(f"[Fast RAG] 检索有效,进入 llm_call 生成回答")
|
||
await dispatch_custom_event("fast_path_end", {"path": "fast_rag", "success": True}, config)
|
||
# 注意:这里不设置 final_result,让 llm_call 节点处理
|
||
return state
|
||
|
||
# 无效结果:升级到 React 循环
|
||
info("[Fast RAG] 无有效检索结果,升级到 React 循环")
|
||
return _mark_fast_path_failed(state, "无有效检索结果")
|
||
|
||
except Exception as e:
|
||
info(f"[Fast RAG] 执行失败: {e}")
|
||
return _mark_fast_path_failed(state, str(e))
|
||
|
||
|
||
def _has_valid_rag_results(state: MainGraphState) -> bool:
|
||
"""检查 RAG 结果是否有效"""
|
||
rag_docs = getattr(state, "rag_docs", [])
|
||
rag_context = getattr(state, "rag_context", "")
|
||
return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10)
|
||
|
||
|
||
async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState:
|
||
"""使用小模型快速生成回答"""
|
||
try:
|
||
chat_llm = get_chat_service()
|
||
rag_context = state.rag_context or str(state.rag_docs)[:2000]
|
||
|
||
prompt = f"""请根据以下信息回答用户问题:
|
||
|
||
检索到的信息:
|
||
{rag_context}
|
||
|
||
用户问题:{query}
|
||
|
||
请给出简洁、准确的回答:"""
|
||
|
||
# 使用流式输出
|
||
from backend.app.main_graph.config import get_stream_writer
|
||
writer = get_stream_writer()
|
||
|
||
full_content = ""
|
||
async for chunk in chat_llm.astream(prompt):
|
||
content = getattr(chunk, 'content', '')
|
||
if content:
|
||
full_content += content
|
||
# 流式输出
|
||
if writer and hasattr(writer, '__call__'):
|
||
try:
|
||
writer({
|
||
"type": "llm_token",
|
||
"token": content
|
||
})
|
||
except Exception:
|
||
pass
|
||
|
||
state.final_result = full_content
|
||
state.success = True
|
||
state.current_phase = "finalizing"
|
||
state.debug_info["fast_rag_success"] = True
|
||
return state
|
||
|
||
except Exception as e:
|
||
info(f"[Fast RAG] 快速回答生成失败: {e}")
|
||
return _mark_fast_path_failed(state, "回答生成失败")
|
||
|
||
|
||
# ========== 快速工具节点 ==========
|
||
async def fast_tool_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||
"""快速工具节点"""
|
||
state.current_phase = "fast_tool"
|
||
|
||
decision = state.debug_info.get("hybrid_decision", {})
|
||
suggested_tools = decision.get("suggested_tools", [])
|
||
info(f"[Fast Tool] 开始处理,建议工具: {suggested_tools}")
|
||
|
||
await dispatch_custom_event("fast_path_start", {"path": "fast_tool", "suggested_tools": suggested_tools}, config)
|
||
|
||
# 无明确工具建议,升级到 React 循环
|
||
if not suggested_tools:
|
||
info("[Fast Tool] 无明确工具建议,升级到 React 循环")
|
||
return _mark_fast_path_failed(state, "无明确工具建议")
|
||
|
||
# 当前版本暂不支持快速工具调用,升级到 React 循环
|
||
info("[Fast Tool] 快速工具调用暂未完善,升级到 React 循环")
|
||
return _mark_fast_path_failed(state, "快速工具调用暂未完善")
|
||
|
||
|
||
# ========== 公共函数 ==========
|
||
def _mark_fast_path_failed(state: MainGraphState, reason: str = "") -> MainGraphState:
|
||
"""标记快速路径失败,准备升级到 React 循环"""
|
||
state.debug_info["fast_path_failed"] = True
|
||
state.debug_info["fast_path_fail_reason"] = reason
|
||
state.success = False
|
||
info(f"[Fast Path] 标记失败,准备升级: {reason}")
|
||
return state
|
||
|
||
|
||
# ========== 导出 ==========
|
||
__all__ = [
|
||
"fast_chitchat_node",
|
||
"fast_rag_node",
|
||
"fast_tool_node",
|
||
"_mark_fast_path_failed",
|
||
]
|