204 lines
7.1 KiB
Python
204 lines
7.1 KiB
Python
|
|
"""
|
|||
|
|
快速路径节点模块
|
|||
|
|
包含闲聊、RAG、工具等快速处理节点
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
from ..state import MainGraphState
|
|||
|
|
from ...logger import info, debug
|
|||
|
|
from ...model_services.chat_services import get_small_llm_service, get_chat_service
|
|||
|
|
from .rag_nodes import rag_retrieve_node
|
|||
|
|
from ._utils import dispatch_custom_event
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 闲聊回复模板 ==========
|
|||
|
|
CHITCHAT_TEMPLATES = {
|
|||
|
|
"谢谢": "不客气!如果还有其他问题,请随时告诉我 😊",
|
|||
|
|
"再见": "再见!期待下次为您服务 👋",
|
|||
|
|
"你好": "你好!有什么我可以帮您的吗?",
|
|||
|
|
"默认": None # 使用 LLM 生成
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
CHITCHAT_KEYWORDS = {
|
|||
|
|
"谢谢": ["谢谢", "感谢", "thanks", "thank you"],
|
|||
|
|
"再见": ["再见", "拜拜", "bye", "goodbye"],
|
|||
|
|
"你好": ["你好", "您好", "hi", "hello", "hey", "早上好", "晚上好", "下午好"],
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 闲聊节点 ==========
|
|||
|
|
async def fast_chitchat_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
|||
|
|
"""快速闲聊节点"""
|
|||
|
|
state.current_phase = "fast_chitchat"
|
|||
|
|
query = state.user_query or ""
|
|||
|
|
info(f"[Fast Chitchat] 处理: {query[:50]}")
|
|||
|
|
|
|||
|
|
# 发送开始事件
|
|||
|
|
await dispatch_custom_event("fast_path_start", {"path": "fast_chitchat"}, config)
|
|||
|
|
|
|||
|
|
# 清除之前的 final_result,让 llm_call 生成新回答
|
|||
|
|
state.final_result = None
|
|||
|
|
|
|||
|
|
# 标记快速路径成功,但不设置 final_result,让 llm_call 生成回答
|
|||
|
|
state.success = True
|
|||
|
|
state.current_phase = "llm_call"
|
|||
|
|
state.debug_info["fast_chitchat_success"] = True
|
|||
|
|
|
|||
|
|
# 发送完成事件
|
|||
|
|
await dispatch_custom_event("fast_path_end", {"path": "fast_chitchat", "success": True}, config)
|
|||
|
|
|
|||
|
|
return state
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _match_chitchat_template(query: str) -> str:
|
|||
|
|
"""匹配闲聊模板"""
|
|||
|
|
query_clean = query.strip().lower()
|
|||
|
|
|
|||
|
|
for intent, keywords in CHITCHAT_KEYWORDS.items():
|
|||
|
|
if any(kw in query_clean for kw in keywords):
|
|||
|
|
return CHITCHAT_TEMPLATES[intent]
|
|||
|
|
|
|||
|
|
# 默认:使用 LLM 生成
|
|||
|
|
try:
|
|||
|
|
llm = get_small_llm_service()
|
|||
|
|
response = llm.invoke(f"你是一个友好的助手。用户说:{query}。请简短友好地回复:")
|
|||
|
|
return response.content
|
|||
|
|
except Exception:
|
|||
|
|
return "你好!有什么我可以帮您的吗?"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 快速 RAG 节点 ==========
|
|||
|
|
async def fast_rag_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
|||
|
|
"""快速 RAG 节点:只负责 RAG 检索,然后交给 llm_call 生成回答"""
|
|||
|
|
state.current_phase = "fast_rag"
|
|||
|
|
query = state.user_query or ""
|
|||
|
|
info(f"[Fast RAG] 开始处理: {query[:50]}")
|
|||
|
|
|
|||
|
|
# 获取 RAG 工具
|
|||
|
|
from app.main_graph.utils.rag_initializer import get_rag_tool
|
|||
|
|
rag_tool = get_rag_tool()
|
|||
|
|
info(f"[Fast RAG] 获取到 rag_tool: {rag_tool is not None}")
|
|||
|
|
|
|||
|
|
# 发送开始事件
|
|||
|
|
await dispatch_custom_event("fast_path_start", {"path": "fast_rag"}, config)
|
|||
|
|
|
|||
|
|
# 清除之前的 final_result,让 llm_call 生成新回答
|
|||
|
|
state.final_result = None
|
|||
|
|
|
|||
|
|
# 如果没有 rag_tool,升级到 React 循环
|
|||
|
|
if not rag_tool:
|
|||
|
|
info("[Fast RAG] 未找到 RAG 工具,升级到 React 循环")
|
|||
|
|
return _mark_fast_path_failed(state, "未找到 RAG 工具")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 尝试 RAG 检索
|
|||
|
|
state = await rag_retrieve_node(state, config)
|
|||
|
|
|
|||
|
|
# 检查检索结果
|
|||
|
|
if _has_valid_rag_results(state):
|
|||
|
|
info(f"[Fast RAG] 检索有效,进入 llm_call 生成回答")
|
|||
|
|
await dispatch_custom_event("fast_path_end", {"path": "fast_rag", "success": True}, config)
|
|||
|
|
# 注意:这里不设置 final_result,让 llm_call 节点处理
|
|||
|
|
return state
|
|||
|
|
|
|||
|
|
# 无效结果:升级到 React 循环
|
|||
|
|
info("[Fast RAG] 无有效检索结果,升级到 React 循环")
|
|||
|
|
return _mark_fast_path_failed(state, "无有效检索结果")
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
info(f"[Fast RAG] 执行失败: {e}")
|
|||
|
|
return _mark_fast_path_failed(state, str(e))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _has_valid_rag_results(state: MainGraphState) -> bool:
|
|||
|
|
"""检查 RAG 结果是否有效"""
|
|||
|
|
rag_docs = getattr(state, "rag_docs", [])
|
|||
|
|
rag_context = getattr(state, "rag_context", "")
|
|||
|
|
return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10)
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState:
|
|||
|
|
"""使用小模型快速生成回答"""
|
|||
|
|
try:
|
|||
|
|
chat_llm = get_chat_service()
|
|||
|
|
rag_context = state.rag_context or str(state.rag_docs)[:2000]
|
|||
|
|
|
|||
|
|
prompt = f"""请根据以下信息回答用户问题:
|
|||
|
|
|
|||
|
|
检索到的信息:
|
|||
|
|
{rag_context}
|
|||
|
|
|
|||
|
|
用户问题:{query}
|
|||
|
|
|
|||
|
|
请给出简洁、准确的回答:"""
|
|||
|
|
|
|||
|
|
# 使用流式输出
|
|||
|
|
from app.main_graph.config import get_stream_writer
|
|||
|
|
writer = get_stream_writer()
|
|||
|
|
|
|||
|
|
full_content = ""
|
|||
|
|
async for chunk in chat_llm.astream(prompt):
|
|||
|
|
content = getattr(chunk, 'content', '')
|
|||
|
|
if content:
|
|||
|
|
full_content += content
|
|||
|
|
# 流式输出
|
|||
|
|
if writer and hasattr(writer, '__call__'):
|
|||
|
|
try:
|
|||
|
|
writer({
|
|||
|
|
"type": "llm_token",
|
|||
|
|
"token": content
|
|||
|
|
})
|
|||
|
|
except Exception:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
state.final_result = full_content
|
|||
|
|
state.success = True
|
|||
|
|
state.current_phase = "finalizing"
|
|||
|
|
state.debug_info["fast_rag_success"] = True
|
|||
|
|
return state
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
info(f"[Fast RAG] 快速回答生成失败: {e}")
|
|||
|
|
return _mark_fast_path_failed(state, "回答生成失败")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 快速工具节点 ==========
|
|||
|
|
async def fast_tool_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
|||
|
|
"""快速工具节点"""
|
|||
|
|
state.current_phase = "fast_tool"
|
|||
|
|
|
|||
|
|
decision = state.debug_info.get("hybrid_decision", {})
|
|||
|
|
suggested_tools = decision.get("suggested_tools", [])
|
|||
|
|
info(f"[Fast Tool] 开始处理,建议工具: {suggested_tools}")
|
|||
|
|
|
|||
|
|
await dispatch_custom_event("fast_path_start", {"path": "fast_tool", "suggested_tools": suggested_tools}, config)
|
|||
|
|
|
|||
|
|
# 无明确工具建议,升级到 React 循环
|
|||
|
|
if not suggested_tools:
|
|||
|
|
info("[Fast Tool] 无明确工具建议,升级到 React 循环")
|
|||
|
|
return _mark_fast_path_failed(state, "无明确工具建议")
|
|||
|
|
|
|||
|
|
# 当前版本暂不支持快速工具调用,升级到 React 循环
|
|||
|
|
info("[Fast Tool] 快速工具调用暂未完善,升级到 React 循环")
|
|||
|
|
return _mark_fast_path_failed(state, "快速工具调用暂未完善")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 公共函数 ==========
|
|||
|
|
def _mark_fast_path_failed(state: MainGraphState, reason: str = "") -> MainGraphState:
|
|||
|
|
"""标记快速路径失败,准备升级到 React 循环"""
|
|||
|
|
state.debug_info["fast_path_failed"] = True
|
|||
|
|
state.debug_info["fast_path_fail_reason"] = reason
|
|||
|
|
state.success = False
|
|||
|
|
info(f"[Fast Path] 标记失败,准备升级: {reason}")
|
|||
|
|
return state
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ========== 导出 ==========
|
|||
|
|
__all__ = [
|
|||
|
|
"fast_chitchat_node",
|
|||
|
|
"fast_rag_node",
|
|||
|
|
"fast_tool_node",
|
|||
|
|
"_mark_fast_path_failed",
|
|||
|
|
]
|