From 6d7f8758d2cf06cefc09e691fcb4df9ed6249ea4 Mon Sep 17 00:00:00 2001
From: root <953994191@qq.com>
Date: Thu, 7 May 2026 02:05:23 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E7=9C=9F=E6=AD=A3?=
 =?UTF-8?q?=E7=9A=84=20LLM=20=E6=B5=81=E5=BC=8F=20token=20=E5=8F=91?=
 =?UTF-8?q?=E9=80=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 backend/app/agent/agent_service.py    | 149 +++++++++++++++-----------
 backend/app/agent/stream_context.py   |   9 ++
 backend/app/main_graph/nodes/agent.py |  68 ++++++++++--
 3 files changed, 157 insertions(+), 69 deletions(-)
 create mode 100644 backend/app/agent/stream_context.py

diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py
index 6c1d2ca..d509e77 100644
--- a/backend/app/agent/agent_service.py
+++ b/backend/app/agent/agent_service.py
@@ -15,6 +15,7 @@ from ..model_services import get_cached_chat_services
 from ..main_graph.main_graph_builder import build_agent_graph
 from backend.app.logger import debug, info, warning, error
 from ..main_graph.state import AgentState
+from .stream_context import token_queue_var
 
 
 class AIAgentService:
@@ -251,71 +252,99 @@ class AIAgentService:
         chunk_count = 0
         full_message_content = ""
 
-        try:
-            info(f"📡 开始调用 graph.astream()...")
-
-            event_count = 0
-
-            async for chunk in self.graph.astream(
-                input_state,
-                config=config,
-                stream_mode=["messages", "updates"],
-                version="v2",
-                subgraphs=True
-            ):
-                chunk_count += 1
-                chunk_type = chunk["type"]
+        # Queue carrying every event from the background graph task to this generator
+        token_queue = asyncio.Queue()
+
+        # Publish the queue through the context variable so graph nodes can push tokens
+        token_queue_var.set(token_queue)
+
+        async def run_graph():
+            """Run the graph in the background and put every resulting event on the queue."""
+            nonlocal chunk_count, full_message_content
+            # Declared up front: a nonlocal statement placed after the first use of a name is a SyntaxError
+            nonlocal current_node, actual_model_used
+            try:
+                info(f"📡 开始调用 graph.astream()...")
 
-                # 记录原始 chunk 信息(前 10 个和后 10 个)
-                if chunk_count <= 10 or chunk_count % 50 == 0:
-                    info(f" [{chunk_count}] chunk_type={chunk_type}, data={type(chunk.get('data'))}")
+                event_count = 0
+
+                async for chunk in self.graph.astream(
+                    input_state,
+                    config=config,
+                    stream_mode=["messages", "updates"],
+                    version="v2",
+                    subgraphs=True
+                ):
+                    chunk_count += 1
+                    chunk_type = chunk["type"]
+
+                    # Log raw chunk info for the first 10 chunks, then every 50th
+                    if chunk_count <= 10 or chunk_count % 50 == 0:
+                        info(f" [{chunk_count}] chunk_type={chunk_type}, data={type(chunk.get('data'))}")
 
-                if chunk_type == "messages":
-                    async for event in self._handle_message_chunk(
-                        chunk, current_node, tool_calls_in_progress
-                    ):
-                        if event.get("type") == "_update_state":
-                            current_node = event.get("current_node", current_node)
-                        else:
-                            event_count += 1
-                            # 记录前 10 个事件
-                            if event_count <= 10:
-                                info(f" → yield event #{event_count}: {event.get('type')}")
-
-                            # 如果是 agent 节点的 token,收集完整消息
-                            if (
-                                event.get("type") == "llm_token"
-                                and event.get("node") == "agent"
-                                and "token" in event
-                            ):
-                                full_message_content += event["token"]
-                            yield event
+                    if chunk_type == "messages":
+                        async for event in self._handle_message_chunk(
+                            chunk, current_node, tool_calls_in_progress
+                        ):
+                            if event.get("type") == "_update_state":
+                                current_node = event.get("current_node", current_node)
+                            else:
+                                event_count += 1
+                                # Log only the first 10 events
+                                if event_count <= 10:
+                                    info(f" → yield event #{event_count}: {event.get('type')}")
+
+                                # Collect agent-node tokens into the full message
+                                if (
+                                    event.get("type") == "llm_token"
+                                    and event.get("node") == "agent"
+                                    and "token" in event
+                                ):
+                                    full_message_content += event["token"]
+                                await token_queue.put(event)
 
