From 4ee769a79f375c163575b01dac7c780ea0e1edf0 Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Fri, 1 May 2026 14:01:48 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=9E=B6=E6=9E=84=EF=BC=9A?= =?UTF-8?q?=E6=81=A2=E5=A4=8D=E7=BB=9F=E4=B8=80=E7=9A=84=20llm=5Fcall=20?= =?UTF-8?q?=E8=8A=82=E7=82=B9=EF=BC=8C=E7=A7=BB=E9=99=A4=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E7=9A=84=20final=5Fresponse=20=E8=8A=82=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/agent/agent_service.py | 2 +- backend/app/main_graph/nodes/llm_call.py | 80 ++++++++----- backend/app/main_graph/nodes/react_nodes.py | 112 ++---------------- .../main_graph/utils/main_graph_builder.py | 61 +++++----- 4 files changed, 94 insertions(+), 161 deletions(-) diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index 8984536..a1e4aa0 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -46,7 +46,7 @@ class AIAgentService: for name, llm in chat_services.items(): try: info(f"🔄 初始化模型 '{name}'...") - graph = build_react_main_graph().compile(checkpointer=self.checkpointer) + graph = build_react_main_graph(llm=llm, tools=self.tools).compile(checkpointer=self.checkpointer) self.graphs[name] = graph info(f"✅ 模型 '{name}' 初始化成功") except Exception as e: diff --git a/backend/app/main_graph/nodes/llm_call.py b/backend/app/main_graph/nodes/llm_call.py index c391fef..aee7053 100644 --- a/backend/app/main_graph/nodes/llm_call.py +++ b/backend/app/main_graph/nodes/llm_call.py @@ -9,54 +9,68 @@ from langchain_core.language_models import BaseLLM from langchain_core.messages import AIMessage # 本地模块 -from app.main_graph.state import MessagesState +from app.main_graph.state import MainGraphState from app.agent.prompts import create_system_prompt from app.utils.logging import log_state_change from app.logger import debug, info, error -def create_llm_call_node(llm: BaseLLM, tools: list): +def create_llm_call_node(llm, tools: list): """ 工厂函数:创建 LLM 调用节点 - + Args: llm: LangChain LLM 实例 tools: 工具列表 - + Returns: 异步节点函数 """ # 构建调用链 prompt = create_system_prompt(tools) llm_with_tools = llm.bind_tools(tools) - + # 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历 chain = prompt | llm_with_tools - + from langchain_core.runnables.config import RunnableConfig - - async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + + async def call_llm(state: MainGraphState, config: RunnableConfig) -> Dict[str, Any]: """ LLM 调用节点(异步方法) - + Args: state: 当前对话状态 config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息 - + Returns: 更新后的状态字典 """ log_state_change("llm_call", state, "进入") - - memory_context = state.get("memory_context", "暂无用户信息") + + memory_context = getattr(state, "memory_context", "暂无用户信息") start_time = time.time() - + try: + # 添加 RAG 上下文到消息 + messages_with_context = list(state.messages) + if state.rag_context: + from langchain_core.messages import SystemMessage + rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}") + inserted = False + for i, msg in enumerate(messages_with_context): + if msg.type == "human": + messages_with_context.insert(i, rag_system_msg) + inserted = True + break + if not inserted: + messages_with_context.insert(0, rag_system_msg) + # 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。 # LangGraph 会自动监听这期间产生的所有 token。 chunks = [] async for chunk in chain.astream( { - "messages": state["messages"], + "messages": messages_with_context, "memory_context": memory_context }, config=config @@ -70,14 +84,14 @@ def create_llm_call_node(llm: BaseLLM, tools: list): response = response + chunk else: response = AIMessage(content="") - + elapsed_time = time.time() - start_time - + # 提取 token 用量(兼容不同 LLM 提供商的元数据格式) token_usage = {} input_tokens = 0 output_tokens = 0 - + # 尝试从 response_metadata 中提取 if hasattr(response, 'response_metadata') and response.response_metadata: meta = response.response_metadata @@ -85,18 +99,18 @@ def create_llm_call_node(llm: BaseLLM, tools: list): token_usage = meta['token_usage'] elif 'usage' in meta: token_usage = meta['usage'] - + # 尝试从 additional_kwargs 中提取 if not token_usage and hasattr(response, 'additional_kwargs'): add_kwargs = response.additional_kwargs if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']: token_usage = add_kwargs['llm_output']['token_usage'] - + # 提取具体的 token 数值 if token_usage: input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0)) output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0)) - + # 打印 LLM 的完整输出 debug("\n" + "="*80) debug("📥 [LLM输出] 大模型返回的完整响应:") @@ -111,18 +125,21 @@ def create_llm_call_node(llm: BaseLLM, tools: list): if token_usage: debug(f"📋 [LLM统计] 详细用量: {token_usage}") debug("="*80 + "\n") - + result = { "messages": [response], - "llm_calls": state.get('llm_calls', 0) + 1, + "llm_calls": getattr(state, 'llm_calls', 0) + 1, "last_token_usage": token_usage, "last_elapsed_time": elapsed_time, - "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 递增计数器 + "turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1, + "final_result": response.content, + "success": True, + "current_phase": "done" } - + log_state_change("llm_call", {**state, **result}, "离开") return result - + except Exception as e: elapsed_time = time.time() - start_time error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)") @@ -131,20 +148,23 @@ def create_llm_call_node(llm: BaseLLM, tools: list): import traceback traceback.print_exc() debug("="*80 + "\n") - + # 返回一个友好的错误消息 error_response = AIMessage( content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。" ) error_result = { "messages": [error_response], - "llm_calls": state.get('llm_calls', 0), + "llm_calls": getattr(state, 'llm_calls', 0), "last_token_usage": {}, "last_elapsed_time": elapsed_time, - "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 即使出错也递增计数器 + "turns_since_last_summary": getattr(state, 'turns_since_last_summary', 0) + 1, + "final_result": "抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。", + "success": False, + "current_phase": "done" } - + log_state_change("llm_call", state, "离开(异常)") return error_result - + return call_llm diff --git a/backend/app/main_graph/nodes/react_nodes.py b/backend/app/main_graph/nodes/react_nodes.py index c99b092..d55cfa9 100644 --- a/backend/app/main_graph/nodes/react_nodes.py +++ b/backend/app/main_graph/nodes/react_nodes.py @@ -3,7 +3,6 @@ React 模式节点模块 - 带超时和重试功能 包含: - react_reason_node: 使用 intent.py 进行推理 - error_handling_node: 错误处理节点 -- final_response_node: 最终回答节点 - init_state_node: 初始化状态节点 注意:为了兼容 LangGraph 的同步接口,我们保留了同步的 react_reason 调用 @@ -233,98 +232,6 @@ def error_handling_node(state: MainGraphState) -> MainGraphState: return state -# ========== 3. 最终回答节点 ========== - -from langchain_core.runnables.config import RunnableConfig -from langchain_core.messages import AIMessage - -async def final_response_node(state: MainGraphState, config: RunnableConfig) -> MainGraphState: - """ - 最终回答节点:调用 LLM 生成最终回答(支持流式输出) - """ - state.current_phase = "finalizing" - - # 如果已经有 final_result 了,直接返回 - if state.final_result: - state.current_phase = "done" - return state - - import time - start_time = time.time() - - try: - # 构建 LLM 调用链 - from app.agent.prompts import create_system_prompt - from app.model_services.chat_services import get_chat_service - from app.logger import debug, info - - llm = get_chat_service() - prompt = create_system_prompt(tools=[]) - chain = prompt | llm - - # 构建上下文 - memory_context = getattr(state, "memory_context", "暂无用户信息") - - # 添加 RAG 上下文到消息 - messages_with_context = list(state.messages) - if state.rag_context: - # 把 RAG 上下文作为系统消息添加 - from langchain_core.messages import SystemMessage - rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}") - # 插入到第一个用户消息之前 - inserted = False - for i, msg in enumerate(messages_with_context): - if msg.type == "human": - messages_with_context.insert(i, rag_system_msg) - inserted = True - break - if not inserted: - messages_with_context.insert(0, rag_system_msg) - - # 调用 LLM(流式输出) - chunks = [] - async for chunk in chain.astream( - { - "messages": messages_with_context, - "memory_context": memory_context - }, - config=config - ): - chunks.append(chunk) - - # 将所有 chunk 合并成最终的 AIMessage - if chunks: - response = chunks[0] - for chunk in chunks[1:]: - response = response + chunk - else: - response = AIMessage(content="") - - elapsed_time = time.time() - start_time - - # 更新状态 - state.messages.append(response) - state.final_result = response.content - state.success = True - state.current_phase = "done" - state.end_time = datetime.now().isoformat() - state.llm_calls = getattr(state, "llm_calls", 0) + 1 - - info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒") - - except Exception as e: - from app.logger import error - import traceback - error(f"❌ [LLM错误] 调用失败: {e}") - traceback.print_exc() - - state.final_result = "抱歉,模型暂时无法响应,请稍后再试。" - state.success = False - state.current_phase = "done" - - return state - - # ========== 4. 初始化状态节点 ========== def init_state_node(state: MainGraphState) -> MainGraphState: @@ -353,11 +260,11 @@ def route_by_reasoning(state: MainGraphState) -> str: """ # 先检查特殊情况 if state.current_phase == "max_steps_exceeded": - return "final_response" + return "llm_call" if state.current_phase == "error_handling" or state.current_error: return "handle_error" if state.current_phase == "finalizing" or state.current_phase == "done": - return "final_response" + return "llm_call" if state.current_phase == "retrying": if state.retry_action and "rag" in state.retry_action.lower(): return "rag_retrieve" @@ -367,7 +274,7 @@ def route_by_reasoning(state: MainGraphState) -> str: reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result") if not reasoning_result: - return "final_response" + return "llm_call" # 使用 intent.py 提供的路由函数 route = get_route_by_reasoning(reasoning_result) @@ -375,18 +282,18 @@ def route_by_reasoning(state: MainGraphState) -> str: # 映射到我们的节点名称 # 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致 route_mapping = { - "direct_response": "final_response", + "direct_response": "llm_call", "retrieve_rag": "rag_retrieve", "re_retrieve_rag": "rag_retrieve", - "web_search": "web_search", # ⭐ 新增:联网搜索 - "clarify": "final_response", - "call_tool": "final_response", # 暂时映射到 final_response,后续可以扩展 + "web_search": "web_search", + "clarify": "llm_call", + "call_tool": "llm_call", "contact": "contact_subgraph", "dictionary": "dictionary_subgraph", "news_analysis": "news_analysis_subgraph", } - return route_mapping.get(route, "final_response") + return route_mapping.get(route, "llm_call") # ========== 导出 ========== @@ -394,8 +301,7 @@ def route_by_reasoning(state: MainGraphState) -> str: __all__ = [ "init_state_node", "react_reason_node", - "web_search_node", # ⭐ 新增 + "web_search_node", "error_handling_node", - "final_response_node", "route_by_reasoning" ] diff --git a/backend/app/main_graph/utils/main_graph_builder.py b/backend/app/main_graph/utils/main_graph_builder.py index 081d316..75923a2 100644 --- a/backend/app/main_graph/utils/main_graph_builder.py +++ b/backend/app/main_graph/utils/main_graph_builder.py @@ -10,11 +10,11 @@ from app.main_graph.state import MainGraphState, CurrentAction from app.main_graph.nodes.react_nodes import ( init_state_node, react_reason_node, - web_search_node, # ⭐ 新增 + web_search_node, error_handling_node, - final_response_node, route_by_reasoning ) +from app.main_graph.nodes.llm_call import create_llm_call_node from app.main_graph.nodes.rag_nodes import rag_retrieve_node from app.subgraphs.contact import build_contact_subgraph from app.subgraphs.dictionary import build_dictionary_subgraph @@ -75,10 +75,10 @@ def wrap_subgraph_for_error_handling(subgraph, name: str): # ========== 主图构建 ========== -def build_react_main_graph() -> StateGraph: +def build_react_main_graph(llm=None, tools=None) -> StateGraph: """ 构建完整的 React 模式主图 - + 流程: START ↓ @@ -87,18 +87,23 @@ def build_react_main_graph() -> StateGraph: react_reason (推理) ←──────────────┐ ↓ │ 条件路由 │ - ├─→ rag_retrieve →───────────────┤ - ├─→ contact_subgraph →───────────┤ - ├─→ dictionary_subgraph →────────┤ - ├─→ news_analysis_subgraph →─────┤ - ├─→ handle_error → (重试或结束) ──┤ - └─→ final_response + ├─ rag_retrieve →───────────────┤ + ├─ contact_subgraph →───────────┤ + ├─ dictionary_subgraph →────────┤ + ├─ news_analysis_subgraph →─────┤ + ├─ handle_error → (重试或结束) ─┤ + └─ llm_call → END ↓ END """ # 创建图 graph = StateGraph(MainGraphState) + # 创建 llm_call 节点 + llm_node = None + if llm is not None: + llm_node = create_llm_call_node(llm, tools or []) + # ========== 添加节点 ========== # 1. 初始化节点 @@ -110,14 +115,15 @@ def build_react_main_graph() -> StateGraph: # 3. RAG 检索节点 graph.add_node("rag_retrieve", rag_retrieve_node) - # 4. 联网搜索节点 ⭐ 新增 + # 4. 联网搜索节点 graph.add_node("web_search", web_search_node) # 5. 错误处理节点 graph.add_node("handle_error", error_handling_node) - # 6. 最终回答节点 - graph.add_node("final_response", final_response_node) + # 6. LLM 调用节点(真正的大模型输出) + if llm_node is not None: + graph.add_node("llm_call", llm_node) # ========== 添加子图节点 ========== @@ -154,33 +160,34 @@ def build_react_main_graph() -> StateGraph: { # 检索分支 → 检索后回到推理 "rag_retrieve": "rag_retrieve", - - # 联网搜索分支 ⭐ 新增 + + # 联网搜索分支 "web_search": "web_search", - + # 子图分支 → 子图后回到推理 "contact_subgraph": "contact_subgraph", "dictionary_subgraph": "dictionary_subgraph", "news_analysis_subgraph": "news_analysis_subgraph", - + # 错误处理分支 "handle_error": "handle_error", - - # 最终回答分支 - "final_response": "final_response", + + # LLM 调用分支 → 直接输出给用户 + "llm_call": "llm_call" } ) - - # 4. 循环边:检索/搜索/子图/错误处理 后 → 回到推理 + + # 4. 循环边:检索/搜索/子图/错误处理后 → 回到推理 graph.add_edge("rag_retrieve", "react_reason") - graph.add_edge("web_search", "react_reason") # ⭐ 新增 + graph.add_edge("web_search", "react_reason") graph.add_edge("contact_subgraph", "react_reason") graph.add_edge("dictionary_subgraph", "react_reason") graph.add_edge("news_analysis_subgraph", "react_reason") - graph.add_edge("handle_error", "react_reason") # 错误处理后可能重试 - - # 5. 最终边:final_response → END - graph.add_edge("final_response", END) + graph.add_edge("handle_error", "react_reason") + + # 5. 最终边:llm_call → END + if llm_node is not None: + graph.add_edge("llm_call", END) return graph