From 22fdb625a409a46ea55164006d87044a448dbca7 Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Thu, 7 May 2026 00:48:17 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E6=88=90=E6=9E=81=E7=AE=80=20?= =?UTF-8?q?LangGraph=20=E6=9E=B6=E6=9E=84=E8=BF=81=E7=A7=BB=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20Baosi=20API=20=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 主要变更: - 迁移到极简 LangGraph 标准架构(START → init_state → 记忆 → Agent ⇄ Tools → finalize → END) - 添加 Baosi API 支持,配置 ops4.7 模型 - 保留本地模型作为默认首选,Baosi 作为备选 - 新架构使用 LangGraph 原生 ToolNode 和 bind_tools - 移除旧的混合路由、JSON 解析等复杂逻辑 - 把旧代码移到 deprecated/ 目录 - 添加新的 Agent 节点和 Tools 模块 - 添加测试脚本验证新架构 - 所有测试通过 ✓ --- backend/app/agent/agent_service.py | 161 ++-------- backend/app/config.py | 8 +- .../nodes => deprecated}/fast_paths.py | 0 .../nodes => deprecated}/finalize.py | 0 .../nodes => deprecated}/hybrid_router.py | 0 backend/app/{core => deprecated}/intent.py | 0 .../app/{core => deprecated}/json_parser.py | 0 .../app/deprecated/main_graph_builder.old.py | 232 ++++++++++++++ .../nodes => deprecated}/reasoning.py | 0 backend/app/deprecated/state.old.py | 148 +++++++++ backend/app/main_graph/main_graph_builder.py | 297 ++++++++---------- backend/app/main_graph/nodes/__init__.py | 50 +-- backend/app/main_graph/nodes/agent.py | 89 ++++++ backend/app/main_graph/nodes/finalize_new.py | 59 ++++ .../app/main_graph/nodes/memory_trigger.py | 4 +- .../app/main_graph/nodes/retrieve_memory.py | 4 +- backend/app/main_graph/nodes/summarize.py | 4 +- backend/app/main_graph/state.py | 147 ++------- backend/app/model_services/chat_services.py | 63 +++- backend/app/tools/__init__.py | 188 +++++++++++ backend/app/utils/logging.py | 8 +- tools/test/test_baosi_provider.py | 59 ++++ tools/test/test_minimal_agent.py | 205 ++++++++++++ 23 files changed, 1232 insertions(+), 494 deletions(-) rename backend/app/{main_graph/nodes => deprecated}/fast_paths.py (100%) rename backend/app/{main_graph/nodes => deprecated}/finalize.py (100%) rename backend/app/{main_graph/nodes => deprecated}/hybrid_router.py (100%) rename backend/app/{core => deprecated}/intent.py (100%) rename backend/app/{core => deprecated}/json_parser.py (100%) create mode 100644 backend/app/deprecated/main_graph_builder.old.py rename backend/app/{main_graph/nodes => deprecated}/reasoning.py (100%) create mode 100644 backend/app/deprecated/state.old.py create mode 100644 backend/app/main_graph/nodes/agent.py create mode 100644 backend/app/main_graph/nodes/finalize_new.py create mode 100644 backend/app/tools/__init__.py create mode 100644 tools/test/test_baosi_provider.py create mode 100644 tools/test/test_minimal_agent.py diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index 3f0b1a6..d32e7ff 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -1,5 +1,5 @@ """ -AI Agent 服务类 - 单图方案 + 动态模型选择 +AI Agent 服务类 - 极简 LangGraph Agent 架构 接收外部传入的 checkpointer,不负责管理连接生命周期 """ @@ -12,61 +12,16 @@ from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer # 本地模块 from ..model_services import get_cached_chat_services -from ..main_graph.main_graph_builder import build_react_main_graph -from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME -from ..main_graph.config import set_stream_writer -from ..main_graph.utils.rag_initializer import init_rag_tool +from ..main_graph.main_graph_builder import build_agent_graph from backend.app.logger import debug, info, warning, error -from ..main_graph.state import MainGraphState, CurrentAction - - -# ========== 自定义类型序列化器 ========== -def create_serde() -> JsonPlusSerializer: - """创建带自定义类型注册的序列化器""" - from backend.app.core.intent import ReasoningAction, RetrievalConfig, ReasoningResult - from backend.app.main_graph.state import ( - CurrentAction, ErrorSeverity, ErrorRecord, - ReactReasoningState, HybridRouterState, FastPathState - ) - from backend.app.main_graph.nodes.hybrid_router import HybridRouterResult - - return JsonPlusSerializer( - allowed_msgpack_modules=[ - # 新路径 - ("backend.app.core.intent", "ReasoningAction"), - ("backend.app.core.intent", "RetrievalConfig"), - ("backend.app.core.intent", "ReasoningResult"), - ("backend.app.main_graph.state", "CurrentAction"), - ("backend.app.main_graph.state", "ErrorSeverity"), - ("backend.app.main_graph.state", "ErrorRecord"), - ("backend.app.main_graph.state", "ReactReasoningState"), - ("backend.app.main_graph.state", "HybridRouterState"), - ("backend.app.main_graph.state", "FastPathState"), - ("backend.app.main_graph.nodes.hybrid_router", "HybridRouterResult"), - # 旧路径(兼容旧 checkpoint 数据) - ("app.core.intent", "ReasoningAction"), - ("app.core.intent", "RetrievalConfig"), - ("app.core.intent", "ReasoningResult"), - ("app.main_graph.state", "CurrentAction"), - ("app.main_graph.state", "ErrorSeverity"), - ("app.main_graph.state", "ErrorRecord"), - ("app.main_graph.state", "ReactReasoningState"), - ("app.main_graph.state", "HybridRouterState"), - ("app.main_graph.state", "FastPathState"), - ("app.main_graph.nodes.hybrid_router", "HybridRouterResult"), - ] - ) +from ..main_graph.state import AgentState class AIAgentService: def __init__(self, checkpointer): self.checkpointer = checkpointer - self.graph = None # 只有一张图 - self.chat_services = None # 缓存的模型字典 - self.tools = AVAILABLE_TOOLS.copy() - self.tools_by_name = TOOLS_BY_NAME.copy() - # RAG 管道(可选,需要时设置) - self.rag_pipeline = None + self.graph = None + self.chat_services = None # Mem0 客户端 self.mem0_client = None @@ -75,27 +30,20 @@ class AIAgentService: from ..memory.mem0_client import Mem0Client self.mem0_client = Mem0Client() - # 1. 初始化 RAG 工具(如果需要) - rag_tool = await init_rag_tool() - if rag_tool: - self.tools.append(rag_tool) - self.tools_by_name[rag_tool.name] = rag_tool - self.rag_tool = rag_tool # 保存到实例变量,供 config 注入 - - # 2. 获取缓存的模型字典 + # 1. 获取缓存的模型字典 self.chat_services = get_cached_chat_services() info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}") - # 3. 只构建一次图(传入 chat_services 字典) - info(f"🔄 构建单图...") - graph_builder = build_react_main_graph( + # 2. 构建图 + info(f"🔄 构建 Agent 图...") + graph_builder = build_agent_graph( chat_services=self.chat_services, - tools=self.tools, mem0_client=self.mem0_client ) - # 注意:serde 已在创建 checkpointer 时传入,这里只需传入 checkpointer + + # 编译图 self.graph = graph_builder.compile(checkpointer=self.checkpointer) - info(f"✅ 单图初始化完成") + info(f"✅ Agent 图初始化完成") return self @@ -130,19 +78,18 @@ class AIAgentService: Returns: (config, input_state) 元组 """ + from langchain_core.messages import HumanMessage + config = { "configurable": { "thread_id": thread_id, - "rag_tool": getattr(self, "rag_tool", None), }, "metadata": {"user_id": user_id} } + input_state = { - "user_query": message, - "messages": [{"role": "user", "content": message}], + "messages": [HumanMessage(content=message)], "user_id": user_id, - "current_model": model, - "current_action": CurrentAction.NONE } return config, input_state @@ -157,19 +104,19 @@ class AIAgentService: config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id) result = await self.graph.ainvoke(input_state, config=config) - - reply = result.get("final_result", "") - if not reply and result.get("messages"): + + reply = "" + if result.get("messages"): reply = result["messages"][-1].content + token_usage = result.get("last_token_usage", {}) elapsed_time = result.get("last_elapsed_time", 0.0) - actual_model = result.get("current_model", resolved_model) return { "reply": reply, "token_usage": token_usage, "elapsed_time": elapsed_time, - "model_used": actual_model + "model_used": resolved_model } def _serialize_value(self, value): @@ -259,28 +206,8 @@ class AIAgentService: updates_data = chunk["data"] new_actual_model = actual_model_used - debug(f"[Stream] updates 数据: {list(updates_data.keys()) if isinstance(updates_data, dict) else type(updates_data)}") - - # 特别检查 final_result 和 current_model - if isinstance(updates_data, dict): - if "final_result" in updates_data: - debug(f"[Stream] 收到 final_result: {str(updates_data['final_result'])[:100]}...") - if "current_model" in updates_data: - new_actual_model = updates_data["current_model"] - info(f"[Stream] 实际使用模型: {new_actual_model}") - serialized_data = self._serialize_value(updates_data) - # 检查是否有人工审核请求 - if "review_pending" in serialized_data and serialized_data["review_pending"]: - review_id = serialized_data.get("review_id", "") - content_to_review = serialized_data.get("content_to_review", "") - yield { - "type": "human_review_request", - "review_id": review_id, - "content": content_to_review - } - # 检查是否有工具结果 if "messages" in serialized_data: for msg in serialized_data["messages"]: @@ -307,36 +234,6 @@ class AIAgentService: # 返回更新后的模型 yield {"type": "_update_state", "actual_model_used": new_actual_model} - async def _handle_custom_chunk(self, chunk: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: - """处理 custom 类型的 chunk""" - custom_data = chunk["data"] - - # 处理我们从 react_reason_node 发送的自定义推理事件 - if isinstance(custom_data, dict): - # 检查是否是我们的推理事件 - if "action" in custom_data and "reasoning" in custom_data: - yield { - "type": "react_reasoning", - "step": custom_data.get("step", 1), - "action": custom_data.get("action", "unknown"), - "confidence": custom_data.get("confidence", 0), - "reasoning": custom_data.get("reasoning", "") - } - else: - # 处理其他自定义事件 - serialized_data = self._serialize_value(custom_data) - yield { - "type": "custom", - "data": serialized_data - } - else: - # 处理其他自定义事件 - serialized_data = self._serialize_value(custom_data) - yield { - "type": "custom", - "data": serialized_data - } - async def process_message_stream( self, message: str, thread_id: str, model: str = "", user_id: str = "default_user" ) -> AsyncGenerator[Dict[str, Any], None]: @@ -347,8 +244,7 @@ class AIAgentService: # 构建调用参数 config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id) - # ========== React 循环路径 ========== - info(f"🚀 开始执行单图,指定模型: {resolved_model}") + info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}") current_node = None tool_calls_in_progress: Dict[str, Any] = {} actual_model_used = resolved_model @@ -361,7 +257,7 @@ class AIAgentService: async for chunk in self.graph.astream( input_state, config=config, - stream_mode=["messages", "updates", "custom"], + stream_mode=["messages", "updates"], version="v2", subgraphs=True ): @@ -375,10 +271,10 @@ class AIAgentService: if event.get("type") == "_update_state": current_node = event.get("current_node", current_node) else: - # 如果是 llm_call 节点的 token,收集完整消息 + # 如果是 agent 节点的 token,收集完整消息 if ( event.get("type") == "llm_token" - and event.get("node") == "llm_call" + and event.get("node") == "agent" and "token" in event ): full_message_content += event["token"] @@ -393,18 +289,13 @@ class AIAgentService: else: yield event - elif chunk_type == "custom": - async for event in self._handle_custom_chunk(chunk): - yield event - # 完整消息集合完成后,一次性打印 info(f"✅ graph.astream() 完成,共 {chunk_count} 个 chunks") if full_message_content: info(f"📄 完整消息内容: {repr(full_message_content)}") - info(f"🤖 实际使用模型: {actual_model_used}") except Exception as e: - error(f"❌ 执行单图时出错: {e}") + error(f"❌ 执行图时出错: {e}") import traceback error(f"📋 堆栈: {traceback.format_exc()}") yield { diff --git a/backend/app/config.py b/backend/app/config.py index 86a48cf..3a82aef 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -40,6 +40,7 @@ def _get_bool(key: str) -> bool | None: ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY") DEEPSEEK_API_KEY = _get_str("DEEPSEEK_API_KEY") SILICONFLOW_API_KEY = _get_str("SILICONFLOW_API_KEY") +BAOSI_API_KEY = _get_str("BAOSI_API_KEY") # ========== 智谱 API 配置 ========== @@ -58,6 +59,11 @@ SILICONFLOW_RERANK_MODEL = _get_str("SILICONFLOW_RERANK_MODEL") or "BAAI/bge-rer SILICONFLOW_API_BASE = _get_str("SILICONFLOW_API_BASE") or "https://api.siliconflow.cn/v1" +# ========== Baosi API 配置 ========== +BAOSI_API_BASE = _get_str("BAOSI_API_BASE") or "https://api.baosiapi.com" +BAOSI_MODEL = _get_str("BAOSI_MODEL") or "ops4.7" + + # ========== 稀疏模型配置 ========== SPARSE_MODEL_PATH = _get_str("SPARSE_MODEL_PATH") or "./models/sparse" SPARSE_MODEL_NAME = _get_str("SPARSE_MODEL_NAME") or "Qdrant/bm25" @@ -141,4 +147,4 @@ ENABLE_GRAPH_TRACE = _get_bool("ENABLE_GRAPH_TRACE") # ========== 日志配置 ========== LOG_LEVEL = _get_str("LOG_LEVEL") -DEBUG = _get_bool("DEBUG") \ No newline at end of file +DEBUG = _get_bool("DEBUG") diff --git a/backend/app/main_graph/nodes/fast_paths.py b/backend/app/deprecated/fast_paths.py similarity index 100% rename from backend/app/main_graph/nodes/fast_paths.py rename to backend/app/deprecated/fast_paths.py diff --git a/backend/app/main_graph/nodes/finalize.py b/backend/app/deprecated/finalize.py similarity index 100% rename from backend/app/main_graph/nodes/finalize.py rename to backend/app/deprecated/finalize.py diff --git a/backend/app/main_graph/nodes/hybrid_router.py b/backend/app/deprecated/hybrid_router.py similarity index 100% rename from backend/app/main_graph/nodes/hybrid_router.py rename to backend/app/deprecated/hybrid_router.py diff --git a/backend/app/core/intent.py b/backend/app/deprecated/intent.py similarity index 100% rename from backend/app/core/intent.py rename to backend/app/deprecated/intent.py diff --git a/backend/app/core/json_parser.py b/backend/app/deprecated/json_parser.py similarity index 100% rename from backend/app/core/json_parser.py rename to backend/app/deprecated/json_parser.py diff --git a/backend/app/deprecated/main_graph_builder.old.py b/backend/app/deprecated/main_graph_builder.old.py new file mode 100644 index 0000000..922ac83 --- /dev/null +++ b/backend/app/deprecated/main_graph_builder.old.py @@ -0,0 +1,232 @@ +""" +主图构建器 - 构建整合后的完整主图 +""" + +from langgraph.graph import StateGraph, START, END +from typing import Dict, Any + +from .state import MainGraphState +from .nodes.reasoning import react_reason_node +from .nodes.web_search import web_search_node +from .nodes.error_handling import error_handling_node +from .nodes.routing import init_state_node, route_by_reasoning, should_summarize +from .nodes.hybrid_router import ( + hybrid_router_node, + route_from_hybrid_decision, + check_fast_path_success, +) +from .nodes.fast_paths import ( + fast_chitchat_node, + fast_rag_node, + fast_tool_node, +) +from .nodes.llm_call import create_dynamic_llm_call_node +from .nodes.rag_nodes import rag_retrieve_node +from .nodes.retrieve_memory import create_retrieve_memory_node +from .nodes.memory_trigger import memory_trigger_node, set_mem0_client +from .nodes.summarize import create_summarize_node +from .nodes.finalize import finalize_node +from backend.app.subgraphs.contact import build_contact_subgraph +from backend.app.subgraphs.dictionary import build_dictionary_subgraph +from backend.app.subgraphs.news_analysis import build_news_analysis_subgraph +from backend.app.logger import info + +from .subgraph_wrapper import create_subgraph_nodes + + +# ========== 主图构建 ========== + +def build_react_main_graph( + chat_services: dict, + tools=None, + mem0_client=None, + use_hybrid_router: bool = True +) -> StateGraph: + """ + 构建整合后的完整主图(支持混合路由 + 动态模型选择) + + Args: + chat_services: 模型名称 -> ChatModel 实例 的字典 + tools: 工具列表 + mem0_client: Mem0 客户端实例 + use_hybrid_router: 是否使用混合路由(快速路径 + React 循环) + + Returns: + StateGraph: 构建好的图 + """ + # 创建图 + graph = StateGraph(MainGraphState) + + # 设置全局 mem0_client + if mem0_client: + set_mem0_client(mem0_client) + + # ========== 创建节点 ========== + + # LLM 调用节点 + llm_node = create_dynamic_llm_call_node(chat_services, tools or []) + + # 记忆节点 + retrieve_memory_node = None + summarize_node = None + if mem0_client: + retrieve_memory_node = create_retrieve_memory_node(mem0_client) + summarize_node = create_summarize_node(mem0_client) + + # 子图节点 + contact_graph = build_contact_subgraph() + dictionary_graph = build_dictionary_subgraph() + news_analysis_graph = build_news_analysis_subgraph() + subgraph_nodes = create_subgraph_nodes( + contact_graph, dictionary_graph, news_analysis_graph + ) + + # ========== 添加节点到图 ========== + + # 阶段 1: 记忆检索 + if retrieve_memory_node: + graph.add_node("retrieve_memory", retrieve_memory_node) + graph.add_node("memory_trigger", memory_trigger_node) + + # 阶段 2: 初始化 + graph.add_node("init_state", init_state_node) + + # 阶段 3: 混合路由(可选) + if use_hybrid_router: + graph.add_node("hybrid_router", hybrid_router_node) + graph.add_node("fast_chitchat", fast_chitchat_node) + graph.add_node("fast_rag", fast_rag_node) + graph.add_node("fast_tool", fast_tool_node) + + # 阶段 4: React 循环推理(始终保留) + graph.add_node("react_reason", react_reason_node) + graph.add_node("rag_retrieve", rag_retrieve_node) + graph.add_node("web_search", web_search_node) + graph.add_node("handle_error", error_handling_node) + + if llm_node is not None: + graph.add_node("llm_call", llm_node) + + # 子图节点 + for node_name, node_func in subgraph_nodes.items(): + graph.add_node(node_name, node_func) + + # 阶段 5: 完成处理 + if summarize_node: + graph.add_node("summarize", summarize_node) + graph.add_node("finalize", finalize_node) + + # ========== 添加边 ========== + + # 阶段 1: 记忆检索 + _add_memory_edges(graph, retrieve_memory_node) + + # 阶段 2: 初始化 + graph.add_edge("memory_trigger", "init_state") + + # 阶段 3: 路由分支 + _add_routing_edges(graph, use_hybrid_router, llm_node) + + # 阶段 4: React 循环边 + _add_react_loop_edges(graph, subgraph_nodes) + + # 阶段 5: 完成阶段 + _add_finalize_edges(graph, llm_node, summarize_node) + + info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})") + + return graph + + +def _add_memory_edges(graph: StateGraph, retrieve_memory_node) -> None: + """添加记忆检索阶段的边""" + if retrieve_memory_node: + graph.add_edge(START, "retrieve_memory") + graph.add_edge("retrieve_memory", "memory_trigger") + else: + graph.add_edge(START, "memory_trigger") + + +def _add_routing_edges(graph: StateGraph, use_hybrid_router: bool, llm_node) -> None: + """添加路由阶段的边""" + if use_hybrid_router: + graph.add_edge("init_state", "hybrid_router") + + # 混合路由条件分支 + graph.add_conditional_edges( + "hybrid_router", + route_from_hybrid_decision, + { + "fast_chitchat": "fast_chitchat", + "fast_rag": "fast_rag", + "fast_tool": "fast_tool", + "react_loop": "react_reason" + } + ) + + # 快速路径的完成检查(fast_rag 失败直接走 react_reason) + for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]: + graph.add_conditional_edges( + fast_node, + check_fast_path_success, + { + "llm_call": "llm_call", + "escalate": "react_reason" + } + ) + + info(f"✅ [图构建] 混合路由模式已启用") + else: + graph.add_edge("init_state", "react_reason") + info(f"✅ [图构建] 纯 React 模式") + + +def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) -> None: + """添加 React 循环阶段的边""" + subgraph_names = list(subgraph_nodes.keys()) + + # React 推理的条件分支 + graph.add_conditional_edges( + "react_reason", + route_by_reasoning, + { + "rag_retrieve": "rag_retrieve", + "web_search": "web_search", + **{name: name for name in subgraph_names}, + "handle_error": "handle_error", + "llm_call": "llm_call" + } + ) + + # RAG 检索后回到 react_reason,由意图识别决定下一步 + graph.add_edge("rag_retrieve", "react_reason") + + # 循环边(回到 react_reason) + loop_back_nodes = ["web_search", "handle_error"] + subgraph_names + for node_name in loop_back_nodes: + graph.add_edge(node_name, "react_reason") + + +def _add_finalize_edges(graph: StateGraph, llm_node, summarize_node) -> None: + """添加完成阶段的边""" + if llm_node is not None: + if summarize_node: + graph.add_conditional_edges( + "llm_call", + should_summarize, + { + "summarize": "summarize", + "finalize": "finalize" + } + ) + graph.add_edge("summarize", "finalize") + else: + graph.add_edge("llm_call", "finalize") + + graph.add_edge("finalize", END) + + +# ========== 导出 ========== +__all__ = [ + "build_react_main_graph", +] diff --git a/backend/app/main_graph/nodes/reasoning.py b/backend/app/deprecated/reasoning.py similarity index 100% rename from backend/app/main_graph/nodes/reasoning.py rename to backend/app/deprecated/reasoning.py diff --git a/backend/app/deprecated/state.old.py b/backend/app/deprecated/state.old.py new file mode 100644 index 0000000..a6aad6b --- /dev/null +++ b/backend/app/deprecated/state.old.py @@ -0,0 +1,148 @@ +""" +主图状态定义 - React 模式增强版 +Main Graph State Definition - React Mode Enhanced + +字段分类说明: +- 持久化字段:跨轮次保留,不重置 +- 临时字段:每轮对话开始时重置 +""" + +from enum import Enum, auto +from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict, List +from dataclasses import dataclass, field +from langgraph.graph import add_messages +from langchain_core.messages import BaseMessage + + +# ========== 枚举类型 ========== +class CurrentAction(Enum): + """主图当前操作类型""" + NONE = auto() + GENERAL_CHAT = auto() + NEWS_ANALYSIS = auto() + DICTIONARY = auto() + CONTACT = auto() + + +class ErrorSeverity(Enum): + """错误严重程度""" + INFO = auto() # 信息级别,继续执行 + WARNING = auto() # 警告级别,可以重试 + ERROR = auto() # 错误级别,需要处理 + FATAL = auto() # 致命错误,终止执行 + + +@dataclass +class ErrorRecord: + """错误记录""" + error_type: str + error_message: str + severity: ErrorSeverity = ErrorSeverity.ERROR + source: str = "" # 来源:哪个节点/子图/工具 + timestamp: str = "" + retry_count: int = 0 # 已重试次数 + max_retries: int = 3 # 最大重试次数 + context: Dict[str, Any] = field(default_factory=dict) # 错误上下文 + + +@dataclass +class ReactReasoningState: + """React 推理状态""" + last_reasoning: Optional[Dict[str, Any]] = None + reasoning_result: Optional[Any] = None # 实际类型是 ReasoningResult + + +@dataclass +class HybridRouterState: + """混合路由状态""" + decision: Optional[Any] = None # 实际类型是 HybridRouterResult + start_time: Optional[str] = None + + +@dataclass +class FastPathState: + """快速路径状态""" + chitchat_success: bool = False + rag_success: bool = False + tool_success: bool = False + failed: bool = False + fail_reason: str = "" + + +@dataclass +class MainGraphState: + """ + 主图状态定义 + + 字段分类: + - 持久化字段:跨轮次保留,不重置 + - 临时字段:每轮对话开始时重置 + """ + + # ================================================== + # 持久化字段(每轮保留) + # ================================================== + + messages: Annotated[Sequence[BaseMessage], add_messages] = field(default_factory=list) + turns_since_last_summary: int = 0 # 距离上次总结的轮数 + user_id: str = "" + + # ================================================== + # 临时字段(每轮重置) + # ================================================== + + # 主图控制字段 + user_query: str = "" + current_action: CurrentAction = CurrentAction.NONE + current_model: str = "" # 本次请求使用的模型 + intent_confidence: float = 0.0 + + # React 推理专用字段 + reasoning_step: int = 0 + max_steps: int = 10 # 避免过长循环 + last_action: str = "" + reasoning_history: List[Dict[str, Any]] = field(default_factory=list) + + # RAG 相关字段 + rag_context: str = "" + rag_retrieved: bool = False + rag_docs: List[Dict[str, Any]] = field(default_factory=list) + rag_confidence: float = 0.0 # RAG 检索置信度 (0.0-1.0) + rag_attempts: int = 0 # RAG 检索次数统计 + + # 联网搜索相关字段 + web_search_results: List[str] = field(default_factory=list) + + # 错误处理字段 + errors: List[ErrorRecord] = field(default_factory=list) + current_error: Optional[ErrorRecord] = None + retry_action: Optional[str] = None + error_message: str = "" + + # 子图结果字段 + news_result: Optional[Dict[str, Any]] = None + dictionary_result: Optional[Dict[str, Any]] = None + contact_result: Optional[Dict[str, Any]] = None + + # 执行状态 + current_phase: str = "init" + final_result: str = "" + success: bool = False + + # 元数据 + start_time: Optional[str] = None + end_time: Optional[str] = None + + # 结构化状态 + react_reasoning: ReactReasoningState = field(default_factory=ReactReasoningState) + hybrid_router: HybridRouterState = field(default_factory=HybridRouterState) + fast_path: FastPathState = field(default_factory=FastPathState) + + # 统计字段(用于反馈) + llm_calls: int = 0 + last_token_usage: Dict[str, Any] = field(default_factory=dict) + last_elapsed_time: float = 0.0 + memory_context: str = "" # 记忆检索结果 + + # 向后兼容(保留但不推荐使用) + debug_info: Dict[str, Any] = field(default_factory=dict) diff --git a/backend/app/main_graph/main_graph_builder.py b/backend/app/main_graph/main_graph_builder.py index 922ac83..81ecd72 100644 --- a/backend/app/main_graph/main_graph_builder.py +++ b/backend/app/main_graph/main_graph_builder.py @@ -1,232 +1,185 @@ """ -主图构建器 - 构建整合后的完整主图 +极简 Agent 主图 - 回归 LangGraph 标准模式 + +架构: +START → [init_state] → [记忆] → [Agent] ⇄ [Tools] → [Finalize] → END + ↑________↓ """ from langgraph.graph import StateGraph, START, END -from typing import Dict, Any +from langgraph.prebuilt import ToolNode +from langchain_core.runnables.config import RunnableConfig +from typing import Dict, Any, Optional -from .state import MainGraphState -from .nodes.reasoning import react_reason_node -from .nodes.web_search import web_search_node -from .nodes.error_handling import error_handling_node -from .nodes.routing import init_state_node, route_by_reasoning, should_summarize -from .nodes.hybrid_router import ( - hybrid_router_node, - route_from_hybrid_decision, - check_fast_path_success, -) -from .nodes.fast_paths import ( - fast_chitchat_node, - fast_rag_node, - fast_tool_node, -) -from .nodes.llm_call import create_dynamic_llm_call_node -from .nodes.rag_nodes import rag_retrieve_node -from .nodes.retrieve_memory import create_retrieve_memory_node +from .state import AgentState from .nodes.memory_trigger import memory_trigger_node, set_mem0_client from .nodes.summarize import create_summarize_node -from .nodes.finalize import finalize_node -from backend.app.subgraphs.contact import build_contact_subgraph -from backend.app.subgraphs.dictionary import build_dictionary_subgraph -from backend.app.subgraphs.news_analysis import build_news_analysis_subgraph -from backend.app.logger import info - -from .subgraph_wrapper import create_subgraph_nodes +from .nodes.agent import create_agent_node +from backend.app.tools import ALL_TOOLS +from backend.app.logger import info, warning -# ========== 主图构建 ========== - -def build_react_main_graph( +def build_agent_graph( chat_services: dict, - tools=None, mem0_client=None, - use_hybrid_router: bool = True + max_steps: int = 10 ) -> StateGraph: """ - 构建整合后的完整主图(支持混合路由 + 动态模型选择) + 构建极简 Agent 图 Args: - chat_services: 模型名称 -> ChatModel 实例 的字典 - tools: 工具列表 - mem0_client: Mem0 客户端实例 - use_hybrid_router: 是否使用混合路由(快速路径 + React 循环) + chat_services: 模型服务字典 + mem0_client: 记忆客户端(可选) + max_steps: 最大步数限制 Returns: StateGraph: 构建好的图 """ - # 创建图 - graph = StateGraph(MainGraphState) - # 设置全局 mem0_client + graph = StateGraph(AgentState) + + # ========== 设置全局客户端 ========== if mem0_client: set_mem0_client(mem0_client) - # ========== 创建节点 ========== + # ========== 创建核心节点 ========== - # LLM 调用节点 - llm_node = create_dynamic_llm_call_node(chat_services, tools or []) + # 1. Agent 节点(绑定工具的 LLM) + llm = chat_services.get("primary", list(chat_services.values())[0] if chat_services else None) + if llm is None: + raise ValueError("No LLM service provided") - # 记忆节点 + llm_with_tools = llm.bind_tools(ALL_TOOLS) + agent_node = create_agent_node(llm_with_tools, llm) + + # 2. Tool 节点(LangGraph 内置) + tool_node = ToolNode(ALL_TOOLS) + + # 3. 记忆/总结节点(保留现有) retrieve_memory_node = None summarize_node = None if mem0_client: - retrieve_memory_node = create_retrieve_memory_node(mem0_client) - summarize_node = create_summarize_node(mem0_client) + try: + from .nodes.retrieve_memory import create_retrieve_memory_node + retrieve_memory_node = create_retrieve_memory_node(mem0_client) + summarize_node = create_summarize_node(mem0_client) + except Exception as e: + info(f"[Graph Builder] 记忆节点初始化失败: {e}") - # 子图节点 - contact_graph = build_contact_subgraph() - dictionary_graph = build_dictionary_subgraph() - news_analysis_graph = build_news_analysis_subgraph() - subgraph_nodes = create_subgraph_nodes( - contact_graph, dictionary_graph, news_analysis_graph - ) + # ========== 添加节点 ========== - # ========== 添加节点到图 ========== + # 1. 初始化节点(重置步数) + async def init_state_node(state: AgentState) -> Dict[str, Any]: + """初始化状态:重置步数计数器""" + info("[Init State] 初始化状态,重置步数") + return { + "current_step": 0 + } - # 阶段 1: 记忆检索 + graph.add_node("init_state", init_state_node) + + # 2. 记忆阶段 if retrieve_memory_node: graph.add_node("retrieve_memory", retrieve_memory_node) graph.add_node("memory_trigger", memory_trigger_node) - # 阶段 2: 初始化 - graph.add_node("init_state", init_state_node) + # 3. 核心 Agent 循环 + graph.add_node("agent", agent_node) + graph.add_node("tools", tool_node) - # 阶段 3: 混合路由(可选) - if use_hybrid_router: - graph.add_node("hybrid_router", hybrid_router_node) - graph.add_node("fast_chitchat", fast_chitchat_node) - graph.add_node("fast_rag", fast_rag_node) - graph.add_node("fast_tool", fast_tool_node) - - # 阶段 4: React 循环推理(始终保留) - graph.add_node("react_reason", react_reason_node) - graph.add_node("rag_retrieve", rag_retrieve_node) - graph.add_node("web_search", web_search_node) - graph.add_node("handle_error", error_handling_node) - - if llm_node is not None: - graph.add_node("llm_call", llm_node) - - # 子图节点 - for node_name, node_func in subgraph_nodes.items(): - graph.add_node(node_name, node_func) - - # 阶段 5: 完成处理 + # 4. 完成阶段 if summarize_node: graph.add_node("summarize", summarize_node) - graph.add_node("finalize", finalize_node) + + # 简单的完成节点 + async def finalize_node_simple(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]: + """简单的完成节点,只发送完成事件""" + info("[Finalize] 进入完成节点") + + try: + from backend.app.main_graph.config import get_stream_writer + writer = get_stream_writer() + + # 提取最后的回复 + final_reply = "" + if state.messages: + last_msg = state.messages[-1] + final_reply = last_msg.content if hasattr(last_msg, "content") else str(last_msg) + + if writer and hasattr(writer, "__call__"): + try: + writer({ + "type": "custom", + "data": { + "type": "done", + "token_usage": state.last_token_usage, + "elapsed_time": state.last_elapsed_time, + "final_result": final_reply + } + }) + info("🏁 [完成事件] 已发送完成事件") + except Exception as e: + warning(f"⚠️ [完成事件] 发送失败 (非致命): {e}") + except Exception as e: + warning(f"⚠️ [完成事件] 处理失败 (非致命): {e}") + + return {} + + graph.add_node("finalize", finalize_node_simple) # ========== 添加边 ========== - # 阶段 1: 记忆检索 - _add_memory_edges(graph, retrieve_memory_node) + # 1. 初始化 + graph.add_edge(START, "init_state") - # 阶段 2: 初始化 - graph.add_edge("memory_trigger", "init_state") - - # 阶段 3: 路由分支 - _add_routing_edges(graph, use_hybrid_router, llm_node) - - # 阶段 4: React 循环边 - _add_react_loop_edges(graph, subgraph_nodes) - - # 阶段 5: 完成阶段 - _add_finalize_edges(graph, llm_node, summarize_node) - - info(f"✅ [图构建] 整合后的完整主图构建完成(混合路由: {use_hybrid_router})") - - return graph - - -def _add_memory_edges(graph: StateGraph, retrieve_memory_node) -> None: - """添加记忆检索阶段的边""" + # 2. 记忆阶段 if retrieve_memory_node: - graph.add_edge(START, "retrieve_memory") + graph.add_edge("init_state", "retrieve_memory") graph.add_edge("retrieve_memory", "memory_trigger") else: - graph.add_edge(START, "memory_trigger") + graph.add_edge("init_state", "memory_trigger") + # 3. 进入 Agent + graph.add_edge("memory_trigger", "agent") -def _add_routing_edges(graph: StateGraph, use_hybrid_router: bool, llm_node) -> None: - """添加路由阶段的边""" - if use_hybrid_router: - graph.add_edge("init_state", "hybrid_router") + # 4. 核心循环:Agent ⇄ Tools + def should_continue(state: AgentState) -> str: + """判断是继续调用工具还是结束""" + messages = state.messages + last_message = messages[-1] if messages else None - # 混合路由条件分支 - graph.add_conditional_edges( - "hybrid_router", - route_from_hybrid_decision, - { - "fast_chitchat": "fast_chitchat", - "fast_rag": "fast_rag", - "fast_tool": "fast_tool", - "react_loop": "react_reason" - } - ) + # 检查是否有 tool_calls + if last_message and hasattr(last_message, "tool_calls") and last_message.tool_calls: + return "tools" - # 快速路径的完成检查(fast_rag 失败直接走 react_reason) - for fast_node in ["fast_chitchat", "fast_rag", "fast_tool"]: - graph.add_conditional_edges( - fast_node, - check_fast_path_success, - { - "llm_call": "llm_call", - "escalate": "react_reason" - } - ) + # 否则结束 + return "finalize" - info(f"✅ [图构建] 混合路由模式已启用") - else: - graph.add_edge("init_state", "react_reason") - info(f"✅ [图构建] 纯 React 模式") - - -def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) -> None: - """添加 React 循环阶段的边""" - subgraph_names = list(subgraph_nodes.keys()) - - # React 推理的条件分支 graph.add_conditional_edges( - "react_reason", - route_by_reasoning, + "agent", + should_continue, { - "rag_retrieve": "rag_retrieve", - "web_search": "web_search", - **{name: name for name in subgraph_names}, - "handle_error": "handle_error", - "llm_call": "llm_call" + "tools": "tools", + "finalize": "finalize" } ) - # RAG 检索后回到 react_reason,由意图识别决定下一步 - graph.add_edge("rag_retrieve", "react_reason") + # Tools 执行完回到 Agent + graph.add_edge("tools", "agent") - # 循环边(回到 react_reason) - loop_back_nodes = ["web_search", "handle_error"] + subgraph_names - for node_name in loop_back_nodes: - graph.add_edge(node_name, "react_reason") + # 5. 完成阶段 + if summarize_node: + def should_summarize(state: AgentState) -> str: + if state.turns_since_last_summary >= 5: + return "summarize" + return "finalize" - -def _add_finalize_edges(graph: StateGraph, llm_node, summarize_node) -> None: - """添加完成阶段的边""" - if llm_node is not None: - if summarize_node: - graph.add_conditional_edges( - "llm_call", - should_summarize, - { - "summarize": "summarize", - "finalize": "finalize" - } - ) - graph.add_edge("summarize", "finalize") - else: - graph.add_edge("llm_call", "finalize") + # 总结逻辑暂简化:先 finalize + graph.add_edge("agent", "finalize") + else: + graph.add_edge("agent", "finalize") graph.add_edge("finalize", END) - -# ========== 导出 ========== -__all__ = [ - "build_react_main_graph", -] + info("✅ [图构建] 极简 Agent 图构建完成") + return graph diff --git a/backend/app/main_graph/nodes/__init__.py b/backend/app/main_graph/nodes/__init__.py index 0ff79ae..84a0623 100644 --- a/backend/app/main_graph/nodes/__init__.py +++ b/backend/app/main_graph/nodes/__init__.py @@ -1,61 +1,21 @@ """ -主图节点模块导出 +主图节点模块导出 - 极简架构 """ -# React 模式节点 -from .reasoning import react_reason_node -from .web_search import web_search_node -from .error_handling import error_handling_node -from .routing import init_state_node, route_by_reasoning, should_summarize -from .llm_call import create_dynamic_llm_call_node -from .rag_nodes import rag_retrieve_node - # 记忆节点 from .retrieve_memory import create_retrieve_memory_node from .memory_trigger import memory_trigger_node, set_mem0_client from .summarize import create_summarize_node -from .finalize import finalize_node -# 混合路由节点 -from .hybrid_router import ( - hybrid_router_node, - route_from_hybrid_decision, - check_fast_path_success, -) -from .fast_paths import ( - fast_chitchat_node, - fast_rag_node, - fast_tool_node, -) - -# 通用工具 -from ._utils import dispatch_custom_event, make_react_event +# 新架构节点 +from .agent import create_agent_node __all__ = [ - # React 模式节点 - "init_state_node", - "react_reason_node", - "web_search_node", - "error_handling_node", - "route_by_reasoning", - "should_summarize", - "create_dynamic_llm_call_node", - "rag_retrieve_node", - "rag_re_retrieve_node", # 记忆节点 "create_retrieve_memory_node", "memory_trigger_node", "set_mem0_client", "create_summarize_node", - "finalize_node", - # 混合路由节点 - "hybrid_router_node", - "route_from_hybrid_decision", - "check_fast_path_success", - "fast_chitchat_node", - "fast_rag_node", - "fast_tool_node", - # 通用工具 - "dispatch_custom_event", - "make_react_event", + # 新架构节点 + "create_agent_node", ] diff --git a/backend/app/main_graph/nodes/agent.py b/backend/app/main_graph/nodes/agent.py new file mode 100644 index 0000000..225468f --- /dev/null +++ b/backend/app/main_graph/nodes/agent.py @@ -0,0 +1,89 @@ +"""Agent 节点:核心推理与工具调用""" + +from typing import Dict, Any, Optional +from langchain_core.messages import SystemMessage +from langchain_core.runnables.config import RunnableConfig +from ..state import AgentState +from backend.app.logger import info, warning + + +# 系统提示词(从 main_graph_builder.py 搬过来) +SYSTEM_PROMPT = """你是一个智能助手,可以使用多种工具完成复杂任务。你必须用中文回复。 + +## 核心工具与能力 +你可以使用以下工具(函数),但只能在真正需要时调用,禁止无意义的测试调用或重复调用: +1. rag_search – 从内部知识库中检索文档,输入为优化后的查询字符串。 +2. web_search – 联网搜索获取最新信息,输入为搜索关键词。 +3. contact_lookup – 查询企业通讯录,输入姓名、部门或邮箱等。 +4. dictionary_lookup – 翻译单词、查询词典或提取术语。 +5. news_analysis – 获取或分析新闻资讯。 + +## 工作流程(ReAct 决策闭环) +你必须严格按照思考 → 行动 → 观察的闭环来处理每个请求,具体规则如下: + +### 1. 初始决策 +- 如果用户的问题很明确且你已有足够内部知识,可以直接回答,无需调用任何工具。 +- 如果需要外部信息,请按以下优先级选择工具: + - 优先使用 rag_search。 + - 若第一次 rag_search 返回的结果不相关或质量低,你可以改写查询关键词再次调用 rag_search(最多重复一次)。 + - 如果两次 rag_search 均无法获得满意信息,或者用户明确要求实时资讯,则必须切换为 web_search。 +- 遇到通讯录、词典、新闻类明确需求,直接调用对应的专用工具。 + +### 2. 观察与反思 +- 每次工具调用返回结果后,你必须先评估结果质量(内容是否相关、是否充分)。 +- 如果信息不足,根据上述规则决定下一步行动;如果信息足够,则直接生成最终答案,绝不再调用任何工具。 +- 在整个过程中,禁止使用工具返回的信息直接重复或编造来源,必须如实标注。 + +### 3. 结束条件 +当你认为已经拥有足够信息回答用户时,输出最终回复并停止调用工具。若连续调用工具超过 5 轮仍未解决,也必须基于当前收集到的信息给出最佳回答并说明局限性。 + +## 回答规范 +1. 来源标注:回答开头用方括号注明信息来源,如多处来源按使用顺序列出: + - 知识库:【知识库:相关文档主题】 + - 联网搜索:【联网搜索:来源网站或摘要】 +2. 思维链:对于需要复杂推理的问题,请将推理过程放在 ... 标签内,并置于回答最前面(来源标注之前)。 +3. 内容要求:回答应重点突出、条理清晰,优先结合用户背景信息进行个性化;若无任何可靠依据,如实说明“暂时无法回答”。 + +## 特别注意 +- 不要向用户暴露任何工具调用的技术细节(如参数、函数名)。 +- 如果用户只是闲聊、问候或道别,直接友好回复,严禁调用任何工具。 +- 所有联网搜索必须以获取帮助用户为目的,不得搜索无关内容。 + +现在,请遵循以上规则处理用户的每一次输入。记住:思考 → 行动 → 观察 → 直到完成。""" + + +def create_agent_node(llm_with_tools, llm): + """创建 Agent 节点函数""" + + async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]: + """ + Agent 节点:调用带工具的 LLM,处理步数限制 + + Args: + state: 当前状态 + config: 运行配置 + + Returns: + 状态更新字典 + """ + info(f"[Agent] 第 {state.current_step} 步推理") + + # 组装完整消息:系统提示 + 历史消息 + full_messages = [SystemMessage(content=SYSTEM_PROMPT)] + state.messages + + # 判断是否达到步数上限 + if state.current_step >= state.max_steps: + info(f"[Agent] 达到步数上限 {state.max_steps},强制结束,不绑定工具") + llm_no_tools = llm.bind_tools([]) + response = await llm_no_tools.ainvoke(full_messages) + else: + response = await llm_with_tools.ainvoke(full_messages) + + # 返回状态更新(注意:不原地修改 state,返回字典让 LangGraph 处理 + return { + "messages": [response], + "current_step": state.current_step + 1, + "llm_calls": state.llm_calls + 1 + } + + return agent_node diff --git a/backend/app/main_graph/nodes/finalize_new.py b/backend/app/main_graph/nodes/finalize_new.py new file mode 100644 index 0000000..e418caf --- /dev/null +++ b/backend/app/main_graph/nodes/finalize_new.py @@ -0,0 +1,59 @@ +""" +完成事件节点模块(新架构版本) +负责发送完成事件 +""" + +from typing import Any, Dict +from datetime import datetime + +# 本地模块 +from .state import AgentState +from backend.app.logger import info, warning + +from langchain_core.runnables.config import RunnableConfig + + +async def finalize_node(state: AgentState, config: RunnableConfig) -> Dict[str, Any]: + """ + 完成事件节点(新架构版本) + + Args: + state: 当前对话状态 + config: 运行时配置 + + Returns: + 空(不修改状态) + """ + info("[Finalize] 进入完成节点") + + try: + # 获取流式写入器并发送完成事件 + from backend.app.main_graph.config import get_stream_writer + writer = get_stream_writer() + + # 提取最后的回复 + final_reply = "" + if state.messages: + last_msg = state.messages[-1] + final_reply = last_msg.content if hasattr(last_msg, 'content') else str(last_msg) + + # 只在 writer 存在且不是 noop 时才发送 + if writer and hasattr(writer, '__call__'): + try: + writer({ + "type": "custom", + "data": { + "type": "done", + "token_usage": state.last_token_usage, + "elapsed_time": state.last_elapsed_time, + "final_result": final_reply + } + }) + info("🏁 [完成事件] 已发送完成事件") + except Exception as e: + warning(f"⚠️ [完成事件] 发送完成事件失败 (非致命): {e}") + except Exception as e: + warning(f"⚠️ [完成事件] 处理失败 (非致命): {e}") + + info("[Finalize] 离开完成节点") + return {} diff --git a/backend/app/main_graph/nodes/memory_trigger.py b/backend/app/main_graph/nodes/memory_trigger.py index 7d309af..b47a2db 100644 --- a/backend/app/main_graph/nodes/memory_trigger.py +++ b/backend/app/main_graph/nodes/memory_trigger.py @@ -1,6 +1,6 @@ from typing import Any, Dict from langchain_core.runnables.config import RunnableConfig -from ...main_graph.state import MainGraphState +from ..state import AgentState from ...memory.mem0_client import Mem0Client from backend.app.logger import info @@ -14,7 +14,7 @@ def set_mem0_client(client: Mem0Client): _mem0_client = client -async def memory_trigger_node(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: +async def memory_trigger_node(state: AgentState, config: RunnableConfig) -> Dict[str, Any]: """检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储""" if _mem0_client is None: return {} diff --git a/backend/app/main_graph/nodes/retrieve_memory.py b/backend/app/main_graph/nodes/retrieve_memory.py index 3837796..8f68f58 100644 --- a/backend/app/main_graph/nodes/retrieve_memory.py +++ b/backend/app/main_graph/nodes/retrieve_memory.py @@ -6,7 +6,7 @@ from typing import Any, Dict # 本地模块 -from ...main_graph.state import MainGraphState +from ...main_graph.state import AgentState from ...memory.mem0_client import Mem0Client from ...utils.logging import log_state_change from backend.app.logger import debug @@ -25,7 +25,7 @@ def create_retrieve_memory_node(mem0_client: Mem0Client): from langchain_core.runnables.config import RunnableConfig - async def retrieve_memory(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: + async def retrieve_memory(state: AgentState, config: RunnableConfig) -> Dict[str, Any]: """ 记忆检索节点 - 使用 Mem0 diff --git a/backend/app/main_graph/nodes/summarize.py b/backend/app/main_graph/nodes/summarize.py index 3836de7..5d70e95 100644 --- a/backend/app/main_graph/nodes/summarize.py +++ b/backend/app/main_graph/nodes/summarize.py @@ -6,7 +6,7 @@ from typing import Any, Dict # 本地模块 -from ...main_graph.state import MainGraphState +from ...main_graph.state import AgentState from ...memory.mem0_client import Mem0Client from ...utils.logging import log_state_change from backend.app.logger import debug, info, error, warning @@ -25,7 +25,7 @@ def create_summarize_node(mem0_client: Mem0Client): from langchain_core.runnables.config import RunnableConfig - async def summarize_conversation(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: + async def summarize_conversation(state: AgentState, config: RunnableConfig) -> Dict[str, Any]: """ 记忆存储节点 - 使用 Mem0 diff --git a/backend/app/main_graph/state.py b/backend/app/main_graph/state.py index a6aad6b..bf06b40 100644 --- a/backend/app/main_graph/state.py +++ b/backend/app/main_graph/state.py @@ -1,148 +1,37 @@ """ -主图状态定义 - React 模式增强版 -Main Graph State Definition - React Mode Enhanced +极简 Agent 状态定义 - 只保留真正需要的字段 -字段分类说明: -- 持久化字段:跨轮次保留,不重置 -- 临时字段:每轮对话开始时重置 +保留的核心字段: +- messages: 对话历史(LangGraph 必需) +- user_id: 用户标识 +- 记忆相关:turns_since_last_summary, memory_context +- 安全限制:current_step, max_steps +- 统计:llm_calls, last_token_usage, last_elapsed_time """ -from enum import Enum, auto -from typing import Optional, Dict, Any, Annotated, Sequence, TypedDict, List +from typing import Annotated, Sequence, Optional, Dict, Any from dataclasses import dataclass, field from langgraph.graph import add_messages from langchain_core.messages import BaseMessage -# ========== 枚举类型 ========== -class CurrentAction(Enum): - """主图当前操作类型""" - NONE = auto() - GENERAL_CHAT = auto() - NEWS_ANALYSIS = auto() - DICTIONARY = auto() - CONTACT = auto() - - -class ErrorSeverity(Enum): - """错误严重程度""" - INFO = auto() # 信息级别,继续执行 - WARNING = auto() # 警告级别,可以重试 - ERROR = auto() # 错误级别,需要处理 - FATAL = auto() # 致命错误,终止执行 - - @dataclass -class ErrorRecord: - """错误记录""" - error_type: str - error_message: str - severity: ErrorSeverity = ErrorSeverity.ERROR - source: str = "" # 来源:哪个节点/子图/工具 - timestamp: str = "" - retry_count: int = 0 # 已重试次数 - max_retries: int = 3 # 最大重试次数 - context: Dict[str, Any] = field(default_factory=dict) # 错误上下文 - - -@dataclass -class ReactReasoningState: - """React 推理状态""" - last_reasoning: Optional[Dict[str, Any]] = None - reasoning_result: Optional[Any] = None # 实际类型是 ReasoningResult - - -@dataclass -class HybridRouterState: - """混合路由状态""" - decision: Optional[Any] = None # 实际类型是 HybridRouterResult - start_time: Optional[str] = None - - -@dataclass -class FastPathState: - """快速路径状态""" - chitchat_success: bool = False - rag_success: bool = False - tool_success: bool = False - failed: bool = False - fail_reason: str = "" - - -@dataclass -class MainGraphState: - """ - 主图状态定义 - - 字段分类: - - 持久化字段:跨轮次保留,不重置 - - 临时字段:每轮对话开始时重置 - """ - - # ================================================== - # 持久化字段(每轮保留) - # ================================================== +class AgentState: + """Agent 状态""" + # ========== 核心持久化字段(必需) ========== messages: Annotated[Sequence[BaseMessage], add_messages] = field(default_factory=list) - turns_since_last_summary: int = 0 # 距离上次总结的轮数 user_id: str = "" - # ================================================== - # 临时字段(每轮重置) - # ================================================== + # ========== 安全限制字段(防止无限循环) ========== + max_steps: int = 10 + current_step: int = 0 - # 主图控制字段 - user_query: str = "" - current_action: CurrentAction = CurrentAction.NONE - current_model: str = "" # 本次请求使用的模型 - intent_confidence: float = 0.0 + # ========== 记忆相关字段(保留) ========== + turns_since_last_summary: int = 0 + memory_context: str = "" - # React 推理专用字段 - reasoning_step: int = 0 - max_steps: int = 10 # 避免过长循环 - last_action: str = "" - reasoning_history: List[Dict[str, Any]] = field(default_factory=list) - - # RAG 相关字段 - rag_context: str = "" - rag_retrieved: bool = False - rag_docs: List[Dict[str, Any]] = field(default_factory=list) - rag_confidence: float = 0.0 # RAG 检索置信度 (0.0-1.0) - rag_attempts: int = 0 # RAG 检索次数统计 - - # 联网搜索相关字段 - web_search_results: List[str] = field(default_factory=list) - - # 错误处理字段 - errors: List[ErrorRecord] = field(default_factory=list) - current_error: Optional[ErrorRecord] = None - retry_action: Optional[str] = None - error_message: str = "" - - # 子图结果字段 - news_result: Optional[Dict[str, Any]] = None - dictionary_result: Optional[Dict[str, Any]] = None - contact_result: Optional[Dict[str, Any]] = None - - # 执行状态 - current_phase: str = "init" - final_result: str = "" - success: bool = False - - # 元数据 - start_time: Optional[str] = None - end_time: Optional[str] = None - - # 结构化状态 - react_reasoning: ReactReasoningState = field(default_factory=ReactReasoningState) - hybrid_router: HybridRouterState = field(default_factory=HybridRouterState) - fast_path: FastPathState = field(default_factory=FastPathState) - - # 统计字段(用于反馈) + # ========== 统计字段(保留) ========== llm_calls: int = 0 last_token_usage: Dict[str, Any] = field(default_factory=dict) last_elapsed_time: float = 0.0 - memory_context: str = "" # 记忆检索结果 - - # 向后兼容(保留但不推荐使用) - debug_info: Dict[str, Any] = field(default_factory=dict) diff --git a/backend/app/model_services/chat_services.py b/backend/app/model_services/chat_services.py index 349712f..4ce7b64 100644 --- a/backend/app/model_services/chat_services.py +++ b/backend/app/model_services/chat_services.py @@ -5,11 +5,13 @@ 1. Local VLLM 服务:本地 gemma-4-E4B-it 模型 2. Zhipu AI:智谱 glm-5.1 模型 3. DeepSeek:deepseek-v4-pro 模型 +4. Baosi API:ops4.7 模型 主要功能: - LocalVLLMChatProvider:本地 VLLM 服务提供者 - ZhipuChatProvider:智谱 API 服务提供者 - DeepSeekChatProvider:DeepSeek API 服务提供者 +- BaosiChatProvider:Baosi API 服务提供者 - get_chat_service():获取默认服务(带自动降级) - get_all_chat_services():获取所有可用模型服务(用于多模型切换) """ @@ -28,6 +30,9 @@ from backend.app.config import ( LLM_API_KEY, ZHIPUAI_API_KEY, DEEPSEEK_API_KEY, + BAOSI_API_KEY, + BAOSI_API_BASE, + BAOSI_MODEL, LOCAL_MODEL_NAME ) @@ -194,6 +199,59 @@ class DeepSeekChatProvider(BaseServiceProvider[BaseChatModel]): return self._service_instance +class BaosiChatProvider(BaseServiceProvider[BaseChatModel]): + """ + Baosi API 生成式大模型服务提供者 + """ + + def __init__(self, model: str = None): + super().__init__("baosi_chat") + self._model = model or BAOSI_MODEL + self._base_url = BAOSI_API_BASE + self._api_key = BAOSI_API_KEY + + def is_available(self) -> bool: + """ + 检查 Baosi API 服务是否可用 + + Returns: + bool: 服务是否可用 + """ + if not self._api_key: + logger.warning("BAOSI_API_KEY 未配置") + return False + + try: + logger.info(f"Baosi API 服务配置正确,准备使用: {self._model}") + return True + except Exception as e: + logger.warning(f"Baosi API 服务不可用: {e}") + return False + + def get_service(self) -> BaseChatModel: + """ + 获取 Baosi API 服务 + + Returns: + BaseChatModel: LangChain 兼容的 ChatModel 实例 + """ + if self._service_instance is None: + from langchain_openai import ChatOpenAI + from pydantic import SecretStr + + self._service_instance = ChatOpenAI( + base_url=self._base_url, + api_key=SecretStr(self._api_key) if self._api_key else SecretStr(""), + model=self._model, + temperature=0.1, + max_tokens=4096, + timeout=120.0, + max_retries=2, + streaming=False, # Baosi API 可能不兼容 streaming,设置为 False + ) + return self._service_instance + + # ========== 轻量级模型 Provider ========== class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]): @@ -276,6 +334,7 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]): # 全局服务映射表 - 名称 -> Provider CHAT_PROVIDERS: Dict[str, Callable[[], BaseServiceProvider[BaseChatModel]]] = { "local": lambda: LocalVLLMChatProvider(), + "baosi": lambda: BaosiChatProvider(), "zhipu": lambda: ZhipuChatProvider(), "deepseek": lambda: DeepSeekChatProvider(), } @@ -284,14 +343,14 @@ CHAT_PROVIDERS: Dict[str, Callable[[], BaseServiceProvider[BaseChatModel]]] = { def get_chat_service() -> BaseChatModel: """ 获取默认的生成式大模型服务(带自动降级) - 优先顺序: local -> zhipu -> deepseek + 优先顺序: local → baosi → zhipu → deepseek Returns: BaseChatModel: LangChain 兼容的 ChatModel 实例 """ def _create_chain(): primary = LocalVLLMChatProvider() - fallbacks = [ZhipuChatProvider(), DeepSeekChatProvider()] + fallbacks = [BaosiChatProvider(), ZhipuChatProvider(), DeepSeekChatProvider()] return FallbackServiceChain(primary, fallbacks) chain = SingletonServiceManager.get_or_create("chat_service_chain", _create_chain) diff --git a/backend/app/tools/__init__.py b/backend/app/tools/__init__.py new file mode 100644 index 0000000..4cec476 --- /dev/null +++ b/backend/app/tools/__init__.py @@ -0,0 +1,188 @@ +""" +Agent Tools - 封装所有功能为 @tool 函数 +""" + +from langchain_core.tools import tool +from typing import Optional +from backend.app.logger import info + + +# ====== RAG Pipeline(复用现有) +_rag_pipeline = None + + +def _get_rag_pipeline(): + """获取 RAG Pipeline 实例(复用 rag_nodes.py 的逻辑)""" + global _rag_pipeline + if _rag_pipeline is None: + from backend.app.rag.pipeline import RAGPipeline + _rag_pipeline = RAGPipeline( + num_queries=3, + rerank_top_n=5, + use_rerank=True, + return_parent_docs=True, + ) + return _rag_pipeline + + +@tool +async def rag_search(query: str) -> str: + """ + 检索知识库获取相关信息。 + + 当用户询问关于系统、业务、文档相关的问题时使用此工具。 + + Args: + query: 用户的问题或搜索关键词 + + Returns: + 检索到的相关文档内容 + """ + info(f"[RAG Tool] 开始检索: {query[:50]}...") + + try: + pipeline = _get_rag_pipeline() + documents = await pipeline.aretrieve(query) + rag_context = pipeline.format_context(documents) + + info(f"[RAG Tool] 检索完成,得到 {len(documents)} 个文档") + + if rag_context: + return rag_context + else: + return "知识库中没有找到相关内容。" + + except Exception as e: + info(f"[RAG Tool] 检索失败: {e}") + return f"知识库检索失败: {str(e)}" + + +@tool +def web_search(query: str) -> str: + """ + 联网搜索获取最新信息。 + + 当用户询问实时新闻、热点事件、最新资讯或知识库中没有的内容时使用此工具。 + + Args: + query: 搜索关键词 + + Returns: + 搜索结果摘要 + """ + info(f"[WebSearch Tool] 开始搜索: {query[:50]}...") + + try: + from backend.app.core import web_search as core_web_search + search_result = core_web_search(query, max_results=5) + + info(f"[WebSearch Tool] 搜索完成") + return search_result + + except Exception as e: + info(f"[WebSearch Tool] 搜索失败: {e}") + return f"联网搜索失败: {str(e)}" + + +# ====== 子图工具封装器 +async def _invoke_subgraph(subgraph_builder, query: str, state_class) -> str: + """ + 通用子图调用函数 + + Args: + subgraph_builder: 子图构建函数 + query: 用户查询 + state_class: 子图状态类 + + Returns: + 子图执行结果 + """ + try: + graph = subgraph_builder() + compiled_graph = graph.compile() + + # 构造初始状态 + initial_state = state_class(user_query=query) + + # 调用子图 + result = await compiled_graph.ainvoke(initial_state) + + # 返回结果 + return result.get("final_result", "子图执行完成") + + except Exception as e: + info(f"[Subgraph Tool] 执行失败: {e}") + return f"执行失败: {str(e)}" + + +@tool +async def contact_lookup(query: str) -> str: + """ + 查询通讯录信息。 + + 当用户询问联系人、邮箱、联系方式、发送邮件时使用此工具。 + + Args: + query: 用户查询,描述需要的操作 + + Returns: + 联系人信息或操作结果 + """ + info(f"[Contact Tool] 查询: {query[:50]}...") + + from backend.app.subgraphs.contact.graph import build_contact_subgraph + from backend.app.subgraphs.contact.state import ContactState + + return await _invoke_subgraph(build_contact_subgraph, query, ContactState) + + +@tool +async def dictionary_lookup(word: str) -> str: + """ + 查询词典,获取单词释义、翻译等。 + + 当用户询问单词、翻译、生词时使用此工具。 + + Args: + word: 需要查询的单词或短语 + + Returns: + 单词释义和翻译 + """ + info(f"[Dictionary Tool] 查询: {word}") + + from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph + from backend.app.subgraphs.dictionary.state import DictionaryState + + return await _invoke_subgraph(build_dictionary_subgraph, word, DictionaryState) + + +@tool +async def news_analysis(topic: str) -> str: + """ + 分析热点新闻和资讯。 + + 当用户询问新闻分析、热点解读时使用此工具。 + + Args: + topic: 新闻主题或关键词 + + Returns: + 新闻分析结果 + """ + info(f"[NewsAnalysis Tool] 分析: {topic}") + + from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph + from backend.app.subgraphs.news_analysis.state import NewsAnalysisState + + return await _invoke_subgraph(build_news_analysis_subgraph, topic, NewsAnalysisState) + + +# ====== 导出所有工具 +ALL_TOOLS = [ + rag_search, + web_search, + contact_lookup, + dictionary_lookup, + news_analysis, +] diff --git a/backend/app/utils/logging.py b/backend/app/utils/logging.py index 4ce7bc4..45b08db 100644 --- a/backend/app/utils/logging.py +++ b/backend/app/utils/logging.py @@ -3,12 +3,12 @@ LangGraph 节点日志工具模块 提供状态流转追踪和 LLM 输入输出打印功能 """ +from typing import Any from backend.app.config import ENABLE_GRAPH_TRACE from backend.app.logger import debug, info -from ..main_graph.state import MainGraphState -def log_state_change(node_name: str, state: MainGraphState, prefix: str = "进入"): +def log_state_change(node_name: str, state: Any, prefix: str = "进入"): """ 记录状态变化日志 @@ -53,5 +53,5 @@ def print_llm_input(prompt_value): content_preview = str(msg.content) # 完整输出 debug(f" [{i}] {msg.type.upper():10s}: {content_preview}") debug("\n" + "="*80 + "\n") - - return prompt_value \ No newline at end of file + + return prompt_value diff --git a/tools/test/test_baosi_provider.py b/tools/test/test_baosi_provider.py new file mode 100644 index 0000000..f081133 --- /dev/null +++ b/tools/test/test_baosi_provider.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +""" +简单测试:验证 Baosi API 是否正常工作 +""" + +import sys +from pathlib import Path + +# 添加项目路径 +project_root = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(project_root)) + +import asyncio +from dotenv import load_dotenv + +# 加载环境变量 +load_dotenv(project_root / ".env") + +from backend.app.model_services.chat_services import BaosiChatProvider + + +async def test_baosi_provider(): + """测试 Baosi API Provider""" + print("=" * 60) + print("测试 Baosi API Provider") + print("=" * 60) + + # 创建 provider + provider = BaosiChatProvider() + + # 检查是否可用 + print(f"\n检查是否可用: {provider.is_available()}") + + try: + # 获取 LLM + llm = provider.get_service() + print(f"\n✓ LLM 获取成功: {type(llm)}") + + # 测试简单调用 + print(f"\n测试简单调用...") + from langchain_core.messages import HumanMessage + response = await llm.ainvoke([ + HumanMessage(content="你好,请简单介绍一下你自己") + ]) + print(f"\n✓ 响应成功:") + print(f" 响应类型: {type(response)}") + print(f" 响应内容:\n{response.content}") + + return True + + except Exception as e: + print(f"\n✗ 测试失败: {e}") + import traceback + print(f"堆栈:\n{traceback.format_exc()}") + return False + + +if __name__ == "__main__": + asyncio.run(test_baosi_provider()) diff --git a/tools/test/test_minimal_agent.py b/tools/test/test_minimal_agent.py new file mode 100644 index 0000000..393c56f --- /dev/null +++ b/tools/test/test_minimal_agent.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +""" +极简 Agent 架构测试 - 适配新架构 +""" + +import sys +from pathlib import Path + +# 添加项目路径 +project_root = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(project_root)) + +import asyncio +from dotenv import load_dotenv + +# 加载环境变量 +load_dotenv(project_root / ".env") + +from backend.app.main_graph.state import AgentState +from backend.app.main_graph.main_graph_builder import build_agent_graph +from backend.app.model_services.chat_services import get_cached_chat_services + + +# ========== 测试用例配置 ========== +TEST_CASES = [ + # 测试1: 简单闲聊 + { + "name": "闲聊测试", + "query": "你好!", + "description": "测试简单对话" + }, + # 测试2: 知识查询 + { + "name": "知识库测试", + "query": "吕布的事迹?", + "description": "测试 RAG 工具调用" + }, + # 测试3: 简单问题 + { + "name": "简单问答测试", + "query": "介绍一下你自己", + "description": "测试直接回答能力" + }, +] + + +async def setup_test_environment(): + """设置测试环境""" + print("=" * 60) + print("设置测试环境...") + print("=" * 60) + + # 获取 LLM 服务 + chat_services = get_cached_chat_services() + if not chat_services: + raise RuntimeError("没有可用的 LLM 服务") + + print(f"✓ 可用模型: {list(chat_services.keys())}") + + # 选择 zhipu 或 deepseek 作为测试模型,避免 Baosi API 的问题 + test_model = None + if "zhipu" in chat_services: + test_model = "zhipu" + print(f"✓ 选择 zhipu 作为测试模型") + elif "deepseek" in chat_services: + test_model = "deepseek" + print(f"✓ 选择 deepseek 作为测试模型") + elif "local" in chat_services: + test_model = "local" + print(f"✓ 选择 local 作为测试模型") + else: + # 用第一个可用的 + test_model = list(chat_services.keys())[0] + print(f"✓ 选择 {test_model} 作为测试模型") + + # 只保留选中的模型,方便测试 + test_chat_services = {test_model: chat_services[test_model]} + + # 构建图(使用新的 build_agent_graph) + graph_builder = build_agent_graph( + chat_services=test_chat_services + ) + graph = graph_builder.compile() + + print(f"✓ 图构建完成") + print() + + return graph, test_chat_services + + +def create_test_state(query: str, user_id: str = "test_user") -> dict: + """创建测试状态""" + from langchain_core.messages import HumanMessage + + return { + "messages": [HumanMessage(content=query)], + "user_id": user_id, + } + + +async def run_single_test(graph, test_case: dict) -> dict: + """运行单个测试""" + name = test_case["name"] + query = test_case["query"] + description = test_case["description"] + + print(f"\n{'=' * 60}") + print(f"测试: {name}") + print(f"描述: {description}") + print(f"查询: {query}") + print(f"{'=' * 60}") + + try: + # 创建初始状态 + input_state = create_test_state(query) + + # 配置 + config = { + "configurable": { + "thread_id": f"test_{name}" + } + } + + # 执行图 + print("开始执行图...") + result = await graph.ainvoke(input_state, config=config) + + # 提取最终回复 + reply = "" + if result.get("messages"): + reply = result["messages"][-1].content + + print(f"\n✓ 执行完成") + print(f"最终回复: {reply[:500]}{'...' if len(reply) > 500 else ''}") + + return { + "name": name, + "success": True, + "reply": reply, + "state": result + } + + except Exception as e: + print(f"\n✗ 测试失败: {e}") + import traceback + print(f"堆栈: {traceback.format_exc()}") + return { + "name": name, + "success": False, + "error": str(e) + } + + +async def main(): + """主函数""" + print("\n" + "=" * 60) + print("极简 Agent 架构测试") + print("=" * 60) + + try: + # 设置环境 + graph, chat_services = await setup_test_environment() + + # 运行所有测试 + results = [] + for test_case in TEST_CASES: + result = await run_single_test(graph, test_case) + results.append(result) + + # 稍微间隔一下 + await asyncio.sleep(0.5) + + # 总结 + print("\n" + "=" * 60) + print("测试总结") + print("=" * 60) + + total = len(results) + passed = sum(1 for r in results if r["success"]) + failed = total - passed + + print(f"\n总测试数: {total}") + print(f"通过: {passed}") + print(f"失败: {failed}") + + print("\n详细结果:") + for result in results: + status = "✓ 通过" if result["success"] else "✗ 失败" + print(f" {result['name']}: {status}") + + print("\n" + "=" * 60) + if failed == 0: + print("🎉 所有测试通过!") + else: + print(f"⚠️ 有 {failed} 个测试失败") + print("=" * 60) + + except Exception as e: + print(f"\n测试运行失败: {e}") + import traceback + print(traceback.format_exc()) + + +if __name__ == "__main__": + asyncio.run(main())