refactor: 重构目录结构 - 简化层级
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled

This commit is contained in:
2026-04-29 12:52:41 +08:00
parent 223d1c9afd
commit ef5113bffb
54 changed files with 42 additions and 1819 deletions

View File

@@ -0,0 +1 @@
"""主图工具函数"""

View File

@@ -0,0 +1,27 @@
# app/rag_initializer.py
from ..rag.tools import create_rag_tool_sync
from rag_core import create_parent_retriever
from ..model_services import get_embedding_service
from ..logger import info, warning
async def init_rag_tool(local_llm_creator):
"""初始化 RAG 工具,失败返回 None"""
try:
info("🔄 正在初始化 RAG 检索系统...")
# 使用统一的嵌入服务获取接口
embeddings = get_embedding_service()
retriever = create_parent_retriever(
collection_name="rag_documents",
search_k=5,
embeddings=embeddings
)
rewrite_llm = local_llm_creator()
rag_tool = create_rag_tool_sync(
retriever, rewrite_llm,
num_queries=3, rerank_top_n=5
)
info("✅ RAG 检索工具初始化成功")
return rag_tool
except Exception as e:
warning(f"⚠️ RAG 检索工具初始化失败: {e}")
return None

View File

@@ -0,0 +1,332 @@
"""
超时和重试工具模块
为 React 模式提供超时控制和重试机制
"""
import time
import asyncio
from functools import wraps
from typing import Callable, Any, Optional, Type, Tuple, Union
from dataclasses import dataclass, field
from enum import Enum, auto
class RetryStrategy(Enum):
"""重试策略"""
FIXED = auto() # 固定间隔
EXPONENTIAL = auto() # 指数退避
LINEAR = auto() # 线性增长
@dataclass
class RetryConfig:
"""重试配置"""
max_retries: int = 3 # 最大重试次数
base_delay: float = 1.0 # 基础延迟(秒)
max_delay: float = 10.0 # 最大延迟(秒)
strategy: RetryStrategy = RetryStrategy.EXPONENTIAL
timeout: Optional[float] = 30.0 # 单次调用超时(秒)
recoverable_exceptions: Tuple[Type[Exception], ...] = field(
default_factory=lambda: (Exception,)
)
unrecoverable_exceptions: Tuple[Type[Exception], ...] = field(
default_factory=tuple
)
@dataclass
class RetryResult:
"""重试结果"""
success: bool
result: Any = None
error: Optional[Exception] = None
retry_count: int = 0
total_time: float = 0.0
timed_out: bool = False
# ========== 同步重试装饰器 ==========
def with_retry(
config: Optional[RetryConfig] = None,
max_retries: int = 3,
timeout: Optional[float] = 30.0,
base_delay: float = 1.0,
on_retry: Optional[Callable[[int, Exception], None]] = None
):
"""
同步重试装饰器
Args:
config: 重试配置对象
max_retries: 最大重试次数(如果没有 config
timeout: 单次调用超时(秒)
base_delay: 基础延迟(秒)
on_retry: 重试回调函数(retry_count, exception)
"""
if config is None:
config = RetryConfig(
max_retries=max_retries,
timeout=timeout,
base_delay=base_delay
)
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> RetryResult:
start_time = time.time()
last_error = None
for attempt in range(config.max_retries + 1):
try:
# 执行函数(带超时)
if config.timeout:
# 使用信号量或线程实现超时(简化版)
result = func(*args, **kwargs)
else:
result = func(*args, **kwargs)
# 成功
total_time = time.time() - start_time
return RetryResult(
success=True,
result=result,
retry_count=attempt,
total_time=total_time
)
except Exception as e:
last_error = e
# 检查是否是不可恢复的异常
if isinstance(e, config.unrecoverable_exceptions):
break
# 检查是否达到最大重试次数
if attempt >= config.max_retries:
break
# 计算延迟
delay = _calculate_delay(attempt, config)
# 回调通知
if on_retry:
on_retry(attempt + 1, e)
# 等待
time.sleep(delay)
# 所有重试都失败
total_time = time.time() - start_time
return RetryResult(
success=False,
error=last_error,
retry_count=config.max_retries,
total_time=total_time
)
return wrapper
return decorator
# ========== 异步重试装饰器 ==========
def with_async_retry(
config: Optional[RetryConfig] = None,
max_retries: int = 3,
timeout: Optional[float] = 30.0,
base_delay: float = 1.0,
on_retry: Optional[Callable[[int, Exception], None]] = None
):
"""
异步重试装饰器
"""
if config is None:
config = RetryConfig(
max_retries=max_retries,
timeout=timeout,
base_delay=base_delay
)
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs) -> RetryResult:
start_time = time.time()
last_error = None
for attempt in range(config.max_retries + 1):
try:
# 执行函数(带超时)
if config.timeout:
result = await asyncio.wait_for(
func(*args, **kwargs),
timeout=config.timeout
)
else:
result = await func(*args, **kwargs)
# 成功
total_time = time.time() - start_time
return RetryResult(
success=True,
result=result,
retry_count=attempt,
total_time=total_time
)
except asyncio.TimeoutError as e:
last_error = e
timed_out = True
except Exception as e:
last_error = e
timed_out = False
# 检查是否是不可恢复的异常
if isinstance(e, config.unrecoverable_exceptions):
break
# 检查是否达到最大重试次数
if attempt >= config.max_retries:
break
# 计算延迟
delay = _calculate_delay(attempt, config)
# 回调通知
if on_retry:
on_retry(attempt + 1, last_error)
# 等待
await asyncio.sleep(delay)
# 所有重试都失败
total_time = time.time() - start_time
return RetryResult(
success=False,
error=last_error,
retry_count=config.max_retries,
total_time=total_time,
timed_out=isinstance(last_error, asyncio.TimeoutError)
)
return wrapper
return decorator
# ========== 辅助函数 ==========
def _calculate_delay(attempt: int, config: RetryConfig) -> float:
"""计算延迟时间"""
if config.strategy == RetryStrategy.FIXED:
delay = config.base_delay
elif config.strategy == RetryStrategy.LINEAR:
delay = config.base_delay * (attempt + 1)
elif config.strategy == RetryStrategy.EXPONENTIAL:
delay = config.base_delay * (2 ** attempt)
else:
delay = config.base_delay
# 不超过最大延迟
return min(delay, config.max_delay)
# ========== 为 React 节点设计的超时重试包装器 ==========
def create_retry_wrapper_for_node(
node_func: Callable,
node_name: str,
max_retries: int = 2,
timeout: float = 30.0
):
"""
为 React 节点创建带重试和超时的包装器
Args:
node_func: 原始节点函数
node_name: 节点名称(用于错误标识)
max_retries: 最大重试次数
timeout: 单次执行超时
Returns: 包装后的节点函数
"""
config = RetryConfig(
max_retries=max_retries,
timeout=timeout,
strategy=RetryStrategy.EXPONENTIAL
)
@wraps(node_func)
def wrapped_node(state):
# 记录开始时间
start_time = time.time()
# 重试循环
last_error = None
for attempt in range(config.max_retries + 1):
try:
# 执行节点
result = node_func(state)
# 检查节点是否报告了错误
if hasattr(state, "current_error") and state.current_error:
# 节点内部报告了错误,继续重试
last_error = Exception(state.current_error.error_message)
if attempt < config.max_retries:
delay = _calculate_delay(attempt, config)
time.sleep(delay)
continue
# 成功
return result
except Exception as e:
last_error = e
if attempt >= config.max_retries:
break
# 等待后重试
delay = _calculate_delay(attempt, config)
time.sleep(delay)
# 所有重试都失败,更新状态错误信息
from .state import ErrorRecord, ErrorSeverity
error_record = ErrorRecord(
error_type=f"{node_name}TimeoutError",
error_message=str(last_error) if last_error else f"{node_name} 执行超时",
severity=ErrorSeverity.ERROR,
source=node_name,
retry_count=config.max_retries,
max_retries=config.max_retries,
context={
"timeout": timeout,
"total_time": time.time() - start_time
}
)
if hasattr(state, "errors"):
state.errors.append(error_record)
if hasattr(state, "current_error"):
state.current_error = error_record
if hasattr(state, "error_message"):
state.error_message = str(last_error)
if hasattr(state, "current_phase"):
state.current_phase = "error_handling"
return state
return wrapped_node
# ========== 预配置的 RAG 重试配置 ==========
RAG_RETRY_CONFIG = RetryConfig(
max_retries=2,
timeout=60.0, # RAG 可以容忍稍长的超时
base_delay=2.0,
strategy=RetryStrategy.EXPONENTIAL
)
# ========== 预配置的子图重试配置 ==========
SUBGRAPH_RETRY_CONFIG = RetryConfig(
max_retries=1, # 子图通常不适合多次重试
timeout=120.0, # 子图执行时间较长
base_delay=3.0
)