-                elif chunk_type == "updates":
-                    async for event in self._handle_updates_chunk(
-                        chunk, tool_calls_in_progress, actual_model_used
-                    ):
-                        if event.get("type") == "_update_state":
-                            actual_model_used = event.get("actual_model_used", actual_model_used)
-                        else:
-                            event_count += 1
-                            if event_count <= 10:
-                                info(f" → yield event #{event_count}: {event.get('type')}")
-                            yield event
+                    elif chunk_type == "updates":
+                        async for event in self._handle_updates_chunk(
+                            chunk, tool_calls_in_progress, actual_model_used
+                        ):
+                            if event.get("type") == "_update_state":
+                                actual_model_used = event.get("actual_model_used", actual_model_used)
+                            else:
+                                event_count += 1
+                                if event_count <= 10:
+                                    info(f" → yield event #{event_count}: {event.get('type')}")
+                                await token_queue.put(event)
 
-            # 完整消息集合完成后,一次性打印
-            info(f"✅ graph.astream() 完成,共 {chunk_count} 个 chunks, {event_count} 个 events")
-            if full_message_content:
-                info(f"📄 完整消息内容: {repr(full_message_content)}")
+                # Once the whole message has been collected, log it in one go
+                info(f"✅ graph.astream() 完成,共 {chunk_count} 个 chunks, {event_count} 个 events")
+                if full_message_content:
+                    info(f"📄 完整消息内容: {repr(full_message_content)}")
+
+            except Exception as e:
+                error(f"❌ 执行图时出错: {e}")
+                import traceback
+                error(f"📋 堆栈: {traceback.format_exc()}")
+                await token_queue.put({
+                    "type": "error",
+                    "message": str(e)
+                })
+            finally:
+                # Sentinel: tells the consumer loop below that no further events will arrive
+                token_queue.put_nowait(None)
+
+        # Start the graph as a background task
+        graph_task = asyncio.create_task(run_graph())
+
+        try:
+            # Forward queued events until run_graph() posts the None sentinel
+            while True:
+                event = await token_queue.get()
+                if event is None:
+                    break
+                yield event
 
-        except Exception as e:
-            error(f"❌ 执行图时出错: {e}")
-            import traceback
-            error(f"📋 堆栈: {traceback.format_exc()}")
-            yield {
-                "type": "error",
-                "message": str(e)
-            }
         finally:
+            # Cancel before any yield below, so the task is stopped even if the consumer closed this generator
+            graph_task.cancel()
             # 无论成功或失败,都发送结束事件,保证前端平稳关闭
             if current_node:
diff --git a/backend/app/agent/stream_context.py b/backend/app/agent/stream_context.py
new file mode 100644
index 0000000..e8efa2f
--- /dev/null
+++ b/backend/app/agent/stream_context.py
@@ -0,0 +1,9 @@
+"""Streaming context: shares the token queue between LangGraph nodes and agent_service."""
+import contextvars
+import asyncio
+from typing import Optional, Any
+
+# Context variable holding the current token queue (None when no stream is active)
+token_queue_var: contextvars.ContextVar[Optional[asyncio.Queue]] = contextvars.ContextVar(
+    "token_queue", default=None
+)
diff --git a/backend/app/main_graph/nodes/agent.py b/backend/app/main_graph/nodes/agent.py
index 7248269..7acce36 100644
--- a/backend/app/main_graph/nodes/agent.py
+++ b/backend/app/main_graph/nodes/agent.py
@@ -1,10 +1,11 @@
 """Agent 节点:核心推理与工具调用"""
 from typing import Dict, Any, Optional
