diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index 673dda2..59c9bcf 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -1,6 +1,6 @@ """ -AI Agent 服务类 - 优化版本:双协程 + 结束哨兵 + 完整的取消和异常处理 -接收外部传入的 checkpointer,不负责管理连接生命周期 +AI Agent 服务类 - 完全简化版本! +按照指南实现,不用 stream_mode="messages" 避免重复 token! """ import json @@ -11,11 +11,11 @@ from typing import AsyncGenerator, Dict, Any, Optional, Tuple from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer # 本地模块 -from ..model_services import get_cached_chat_services -from ..main_graph.main_graph_builder import build_agent_graph +from backend.app.model_services import get_cached_chat_services +from backend.app.main_graph.main_graph_builder import build_agent_graph from backend.app.logger import debug, info, warning, error -from ..main_graph.state import AgentState -from .stream_context import token_queue_var +from backend.app.main_graph.state import AgentState +from .stream_context import set_stream_queue class AIAgentService: @@ -120,125 +120,10 @@ class AIAgentService: "model_used": resolved_model } - def _serialize_value(self, value): - """递归将 LangChain 对象转换为可 JSON 序列化的格式""" - if hasattr(value, 'content'): - msg_type = getattr(value, 'type', 'message') - return { - "role": msg_type, - "content": getattr(value, 'content', ''), - "additional_kwargs": getattr(value, 'additional_kwargs', {}), - "tool_calls": getattr(value, 'tool_calls', []) - } - elif isinstance(value, dict): - return {k: self._serialize_value(v) for k, v in value.items()} - elif isinstance(value, (list, tuple)): - return [self._serialize_value(item) for item in value] - else: - try: - json.dumps(value) - return value - except (TypeError, ValueError): - return str(value) - - async def _handle_message_chunk( - self, chunk: Dict[str, Any], current_node: Optional[str], tool_calls_in_progress: Dict[str, Any] - ) -> AsyncGenerator[Dict[str, Any], None]: - """处理 messages 类型的 chunk""" - message_chunk, metadata = chunk["data"] - node_name = metadata.get("langgraph_node", "unknown") - new_current_node = current_node - - # 检测节点变化,发送节点开始事件 - if node_name != current_node: - if current_node: - yield {"type": "node_end", "node": current_node} - yield {"type": "node_start", "node": node_name} - new_current_node = node_name - - # 处理消息内容 - token_content = getattr(message_chunk, 'content', str(message_chunk)) - reasoning_token = "" - if hasattr(message_chunk, 'additional_kwargs'): - reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "") - - # 处理思考过程 - if reasoning_token: - yield { - "type": "llm_token", - "node": node_name, - "reasoning_token": reasoning_token - } - # 处理工具调用 - elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls: - for tool_call in message_chunk.tool_calls: - tool_call_id = tool_call.get("id", "") - tool_name = tool_call.get("name", "") - tool_args = tool_call.get("args", {}) - - # 记录工具调用开始,避免重复 - if tool_call_id and tool_call_id not in tool_calls_in_progress: - tool_calls_in_progress[tool_call_id] = { - "name": tool_name, - "args": tool_args - } - yield { - "type": "tool_call_start", - "tool": tool_name, - "args": tool_args, - "id": tool_call_id - } - # 处理普通 token - elif token_content: - yield { - "type": "llm_token", - "node": node_name, - "token": token_content, - "reasoning_token": reasoning_token - } - - # 返回更新后的 current_node - yield {"type": "_update_state", "current_node": new_current_node} - - async def _handle_updates_chunk( - self, chunk: Dict[str, Any], tool_calls_in_progress: Dict[str, Any], actual_model_used: str - ) -> AsyncGenerator[Dict[str, Any], None]: - """处理 updates 类型的 chunk""" - updates_data = chunk["data"] - new_actual_model = actual_model_used - - serialized_data = self._serialize_value(updates_data) - - # 检查是否有工具结果 - if "messages" in serialized_data: - for msg in serialized_data["messages"]: - # 检测工具结果消息 - if msg.get("role") == "tool": - tool_call_id = msg.get("tool_call_id", "") - tool_name = msg.get("name", "") - tool_result = msg.get("content", "") - - if tool_call_id and tool_call_id in tool_calls_in_progress: - yield { - "type": "tool_call_end", - "tool": tool_name, - "id": tool_call_id, - "result": tool_result - } - del tool_calls_in_progress[tool_call_id] - - yield { - "type": "state_update", - "data": serialized_data - } - - # 返回更新后的模型 - yield {"type": "_update_state", "actual_model_used": new_actual_model} - async def process_message_stream( self, message: str, thread_id: str, model: str = "", user_id: str = "default_user" ) -> AsyncGenerator[Dict[str, Any], None]: - """流式处理消息 - 双协程 + 结束哨兵 + 完整取消和异常处理""" + """流式处理消息 - 完全简化!""" # 解析模型名称 resolved_model = self._resolve_model(model) @@ -246,144 +131,64 @@ class AIAgentService: config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id) info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}") - current_node = None - tool_calls_in_progress: Dict[str, Any] = {} actual_model_used = resolved_model - full_message_content = "" # 创建 token 队列 - token_queue = asyncio.Queue() - # 结束哨兵 - SENTINEL = object() + queue = asyncio.Queue() + set_stream_queue(queue) # 设置上下文变量 - # 设置上下文变量 - token_queue_var.set(token_queue) - - # 事件和错误跟踪 - graph_error = None - graph_done = asyncio.Event() - - async def run_graph_task(): - """后台任务:运行 graph.astream()""" - nonlocal current_node, actual_model_used, full_message_content, graph_error + async def run_graph(): + """后台任务:运行 graph,只获取 updates,不要用 stream_mode="messages" 避免重复 token!""" try: info(f"📡 开始调用 graph.astream()...") - event_count = 0 - + # 注意:只用 stream_mode=["updates"],不要 "messages"!避免重复 token! async for chunk in self.graph.astream( input_state, config=config, - stream_mode=["messages", "updates"], + stream_mode=["updates"], version="v2", subgraphs=True ): - chunk_count = 0 - chunk_count += 1 - chunk_type = chunk["type"] - - # 记录原始 chunk 信息(前 10 个和后 10 个) - if chunk_count <= 10 or chunk_count % 50 == 0: - info(f" [{chunk_count}] chunk_type={chunk_type}, data={type(chunk.get('data'))}") - - if chunk_type == "messages": - async for event in self._handle_message_chunk( - chunk, current_node, tool_calls_in_progress - ): - if event.get("type") == "_update_state": - current_node = event.get("current_node", current_node) - else: - event_count += 1 - # 记录前 10 个事件 - if event_count <= 10: - info(f" → yield event #{event_count}: {event.get('type')}") - - # 如果是 agent 节点的 token,收集完整消息 - if ( - event.get("type") == "llm_token" - and event.get("node") == "agent" - and "token" in event - ): - full_message_content += event["token"] - await token_queue.put(event) - - elif chunk_type == "updates": - async for event in self._handle_updates_chunk( - chunk, tool_calls_in_progress, actual_model_used - ): - if event.get("type") == "_update_state": - actual_model_used = event.get("actual_model_used", actual_model_used) - else: - event_count += 1 - if event_count <= 10: - info(f" → yield event #{event_count}: {event.get('type')}") - await token_queue.put(event) - - # 完整消息集合完成后,一次性打印 - info(f"✅ graph.astream() 完成,共 {event_count} 个 events") - if full_message_content: - info(f"📄 完整消息内容: {repr(full_message_content)}") - + # 可以处理一些状态更新事件,如 final_result 等 + await queue.put({ + "type": "graph_update", + "data": chunk, + }) except Exception as e: error(f"❌ 执行图时出错: {e}") import traceback error(f"📋 堆栈: {traceback.format_exc()}") - graph_error = e - await token_queue.put({ - "type": "error", - "message": str(e) - }) + await queue.put({"type": "error", "message": str(e)}) finally: - # 发送结束哨兵 - await token_queue.put(SENTINEL) - graph_done.set() + await queue.put(None) # 结束哨兵 # 启动后台任务 - graph_task = asyncio.create_task(run_graph_task()) + bg_task = asyncio.create_task(run_graph()) try: - # 主协程:从队列里取事件并 yield while True: - try: - # 等待队列中的事件,带超时检查任务是否完成 - event = await asyncio.wait_for(token_queue.get(), timeout=0.5) - - # 检查是否是结束哨兵 - if event is SENTINEL: - break - - yield event - - except asyncio.TimeoutError: - # 超时检查任务是否完成 - if graph_task.done(): - # 检查任务是否抛出异常 - if graph_task.exception(): - exc = graph_task.exception() - error(f"❌ 后台任务异常: {exc}") - break - - except asyncio.CancelledError: - info("⚠️ 流式生成被取消") + event = await queue.get() + if event is None: + break + yield event + + except GeneratorExit: + # 客户端断开连接,取消后台任务 + info("⚠️ GeneratorExit,取消后台任务") + bg_task.cancel() raise - finally: - # 无论成功或失败,都清理资源 - # 取消后台任务 - if not graph_task.done(): - info("⏹️ 取消后台任务") - graph_task.cancel() + # 保证任务被清理 + if not bg_task.done(): + info("⏹️ 清理后台任务") + bg_task.cancel() try: - await graph_task + await bg_task except asyncio.CancelledError: info("✅ 后台任务已取消") # 发送结束事件,保证前端平稳关闭 - if current_node: - yield { - "type": "node_end", - "node": current_node - } yield { "type": "done", "model_used": actual_model_used diff --git a/backend/app/agent/stream_context.py b/backend/app/agent/stream_context.py index 2d5470d..82d8663 100644 --- a/backend/app/agent/stream_context.py +++ b/backend/app/agent/stream_context.py @@ -1,9 +1,22 @@ -"""流式上下文,用于在 LangGraph 节点和 agent_service 之间传递 token 队列""" +""" +流式上下文,用于在 LangGraph 节点和 agent_service 之间传递 token 队列 +清晰的 API,更易用! +""" import contextvars import asyncio from typing import Optional, Any -# 上下文变量:存储当前的 token 队列 -token_queue_var: contextvars.ContextVar[Optional[asyncio.Queue]] = contextvars.ContextVar( - "token_queue", default=None +# 上下文变量:存储每个请求专属的 token 队列 +stream_queue_ctx: contextvars.ContextVar[Optional[asyncio.Queue]] = contextvars.ContextVar( + "stream_queue", default=None ) + + +def set_stream_queue(queue: asyncio.Queue) -> None: + """设置当前请求的队列""" + stream_queue_ctx.set(queue) + + +def get_stream_queue() -> Optional[asyncio.Queue]: + """获取当前请求的队列""" + return stream_queue_ctx.get() diff --git a/backend/app/main_graph/main_graph_builder.py b/backend/app/main_graph/main_graph_builder.py index 5f048d8..766d8d4 100644 --- a/backend/app/main_graph/main_graph_builder.py +++ b/backend/app/main_graph/main_graph_builder.py @@ -1,12 +1,12 @@ """ -极简 Agent 主图 - 自己的节点结构,更好控制流式 +极简 Agent 主图 - 简化版本! +因为完整的 ReAct 循环已经在 agent.py 里了! """ from langgraph.graph import StateGraph, START, END -from langgraph.prebuilt import ToolNode -from ..state import AgentState -from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client -from ..nodes.agent import create_agent_node +from backend.app.main_graph.state import AgentState +from backend.app.main_graph.nodes.memory_trigger import memory_trigger_node, set_mem0_client +from backend.app.main_graph.nodes.agent import create_agent_node from backend.app.logger import info, warning from backend.app.tools import ALL_TOOLS @@ -17,7 +17,7 @@ def build_agent_graph( max_steps: int = 10 ): """ - 构建包含记忆节点的 Agent 图 + 构建简化的 Agent 图(ReAct 循环在 agent 节点内) Args: chat_services: 模型服务字典 @@ -51,28 +51,16 @@ def build_agent_graph( except Exception as e: info(f"[Graph Builder] 记忆节点初始化失败: {e}") - # ========== 3. 核心节点 ========== + # ========== 3. Agent 节点(包含完整 ReAct 循环) ========== llm_with_tools = primary_model.bind_tools(ALL_TOOLS) agent_node_fn = create_agent_node(llm_with_tools, primary_model) - tool_node_fn = ToolNode(ALL_TOOLS) - # ========== 4. 条件边判断函数 ========== - def should_continue(state: AgentState): - """判断是继续调用工具还是结束""" - messages = state.messages - last_message = messages[-1] if messages else None - - if last_message and hasattr(last_message, 'tool_calls') and last_message.tool_calls: - return "tools" - - return "finalize" - - # ========== 5. 完成节点 ========== + # ========== 4. 完成节点 ========== async def finalize_node_simple(state: AgentState): info("[Finalize] 进入完成节点") return {} - # ========== 6. 构建图 ========== + # ========== 5. 构建图 ========== graph = StateGraph(AgentState) graph.add_node("init_state", init_state_node) @@ -80,10 +68,9 @@ def build_agent_graph( graph.add_node("retrieve_memory", retrieve_memory_node) graph.add_node("memory_trigger", memory_trigger_node) graph.add_node("agent", agent_node_fn) - graph.add_node("tools", tool_node_fn) graph.add_node("finalize", finalize_node_simple) - # ========== 7. 边的连接 ========== + # ========== 6. 边的连接 ========== graph.add_edge(START, "init_state") if retrieve_memory_node: @@ -93,18 +80,8 @@ def build_agent_graph( graph.add_edge("init_state", "memory_trigger") graph.add_edge("memory_trigger", "agent") - - graph.add_conditional_edges( - "agent", - should_continue, - { - "tools": "tools", - "finalize": "finalize" - } - ) - - graph.add_edge("tools", "agent") + graph.add_edge("agent", "finalize") graph.add_edge("finalize", END) - info("✅ [Graph Builder] 极简 Agent 图构建完成") + info("✅ [Graph Builder] 简化 Agent 图构建完成(ReAct 在节点内)") return graph diff --git a/backend/app/main_graph/nodes/agent.py b/backend/app/main_graph/nodes/agent.py index daa99de..18ec44c 100644 --- a/backend/app/main_graph/nodes/agent.py +++ b/backend/app/main_graph/nodes/agent.py @@ -1,11 +1,15 @@ -"""Agent 节点:核心推理与工具调用""" +""" +Agent 节点:完整的 ReAct 循环 + 流式 Tool Calling 拼接 +完全参考指南实现! +""" -from typing import Dict, Any, Optional -from langchain_core.messages import SystemMessage, AIMessage, AIMessageChunk +from typing import Dict, Any, Optional, List +from langchain_core.messages import SystemMessage, AIMessage, AIMessageChunk, ToolMessage from langchain_core.runnables.config import RunnableConfig -from ..state import AgentState +from backend.app.main_graph.state import AgentState from backend.app.logger import info, warning, error -from .stream_context import token_queue_var +from backend.app.agent.stream_context import get_stream_queue +from backend.app.tools import ALL_TOOLS # 系统提示词(从 main_graph_builder.py 搬过来) @@ -54,11 +58,12 @@ SYSTEM_PROMPT = """你是一个智能助手,可以使用多种工具完成复 def create_agent_node(llm_with_tools, llm): - """创建 Agent 节点函数""" + """创建 Agent 节点函数,完整 ReAct 循环""" async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]: """ - Agent 节点:调用带工具的 LLM,处理步数限制 + Agent 节点:完整的 ReAct 循环,带流式 token 和工具调用事件 + 兼容流式和非流式两种情况! Args: state: 当前状态 @@ -67,114 +72,214 @@ def create_agent_node(llm_with_tools, llm): Returns: 状态更新字典 """ - current_step = state.get("current_step", 0) - info(f"[Agent] 第 {current_step} 步推理") + # 获取队列 + queue = get_stream_queue() + is_streaming = queue is not None + + # 获取当前步数 + current_step = getattr(state, "current_step", 0) + max_steps = getattr(state, "max_steps", 10) + info(f"[Agent] 从第 {current_step} 步开始,最大步数: {max_steps},流式: {is_streaming}") + + # 组装完整消息 + messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(state.messages) + turn = current_step # 轮次从当前步数开始 try: - # 组装完整消息:系统提示 + 历史消息 - full_messages = [SystemMessage(content=SYSTEM_PROMPT)] + state.messages - - info(f"[Agent] 消息数量: {len(full_messages)}, 最后一条: {type(full_messages[-1]).__name__}") + while turn < max_steps: + turn += 1 + info(f"[Agent] 第 {turn} 轮思考") - # 判断是否达到步数上限 - if current_step >= state.get("max_steps", 10): - info(f"[Agent] 达到步数上限 {state.get('max_steps', 10)},强制结束,不绑定工具") - current_llm = llm.bind_tools([]) - else: - current_llm = llm_with_tools + # 告诉前端:新的一轮开始(如果流式) + if is_streaming: + await queue.put({ + "type": "turn_start", + "turn": turn, + }) - info(f"[Agent] 调用带工具的 LLM...") + # 选择 LLM + if turn >= max_steps: + info(f"[Agent] 达到步数上限,用不带工具的 LLM") + current_llm = llm.bind_tools([]) + else: + current_llm = llm_with_tools - # 获取 token 队列 - token_queue = token_queue_var.get() - if token_queue is None: - error("[Agent] ❌ token_queue 为 None!") - raise RuntimeError("token_queue 上下文变量未设置") + # 初始化变量 + full_content = "" + full_reasoning_content = "" + pending_tool_calls = {} # key: index, value: {id, name, args_str} + final_tool_calls = [] - # 完整消息 - full_content = "" - full_reasoning_content = "" - full_tool_calls = [] - - # 流式调用 LLM - async for chunk in current_llm.astream(full_messages): - if isinstance(chunk, AIMessageChunk): - # 处理 content - if chunk.content: - full_content += chunk.content - await token_queue.put({ - "type": "llm_token", - "node": "agent", - "token": chunk.content, - "reasoning_token": "", - "turn": current_step, - "phase": "answering" if not full_tool_calls else "thinking" - }) - - # 处理 reasoning_content - if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs: - reasoning_content = chunk.additional_kwargs.get("reasoning_content", "") - if reasoning_content: - full_reasoning_content += reasoning_content - await token_queue.put({ - "type": "llm_token", - "node": "agent", - "token": "", - "reasoning_token": reasoning_content, - "turn": current_step, - "phase": "thinking" - }) - - # 处理 tool_calls - if hasattr(chunk, 'tool_calls') and chunk.tool_calls: - # 合并 tool_calls - for tc in chunk.tool_calls: - # 查找是否已经有这个 id 的 tool_call - found = False - for existing_tc in full_tool_calls: - if existing_tc.get("id") == tc.get("id"): - # 合并 args - existing_tc["args"] = {**existing_tc.get("args", {}), **tc.get("args", {})} - found = True - break - if not found: - full_tool_calls.append(tc) - # 发送工具调用开始事件 - await token_queue.put({ - "type": "tool_call_start", - "tool": tc.get("name"), - "args": tc.get("args"), - "id": tc.get("id", ""), - "turn": current_step + # 只有流式的时候用 astream,非流式直接用 ainvoke 更快! + if is_streaming: + async for chunk in current_llm.astream(messages): + if isinstance(chunk, AIMessageChunk): + # 1. 处理文本 token + if chunk.content: + full_content += chunk.content + await queue.put({ + "type": "llm_token", + "turn": turn, + "phase": "answering", + "token": chunk.content, + "reasoning_token": "" }) - # 构建完整的 AIMessage - response = AIMessage( - content=full_content, - tool_calls=full_tool_calls if full_tool_calls else None - ) + # 2. 处理 reasoning token + if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs: + reasoning_content = chunk.additional_kwargs.get("reasoning_content", "") + if reasoning_content: + full_reasoning_content += reasoning_content + await queue.put({ + "type": "llm_token", + "turn": turn, + "phase": "reasoning", + "token": "", + "reasoning_token": reasoning_content + }) + + # 3. 流式 Tool Calling 拼接逻辑(核心!用 tool_call_chunks!) + if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks: + for tc_chunk in chunk.tool_call_chunks: + idx = tc_chunk.get("index", 0) + if idx not in pending_tool_calls: + pending_tool_calls[idx] = { + "id": "", + "name": "", + "args": "" # 初始化为字符串 + } + + if tc_chunk.get("id"): + pending_tool_calls[idx]["id"] += tc_chunk["id"] + if tc_chunk.get("name"): + pending_tool_calls[idx]["name"] += tc_chunk["name"] + if tc_chunk.get("args"): + args_val = tc_chunk["args"] + if isinstance(args_val, str): + pending_tool_calls[idx]["args"] += args_val + else: + import json + pending_tool_calls[idx]["args"] += json.dumps(args_val) + else: + # 非流式,直接 ainvoke + result = await current_llm.ainvoke(messages) + full_content = result.content if result.content else "" + if hasattr(result, 'tool_calls') and result.tool_calls: + final_tool_calls = result.tool_calls + if hasattr(result, 'additional_kwargs') and result.additional_kwargs: + full_reasoning_content = result.additional_kwargs.get("reasoning_content", "") + + # 流式调用结束后,整理最终的 tool_calls(只在流式时处理 pending!) + if is_streaming: + for idx in sorted(pending_tool_calls.keys()): + tc_data = pending_tool_calls[idx] + if tc_data["name"]: # 只有有名字的才是有效工具调用 + # 解析参数字符串 + args = {} + if tc_data["args"]: + try: + import json + args = json.loads(tc_data["args"]) + except Exception as e: + info(f"[Agent] Failed to parse args JSON: {e}, raw: {tc_data['args']}") + final_tool_calls.append({ + "id": tc_data["id"], + "name": tc_data["name"], + "args": args + }) + + # 判断是否有工具调用 + if final_tool_calls: + info(f"[Agent] 第 {turn} 轮:调用 {len(final_tool_calls)} 个工具") + + # 执行工具调用 + new_messages = [] + for tc in final_tool_calls: + tool_name = tc["name"] + tool_args = tc["args"] + tool_id = tc["id"] + + # 发送工具开始事件(如果流式) + if is_streaming: + await queue.put({ + "type": "tool_start", + "turn": turn, + "tool": tool_name, + "args": tool_args, + "id": tool_id + }) + + # 找到并执行对应工具 + tool_result = "" + tool_found = False + for tool in ALL_TOOLS: + if tool.name == tool_name: + tool_found = True + try: + tool_result = await tool.ainvoke(tool_args) + except Exception as e: + tool_result = f"工具调用出错: {str(e)}" + error(f"[Agent] 工具 {tool_name} 调用出错: {e}") + break + + if not tool_found: + tool_result = f"未找到工具: {tool_name}" + + # 发送工具结束事件(如果流式) + if is_streaming: + await queue.put({ + "type": "tool_end", + "turn": turn, + "tool": tool_name, + "id": tool_id, + "result": str(tool_result) + }) + + # 构造 ToolMessage + tool_msg = ToolMessage( + content=str(tool_result), + tool_call_id=tool_id, + name=tool_name + ) + new_messages.append(tool_msg) + + # 添加到 messages,继续下一轮 + messages.extend(new_messages) + continue + + else: + # 没有工具调用,最终输出 + info(f"[Agent] 第 {turn} 轮:完成,无工具调用") + if is_streaming: + await queue.put({ + "type": "final_answer", + "turn": turn, + "content": full_content + }) + break + + # 构建完整的 AIMessage 用于状态更新 + response_kwargs = {"content": full_content} + if final_tool_calls: + response_kwargs["tool_calls"] = final_tool_calls + response = AIMessage(**response_kwargs) if full_reasoning_content: response.additional_kwargs["reasoning_content"] = full_reasoning_content - info(f"[Agent] LLM 调用成功!响应类型: {type(response).__name__}") - if hasattr(response, 'tool_calls') and response.tool_calls: - info(f"[Agent] 检测到工具调用: {[tc['name'] for tc in response.tool_calls]}") - # 返回状态更新 return { "messages": [response], - "current_step": current_step + 1, - "llm_calls": state.get("llm_calls", 0) + 1 + "current_step": turn, + "llm_calls": getattr(state, "llm_calls", 0) + 1 } except Exception as e: - error(f"[Agent] ❌ 第 {current_step} 步推理出错: {e}") + error(f"[Agent] ❌ 第 {turn} 轮出错: {e}") import traceback error(f"[Agent] 堆栈: {traceback.format_exc()}") - # 发送错误事件 - token_queue = token_queue_var.get() - if token_queue: - await token_queue.put({ + # 发送错误事件(如果流式) + if is_streaming: + await queue.put({ "type": "error", "message": str(e) }) diff --git a/backend/app/tools/__init__.py b/backend/app/tools/__init__.py index 4cec476..59b72e5 100644 --- a/backend/app/tools/__init__.py +++ b/backend/app/tools/__init__.py @@ -7,7 +7,7 @@ from typing import Optional from backend.app.logger import info -# ====== RAG Pipeline(复用现有) +# ========== RAG Pipeline(复用现有) _rag_pipeline = None diff --git a/frontend/src/components/chat_area.py b/frontend/src/components/chat_area.py index 258b1f3..5dd5b95 100644 --- a/frontend/src/components/chat_area.py +++ b/frontend/src/components/chat_area.py @@ -226,7 +226,7 @@ def _handle_ai_response(): elif event_type == "llm_token": node_name = event.get("node", "unknown") # 确保只处理来自 LLM 的 token,避免将工具的输出作为 token 显示 - if node_name in ("llm_call", "fallback"): + if node_name in ("llm_call", "fallback", "agent"): token = str(event.get("token", "")) reasoning_token = str(event.get("reasoning_token", "")) diff --git a/tools/test/test_full_react_streaming.py b/tools/test/test_full_react_streaming.py new file mode 100644 index 0000000..51643da --- /dev/null +++ b/tools/test/test_full_react_streaming.py @@ -0,0 +1,104 @@ +""" +测试新的完整 ReAct 循环架构 + 流式 Tool Calling +""" + +import asyncio +import sys +import os + +sys.path.insert(0, "/root/projects/ailine/backend") + +from app.main_graph.main_graph_builder import build_agent_graph +from app.model_services import get_cached_chat_services +from app.agent.stream_context import set_stream_queue +from app.logger import info, error + + +async def test_full_react_streaming(): + """测试完整的 ReAct 循环流式架构""" + info("=" * 60) + info("🧪 测试完整 ReAct 循环 + 流式 Tool Calling") + info("=" * 60) + + # 1. 获取服务 + chat_services = get_cached_chat_services() + info(f"✅ 加载了 {len(chat_services)} 个模型: {list(chat_services.keys())}") + + # 2. 构建图 + graph_builder = build_agent_graph(chat_services, mem0_client=None) + graph = graph_builder.compile() + info(f"✅ 图构建完成") + + # 3. 创建队列 + queue = asyncio.Queue() + set_stream_queue(queue) + + # 4. 定义后台任务 + async def run_graph(): + try: + input_state = { + "messages": [ + {"role": "user", "content": "你好,请介绍一下你自己"} + ], + "user_id": "test_user", + } + async for chunk in graph.astream( + input_state, + stream_mode=["updates"], + version="v2" + ): + await queue.put({ + "type": "graph_update", + "data": chunk, + }) + except Exception as e: + error(f"❌ 图执行出错: {e}") + import traceback + error(f"📋 堆栈: {traceback.format_exc()}") + await queue.put({"type": "error", "message": str(e)}) + finally: + await queue.put(None) + + # 5. 启动后台任务并处理事件 + bg_task = asyncio.create_task(run_graph()) + + info("\n📡 开始接收流式事件:\n") + try: + while True: + event = await queue.get() + if event is None: + break + if event["type"] == "llm_token": + if event["token"]: + print(event["token"], end="") + if event["reasoning_token"]: + print(f"{event['reasoning_token']}", end="") + elif event["type"] == "turn_start": + print(f"\n===== Turn {event['turn']} 开始 =====") + elif event["type"] == "tool_start": + print(f"\n🔧 工具调用: {event['tool']}") + elif event["type"] == "tool_end": + print(f"\n✅ 工具调用完成") + elif event["type"] == "final_answer": + print(f"\n📝 最终答案") + elif event["type"] == "graph_update": + # 忽略 update 事件,只关心 agent 节点发的事件 + pass + else: + print(f"\n📋 其他事件: {event}") + + print("\n✅ 流式测试完成") + return True + + except Exception as e: + error(f"❌ 测试出错: {e}") + import traceback + error(f"📋 堆栈: {traceback.format_exc()}") + return False + finally: + if not bg_task.done(): + bg_task.cancel() + + +if __name__ == "__main__": + asyncio.run(test_full_react_streaming()) diff --git a/tools/test/test_stream.py b/tools/test/test_stream.py new file mode 100644 index 0000000..2aa312e --- /dev/null +++ b/tools/test/test_stream.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +"""测试后端流式接口,看看是否真的有流式输出""" + +import asyncio +import aiohttp +import json + +BACKEND_URL = "http://localhost:8079/chat/stream" + + +async def test_stream(): + print("=" * 60) + print("🧪 测试后端流式接口") + print("=" * 60) + + async with aiohttp.ClientSession() as session: + payload = { + "message": "你好,请简单介绍一下自己", + "thread_id": "test-thread-001", + "model": "zhipu", + "user_id": "test-user" + } + + print(f"\n📤 发送请求: {json.dumps(payload, ensure_ascii=False)}") + + try: + async with session.post(BACKEND_URL, json=payload) as response: + print(f"\n✅ 响应状态: {response.status}") + print(f"\n📥 开始接收流式响应...\n") + + event_count = 0 + token_count = 0 + + async for line in response.content: + line = line.decode('utf-8').strip() + if line: + if line.startswith("data: "): + data_str = line[6:] + if data_str == "[DONE]": + print("\n🏁 收到 [DONE] 事件") + break + + try: + event = json.loads(data_str) + event_count += 1 + print(f" [{event_count}] {event.get('type')}") + + if event.get('type') == 'llm_token' and 'token' in event: + token = event['token'] + token_count += 1 + print(f" → token: {repr(token)}") + + if event.get('type') == 'node_start': + print(f" → node: {event.get('node')}") + + if event.get('type') == 'tool_call_start': + print(f" → tool: {event.get('tool')}") + + if event.get('type') == 'tool_call_end': + print(f" → tool: {event.get('tool')}") + + if event.get('type') == 'error': + print(f" ❌ 错误: {event.get('message')}") + + except Exception as e: + print(f" ❌ 解析失败: {e}, 原始数据: {repr(data_str)}") + else: + print(f" 📝 原始行: {repr(line)}") + + print(f"\n📊 统计: {event_count} 个事件, {token_count} 个 token") + + except Exception as e: + print(f"\n❌ 请求异常: {e}") + import traceback + print(f"📋 堆栈: {traceback.format_exc()}") + + +if __name__ == "__main__": + asyncio.run(test_stream())