This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from ..logger import error # 保持兼容,或者替换为 logger
|
||||
from app.logger import error # 保持兼容,或者替换为 logger
|
||||
|
||||
class ThreadHistoryService:
|
||||
"""线程历史查询服务"""
|
||||
|
||||
@@ -8,11 +8,12 @@ import asyncio
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.graph_builder import GraphBuilder, GraphContext
|
||||
from app.main_graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from app.main_graph.config import set_stream_writer
|
||||
from ..model_services.chat_services import get_all_chat_services, LocalVLLMChatProvider
|
||||
from .rag_initializer import init_rag_tool
|
||||
from .intent_classifier import get_intent_classifier
|
||||
from ..logger import info, warning
|
||||
from app.main_graph.utils.rag_initializer import init_rag_tool
|
||||
from app.core.intent_classifier import get_intent_classifier
|
||||
from app.logger import info, warning
|
||||
|
||||
class AIAgentService:
|
||||
def __init__(self, checkpointer):
|
||||
|
||||
@@ -4,7 +4,7 @@ FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
|
||||
"""
|
||||
|
||||
import os
|
||||
from .config import DB_URI, BACKEND_PORT
|
||||
from app.config import DB_URI, BACKEND_PORT
|
||||
import uuid
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -28,7 +28,7 @@ from app.subgraphs.dictionary.api_client import DictionaryAPIClient
|
||||
from app.subgraphs.news_analysis.api_client import NewsAPIClient
|
||||
from .db.init_db import init_subgraph_tables
|
||||
from .db.models import ContactRepository, DictionaryRepository, NewsRepository
|
||||
from .logger import info, error
|
||||
from app.logger import info, error
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
from .config import LOG_LEVEL, DEBUG
|
||||
from app.config import LOG_LEVEL, DEBUG
|
||||
import logging
|
||||
from typing import Any
|
||||
# 根据环境变量控制是否显示详细调试信息
|
||||
|
||||
21
backend/app/main_graph/config.py
Normal file
21
backend/app/main_graph/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Main Graph Configuration - Streaming Writer
|
||||
"""
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
_stream_writer: Optional[Callable[[Any], None]] = None
|
||||
|
||||
def set_stream_writer(writer: Callable[[Any], None]):
|
||||
"""Set the global stream writer"""
|
||||
global _stream_writer
|
||||
_stream_writer = writer
|
||||
|
||||
def get_stream_writer() -> Callable[[Any], None]:
|
||||
"""Get the global stream writer"""
|
||||
global _stream_writer
|
||||
if _stream_writer is None:
|
||||
# Default no-op writer
|
||||
def noop(_):
|
||||
pass
|
||||
return noop
|
||||
return _stream_writer
|
||||
8
backend/app/main_graph/graph.py
Normal file
8
backend/app/main_graph/graph.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
LangGraph 核心组件重新导出
|
||||
统一导入入口,避免直接依赖 langgraph
|
||||
"""
|
||||
|
||||
from langgraph.graph import StateGraph, START, END, add_messages
|
||||
|
||||
__all__ = ["StateGraph", "START", "END", "add_messages"]
|
||||
@@ -5,8 +5,8 @@ LangGraph 状态图构建模块 - 精简版,仅负责组装图
|
||||
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
from .state import MessagesState, GraphContext
|
||||
from ..nodes import (
|
||||
from app.main_graph.state import MessagesState, GraphContext
|
||||
from .nodes import (
|
||||
should_continue,
|
||||
create_llm_call_node,
|
||||
create_tool_call_node,
|
||||
@@ -15,7 +15,7 @@ from ..nodes import (
|
||||
finalize_node,
|
||||
)
|
||||
from app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from ..memory import Mem0Client
|
||||
from app.memory import Mem0Client
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
|
||||
@@ -1 +1,19 @@
|
||||
"""主图节点"""
|
||||
"""
|
||||
主图节点模块导出
|
||||
"""
|
||||
|
||||
from .router import should_continue
|
||||
from .llm_call import create_llm_call_node
|
||||
from .tool_call import create_tool_call_node
|
||||
from .retrieve_memory import create_retrieve_memory_node
|
||||
from .summarize import create_summarize_node
|
||||
from .finalize import finalize_node
|
||||
|
||||
__all__ = [
|
||||
"should_continue",
|
||||
"create_llm_call_node",
|
||||
"create_tool_call_node",
|
||||
"create_retrieve_memory_node",
|
||||
"create_summarize_node",
|
||||
"finalize_node",
|
||||
]
|
||||
|
||||
@@ -8,8 +8,8 @@ from app.main_graph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MessagesState
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import info, error
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import info, error
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MessagesState
|
||||
from ..agent.prompts import create_system_prompt
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import debug, info, error
|
||||
from app.agent.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error
|
||||
|
||||
def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Any, Dict
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from app.main_graph.state import MessagesState
|
||||
from ..memory.mem0_client import Mem0Client
|
||||
from ..logger import info
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.logger import info
|
||||
|
||||
# 全局变量,在 GraphBuilder 中注入
|
||||
_mem0_client: Mem0Client = None
|
||||
|
||||
@@ -11,7 +11,7 @@ import asyncio
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from .retry_utils import (
|
||||
RetryConfig,
|
||||
RAG_RETRY_CONFIG,
|
||||
@@ -19,8 +19,8 @@ from .retry_utils import (
|
||||
)
|
||||
|
||||
# 真正导入和利用已有 RAG 代码
|
||||
from ..rag.tools import create_rag_tool_sync
|
||||
from ..rag.pipeline import RAGPipeline
|
||||
from app.rag.tools import create_rag_tool_sync
|
||||
from app.rag.pipeline import RAGPipeline
|
||||
|
||||
|
||||
# ========== 全局 RAG 工具实例(延迟初始化)==========
|
||||
|
||||
@@ -22,7 +22,7 @@ from app.core.intent import (
|
||||
ReasoningResult
|
||||
)
|
||||
from app.core.state_base import StateUtils
|
||||
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from .retry_utils import (
|
||||
RetryConfig,
|
||||
SUBGRAPH_RETRY_CONFIG
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
# 本地模块
|
||||
from .state import MessagesState
|
||||
from ..memory.mem0_client import Mem0Client
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import debug
|
||||
from app.main_graph.state import MessagesState
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug
|
||||
|
||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
"""
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
from ..config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
|
||||
from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
|
||||
from app.main_graph.state import MessagesState
|
||||
from ..logger import info
|
||||
from app.logger import info
|
||||
|
||||
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'finalize']:
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import Any, Dict
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MessagesState
|
||||
from ..memory.mem0_client import Mem0Client
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import debug, info, error, warning
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error, warning
|
||||
|
||||
def create_summarize_node(mem0_client: Mem0Client):
|
||||
"""
|
||||
|
||||
@@ -10,8 +10,8 @@ from app.main_graph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MessagesState
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import debug, info
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info
|
||||
|
||||
def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
"""
|
||||
|
||||
@@ -1 +1,10 @@
|
||||
"""主图工具"""
|
||||
from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from .subgraph_tools import SUBGRAPH_TOOLS, SUBGRAPH_TOOLS_BY_NAME
|
||||
|
||||
__all__ = [
|
||||
"AVAILABLE_TOOLS",
|
||||
"TOOLS_BY_NAME",
|
||||
"SUBGRAPH_TOOLS",
|
||||
"SUBGRAPH_TOOLS_BY_NAME",
|
||||
]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# app/rag_initializer.py
|
||||
from ..rag.tools import create_rag_tool_sync
|
||||
from app.rag.tools import create_rag_tool_sync
|
||||
from rag_core import create_parent_retriever
|
||||
from ..model_services import get_embedding_service
|
||||
from ..logger import info, warning
|
||||
from app.model_services import get_embedding_service
|
||||
from app.logger import info, warning
|
||||
|
||||
async def init_rag_tool(local_llm_creator):
|
||||
"""初始化 RAG 工具,失败返回 None"""
|
||||
|
||||
@@ -287,7 +287,7 @@ def create_retry_wrapper_for_node(
|
||||
time.sleep(delay)
|
||||
|
||||
# 所有重试都失败,更新状态错误信息
|
||||
from .state import ErrorRecord, ErrorSeverity
|
||||
from app.main_graph.state import ErrorRecord, ErrorSeverity
|
||||
|
||||
error_record = ErrorRecord(
|
||||
error_type=f"{node_name}TimeoutError",
|
||||
|
||||
@@ -6,7 +6,7 @@ Main Graph Builder - Full React Mode with Loop Reasoning
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
from typing import Dict, Any
|
||||
|
||||
from .state import MainGraphState, CurrentAction
|
||||
from app.main_graph.state import MainGraphState, CurrentAction
|
||||
from .react_nodes import (
|
||||
init_state_node,
|
||||
react_reason_node,
|
||||
@@ -50,7 +50,7 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
|
||||
except Exception as e:
|
||||
# 捕获子图错误,传递给主图
|
||||
from .state import ErrorRecord, ErrorSeverity
|
||||
from app.main_graph.state import ErrorRecord, ErrorSeverity
|
||||
from datetime import datetime
|
||||
|
||||
error_record = ErrorRecord(
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from ..config import (
|
||||
from app.config import (
|
||||
LLM_API_KEY, ZHIPUAI_API_KEY,
|
||||
VLLM_BASE_URL, QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY,
|
||||
LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY,
|
||||
ZHIPU_EMBEDDING_MODEL, ZHIPU_API_BASE
|
||||
)
|
||||
from ..model_services import get_embedding_service
|
||||
from ..logger import info, warning, error
|
||||
from app.logger import info, warning, error
|
||||
import time
|
||||
"""
|
||||
Mem0 记忆层客户端封装模块
|
||||
|
||||
@@ -23,7 +23,7 @@ from .base import (
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from ..config import (
|
||||
from app.config import (
|
||||
VLLM_BASE_URL,
|
||||
LLM_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
|
||||
@@ -21,7 +21,7 @@ from .base import (
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from ..config import (
|
||||
from app.config import (
|
||||
LLAMACPP_EMBEDDING_URL,
|
||||
LLAMACPP_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
|
||||
@@ -23,7 +23,7 @@ from .base import (
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from ..config import (
|
||||
from app.config import (
|
||||
LLAMACPP_RERANKER_URL,
|
||||
LLAMACPP_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
Contact Subgraph Module - Complete
|
||||
"""
|
||||
|
||||
from .state import (
|
||||
from app.main_graph.state import (
|
||||
ContactState,
|
||||
Contact,
|
||||
Email,
|
||||
ContactAction
|
||||
)
|
||||
from .graph import build_contact_subgraph
|
||||
from app.main_graph.graph import build_contact_subgraph
|
||||
from .nodes import (
|
||||
parse_intent,
|
||||
list_contacts,
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .state import Contact, Email
|
||||
from app.main_graph.state import Contact, Email
|
||||
|
||||
|
||||
# ========== 模拟数据(保留作为备选)==========
|
||||
|
||||
@@ -6,7 +6,7 @@ Contact Subgraph Builder
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
|
||||
from .state import ContactState
|
||||
from app.main_graph.state import ContactState
|
||||
from .nodes import create_contact_nodes
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from datetime import datetime
|
||||
# 公共工具
|
||||
from ..common import MarkdownFormatter
|
||||
|
||||
from .state import ContactState, ContactAction, Contact, Email
|
||||
from app.main_graph.state import ContactState, ContactAction, Contact, Email
|
||||
from .api_client import ContactAPIClient
|
||||
|
||||
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
Dictionary Subgraph Module - Complete
|
||||
"""
|
||||
|
||||
from .state import (
|
||||
from app.main_graph.state import (
|
||||
DictionaryState,
|
||||
DictionaryAction,
|
||||
WordEntry,
|
||||
ExtractedTerm
|
||||
)
|
||||
from .graph import build_dictionary_subgraph
|
||||
from app.main_graph.graph import build_dictionary_subgraph
|
||||
from .nodes import (
|
||||
parse_intent,
|
||||
query_word,
|
||||
|
||||
@@ -5,7 +5,7 @@ Dictionary Subgraph Builder - Complete
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
|
||||
from .state import DictionaryState
|
||||
from app.main_graph.state import DictionaryState
|
||||
from .nodes import (
|
||||
parse_intent,
|
||||
query_word,
|
||||
|
||||
@@ -12,7 +12,7 @@ from ..common import (
|
||||
MarkdownFormatter
|
||||
)
|
||||
|
||||
from .state import (
|
||||
from app.main_graph.state import (
|
||||
DictionaryState,
|
||||
DictionaryAction,
|
||||
WordEntry,
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
News Analysis Subgraph Module - Complete
|
||||
"""
|
||||
|
||||
from .state import (
|
||||
from app.main_graph.state import (
|
||||
NewsAnalysisState,
|
||||
NewsAction,
|
||||
NewsItem,
|
||||
NewsSource
|
||||
)
|
||||
from .graph import build_news_analysis_subgraph
|
||||
from app.main_graph.graph import build_news_analysis_subgraph
|
||||
from .nodes import (
|
||||
parse_intent,
|
||||
query_news,
|
||||
|
||||
@@ -5,7 +5,7 @@ News Analysis Subgraph Builder
|
||||
|
||||
from app.main_graph.graph import StateGraph, START, END
|
||||
|
||||
from .state import NewsAnalysisState
|
||||
from app.main_graph.state import NewsAnalysisState
|
||||
from .nodes import (
|
||||
parse_intent,
|
||||
query_news,
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime
|
||||
# 公共工具
|
||||
from ..common import MarkdownFormatter
|
||||
|
||||
from .state import (
|
||||
from app.main_graph.state import (
|
||||
NewsAnalysisState,
|
||||
NewsAction,
|
||||
NewsItem,
|
||||
|
||||
@@ -3,8 +3,8 @@ LangGraph 节点日志工具模块
|
||||
提供状态流转追踪和 LLM 输入输出打印功能
|
||||
"""
|
||||
|
||||
from ..config import ENABLE_GRAPH_TRACE
|
||||
from ..logger import debug, info
|
||||
from app.config import ENABLE_GRAPH_TRACE
|
||||
from app.logger import debug, info
|
||||
|
||||
|
||||
def log_state_change(node_name: str, state: dict, prefix: str = "进入"):
|
||||
|
||||
Reference in New Issue
Block a user