refactor: 重构目录结构 - 简化层级
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled

This commit is contained in:
2026-04-29 12:52:41 +08:00
parent 223d1c9afd
commit ef5113bffb
54 changed files with 42 additions and 1819 deletions

View File

@@ -0,0 +1 @@
"""主图模块 - LangGraph 主流程"""

View File

@@ -0,0 +1,83 @@
"""
LangGraph 状态图构建模块 - 精简版,仅负责组装图
所有节点逻辑已拆分到独立模块
"""
from langchain_core.language_models import BaseLLM
from app.main_graph.graph import StateGraph, START, END
from .state import MessagesState, GraphContext
from ..nodes import (
should_continue,
create_llm_call_node,
create_tool_call_node,
create_retrieve_memory_node,
create_summarize_node,
finalize_node,
)
from app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client
from ..memory import Mem0Client
class GraphBuilder:
    """Assembles the LangGraph state graph; all node logic lives in other modules."""
    def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
        """
        Initialize the builder.

        Args:
            llm: LLM instance
            tools: list of tools
            tools_by_name: mapping from tool name to tool callable
        """
        self.llm = llm
        self.tools = tools
        self.tools_by_name = tools_by_name
        # Mem0 client is lazy: created here, initialized on first use.
        self.mem0_client = Mem0Client(llm)
    def build(self) -> StateGraph:
        """
        Build the uncompiled state graph.

        Returns:
            StateGraph instance
        """
        # Make the Mem0 client available to module-level nodes.
        set_mem0_client(self.mem0_client)
        graph = StateGraph(MessagesState, context_schema=GraphContext)
        # Nodes are produced by factory functions (dependency injection).
        node_table = [
            ("retrieve_memory", create_retrieve_memory_node(self.mem0_client)),
            ("memory_trigger", memory_trigger_node),
            ("llm_call", create_llm_call_node(self.llm, self.tools)),
            ("tool_node", create_tool_call_node(self.tools_by_name)),
            ("summarize", create_summarize_node(self.mem0_client)),
            ("finalize", finalize_node),
        ]
        for node_name, node_fn in node_table:
            graph.add_node(node_name, node_fn)
        # Linear prefix: START -> retrieve_memory -> memory_trigger -> llm_call.
        graph.add_edge(START, "retrieve_memory")
        graph.add_edge("retrieve_memory", "memory_trigger")
        graph.add_edge("memory_trigger", "llm_call")
        # After each LLM call, route to tools, summarization, or finalization.
        graph.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                "tool_node": "tool_node",
                "summarize": "summarize",
                "finalize": "finalize"
            }
        )
        graph.add_edge("tool_node", "llm_call")
        graph.add_edge("summarize", "finalize")
        graph.add_edge("finalize", END)
        return graph

View File

@@ -0,0 +1 @@
"""主图节点"""

View File

@@ -0,0 +1,45 @@
"""
完成事件节点模块
负责发送完成事件包含token使用情况和耗时信息
"""
from typing import Any, Dict
from app.main_graph.config import get_stream_writer
# 本地模块
from app.main_graph.state import MessagesState
from ..utils.logging import log_state_change
from ..logger import info, error
from langchain_core.runnables.config import RunnableConfig
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """
    Completion node — emits the "done" stream event carrying token usage
    and elapsed-time information from the last LLM call.

    Args:
        state: current conversation state
        config: runtime config

    Returns:
        Empty dict (terminal node, no state update)
    """
    log_state_change("finalize", state, "进入")
    # Payload mirrors what llm_call stored on the state.
    payload = {
        "type": "done",
        "token_usage": state.get("last_token_usage", {}),
        "elapsed_time": state.get("last_elapsed_time", 0.0)
    }
    try:
        get_stream_writer()({"type": "custom", "data": payload})
        info("🏁 [完成事件] 已发送完成事件包含token使用情况和耗时信息")
    except Exception as e:
        # Best-effort: a failed event emission must not break the graph run.
        error(f"❌ [完成事件] 发送完成事件时发生异常: {e}")
    log_state_change("finalize", state, "离开")
    return {}

View File

@@ -0,0 +1,150 @@
"""
LLM 调用节点模块
负责调用大语言模型并处理响应
"""
import time
from typing import Any, Dict
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AIMessage
# 本地模块
from app.main_graph.state import MessagesState
from ..agent.prompts import create_system_prompt
from ..utils.logging import log_state_change
from ..logger import debug, info, error
def create_llm_call_node(llm: BaseLLM, tools: list):
    """
    Factory function: create the LLM call node.

    Args:
        llm: LangChain LLM instance
        tools: tool list to bind to the model

    Returns:
        Async node function usable as a LangGraph node
    """
    # Build the invocation chain once, outside the node function.
    prompt = create_system_prompt(tools)
    llm_with_tools = llm.bind_tools(tools)
    # Chain kept in Runnable form; consumed below via astream.
    chain = prompt | llm_with_tools
    from langchain_core.runnables.config import RunnableConfig
    async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
        """
        LLM call node (async).

        Args:
            state: current conversation state
            config: config auto-injected by LangChain/LangGraph (callbacks etc.)

        Returns:
            State update: new message, call counter, token usage, timing
        """
        log_state_change("llm_call", state, "进入")
        memory_context = state.get("memory_context", "暂无用户信息")
        start_time = time.time()
        try:
            # Manually astream and merge all chunks into the final response;
            # LangGraph observes every token produced during this iteration.
            chunks = []
            async for chunk in chain.astream(
                {
                    "messages": state["messages"],
                    "memory_context": memory_context
                },
                config=config
            ):
                chunks.append(chunk)
            # Merge all chunks into the final AIMessage (chunk __add__ concatenates).
            if chunks:
                response = chunks[0]
                for chunk in chunks[1:]:
                    response = response + chunk
            else:
                response = AIMessage(content="")
            elapsed_time = time.time() - start_time
            # Extract token usage (metadata format differs between providers).
            token_usage = {}
            input_tokens = 0
            output_tokens = 0
            # First try response_metadata.
            if hasattr(response, 'response_metadata') and response.response_metadata:
                meta = response.response_metadata
                if 'token_usage' in meta:
                    token_usage = meta['token_usage']
                elif 'usage' in meta:
                    token_usage = meta['usage']
            # Fall back to additional_kwargs.
            if not token_usage and hasattr(response, 'additional_kwargs'):
                add_kwargs = response.additional_kwargs
                if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
                    token_usage = add_kwargs['llm_output']['token_usage']
            # Pull out the concrete token counts (OpenAI vs Anthropic key names).
            if token_usage:
                input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
                output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
            # Log the complete LLM output.
            debug("\n" + "="*80)
            debug("📥 [LLM输出] 大模型返回的完整响应:")
            debug(f" 消息类型: {response.type.upper()}")
            debug(f" 内容长度: {len(str(response.content))} 字符")
            debug("-"*80)
            debug(f"{response.content}")
            # Log response statistics.
            # NOTE(review): this message omits the "秒" unit that the error path
            # (below) includes — confirm whether that is intentional.
            info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}")
            info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
            if token_usage:
                debug(f"📋 [LLM统计] 详细用量: {token_usage}")
            debug("="*80 + "\n")
            result = {
                "messages": [response],
                "llm_calls": state.get('llm_calls', 0) + 1,
                "last_token_usage": token_usage,
                "last_elapsed_time": elapsed_time,
                "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1  # increment turn counter
            }
            log_state_change("llm_call", {**state, **result}, "离开")
            return result
        except Exception as e:
            elapsed_time = time.time() - start_time
            error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
            error(f" 错误类型: {type(e).__name__}")
            error(f" 错误信息: {str(e)}")
            import traceback
            traceback.print_exc()
            debug("="*80 + "\n")
            # Return a friendly error message instead of raising.
            error_response = AIMessage(
                content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
            )
            error_result = {
                "messages": [error_response],
                "llm_calls": state.get('llm_calls', 0),
                "last_token_usage": {},
                "last_elapsed_time": elapsed_time,
                "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1  # counter advances even on error
            }
            log_state_change("llm_call", state, "离开(异常)")
            return error_result
    return call_llm

View File

@@ -0,0 +1,38 @@
from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from app.main_graph.state import MessagesState
from ..memory.mem0_client import Mem0Client
from ..logger import info
# Module-level Mem0 client, injected by GraphBuilder via set_mem0_client().
# The annotation is optional (string form, never eagerly evaluated): the
# client stays None until injection happens, which the original annotation
# (`Mem0Client = None`) mis-stated.
_mem0_client: "Mem0Client | None" = None
def set_mem0_client(client: Mem0Client) -> None:
    """Inject the shared Mem0 client used by memory_trigger_node.

    Args:
        client: Mem0Client instance (may still be lazily initialized).
    """
    global _mem0_client
    _mem0_client = client