View File

@@ -0,0 +1,193 @@
"""
React 模式主图构建器 - 完整循环推理版本
Main Graph Builder - Full React Mode with Loop Reasoning
"""
from app.main_graph.graph import StateGraph, START, END
from typing import Dict, Any
from .state import MainGraphState, CurrentAction
from .react_nodes import (
init_state_node,
react_reason_node,
error_handling_node,
final_response_node,
route_by_reasoning
)
from .rag_nodes import rag_retrieve_node
from app.subgraphs.contact import build_contact_subgraph
from app.subgraphs.dictionary import build_dictionary_subgraph
from app.subgraphs.news_analysis import build_news_analysis_subgraph
# ========== 子图包装器(处理子图错误传递) ==========
def wrap_subgraph_for_error_handling(subgraph, name: str):
"""
包装子图,使其错误能传递给主图
Args:
subgraph: 编译好的子图
name: 子图名称(用于错误标识)
Returns: 包装后的节点函数
"""
def wrapped_node(state: MainGraphState) -> MainGraphState:
try:
# 调用子图
result = subgraph.invoke(state)
# 更新主图状态
if name == "contact":
state.contact_result = result
elif name == "dictionary":
state.dictionary_result = result
elif name == "news_analysis":
state.news_result = result
# 标记成功
state.success = True
return state
except Exception as e:
# 捕获子图错误,传递给主图
from .state import ErrorRecord, ErrorSeverity
from datetime import datetime
error_record = ErrorRecord(
error_type=f"{name}SubgraphError",
error_message=str(e),
severity=ErrorSeverity.WARNING,
source=f"{name}_subgraph",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=1,
context={"user_query": state.user_query}
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
state.success = False
return state
return wrapped_node
# ========== 主图构建 ==========
def build_react_main_graph() -> StateGraph:
"""
构建完整的 React 模式主图
流程:
START
init_state (初始化)
react_reason (推理) ←──────────────┐
↓ │
条件路由 │
├─→ rag_retrieve →───────────────┤
├─→ contact_subgraph →───────────┤
├─→ dictionary_subgraph →────────┤
├─→ news_analysis_subgraph →─────┤
├─→ handle_error → (重试或结束) ──┤
└─→ final_response
END
"""
# 创建图
graph = StateGraph(MainGraphState)
# ========== 添加节点 ==========
# 1. 初始化节点
graph.add_node("init_state", init_state_node)
# 2. React 推理节点
graph.add_node("react_reason", react_reason_node)
# 3. RAG 检索节点
graph.add_node("rag_retrieve", rag_retrieve_node)
# 4. 错误处理节点
graph.add_node("handle_error", error_handling_node)
# 5. 最终回答节点
graph.add_node("final_response", final_response_node)
# ========== 添加子图节点 ==========
# 构建并包装子图(带错误处理)
contact_graph = build_contact_subgraph()
dictionary_graph = build_dictionary_subgraph()
news_analysis_graph = build_news_analysis_subgraph()
graph.add_node(
"contact_subgraph",
wrap_subgraph_for_error_handling(contact_graph.compile(), "contact")
)
graph.add_node(
"dictionary_subgraph",
wrap_subgraph_for_error_handling(dictionary_graph.compile(), "dictionary")
)
graph.add_node(
"news_analysis_subgraph",
wrap_subgraph_for_error_handling(news_analysis_graph.compile(), "news_analysis")
)
# ========== 添加边 ==========
# 1. START → init_state
graph.add_edge(START, "init_state")
# 2. init_state → react_reason
graph.add_edge("init_state", "react_reason")
# 3. 条件路由react_reason → 各分支
graph.add_conditional_edges(
"react_reason",
route_by_reasoning,
{
# 检索分支 → 检索后回到推理
"rag_retrieve": "rag_retrieve",
# 子图分支 → 子图后回到推理
"contact_subgraph": "contact_subgraph",
"dictionary_subgraph": "dictionary_subgraph",
"news_analysis_subgraph": "news_analysis_subgraph",
# 错误处理分支
"handle_error": "handle_error",
# 最终回答分支
"final_response": "final_response",
}
)
# 4. 循环边:检索/子图/错误处理 后 → 回到推理
graph.add_edge("rag_retrieve", "react_reason")
graph.add_edge("contact_subgraph", "react_reason")
graph.add_edge("dictionary_subgraph", "react_reason")
graph.add_edge("news_analysis_subgraph", "react_reason")
graph.add_edge("handle_error", "react_reason") # 错误处理后可能重试
# 5. 最终边final_response → END
graph.add_edge("final_response", END)
return graph
# ========== 兼容性:保留旧的函数名 ==========
def build_main_graph() -> StateGraph:
"""
兼容性函数:旧代码调用 build_main_graph() 时返回 React 版本
"""
return build_react_main_graph()
# ========== 导出 ==========
__all__ = [
"build_react_main_graph",
"build_main_graph",
"wrap_subgraph_for_error_handling"
]

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""
LangGraph 图结构可视化脚本
快速查看节点和边的连接关系
运行方式python backend/app/graph/visualize_graph.py
"""
import sys
from pathlib import Path
from dotenv import load_dotenv
# 确定项目根目录Agent1 目录)
# 当前文件位置backend/app/graph/visualize_graph.py
# 向上 4 级到 Agent1
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
BACKEND_DIR = PROJECT_ROOT / "backend"
# 关键:把 backend 目录加入 sys.path这样才能找到 rag_core
# 注意:这只对直接运行脚本有效,对 -m 方式无效(因为 -m 方式在脚本运行前就导入了)
if str(BACKEND_DIR) not in sys.path:
sys.path.insert(0, str(BACKEND_DIR))
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
load_dotenv(PROJECT_ROOT / ".env")
from app.agent.service import AIAgentService
from app.config import DB_URI
from app.main_graph.checkpoint.postgres.aio import AsyncPostgresSaver
import asyncio
async def visualize_graph():
"""可视化 LangGraph 结构"""
print("=" * 80)
print(" LangGraph 图结构可视化")
print("=" * 80)
print(f"项目根目录: {PROJECT_ROOT}")
print(f"Backend 目录: {BACKEND_DIR}")
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
await checkpointer.setup()
# 创建服务实例
print("\n正在初始化 Agent 服务...")
agent_service = AIAgentService(checkpointer)
await agent_service.initialize()
for model_name, graph in agent_service.graphs.items():
print(f"\n{'=' * 80}")
print(f" 模型: {model_name}")
print(f"{'=' * 80}")
# 获取图结构
graph_structure = graph.get_graph()
# 1. 直接打印节点和边
print("\n[1] 节点列表:")
print("-" * 80)
for node_id, node in graph_structure.nodes.items():
print(f" - {node_id}: {node.name}")
print("\n[2] 边列表:")
print("-" * 80)
for edge in graph_structure.edges:
print(f" {edge.source} --> {edge.target}")
# 3. ASCII 字符画(需要 grandalf
print("\n[3] ASCII 字符画:")
print("-" * 80)
try:
print(graph_structure.draw_ascii())
except Exception as e:
print(f"⚠️ ASCII 绘制失败: {e}")
# 4. Mermaid 源码
print("\n[4] Mermaid 源码 (可复制到 https://mermaid.live/):")
print("-" * 80)
print(graph_structure.draw_mermaid())
if __name__ == "__main__":
asyncio.run(visualize_graph())