This commit is contained in:
@@ -7,6 +7,9 @@ import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
|
||||
|
||||
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
|
||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||
|
||||
# 本地模块
|
||||
from ..model_services import get_cached_chat_services
|
||||
from ..main_graph.main_graph_builder import build_react_main_graph
|
||||
@@ -18,6 +21,23 @@ from ..logger import debug, info, warning, error
|
||||
from ..main_graph.state import MainGraphState, CurrentAction
|
||||
|
||||
|
||||
# ========== 自定义类型序列化器 ==========
|
||||
def create_serde() -> JsonPlusSerializer:
|
||||
"""创建带自定义类型注册的序列化器"""
|
||||
from backend.app.core.intent import ReasoningAction, RetrievalConfig, ReasoningResult
|
||||
|
||||
return JsonPlusSerializer(
|
||||
allowed_msgpack_modules=[
|
||||
("app.core.intent", "ReasoningAction"),
|
||||
("app.core.intent", "RetrievalConfig"),
|
||||
("app.core.intent", "ReasoningResult"),
|
||||
("app.main_graph.state", "CurrentAction"),
|
||||
("app.main_graph.state", "ErrorSeverity"),
|
||||
("app.main_graph.state", "ErrorRecord"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class AIAgentService:
|
||||
def __init__(self, checkpointer):
|
||||
self.checkpointer = checkpointer
|
||||
@@ -55,6 +75,7 @@ class AIAgentService:
|
||||
tools=self.tools,
|
||||
mem0_client=self.mem0_client
|
||||
)
|
||||
# 注意:serde 已在创建 checkpointer 时传入,这里只需传入 checkpointer
|
||||
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
|
||||
info(f"✅ 单图初始化完成")
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from app.logger import error # 保持兼容,或者替换为 logger
|
||||
from ..logger import error # 保持兼容,或者替换为 logger
|
||||
|
||||
class ThreadHistoryService:
|
||||
"""线程历史查询服务"""
|
||||
|
||||
Reference in New Issue
Block a user