async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """Detect explicit "remember this" instructions in the latest user message
    and, when found, proactively push the message to Mem0 for storage."""
    if _mem0_client is None:
        return {}
    history = state.get("messages", [])
    if not history:
        return {}
    latest = history[-1]
    content = latest.content if hasattr(latest, 'content') else str(latest)
    # Trigger keywords (extend as needed).
    trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"]
    if not any(word in content for word in trigger_words):
        return {}
    user_id = config.get("metadata", {}).get("user_id", "default_user")
    # Lazily initialize Mem0 before first use.
    if not _mem0_client._initialized:
        await _mem0_client.initialize()
    # Hand the user message to Mem0 as a source of facts.
    info(f"📌 检测到记忆指令,已主动触发 Mem0 存储")
    await _mem0_client.add_memories([{"role": "user", "content": content}], user_id=user_id)
    return {}  # state unchanged

View File

@@ -0,0 +1,294 @@
"""
RAG 节点模块 - 真正利用已有 RAG 代码
包含:
- rag_retrieve_node: RAG 检索节点(带超时重试)
- rag_re_retrieve_node: 重新检索节点
- 集成 backend/app/rag/tools.py 和 rag_initializer.py
"""
import time
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
from .state import MainGraphState, ErrorRecord, ErrorSeverity
from .retry_utils import (
RetryConfig,
RAG_RETRY_CONFIG,
create_retry_wrapper_for_node
)
# 真正导入和利用已有 RAG 代码
from ..rag.tools import create_rag_tool_sync
from ..rag.pipeline import RAGPipeline
# ========== Global RAG tool instances (lazily initialized) ==========
# Module-level singletons; populated via set_global_rag_tool() /
# set_global_rag_pipeline(), typically at application startup.
_GLOBAL_RAG_TOOL: Optional[Any] = None
_GLOBAL_RAG_PIPELINE: Optional[RAGPipeline] = None
def get_global_rag_tool() -> Optional[Any]:
    """
    Get the global RAG tool (singleton pattern).

    Returns:
        The RAG tool instance, or None if none has been set
    """
    return _GLOBAL_RAG_TOOL
def set_global_rag_tool(tool: Any) -> None:
    """
    Set the global RAG tool (usually called at application startup).

    Args:
        tool: RAG tool instance
    """
    global _GLOBAL_RAG_TOOL
    _GLOBAL_RAG_TOOL = tool
def set_global_rag_pipeline(pipeline: RAGPipeline) -> None:
    """
    Set the global RAG Pipeline.

    Args:
        pipeline: RAGPipeline instance
    """
    global _GLOBAL_RAG_PIPELINE
    _GLOBAL_RAG_PIPELINE = pipeline
# ========== 从状态获取 RAG 工具 ==========
def get_rag_tool_from_state(state: MainGraphState) -> Optional[Any]:
    """
    Resolve the RAG tool for this run.

    Lookup order: a tool stashed in ``state.debug_info`` wins; otherwise fall
    back to the module-level singleton.

    Args:
        state: main graph state

    Returns:
        RAG tool instance, or None when neither source provides one
    """
    try:
        # Whatever was injected onto the state takes precedence, even None.
        return state.debug_info["rag_tool"]
    except KeyError:
        return get_global_rag_tool()
# ========== 工具:将 RAG 工具注入到状态 ==========
def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphState:
    """
    Stash a RAG tool on the state so downstream nodes can use it.

    Args:
        state: main graph state
        rag_tool: RAG tool instance

    Returns:
        The same state object, mutated in place
    """
    debug_info = state.debug_info
    debug_info["rag_tool"] = rag_tool
    # Record when injection happened, for debugging.
    debug_info["rag_tool_injected"] = datetime.now().isoformat()
    return state
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
    """
    RAG retrieval core logic (delegates to the existing rag/tools.py code).

    Resolution order: RAG tool (from state or global) first, then the global
    RAG pipeline; raises RuntimeError when neither is available or the call
    fails.

    NOTE(review): on failure this function may have already mutated parts of
    `state` before raising — callers that retry reuse that mutated state.

    Args:
        state: main graph state

    Returns:
        Updated state (rag_context, rag_docs, rag_retrieved, success)

    Raises:
        RuntimeError: when no tool/pipeline is configured or invocation fails
    """
    # Pick the retrieval query (prefer the optimized query from reasoning).
    retrieval_query = state.user_query
    if "reasoning_result" in state.debug_info:
        reasoning_result = state.debug_info["reasoning_result"]
        if hasattr(reasoning_result, "retrieval_config"):
            cfg = reasoning_result.retrieval_config
            if cfg and cfg.retrieval_query:
                retrieval_query = cfg.retrieval_query
    # Try to obtain a RAG tool (state first, then global).
    rag_tool = get_rag_tool_from_state(state)
    if rag_tool:
        # Use the real RAG tool (from rag/tools.py).
        try:
            # Invoke the LangChain Tool.
            rag_context = rag_tool.invoke(retrieval_query)
            state.rag_context = rag_context
            state.rag_docs = [
                {"source": "rag_retrieval", "content": rag_context}
            ]
            state.rag_retrieved = True
            state.success = True
            state.debug_info["rag_source"] = "rag_tool"
            return state
        except Exception as e:
            raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
    elif _GLOBAL_RAG_PIPELINE:
        # Retrieve directly through the RAG pipeline.
        try:
            documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query)
            if documents:
                rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
                state.rag_context = rag_context
                state.rag_docs = [
                    {"source": doc.metadata.get("source", "unknown"), "content": doc.page_content}
                    for doc in documents
                ]
            else:
                state.rag_context = f"未找到与 '{retrieval_query}' 相关的知识库信息。"
                state.rag_docs = []
            state.rag_retrieved = True
            state.success = True
            state.debug_info["rag_source"] = "rag_pipeline"
            return state
        except Exception as e:
            raise RuntimeError(f"RAG Pipeline 调用失败: {str(e)}") from e
    else:
        # No RAG tool/pipeline available.
        raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