-from langchain_core.messages import SystemMessage
+from langchain_core.messages import SystemMessage, AIMessage, AIMessageChunk
 from langchain_core.runnables.config import RunnableConfig
 
 from ..state import AgentState
-from backend.app.logger import info, warning
+from backend.app.logger import info, warning, error
+from ...agent.stream_context import token_queue_var
 
 
 # 系统提示词(从 main_graph_builder.py 搬过来)
@@ -77,23 +78,72 @@ def create_agent_node(llm_with_tools, llm):
             # 判断是否达到步数上限
             if state.current_step >= state.max_steps:
                 info(f"[Agent] 达到步数上限 {state.max_steps},强制结束,不绑定工具")
-                llm_no_tools = llm.bind_tools([])
-                response = await llm_no_tools.ainvoke(full_messages)
+                current_llm = llm.bind_tools([])
             else:
-                info(f"[Agent] 调用带工具的 LLM...")
-                response = await llm_with_tools.ainvoke(full_messages)
-
+                current_llm = llm_with_tools
+
+            info(f"[Agent] 调用带工具的 LLM...")
+
+            # Queue published by the agent service (None when nothing is consuming the stream)
+            token_queue = token_queue_var.get()
+
+            # Text of the final message, accumulated while streaming
+            full_content = ""
+            full_reasoning_content = ""
+            # Running sum of every chunk; complete tool calls are read from it afterwards
+            merged_chunk = None
+
+            # NOTE(review): graph.astream(stream_mode="messages") may already surface these
+            # tokens through _handle_message_chunk — confirm they are not emitted twice.
+            async for chunk in current_llm.astream(full_messages):
+                if isinstance(chunk, AIMessageChunk):
+                    # Tool-call arguments arrive as JSON fragments, often without an id;
+                    # adding chunks concatenates the fragments so they parse once complete.
+                    merged_chunk = chunk if merged_chunk is None else merged_chunk + chunk
+
+                    # Answer text
+                    if chunk.content:
+                        full_content += chunk.content
+                        if token_queue:
+                            await token_queue.put({
+                                "type": "llm_token",
+                                "node": "agent",
+                                "token": chunk.content,
+                                "reasoning_token": ""
+                            })
+
+                    # Reasoning text, when present in additional_kwargs
+                    if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs:
+                        reasoning_content = chunk.additional_kwargs.get("reasoning_content", "")
+                        if reasoning_content:
+                            full_reasoning_content += reasoning_content
+                            if token_queue:
+                                await token_queue.put({
+                                    "type": "llm_token",
+                                    "node": "agent",
+                                    "token": "",
+                                    "reasoning_token": reasoning_content
+                                })
+
+            # Build the complete AIMessage; tool_calls must be a list, never None
+            response = AIMessage(
+                content=full_content,
+                tool_calls=list(merged_chunk.tool_calls) if merged_chunk is not None else []
+            )
+            if full_reasoning_content:
+                response.additional_kwargs["reasoning_content"] = full_reasoning_content
+
             info(f"[Agent] LLM 调用成功!响应类型: {type(response).__name__}")
             if hasattr(response, 'tool_calls') and response.tool_calls:
                 info(f"[Agent] 检测到工具调用: {[tc['name'] for tc in response.tool_calls]}")
 
-            # 返回状态更新(注意:不原地修改 state,返回字典让 LangGraph 处理
+            # Return the state update
             return {
                 "messages": [response],
                 "current_step": state.current_step + 1,
                 "llm_calls": state.llm_calls + 1
             }
-
+
         except Exception as e:
             error(f"[Agent] ❌ 第 {state.current_step} 步推理出错: {e}")
             import traceback