Files
ailine/backend/app/main_graph/nodes/fast_paths.py
root 3ae9daa01a
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m44s
导入方式修改
2026-05-05 23:17:00 +08:00

205 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
快速路径节点模块
包含闲聊、RAG、工具等快速处理节点
"""
from typing import Optional
from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState
from ...logger import info, debug
from ...model_services.chat_services import get_small_llm_service, get_chat_service
from .rag_nodes import rag_retrieve_node
from ._utils import dispatch_custom_event
# ========== 闲聊回复模板 ==========
CHITCHAT_TEMPLATES = {
"谢谢": "不客气!如果还有其他问题,请随时告诉我 😊",
"再见": "再见!期待下次为您服务 👋",
"你好": "你好!有什么我可以帮您的吗?",
"默认": None # 使用 LLM 生成
}
CHITCHAT_KEYWORDS = {
"谢谢": ["谢谢", "感谢", "thanks", "thank you"],
"再见": ["再见", "拜拜", "bye", "goodbye"],
"你好": ["你好", "您好", "hi", "hello", "hey", "早上好", "晚上好", "下午好"],
}
# ========== 闲聊节点 ==========
async def fast_chitchat_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""快速闲聊节点"""
state.current_phase = "fast_chitchat"
query = state.user_query or ""
info(f"[Fast Chitchat] 处理: {query[:50]}")
# 发送开始事件
await dispatch_custom_event("fast_path_start", {"path": "fast_chitchat"}, config)
# 清除之前的 final_result让 llm_call 生成新回答
state.final_result = None
# 标记快速路径成功,但不设置 final_result让 llm_call 生成回答
state.success = True
state.current_phase = "llm_call"
state.debug_info["fast_chitchat_success"] = True
# 发送完成事件
await dispatch_custom_event("fast_path_end", {"path": "fast_chitchat", "success": True}, config)
return state
def _match_chitchat_template(query: str) -> str:
"""匹配闲聊模板"""
query_clean = query.strip().lower()
for intent, keywords in CHITCHAT_KEYWORDS.items():
if any(kw in query_clean for kw in keywords):
return CHITCHAT_TEMPLATES[intent]
# 默认:使用 LLM 生成
try:
llm = get_small_llm_service()
response = llm.invoke(f"你是一个友好的助手。用户说:{query}。请简短友好地回复:")
return response.content
except Exception:
return "你好!有什么我可以帮您的吗?"
# ========== 快速 RAG 节点 ==========
async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""快速 RAG 节点:只负责 RAG 检索,然后交给 llm_call 生成回答"""
state.current_phase = "fast_rag"
query = state.user_query or ""
info(f"[Fast RAG] 开始处理: {query[:50]}")
# 获取 RAG 工具
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
rag_tool = get_rag_tool()
info(f"[Fast RAG] 获取到 rag_tool: {rag_tool is not None}")
# 发送开始事件
await dispatch_custom_event("fast_path_start", {"path": "fast_rag"}, config)
# 清除之前的 final_result让 llm_call 生成新回答
state.final_result = None
# 如果没有 rag_tool升级到 React 循环
if not rag_tool:
info("[Fast RAG] 未找到 RAG 工具,升级到 React 循环")
return _mark_fast_path_failed(state, "未找到 RAG 工具")
try:
# 尝试 RAG 检索
state = await rag_retrieve_node(state, config)
# 检查检索结果
if _has_valid_rag_results(state):
info(f"[Fast RAG] 检索有效,进入 llm_call 生成回答")
await dispatch_custom_event("fast_path_end", {"path": "fast_rag", "success": True}, config)
# 注意:这里不设置 final_result让 llm_call 节点处理
return state
# 无效结果:升级到 React 循环
info("[Fast RAG] 无有效检索结果,升级到 React 循环")
return _mark_fast_path_failed(state, "无有效检索结果")
except Exception as e:
info(f"[Fast RAG] 执行失败: {e}")
return _mark_fast_path_failed(state, str(e))
def _has_valid_rag_results(state: MainGraphState) -> bool:
"""检查 RAG 结果是否有效"""
rag_docs = getattr(state, "rag_docs", [])
rag_context = getattr(state, "rag_context", "")
return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10)
async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState:
"""使用小模型快速生成回答"""
try:
chat_llm = get_chat_service()
rag_context = state.rag_context or str(state.rag_docs)[:2000]
prompt = f"""请根据以下信息回答用户问题:
检索到的信息:
{rag_context}
用户问题:{query}
请给出简洁、准确的回答:"""
# 使用流式输出
from backend.app.main_graph.config import get_stream_writer
writer = get_stream_writer()
full_content = ""
async for chunk in chat_llm.astream(prompt):
content = getattr(chunk, 'content', '')
if content:
full_content += content
# 流式输出
if writer and hasattr(writer, '__call__'):
try:
writer({
"type": "llm_token",
"token": content
})
except Exception:
pass
state.final_result = full_content
state.success = True
state.current_phase = "finalizing"
state.debug_info["fast_rag_success"] = True
return state
except Exception as e:
info(f"[Fast RAG] 快速回答生成失败: {e}")
return _mark_fast_path_failed(state, "回答生成失败")
# ========== 快速工具节点 ==========
async def fast_tool_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""快速工具节点"""
state.current_phase = "fast_tool"
decision = state.debug_info.get("hybrid_decision", {})
suggested_tools = decision.get("suggested_tools", [])
info(f"[Fast Tool] 开始处理,建议工具: {suggested_tools}")
await dispatch_custom_event("fast_path_start", {"path": "fast_tool", "suggested_tools": suggested_tools}, config)
# 无明确工具建议,升级到 React 循环
if not suggested_tools:
info("[Fast Tool] 无明确工具建议,升级到 React 循环")
return _mark_fast_path_failed(state, "无明确工具建议")
# 当前版本暂不支持快速工具调用,升级到 React 循环
info("[Fast Tool] 快速工具调用暂未完善,升级到 React 循环")
return _mark_fast_path_failed(state, "快速工具调用暂未完善")
# ========== 公共函数 ==========
def _mark_fast_path_failed(state: MainGraphState, reason: str = "") -> MainGraphState:
"""标记快速路径失败,准备升级到 React 循环"""
state.debug_info["fast_path_failed"] = True
state.debug_info["fast_path_fail_reason"] = reason
state.success = False
info(f"[Fast Path] 标记失败,准备升级: {reason}")
return state
# ========== 导出 ==========
__all__ = [
"fast_chitchat_node",
"fast_rag_node",
"fast_tool_node",
"_mark_fast_path_failed",
]