# ========== RAG 检索节点(带超时和重试)==========
def rag_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    RAG retrieval node: retries with exponential backoff around the core logic.

    NOTE(review): RAG_RETRY_CONFIG.timeout is only recorded in the error
    context below — no timeout is actually enforced around the core call.
    Confirm whether enforcement was intended.

    Args:
        state: main graph state

    Returns:
        Updated state (on total failure: a structured ErrorRecord is appended
        and current_phase is set to "error_handling")
    """
    state.current_phase = "rag_retrieving"
    start_time = time.time()
    last_error = None
    # max_retries retries => max_retries + 1 attempts in total.
    for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
        try:
            # Run the core retrieval logic.
            result = _rag_retrieve_core(state)
            # Success: record attempt metadata and return.
            state.debug_info["rag_retrieval"] = {
                "attempt": attempt + 1,
                "success": True,
                "time": time.time() - start_time
            }
            return result
        except Exception as e:
            last_error = e
            if attempt >= RAG_RETRY_CONFIG.max_retries:
                break
            # Exponential backoff, capped at max_delay.
            delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
            time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
    # All retries failed: record a structured error for error_handling_node.
    error_record = ErrorRecord(
        error_type="RAGRetrievalError",
        error_message=str(last_error) if last_error else "RAG 检索超时",
        severity=ErrorSeverity.WARNING,
        source="rag_retrieve_node",
        timestamp=datetime.now().isoformat(),
        retry_count=RAG_RETRY_CONFIG.max_retries,
        max_retries=RAG_RETRY_CONFIG.max_retries,
        context={
            "query": state.user_query,
            "total_time": time.time() - start_time,
            "timeout": RAG_RETRY_CONFIG.timeout,
            "has_rag_tool": get_global_rag_tool() is not None,
            "has_rag_pipeline": _GLOBAL_RAG_PIPELINE is not None
        }
    )
    state.errors.append(error_record)
    state.current_error = error_record
    state.current_phase = "error_handling"
    return state
# ========== 重新检索节点 ==========
def rag_re_retrieve_node(state: MainGraphState) -> MainGraphState:
    """
    Re-retrieval node: second retrieval pass after a failed/insufficient one.

    Args:
        state: main graph state

    Returns:
        Updated state (from the reused retrieval logic)
    """
    state.current_phase = "rag_re_retrieving"
    # Snapshot what the first retrieval produced before running again.
    first_pass_info = {
        "original_retrieved": state.rag_retrieved,
        "original_docs_count": len(state.rag_docs)
    }
    state.debug_info["rag_re_retrieve"] = first_pass_info
    # Retrieval parameters (query, k, ...) could be adjusted here; for now
    # the exact same retrieval logic is reused.
    return rag_retrieve_node(state)
# ========== 便捷函数:从 rag_initializer 初始化 ==========
async def initialize_rag_from_initializer() -> None:
    """
    Initialize RAG from rag_initializer (convenience helper).

    NOTE: this is example scaffolding — actual use must supply a
    local_llm_creator callable, see the commented example below.
    """
    try:
        # NOTE(review): init_rag_tool is imported but never called here; the
        # import appears to double as an availability check — confirm.
        from app.main_graph.utils.rag_initializer import init_rag_tool
        # A local_llm_creator must be passed in here.
        # Example:
        # def my_llm_creator():
        #     from ..model_services import get_llm
        #     return get_llm()
        #
        # rag_tool = await init_rag_tool(my_llm_creator)
        # set_global_rag_tool(rag_tool)
        print("⚠️ initialize_rag_from_initializer 需要传入 local_llm_creator")
        print("⚠️ 请在应用启动时调用 init_rag_tool() 并设置全局 RAG 工具")
    except ImportError as e:
        print(f"⚠️ 无法导入 rag_initializer: {e}")
    except Exception as e:
        print(f"⚠️ RAG 初始化失败: {e}")
# ========== Exports ==========
__all__ = [
    # Node functions
    "rag_retrieve_node",
    "rag_re_retrieve_node",
    # Helper functions
    "inject_rag_tool_to_state",
    "get_rag_tool_from_state",
    # Global RAG management
    "get_global_rag_tool",
    "set_global_rag_tool",
    "set_global_rag_pipeline",
    # Initialization
    "initialize_rag_from_initializer"
]

View File

@@ -0,0 +1,297 @@
"""
React 模式节点模块 - 带超时和重试功能
包含:
- react_reason_node: 使用 intent.py 进行推理
- error_handling_node: 错误处理节点
- final_response_node: 最终回答节点
- init_state_node: 初始化状态节点
注意:为了兼容 LangGraph 的同步接口,我们保留了同步的 react_reason 调用
但内部会根据情况使用规则推理或尝试异步调用
"""
import sys
from typing import Dict, Any, Optional
from datetime import datetime
# 导入我们的 intent.py
from app.core.intent import (
react_reason,
get_route_by_reasoning,
ReasoningAction,
ReasoningResult
)
from app.core.state_base import StateUtils
from .state import MainGraphState, ErrorRecord, ErrorSeverity
from .retry_utils import (
RetryConfig,
SUBGRAPH_RETRY_CONFIG
)
# ========== 1. React 推理节点 ==========
def react_reason_node(state: MainGraphState) -> MainGraphState:
    """
    React-style reasoning node: decide what to do next.

    Increments the step counter, aborts when max_steps is exceeded, otherwise
    delegates reasoning to intent.react_reason and records the outcome in
    reasoning_history / debug_info.

    Returns: updated state
    """
    state.current_phase = "react_reasoning"
    state.reasoning_step += 1
    # Bail out when the step budget is exhausted.
    if state.reasoning_step > state.max_steps:
        state.current_phase = "max_steps_exceeded"
        state.final_result = (
            f"❌ 推理步数超过限制(最大 {state.max_steps} 步),"
            f"已执行 {state.reasoning_step - 1} 步。"
            f"请简化您的问题或分批提问。"
        )
        state.success = False
        return state
    # Assemble the reasoning context.
    context = {
        "retrieved_docs": state.rag_docs,
        "previous_actions": [h.get("action") for h in state.reasoning_history],
        "messages": state.messages,
        "errors": state.errors
    }
    # Delegate reasoning to intent.py.
    # Note: this is the synchronous entry point; internals handle the rest.
    result: ReasoningResult = react_reason(state.user_query, context)
    # Record the reasoning step in the history.
    state.reasoning_history.append({
        "step": state.reasoning_step,
        "action": result.action.name,
        "confidence": result.confidence,
        "reasoning": result.reasoning,
        "timestamp": datetime.now().isoformat()
    })
    # Expose the latest reasoning for debugging.
    state.debug_info["last_reasoning"] = {
        "action": result.action.name,
        "confidence": result.confidence,
        "reasoning": result.reasoning
    }
    # Keep the full result object for route_by_reasoning.
    state.debug_info["reasoning_result"] = result
    # Determine the next action.
    state.last_action = result.action.name
    return state
# ========== 2. 错误处理节点 ==========
def error_handling_node(state: "MainGraphState") -> "MainGraphState":
    """
    Error-handling node: process subgraph/tool failures recorded on the state.

    Emits structured error info of the form:
    {
        "tool": "...",
        "status": "failed",
        "error": "...",
        "retries_exceeded": true/false,
        "suggestion": "..."
    }

    Strategy: (1) retry when severity and retry budget allow, (2) degrade
    gracefully for non-fatal errors, (3) terminate with failure for fatal ones.

    Args:
        state: main graph state carrying `current_error`

    Returns:
        Updated state with routing hints (`current_phase`, `last_action`)
    """
    import json  # local import: only needed to render the error payload

    state.current_phase = "error_handling"
    if not state.current_error:
        # Nothing to handle — return to reasoning.
        state.current_phase = "react_reasoning"
        return state
    error = state.current_error
    # Human-readable summary of the current error.
    state.error_message = f"{error.error_type}: {error.error_message}"
    # Structured error record (JSON-serializable values only).
    structured_error = {
        "tool": error.source,
        "status": "failed",
        "error": error.error_message,
        "retries_exceeded": error.retry_count >= error.max_retries,
        "retry_count": error.retry_count,
        "max_retries": error.max_retries
    }
    # Attach a user-facing suggestion based on error type/source.
    if "RAG" in error.error_type:
        structured_error["suggestion"] = "尝试重新表述问题或直接询问"
    elif "subgraph" in error.source or "contact" in error.source:
        structured_error["suggestion"] = "子图执行失败,请尝试简化查询"
    elif "timeout" in error.error_message.lower():
        structured_error["suggestion"] = "请求超时,请稍后再试"
    else:
        structured_error["suggestion"] = "请尝试其他方式提问"
    state.debug_info["structured_error"] = structured_error
    # Render real JSON for the ```json fenced blocks below (the original
    # interpolated the Python dict repr, which is not valid JSON).
    error_json = json.dumps(structured_error, ensure_ascii=False, indent=2)
    # Strategy 1: retry when severity allows and retries remain.
    can_retry = (
        error.severity in [ErrorSeverity.WARNING, ErrorSeverity.ERROR]
        and error.retry_count < error.max_retries
    )
    if can_retry:
        error.retry_count += 1
        state.retry_action = error.source
        state.debug_info["retry_count"] = error.retry_count
        if "RAG" in error.error_type:
            state.last_action = "RE_RETRIEVE_RAG"
        elif "subgraph" in error.source:
            state.last_action = "DIRECT_RESPONSE"
        else:
            state.last_action = "REASON"
        state.current_phase = "retrying"
        return state
    # Strategy 2: non-fatal and out of retries — degrade but keep answering.
    if error.severity != ErrorSeverity.FATAL:
        state.final_result = (
            f"⚠️ 遇到一些问题:\n"
            f"```json\n{error_json}\n```\n"
            f"但我会尽力用现有信息回答您。"
        )
        state.success = True
        state.current_phase = "finalizing"
        return state
    # Strategy 3: fatal error — terminate with a failure message.
    state.final_result = (
        f"❌ 服务暂时不可用,请稍后再试。\n"
        f"```json\n{error_json}\n```"
    )
    state.success = False
    state.current_phase = "finalizing"
    return state
# ========== 3. 最终回答节点 ==========
def final_response_node(state: MainGraphState) -> MainGraphState:
    """
    Final answer node: assemble the reply from whatever results are present.
    """
    state.current_phase = "finalizing"
    # A pre-existing final_result short-circuits assembly.
    if state.final_result:
        state.current_phase = "done"
        return state
    segments = []
    # RAG context first, separated from the rest by a divider.
    if state.rag_context:
        segments.append(state.rag_context)
        segments.append("---")
    # Then each subgraph result that carries a final_result
    # (order preserved: contact, dictionary, news).
    for subgraph_result in (state.contact_result, state.dictionary_result, state.news_result):
        if subgraph_result and hasattr(subgraph_result, "get"):
            if subgraph_result.get("final_result"):
                segments.append(subgraph_result["final_result"])
    # Fallback when nothing was produced.
    if not segments:
        segments.append(f"我理解了您的问题:{state.user_query}")
    state.final_result = "\n".join(segments)
    state.success = True
    state.current_phase = "done"
    state.end_time = datetime.now().isoformat()
    return state
# ========== 4. 初始化状态节点 ==========
def init_state_node(state: MainGraphState) -> MainGraphState:
    """
    Initialization node: set starting values before the flow begins.
    """
    state.current_phase = "initializing"
    state.reasoning_step = 0
    state.start_time = datetime.now().isoformat()
    # Derive the user query from the newest message when not set explicitly.
    if not state.user_query and state.messages:
        newest = state.messages[-1]
        state.user_query = getattr(newest, "content", str(newest))
    return state
# ========== 5. 条件路由函数 ==========
def route_by_reasoning(state: MainGraphState) -> str:
    """
    Decide the next route from the current reasoning result.

    Returns: a route key matching the edges in graph_builder.py
    """
    phase = state.current_phase
    # Special phases take precedence over the reasoning result.
    if phase == "max_steps_exceeded":
        return "final_response"
    if phase == "error_handling" or state.current_error:
        return "handle_error"
    if phase in ("finalizing", "done"):
        return "final_response"
    if phase == "retrying":
        retry = state.retry_action
        return "rag_retrieve" if retry and "rag" in retry.lower() else "react_reason"
    # Fall back to the stored reasoning result.
    reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
    if not reasoning_result:
        return "final_response"
    # intent.py owns the canonical route decision.
    route = get_route_by_reasoning(reasoning_result)
    # Map intent routes onto our node names.
    # Note: these names must match those defined in subgraph_builder.py.
    route_mapping = {
        "direct_response": "final_response",
        "retrieve_rag": "rag_retrieve",
        "re_retrieve_rag": "rag_retrieve",
        "clarify": "final_response",
        "call_tool": "final_response",  # mapped to final_response for now; extend later
        "contact": "contact_subgraph",
        "dictionary": "dictionary_subgraph",
        "news_analysis": "news_analysis_subgraph",
    }
    return route_mapping.get(route, "final_response")
# ========== Exports ==========
__all__ = [
    "init_state_node",
    "react_reason_node",
    "error_handling_node",
    "final_response_node",
    "route_by_reasoning"
]

View File

@@ -0,0 +1,76 @@
"""
记忆检索节点模块
负责从 Mem0 检索相关长期记忆
"""
from typing import Any, Dict
# 本地模块
from .state import MessagesState
from ..memory.mem0_client import Mem0Client
from ..utils.logging import log_state_change
from ..logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client):
    """
    Factory function: create the memory-retrieval node.

    Args:
        mem0_client: Mem0 client instance

    Returns:
        Async node function
    """
    from langchain_core.runnables.config import RunnableConfig
    async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
        """
        Memory-retrieval node — semantic search over Mem0 long-term memory.

        Args:
            state: current conversation state
            config: runtime config

        Returns:
            State update containing memory_context
        """
        log_state_change("retrieve_memory", state, "进入")
        # Resolve user_id from the run metadata.
        user_id = config.get("metadata", {}).get("user_id", "default_user")
        # Messages may be dicts or message objects; handle both.
        last_msg = state["messages"][-1]
        if isinstance(last_msg, dict):
            query = str(last_msg.get("content", ""))
        else:
            query = str(last_msg.content)
        memory_text_parts = []
        # Lazily initialize Mem0 on first use.
        # NOTE(review): reaches into the private `_initialized` flag — consider
        # a public accessor on Mem0Client.
        if not mem0_client._initialized:
            await mem0_client.initialize()
        if mem0_client.mem0:
            try:
                # Async semantic search against Mem0 (top-5 facts).
                facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
                if facts:
                    memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts))
                else:
                    debug("🔍 [记忆检索] 未找到相关记忆")
            except Exception as e:
                # Retrieval failure is non-fatal; fall through to the default context.
                from app.logger import warning
                warning(f"⚠️ Mem0 检索失败: {e}")
        else:
            from app.logger import warning
            warning("⚠️ Mem0 未初始化,跳过记忆检索")
        memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
        result = {"memory_context": memory_context}
        log_state_change("retrieve_memory", {**state, **result}, "离开")
        return result
    return retrieve_memory

View File

@@ -0,0 +1,48 @@
"""
路由决策节点
根据当前状态决定下一步走向
"""
from typing import Literal
from langchain_core.messages import AIMessage
# 本地模块
from ..config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
from app.main_graph.state import MessagesState
from ..logger import info
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'finalize']:
    """
    Decide the next step: tool execution, summary generation, or finishing.

    Args:
        state: current conversation state

    Returns:
        Name of the next node
    """
    latest = state["messages"][-1]
    is_ai_reply = isinstance(latest, AIMessage)
    # 1. Pending tool calls always win.
    if is_ai_reply and latest.tool_calls:
        if ENABLE_GRAPH_TRACE:
            info(f"🔀 [路由决策] 检测到 {len(latest.tool_calls)} 个工具调用 → 转向 'tool_node'")
        return 'tool_node'
    # 2. A final AI reply: summarize once enough turns have accumulated.
    if is_ai_reply:
        turns = state.get("turns_since_last_summary", 0)
        if turns >= MEMORY_SUMMARIZE_INTERVAL:
            if ENABLE_GRAPH_TRACE:
                info(f"🔀 [路由决策] 收到 AI 最终回复,已达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 转向 'summarize'")
            return 'summarize'
        if ENABLE_GRAPH_TRACE:
            info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")
        return 'finalize'
    # 3. Anything else (e.g. only user messages) ends the flow.
    if ENABLE_GRAPH_TRACE:
        info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
    return 'finalize'

View File

@@ -0,0 +1,87 @@
"""
记忆存储节点模块
负责将对话历史提交给 Mem0 进行事实提取和存储
"""
from typing import Any, Dict
# 本地模块
from app.main_graph.state import MessagesState
from ..memory.mem0_client import Mem0Client
from ..utils.logging import log_state_change
from ..logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client):
    """
    Factory function: create the memory-storage node.

    Args:
        mem0_client: Mem0 client instance

    Returns:
        Async node function
    """
    from langchain_core.runnables.config import RunnableConfig
    async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
        """
        Memory-storage node — submits conversation history to Mem0 for fact
        extraction and storage.

        Args:
            state: current conversation state
            config: runtime config

        Returns:
            State update resetting turns_since_last_summary to 0
        """
        log_state_change("summarize", state, "进入")
        messages = state["messages"]
        # Skip very short conversations (fewer than 4 messages).
        if len(messages) < 4:
            debug("📝 [记忆添加] 对话过短,跳过")
            return {"turns_since_last_summary": 0}
        # Resolve user_id from the run metadata.
        user_id = config.get("metadata", {}).get("user_id", "default_user")
        # Lazily initialize Mem0 on first use.
        # NOTE(review): reaches into the private `_initialized` flag.
        if not mem0_client._initialized:
            await mem0_client.initialize()
        # Convert the conversation history to Mem0's message format.
        mem0_messages = []
        for msg in messages:
            # Messages may be dicts or message objects; handle both.
            if isinstance(msg, dict):
                msg_type = msg.get("type", "")
                msg_content = msg.get("content", "")
            else:
                msg_type = getattr(msg, 'type', '')
                msg_content = getattr(msg, 'content', '')
            if msg_type == "human":
                mem0_messages.append({"role": "user", "content": msg_content})
            elif msg_type == "ai":
                mem0_messages.append({"role": "assistant", "content": msg_content})
            elif msg_type == "tool":
                mem0_messages.append({"role": "system", "content": f"[工具返回] {msg_content}"})
        if mem0_client.mem0:
            try:
                # Async: let Mem0 extract and store facts automatically.
                success = await mem0_client.add_memories(
                    mem0_messages,
                    user_id=user_id
                )
                if success:
                    info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取")
            except Exception as e:
                # Storage failure is non-fatal for the conversation flow.
                error(f"❌ Mem0 记忆添加失败: {e}")
        else:
            warning("⚠️ Mem0 未初始化,跳过记忆添加")
        log_state_change("summarize", state, "离开")
        return {"turns_since_last_summary": 0}
    return summarize_conversation

View File

@@ -0,0 +1,101 @@
"""
工具执行节点模块
负责执行 AI 调用的工具函数
"""
import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from app.main_graph.config import get_stream_writer
# 本地模块
from app.main_graph.state import MessagesState
from ..utils.logging import log_state_change
from ..logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
    """
    Factory function: create the tool-execution node.

    Args:
        tools_by_name: mapping from tool name to tool callable

    Returns:
        Async node function
    """
    from langchain_core.runnables.config import RunnableConfig
    async def call_tools(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
        """
        Tool-execution node (async).

        Executes every tool call attached to the last AI message, emitting
        tool_start / tool_end stream events around each invocation.

        Args:
            state: current conversation state
            config: runtime config

        Returns:
            State update containing one ToolMessage per tool call
        """
        log_state_change("tool_node", state, "进入")
        last_message = state['messages'][-1]
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            log_state_change("tool_node", state, "离开(无工具调用)")
            return {"messages": []}
        results = []
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()
        info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")
        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool_id = tool_call["id"]
            tool_func = tools_by_name.get(tool_name)
            debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")
            if tool_func is None:
                # Unknown tool: report the error back to the LLM as a ToolMessage.
                err_msg = f"Tool {tool_name} not found"
                debug(f" └─ ❌ {err_msg}")
                results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
                continue
            # Emit the tool-start stream event.
            writer = get_stream_writer()
            writer({"type": "custom", "data": {"type": "tool_start", "tool": tool_name}})
            try:
                # Prefer async invocation (ainvoke) when the tool supports it.
                if hasattr(tool_func, 'ainvoke'):
                    observation = await tool_func.ainvoke(tool_args)
                else:
                    # Sync tools run in the default executor; the default arg
                    # binds tool_args now, avoiding the late-binding closure pitfall.
                    observation = await loop.run_in_executor(
                        None,
                        lambda args=tool_args: tool_func.invoke(args)
                    )
                result_preview = str(observation).replace("\n", " ")
                debug(f" └─ ✅ 结果: {result_preview}")
                results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
                # Emit the tool-end (success) stream event.
                writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": True}})
            except Exception as e:
                debug(f" └─ ❌ 异常: {e}")
                results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
                # Emit the tool-end (failure) stream event.
                writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": False, "error": str(e)}})
        info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
        result = {"messages": results}
        log_state_change("tool_node", {**state, **result}, "离开")
        return result
    return call_tools

View File

@@ -0,0 +1,113 @@
"""
主图状态定义 - React 模式增强版
Main Graph State Definition - React Mode Enhanced
"""
from enum import Enum, auto
from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict, List
from dataclasses import dataclass, field
from app.main_graph.graph import add_messages
from langchain_core.messages import BaseMessage
# ========== 兼容旧代码的类型 ==========
class MessagesState(TypedDict):
    """Legacy MessagesState type, kept for backward compatibility."""
    # Conversation history; add_messages is the LangGraph reducer for appends.
    messages: Annotated[Sequence[BaseMessage], add_messages]
class GraphContext(TypedDict):
    """Legacy GraphContext type, kept for backward compatibility."""
    # Number of LLM invocations made so far.
    llm_calls: int
    # Long-term memory text injected into the prompt.
    memory_context: str
    # System prompt text.
    system_prompt: str
# ========== 新的类型 ==========
class CurrentAction(Enum):
    """Main-graph current operation type."""
    NONE = auto()
    GENERAL_CHAT = auto()
    NEWS_ANALYSIS = auto()
    DICTIONARY = auto()
    CONTACT = auto()
class ErrorSeverity(Enum):
    """Error severity levels."""
    INFO = auto()     # informational; execution continues
    WARNING = auto()  # warning; retry is possible
    ERROR = auto()    # error; needs handling
    FATAL = auto()    # fatal; terminate execution
@dataclass
class ErrorRecord:
    """Structured record of a single error raised by a node/subgraph/tool."""
    error_type: str
    error_message: str
    severity: ErrorSeverity = ErrorSeverity.ERROR
    source: str = ""  # origin: which node/subgraph/tool raised it
    timestamp: str = ""
    retry_count: int = 0  # retries already performed
    max_retries: int = 3  # maximum number of retries
    context: Dict[str, Any] = field(default_factory=dict)  # error context payload
@dataclass
class MainGraphState:
    """
    Main graph state — React loop-reasoning version.

    Contains:
    1. Backward-compat fields from the legacy MessagesState
    2. React reasoning control fields
    3. Loop and error handling
    4. Subgraph result placeholders
    5. User information
    """
    # ========== Compatibility fields (legacy MessagesState) ==========
    messages: Annotated[Sequence[BaseMessage], add_messages] = field(default_factory=list)
    llm_calls: int = 0
    memory_context: str = ""
    system_prompt: str = ""
    # ========== Main-graph control fields ==========
    user_query: str = ""  # current user query
    current_action: CurrentAction = CurrentAction.NONE  # current operation
    intent_confidence: float = 0.0  # intent-recognition confidence
    # ========== React reasoning fields ==========
    reasoning_step: int = 0  # current reasoning step
    max_steps: int = 40  # maximum reasoning steps (<= 40)
    last_action: str = ""  # previous action name
    reasoning_history: List[Dict[str, Any]] = field(default_factory=list)  # reasoning history
    # ========== RAG fields ==========
    rag_context: str = ""  # context retrieved by RAG
    rag_retrieved: bool = False  # whether retrieval already ran
    rag_docs: List[Dict[str, Any]] = field(default_factory=list)  # retrieved documents
    # ========== Error-handling fields ==========
    errors: List[ErrorRecord] = field(default_factory=list)  # accumulated errors
    current_error: Optional[ErrorRecord] = None  # error currently being handled
    retry_action: Optional[str] = None  # action to retry
    # ========== Subgraph result placeholders ==========
    news_result: Optional[Dict[str, Any]] = None  # news subgraph result
    dictionary_result: Optional[Dict[str, Any]] = None  # dictionary subgraph result
    contact_result: Optional[Dict[str, Any]] = None  # contacts subgraph result
    # ========== User info ==========
    user_id: str = ""
    # ========== Execution status ==========
    current_phase: str = "init"
    error_message: str = ""
    final_result: str = ""
    success: bool = False
    # ========== Metadata ==========
    start_time: Optional[str] = None
    end_time: Optional[str] = None
    debug_info: Dict[str, Any] = field(default_factory=dict)

View File

@@ -0,0 +1 @@
"""主图工具"""

View File

@@ -0,0 +1,17 @@
"""
工具定义模块 - 子图工具 + RAG 工具
Subgraph Tools + RAG Tools
"""
# Subgraph tools
from .subgraph_tools import (
    SUBGRAPH_TOOLS,
    SUBGRAPH_TOOLS_BY_NAME,
    dictionary_tool,
    news_analysis_tool,
    contact_tool
)
# Aggregated tool list and name->tool mapping (module-level constants).
# Copies are taken so later mutation of these does not affect the originals.
# NOTE(review): the module docstring mentions RAG tools, but only subgraph
# tools are aggregated here — confirm whether RAG tools should be added.
AVAILABLE_TOOLS = SUBGRAPH_TOOLS.copy()
TOOLS_BY_NAME = SUBGRAPH_TOOLS_BY_NAME.copy()

View File

@@ -0,0 +1,193 @@
"""
子图工具模块 - 将三个子图包装成 LangChain 工具
Subgraph Tools Module - Wrap three subgraphs as LangChain tools
"""
from langchain_core.tools import tool
from typing import Optional
# ============== 词典子图工具 ==============
# NOTE: the docstring below is the tool description surfaced to the LLM at
# runtime by the @tool decorator — treat it as behavior, not documentation.
@tool
def dictionary_tool(query: str, action: Optional[str] = None) -> str:
    """
    词典/翻译工具 - 查询单词、翻译文本、提取术语、获取每日一词
    Args:
        query: 用户查询内容(单词、句子、文本等)
        action: 可选,指定操作类型("query" 查询单词,"translate" 翻译,
                "extract" 提取术语,"daily" 每日一词,不指定则自动识别)
    Returns:
        格式化的结果文本
    """
    try:
        from app.subgraphs.dictionary import (
            DictionaryState,
            DictionaryAction,
            parse_intent,
            format_result
        )
        from app.subgraphs.dictionary.nodes import (
            query_word, translate_text, extract_terms, get_daily_word
        )
        # Build the initial subgraph state.
        state = DictionaryState(user_query=query, user_id="default")
        # Dispatch on the explicit action, if one was given.
        if action == "query":
            state.action = DictionaryAction.QUERY
            state.action_params = {"word": query}
            state = query_word(state)
        elif action == "translate":
            state.action = DictionaryAction.TRANSLATE
            state.source_text = query
            state = translate_text(state)
        elif action == "daily":
            state.action = DictionaryAction.DAILY_WORD
            state = get_daily_word(state)
        elif action == "extract":
            state.action = DictionaryAction.EXTRACT
            state.action_params = {"text": query}
            state = extract_terms(state)
        else:
            # No explicit action: parse the intent, then dispatch on it.
            state = parse_intent(state)
            if state.action == DictionaryAction.QUERY:
                state = query_word(state)
            elif state.action == DictionaryAction.TRANSLATE:
                state = translate_text(state)
            elif state.action == DictionaryAction.DAILY_WORD:
                state = get_daily_word(state)
            elif state.action == DictionaryAction.EXTRACT:
                state = extract_terms(state)
        # Format the result for the caller.
        state = format_result(state)
        return state.final_result or "操作完成"
    except Exception as e:
        # Tool errors are returned as text so the LLM can react to them.
        return f"词典工具执行出错:{str(e)}"
# ============== 资讯分析子图工具 ==============
@tool
def news_analysis_tool(query: str, action: Optional[str] = None) -> str:
    """
    资讯分析工具 - 查询新闻、分析URL、提取关键词、生成报告
    Args:
        query: 用户查询内容关键词、URL、文本等
        action: 可选,指定操作类型("query" 查询新闻,"analyze" 分析URL
                "keywords" 提取关键词,"report" 生成报告,不指定则自动识别)
    Returns:
        格式化的结果文本
    """
    # NOTE: the docstring above doubles as the LLM-facing tool description,
    # so it is kept verbatim.
    try:
        from app.subgraphs.news_analysis import (
            NewsAnalysisState,
            NewsAction,
            parse_intent,
            format_result
        )
        from app.subgraphs.news_analysis.nodes import (
            query_news, analyze_url, extract_keywords, generate_report
        )
        state = NewsAnalysisState(user_query=query, user_id="default")
        # Explicit action string -> configure state and run the matching node.
        if action == "analyze":
            state.action = NewsAction.ANALYZE_URL
            state.custom_urls = [query]
            state = analyze_url(state)
        elif action == "query":
            state.action = NewsAction.QUERY_NEWS
            state = query_news(state)
        elif action == "keywords":
            state.action = NewsAction.EXTRACT_KEYWORDS
            state = extract_keywords(state)
        elif action == "report":
            state.action = NewsAction.GENERATE_REPORT
            state = generate_report(state)
        else:
            # No (or unrecognized) explicit action: infer the intent first.
            state = parse_intent(state)
            if state.action == NewsAction.ANALYZE_URL:
                # URL analysis treats the raw query as the URL to inspect.
                state.custom_urls = [query]
                state = analyze_url(state)
            else:
                auto_handlers = {
                    NewsAction.QUERY_NEWS: query_news,
                    NewsAction.EXTRACT_KEYWORDS: extract_keywords,
                    NewsAction.GENERATE_REPORT: generate_report,
                }
                handler = auto_handlers.get(state.action)
                if handler is not None:
                    state = handler(state)
        # Render the final user-facing text.
        state = format_result(state)
        return state.final_result or "操作完成"
    except Exception as e:
        return f"资讯分析工具执行出错:{str(e)}"
# ============== 通讯录子图工具 ==============
@tool
def contact_tool(query: str, action: Optional[str] = None) -> str:
    """
    通讯录工具 - 查询联系人、添加联系人、管理通讯录
    Args:
        query: 用户查询内容(姓名、电话、信息等)
        action: 可选,指定操作类型("query" 查询联系人,"add" 添加联系人,
                "list" 列出联系人,不指定则自动识别)
    Returns:
        格式化的结果文本
    """
    # FIX: the `action` parameter was documented but previously ignored,
    # unlike the sibling dictionary/news tools. It now overrides the parsed
    # intent when it names a known operation.
    try:
        from app.subgraphs.contact import (
            ContactState,
            ContactAction,
            parse_intent,
            format_result
        )
        from app.subgraphs.contact.nodes import (
            query_contact, add_contact, list_contacts
        )
        state = ContactState(user_query=query, user_id="default")
        # Always run intent parsing so any parameters it extracts from the
        # query (e.g. contact fields) remain available to the nodes.
        state = parse_intent(state)
        # Explicit action string (if recognized) overrides the parsed intent.
        explicit_actions = {
            "query": ContactAction.QUERY,
            "add": ContactAction.ADD,
            "list": ContactAction.LIST,
        }
        if action in explicit_actions:
            state.action = explicit_actions[action]
        # Dispatch to the node matching the resolved action.
        handlers = {
            ContactAction.QUERY: query_contact,
            ContactAction.ADD: add_contact,
            ContactAction.LIST: list_contacts,
        }
        handler = handlers.get(state.action)
        if handler is not None:
            state = handler(state)
        # Render the final user-facing text.
        state = format_result(state)
        return state.final_result or "操作完成"
    except Exception as e:
        return f"通讯录工具执行出错:{str(e)}"
# ============== Tool registry ==============
SUBGRAPH_TOOLS = [
    dictionary_tool,
    news_analysis_tool,
    contact_tool
]
# FIX: use `t` as the comprehension variable — `tool` would shadow the
# `tool` decorator imported from langchain_core at the top of this module.
SUBGRAPH_TOOLS_BY_NAME = {t.name: t for t in SUBGRAPH_TOOLS}

View File

@@ -0,0 +1 @@
"""主图工具函数"""

View File

@@ -0,0 +1,27 @@
# app/rag_initializer.py
from ..rag.tools import create_rag_tool_sync
from rag_core import create_parent_retriever
from ..model_services import get_embedding_service
from ..logger import info, warning
async def init_rag_tool(local_llm_creator):
    """Initialize the RAG retrieval tool.

    Args:
        local_llm_creator: zero-argument factory producing the LLM used
            for query rewriting.
    Returns:
        The RAG tool instance, or None if any step of initialization fails.
    """
    try:
        info("🔄 正在初始化 RAG 检索系统...")
        # Embeddings come from the unified embedding-service accessor.
        embeddings = get_embedding_service()
        retriever = create_parent_retriever(
            collection_name="rag_documents",
            search_k=5,
            embeddings=embeddings
        )
        # The rewrite LLM is created lazily via the injected factory.
        rag_tool = create_rag_tool_sync(
            retriever, local_llm_creator(),
            num_queries=3, rerank_top_n=5
        )
        info("✅ RAG 检索工具初始化成功")
        return rag_tool
    except Exception as e:
        # Best-effort: RAG is optional, so failures degrade to None.
        warning(f"⚠️ RAG 检索工具初始化失败: {e}")
        return None

View File

@@ -0,0 +1,332 @@
"""
超时和重试工具模块
为 React 模式提供超时控制和重试机制
"""
import time
import asyncio
from functools import wraps
from typing import Callable, Any, Optional, Type, Tuple, Union
from dataclasses import dataclass, field
from enum import Enum, auto
class RetryStrategy(Enum):
    """Backoff strategy for computing the delay between retry attempts.

    The actual formula is applied in ``_calculate_delay`` and is always
    capped at ``RetryConfig.max_delay``.
    """
    FIXED = auto()        # constant delay: base_delay every attempt
    EXPONENTIAL = auto()  # exponential backoff: base_delay * 2**attempt
    LINEAR = auto()       # linear growth: base_delay * (attempt + 1)
@dataclass
class RetryConfig:
    """Configuration for retry behaviour.

    Attributes:
        max_retries: maximum number of retries after the first attempt.
        base_delay: base delay in seconds fed into the backoff formula.
        max_delay: upper bound on any single computed delay (seconds).
        strategy: backoff strategy (see RetryStrategy).
        timeout: per-attempt timeout in seconds (None disables it).
        recoverable_exceptions: exception types that may be retried.
        unrecoverable_exceptions: exception types that abort immediately.
    """
    max_retries: int = 3
    base_delay: float = 1.0
    max_delay: float = 10.0
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL
    timeout: Optional[float] = 30.0
    # Tuples are immutable, so they are safe as plain dataclass defaults —
    # no field(default_factory=...) indirection is needed.
    recoverable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
    unrecoverable_exceptions: Tuple[Type[Exception], ...] = ()
@dataclass
class RetryResult:
    """Outcome of a retried call, returned by the retry decorators."""
    success: bool                      # True if the call eventually succeeded
    result: Any = None                 # return value of the wrapped call on success
    error: Optional[Exception] = None  # last exception seen on failure
    retry_count: int = 0               # number of attempts beyond the first
    total_time: float = 0.0            # wall-clock seconds across all attempts
    timed_out: bool = False            # True if the final failure was a timeout
# ========== 同步重试装饰器 ==========
def with_retry(
    config: Optional[RetryConfig] = None,
    max_retries: int = 3,
    timeout: Optional[float] = 30.0,
    base_delay: float = 1.0,
    on_retry: Optional[Callable[[int, Exception], None]] = None
):
    """Synchronous retry decorator.

    The wrapped function returns a ``RetryResult`` instead of its raw value.

    Args:
        config: full retry configuration; when given, overrides the other
            keyword arguments.
        max_retries: maximum number of retries (used when config is None).
        timeout: kept for API symmetry with ``with_async_retry``. The sync
            path cannot interrupt a running call, so it is NOT enforced
            (the previous implementation had an identical if/else here —
            dead code, now removed).
        base_delay: base delay in seconds (used when config is None).
        on_retry: callback invoked as ``on_retry(retry_count, exception)``
            before each retry sleep.
    """
    if config is None:
        config = RetryConfig(
            max_retries=max_retries,
            timeout=timeout,
            base_delay=base_delay
        )
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> RetryResult:
            start_time = time.time()
            last_error: Optional[Exception] = None
            attempt = 0  # survives the loop for accurate failure reporting
            for attempt in range(config.max_retries + 1):
                try:
                    # NOTE: timeout is not enforced on the sync path;
                    # interrupting a synchronous call would require
                    # threads or signals.
                    result = func(*args, **kwargs)
                    return RetryResult(
                        success=True,
                        result=result,
                        retry_count=attempt,
                        total_time=time.time() - start_time
                    )
                except Exception as e:
                    last_error = e
                    # Abort immediately on explicitly unrecoverable errors.
                    if isinstance(e, config.unrecoverable_exceptions):
                        break
                    # FIX: honor recoverable_exceptions (previously ignored);
                    # the default (Exception,) preserves old behavior.
                    if not isinstance(e, config.recoverable_exceptions):
                        break
                    if attempt >= config.max_retries:
                        break
                    if on_retry:
                        on_retry(attempt + 1, e)
                    time.sleep(_calculate_delay(attempt, config))
            # All attempts failed (or we bailed out early).
            return RetryResult(
                success=False,
                error=last_error,
                # FIX: report the attempts actually made, not always
                # config.max_retries (early unrecoverable breaks used to
                # over-report).
                retry_count=attempt,
                total_time=time.time() - start_time
            )
        return wrapper
    return decorator
# ========== 异步重试装饰器 ==========
def with_async_retry(
    config: Optional[RetryConfig] = None,
    max_retries: int = 3,
    timeout: Optional[float] = 30.0,
    base_delay: float = 1.0,
    on_retry: Optional[Callable[[int, Exception], None]] = None
):
    """Asynchronous retry decorator.

    Like ``with_retry``, but awaits the wrapped coroutine and enforces
    ``config.timeout`` per attempt via ``asyncio.wait_for``. The wrapped
    coroutine returns a ``RetryResult`` instead of its raw value.

    Args:
        config: full retry configuration; when given, overrides the other
            keyword arguments.
        max_retries: maximum number of retries (used when config is None).
        timeout: per-attempt timeout in seconds (used when config is None).
        base_delay: base delay in seconds (used when config is None).
        on_retry: callback invoked as ``on_retry(retry_count, exception)``
            before each retry sleep.
    """
    if config is None:
        config = RetryConfig(
            max_retries=max_retries,
            timeout=timeout,
            base_delay=base_delay
        )
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> RetryResult:
            start_time = time.time()
            last_error: Optional[Exception] = None
            attempt = 0  # survives the loop for accurate failure reporting
            for attempt in range(config.max_retries + 1):
                try:
                    if config.timeout:
                        result = await asyncio.wait_for(
                            func(*args, **kwargs),
                            timeout=config.timeout
                        )
                    else:
                        result = await func(*args, **kwargs)
                    return RetryResult(
                        success=True,
                        result=result,
                        retry_count=attempt,
                        total_time=time.time() - start_time
                    )
                except Exception as e:
                    # FIX: a single handler covers asyncio.TimeoutError and
                    # ordinary exceptions alike. The old split handlers left
                    # the retry checks referencing `e` after/outside its
                    # except block, so timeouts could bypass the backoff
                    # logic (or hit an unbound name, depending on indent).
                    last_error = e
                    if isinstance(e, config.unrecoverable_exceptions):
                        break
                    # Honor recoverable_exceptions; default (Exception,)
                    # retries everything, including timeouts.
                    if not isinstance(e, config.recoverable_exceptions):
                        break
                    if attempt >= config.max_retries:
                        break
                    if on_retry:
                        on_retry(attempt + 1, e)
                    await asyncio.sleep(_calculate_delay(attempt, config))
            # All attempts failed (or we bailed out early).
            return RetryResult(
                success=False,
                error=last_error,
                retry_count=attempt,  # attempts actually made
                total_time=time.time() - start_time,
                timed_out=isinstance(last_error, asyncio.TimeoutError)
            )
        return wrapper
    return decorator
# ========== 辅助函数 ==========
def _calculate_delay(attempt: int, config: RetryConfig) -> float:
    """Compute the sleep interval before the next retry, capped at max_delay.

    Args:
        attempt: zero-based index of the attempt that just failed.
        config: retry configuration providing strategy and delay bounds.
    """
    if config.strategy == RetryStrategy.EXPONENTIAL:
        delay = config.base_delay * (2 ** attempt)
    elif config.strategy == RetryStrategy.LINEAR:
        delay = config.base_delay * (attempt + 1)
    else:
        # FIXED — and any unknown strategy — falls back to the base delay.
        delay = config.base_delay
    # Never exceed the configured ceiling.
    return min(delay, config.max_delay)
# ========== 为 React 节点设计的超时重试包装器 ==========
def create_retry_wrapper_for_node(
    node_func: Callable,
    node_name: str,
    max_retries: int = 2,
    timeout: float = 30.0
):
    """Wrap a React graph node with retry handling.

    Retries when the node either raises or reports a failure through
    ``state.current_error``. If every attempt raises, an ErrorRecord is
    attached to the state and the state is returned so the graph can route
    to error handling.

    Args:
        node_func: original node function (takes the state, returns a result).
        node_name: node name used to label error records.
        max_retries: maximum number of retries after the first attempt.
        timeout: nominal per-attempt timeout in seconds.
            NOTE(review): the timeout is only recorded in the error
            context below; nothing here actually interrupts a
            long-running node — confirm whether enforcement is intended.

    Returns:
        The wrapped node function.
    """
    config = RetryConfig(
        max_retries=max_retries,
        timeout=timeout,
        strategy=RetryStrategy.EXPONENTIAL
    )
    @wraps(node_func)
    def wrapped_node(state):
        # Track wall-clock time across all attempts.
        start_time = time.time()
        # Retry loop: max_retries + 1 total attempts.
        last_error = None
        for attempt in range(config.max_retries + 1):
            try:
                # Run the node.
                result = node_func(state)
                # The node may signal failure via state.current_error
                # instead of raising; treat that as a retryable failure.
                if hasattr(state, "current_error") and state.current_error:
                    last_error = Exception(state.current_error.error_message)
                    if attempt < config.max_retries:
                        delay = _calculate_delay(attempt, config)
                        time.sleep(delay)
                        continue
                # Success — or final attempt with a node-reported error, in
                # which case the node's own (error-carrying) result is
                # returned as-is without appending a new ErrorRecord.
                return result
            except Exception as e:
                last_error = e
                if attempt >= config.max_retries:
                    break
                # Back off before retrying.
                delay = _calculate_delay(attempt, config)
                time.sleep(delay)
        # Every attempt raised: record the failure on the state.
        # Local import — presumably to avoid an import cycle with the
        # state module; confirm before moving it to the top of the file.
        from .state import ErrorRecord, ErrorSeverity
        error_record = ErrorRecord(
            error_type=f"{node_name}TimeoutError",
            error_message=str(last_error) if last_error else f"{node_name} 执行超时",
            severity=ErrorSeverity.ERROR,
            source=node_name,
            retry_count=config.max_retries,
            max_retries=config.max_retries,
            context={
                "timeout": timeout,
                "total_time": time.time() - start_time
            }
        )
        # The state is duck-typed: only touch the fields it actually has.
        if hasattr(state, "errors"):
            state.errors.append(error_record)
        if hasattr(state, "current_error"):
            state.current_error = error_record
        if hasattr(state, "error_message"):
            state.error_message = str(last_error)
        if hasattr(state, "current_phase"):
            state.current_phase = "error_handling"
        return state
    return wrapped_node
# ========== Preconfigured retry settings for RAG retrieval ==========
RAG_RETRY_CONFIG = RetryConfig(
    max_retries=2,
    timeout=60.0,  # RAG can tolerate a somewhat longer timeout
    base_delay=2.0,
    strategy=RetryStrategy.EXPONENTIAL
)
# ========== Preconfigured retry settings for subgraph execution ==========
SUBGRAPH_RETRY_CONFIG = RetryConfig(
    max_retries=1,  # subgraphs usually should not be retried many times
    timeout=120.0,  # subgraph runs take longer
    base_delay=3.0
)

View File

@@ -0,0 +1,193 @@
"""
React 模式主图构建器 - 完整循环推理版本
Main Graph Builder - Full React Mode with Loop Reasoning
"""
from app.main_graph.graph import StateGraph, START, END
from typing import Dict, Any
from .state import MainGraphState, CurrentAction
from .react_nodes import (
init_state_node,
react_reason_node,
error_handling_node,
final_response_node,
route_by_reasoning
)
from .rag_nodes import rag_retrieve_node
from app.subgraphs.contact import build_contact_subgraph
from app.subgraphs.dictionary import build_dictionary_subgraph
from app.subgraphs.news_analysis import build_news_analysis_subgraph
# ========== 子图包装器(处理子图错误传递) ==========
def wrap_subgraph_for_error_handling(subgraph, name: str):
    """Wrap a compiled subgraph so its failures surface on the main-graph state.

    Args:
        subgraph: compiled subgraph to invoke.
        name: subgraph identifier, used both for error records and to pick
            the state attribute that receives the result.

    Returns:
        A node function suitable for ``graph.add_node``.
    """
    # Which state attribute receives each subgraph's result.
    result_fields = {
        "contact": "contact_result",
        "dictionary": "dictionary_result",
        "news_analysis": "news_result",
    }
    def wrapped_node(state: MainGraphState) -> MainGraphState:
        try:
            invocation_result = subgraph.invoke(state)
            field_name = result_fields.get(name)
            if field_name is not None:
                setattr(state, field_name, invocation_result)
            # Mark the step as successful.
            state.success = True
        except Exception as e:
            # Translate the subgraph failure into a main-graph error record.
            from .state import ErrorRecord, ErrorSeverity
            from datetime import datetime
            failure_record = ErrorRecord(
                error_type=f"{name}SubgraphError",
                error_message=str(e),
                severity=ErrorSeverity.WARNING,
                source=f"{name}_subgraph",
                timestamp=datetime.now().isoformat(),
                retry_count=0,
                max_retries=1,
                context={"user_query": state.user_query}
            )
            state.errors.append(failure_record)
            state.current_error = failure_record
            state.current_phase = "error_handling"
            state.success = False
        return state
    return wrapped_node
# ========== 主图构建 ==========
def build_react_main_graph() -> StateGraph:
    """Assemble the full React-mode main graph.

    Flow: START -> init_state -> react_reason, which conditionally routes
    to RAG retrieval, one of the three subgraphs, error handling, or the
    final response. Every branch except final_response loops back into
    react_reason; final_response leads to END.
    """
    graph = StateGraph(MainGraphState)
    # --- core nodes ---
    graph.add_node("init_state", init_state_node)        # initialization
    graph.add_node("react_reason", react_reason_node)    # React reasoning
    graph.add_node("rag_retrieve", rag_retrieve_node)    # RAG retrieval
    graph.add_node("handle_error", error_handling_node)  # error handling
    graph.add_node("final_response", final_response_node)  # final answer
    # --- subgraph nodes, wrapped so their errors reach the main state ---
    subgraph_builders = {
        "contact": build_contact_subgraph,
        "dictionary": build_dictionary_subgraph,
        "news_analysis": build_news_analysis_subgraph,
    }
    for sub_name, builder in subgraph_builders.items():
        graph.add_node(
            f"{sub_name}_subgraph",
            wrap_subgraph_for_error_handling(builder().compile(), sub_name)
        )
    # --- edges ---
    graph.add_edge(START, "init_state")
    graph.add_edge("init_state", "react_reason")
    # Conditional routing out of the reasoning node; every route key maps
    # to the node of the same name.
    branch_targets = [
        "rag_retrieve",
        "contact_subgraph",
        "dictionary_subgraph",
        "news_analysis_subgraph",
        "handle_error",
        "final_response",
    ]
    graph.add_conditional_edges(
        "react_reason",
        route_by_reasoning,
        {target: target for target in branch_targets}
    )
    # Every branch except the final answer loops back into reasoning
    # (error handling may lead to a retry).
    for loop_source in (
        "rag_retrieve",
        "contact_subgraph",
        "dictionary_subgraph",
        "news_analysis_subgraph",
        "handle_error",
    ):
        graph.add_edge(loop_source, "react_reason")
    graph.add_edge("final_response", END)
    return graph
# ========== 兼容性:保留旧的函数名 ==========
def build_main_graph() -> StateGraph:
    """Backward-compatibility shim: legacy callers of build_main_graph()
    receive the React-mode graph."""
    return build_react_main_graph()
# ========== Public exports ==========
__all__ = [
    "build_react_main_graph",
    "build_main_graph",
    "wrap_subgraph_for_error_handling"
]

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""
LangGraph 图结构可视化脚本
快速查看节点和边的连接关系
运行方式python backend/app/graph/visualize_graph.py
"""
import sys
from pathlib import Path
from dotenv import load_dotenv
# Locate the project root (the Agent1 directory).
# This file lives at backend/app/graph/visualize_graph.py, so the root is
# four parent levels up.
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
BACKEND_DIR = PROJECT_ROOT / "backend"
# Crucial: put the backend directory on sys.path so that `rag_core` resolves.
# NOTE: this only works when the script is run directly; with `python -m`
# the imports happen before this code executes, so it has no effect there.
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
# Load environment variables from the project-root .env file.
load_dotenv(PROJECT_ROOT / ".env")
from app.agent.service import AIAgentService
from app.config import DB_URI
from app.main_graph.checkpoint.postgres.aio import AsyncPostgresSaver
import asyncio
async def visualize_graph():
    """Visualize the LangGraph structure of every registered model graph.

    For each graph held by the agent service, prints four views: the node
    list, the edge list, an ASCII rendering, and the Mermaid source (which
    can be pasted into https://mermaid.live/ to render).
    """
    print("=" * 80)
    print(" LangGraph 图结构可视化")
    print("=" * 80)
    print(f"项目根目录: {PROJECT_ROOT}")
    print(f"Backend 目录: {BACKEND_DIR}")
    async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
        await checkpointer.setup()
        # Build the agent service so its compiled graphs become available.
        print("\n正在初始化 Agent 服务...")
        agent_service = AIAgentService(checkpointer)
        await agent_service.initialize()
        for model_name, graph in agent_service.graphs.items():
            print(f"\n{'=' * 80}")
            print(f" 模型: {model_name}")
            print(f"{'=' * 80}")
            # Introspect the compiled graph's structure.
            graph_structure = graph.get_graph()
            # [1] Plain node listing.
            print("\n[1] 节点列表:")
            print("-" * 80)
            for node_id, node in graph_structure.nodes.items():
                print(f"  - {node_id}: {node.name}")
            # [2] Plain edge listing.
            print("\n[2] 边列表:")
            print("-" * 80)
            for edge in graph_structure.edges:
                print(f"  {edge.source} --> {edge.target}")
            # [3] ASCII rendering (requires the optional `grandalf` package).
            print("\n[3] ASCII 字符画:")
            print("-" * 80)
            try:
                print(graph_structure.draw_ascii())
            except Exception as e:
                print(f"⚠️ ASCII 绘制失败: {e}")
            # [4] Mermaid source for external rendering.
            print("\n[4] Mermaid 源码 (可复制到 https://mermaid.live/):")
            print("-" * 80)
            print(graph_structure.draw_mermaid())
# Script entry point: run the async visualization.
if __name__ == "__main__":
    asyncio.run(visualize_graph())