"""
Agent node - simplified version (single-step reasoning).

Each invocation performs exactly one LLM call and never executes tools;
tool execution is the responsibility of a separate tools node.
"""

import json
import traceback
from typing import Dict, Any, Optional, List

from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage

from backend.app.main_graph.state import AgentState
from backend.app.logger import info, error, debug
from backend.app.tools import ALL_TOOLS
from backend.app.agent.stream_context import get_stream_queue
from backend.app.agent.prompts import SYSTEM_PROMPT

# Substituted into SYSTEM_PROMPT when the state has no (or an empty) memory context.
_DEFAULT_MEMORY_CONTEXT = "暂无用户背景信息"


def _normalize_args(args: Optional[dict]) -> str:
    """Return a canonical string form of tool-call arguments for equality checks.

    ``json.dumps(sort_keys=True)`` sorts keys at every nesting level, so two
    argument dicts that differ only in key order (including inside nested
    dicts) normalize identically. ``None`` is treated as ``{}``. Values that
    are not JSON-serializable fall back to ``str``; the result is only ever
    compared, never parsed back.
    """
    return json.dumps(args or {}, sort_keys=True, ensure_ascii=False, default=str)


def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool:
    """Heuristically decide whether the last two tool results are similar.

    Two results count as similar when their lengths are within a factor of two
    and the shared prefix covers more than ``threshold`` of the characters
    compared (the first 100, or fewer if a result is shorter). Two empty
    results are similar; an empty and a non-empty one are not.

    Args:
        results: Tool results in chronological order; only the last two are used.
        threshold: Shared-prefix ratio that must be exceeded (exclusive).

    Returns:
        False when fewer than two results are given, otherwise the verdict.
    """
    if len(results) < 2:
        return False
    latest, prev = results[-1], results[-2]
    if not latest or not prev:
        return len(latest) == len(prev)

    shorter, longer = sorted((len(latest), len(prev)))
    if shorter / longer < 0.5:
        return False

    # Normalize by the number of characters actually compared. Dividing by a
    # fixed 100 meant results shorter than 100 chars could never be similar,
    # even when identical.
    compared = min(100, shorter)
    common_len = 0
    for a, b in zip(latest[:compared], prev[:compared]):
        if a != b:
            break
        common_len += 1
    return common_len / compared > threshold


def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool:
    """
    Loop detection: decide whether the agent should stop.

    Stops when the last two tool calls use the same tool name, have equal
    normalized arguments, and their last two results are similar.
    """
    if len(tool_calls) < 2:
        return False
    last_tc, prev_tc = tool_calls[-1], tool_calls[-2]
    if last_tc.get("name") != prev_tc.get("name"):
        return False
    if _normalize_args(last_tc.get("args")) != _normalize_args(prev_tc.get("args")):
        return False
    # _is_similar_result already returns False for fewer than two results.
    return _is_similar_result(tool_results[-2:])


def _content_text(content: Any) -> str:
    """Extract plain text from message content.

    Content may be a plain string or a list of content blocks; for a list,
    string items and the ``text`` of ``{"type": "text"}`` blocks are
    concatenated. Anything else yields ``""``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(block.get("text") or "")
        return "".join(parts)
    return ""


def _accumulate_tool_call_chunk(pending: Dict[int, Dict[str, str]], tc_chunk: dict) -> None:
    """Merge one streamed tool-call chunk into ``pending``, keyed by chunk index.

    The ``id``, ``name`` and ``args`` fragments are concatenated per index;
    non-string ``args`` fragments are JSON-encoded before appending.
    """
    idx = tc_chunk.get("index")
    if idx is None:
        # Some providers omit the index; map it to 0 so that sorting the keys
        # later never has to compare None with int.
        idx = 0
    slot = pending.setdefault(idx, {"id": "", "name": "", "args": ""})
    if tc_chunk.get("id"):
        slot["id"] += tc_chunk["id"]
    if tc_chunk.get("name"):
        slot["name"] += tc_chunk["name"]
    args_val = tc_chunk.get("args")
    if args_val:
        slot["args"] += args_val if isinstance(args_val, str) else json.dumps(args_val)


def _finalize_tool_calls(pending: Dict[int, Dict[str, str]]) -> List[dict]:
    """Convert accumulated streamed chunks into tool-call dicts, in index order.

    Entries without a name are dropped. Arguments that fail to parse, or that
    parse to something other than a JSON object, are logged as errors and
    replaced with ``{}`` so the call is still surfaced to the tools node.
    """
    tool_calls = []
    for idx in sorted(pending):
        data = pending[idx]
        if not data["name"]:
            continue
        args: Any = {}
        if data["args"]:
            try:
                args = json.loads(data["args"])
            except ValueError as e:
                error(f"[Agent] 解析工具 {data['name']} 参数失败: {e}")
                args = {}
            if not isinstance(args, dict):
                error(f"[Agent] 工具 {data['name']} 参数不是 JSON 对象: {data['args']!r}")
                args = {}
        tool_calls.append({"id": data["id"], "name": data["name"], "args": args})
    return tool_calls


def _select_llm(chat_services: dict, config: Optional[RunnableConfig]):
    """Pick the chat model named by ``config["configurable"]["model"]``.

    Defaults to "primary"; falls back to the first registered model when the
    requested name is not present.

    Raises:
        ValueError: if ``chat_services`` is empty (previously this surfaced as
            an opaque StopIteration/RuntimeError from inside the coroutine).
    """
    if not chat_services:
        raise ValueError("[Agent] chat_services 为空,没有可用的模型")
    model_name = "primary"
    if config:
        configurable = config.get("configurable") or {}
        model_name = configurable.get("model", "primary")
    llm = chat_services.get(model_name)
    if llm is None:
        llm = next(iter(chat_services.values()))
        info(f"[Agent] 模型 '{model_name}' 不可用,使用 '{type(llm).__name__}'")
    return llm


def create_agent_node(chat_services: dict):
    """
    Create the Agent node (single-step reasoning version).

    Design:
    - Performs exactly one LLM call per invocation.
    - Does not execute tools (the tools node is responsible for that).
    - Returns an AIMessage, which may carry tool_calls.

    Args:
        chat_services: Mapping of model name -> chat model. Models are used via
            ``bind_tools``, ``astream`` and ``ainvoke``. The model is chosen on
            each call from ``config["configurable"]["model"]`` (default "primary").

    Returns:
        The async ``agent_node(state, config=None)`` function.
    """

    async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
        """Agent node: one LLM call, streamed when a stream queue is active.

        Returns a partial state update: a stop marker when the step budget is
        exhausted or a tool-call loop is detected, otherwise the model's
        AIMessage together with ``llm_calls`` incremented by one.

        Raises:
            ValueError: if ``chat_services`` is empty.
            Errors from the LLM call are logged, reported on the stream queue
            (when streaming) and re-raised.
        """
        queue = get_stream_queue()
        is_streaming = queue is not None

        current_step = getattr(state, "current_step", 0)
        max_steps = getattr(state, "max_steps", 10)
        info(f"[Agent] 第 {current_step + 1} 步开始")

        # Step budget exhausted: stop without calling the LLM.
        if current_step >= max_steps:
            info("[Agent] 达到步数上限,强制结束")
            return {
                "messages": [AIMessage(content="[系统] 已达到最大步数限制。")],
                "stop": True,
                "stop_reason": "max_steps",
            }

        # Loop detection over recorded tool calls/results (None treated as empty).
        tool_history = getattr(state, "tool_call_history", None) or []
        result_history = getattr(state, "tool_result_history", None) or []
        if _should_stop_for_loop(tool_history, result_history):
            info("[Agent] 检测到循环,终止推理")
            return {
                "messages": [AIMessage(content="[系统] 检测到工具调用循环,已终止。")],
                "stop": True,
                "stop_reason": "loop_detected",
            }

        llm = _select_llm(chat_services, config)

        # Inject the memory context into the system prompt.
        memory_context = getattr(state, "memory_context", None) or _DEFAULT_MEMORY_CONTEXT
        prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context)
        messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages)

        # On the last allowed step bind no tools, so the model has to answer
        # directly; only the tool set actually used is bound.
        final_round = current_step + 1 >= max_steps
        current_llm = llm.bind_tools([] if final_round else ALL_TOOLS)

        if is_streaming:
            await queue.put({"type": "node_start", "node": "agent"})
        if final_round:
            info("[Agent] 达到步数上限,使用无工具模型")

        full_content: Any = ""
        full_reasoning_content = ""
        pending_tool_calls: Dict[int, Dict[str, str]] = {}
        final_tool_calls: List[dict] = []

        try:
            if is_streaming:
                async for chunk in current_llm.astream(messages):
                    if not isinstance(chunk, AIMessageChunk):
                        continue
                    text = _content_text(chunk.content)
                    if text:
                        full_content += text
                        await queue.put({
                            "type": "llm_token",
                            "node": "agent",
                            "token": text,
                            "reasoning_token": ""
                        })
                    # Some providers stream reasoning separately in additional_kwargs.
                    reasoning = (chunk.additional_kwargs or {}).get("reasoning_content")
                    if reasoning:
                        full_reasoning_content += reasoning
                        await queue.put({
                            "type": "llm_token",
                            "node": "agent",
                            "token": "",
                            "reasoning_token": reasoning
                        })
                    for tc_chunk in chunk.tool_call_chunks or []:
                        _accumulate_tool_call_chunk(pending_tool_calls, tc_chunk)
                final_tool_calls = _finalize_tool_calls(pending_tool_calls)
            else:
                result = await current_llm.ainvoke(messages)
                full_content = result.content if result.content else ""
                final_tool_calls = list(getattr(result, "tool_calls", None) or [])
                extra = getattr(result, "additional_kwargs", None) or {}
                full_reasoning_content = extra.get("reasoning_content") or ""

            if is_streaming:
                await queue.put({"type": "node_end", "node": "agent"})

            # Build the response message; only attach tool_calls when present.
            response_kwargs = {"content": full_content}
            if final_tool_calls:
                response_kwargs["tool_calls"] = final_tool_calls
            response = AIMessage(**response_kwargs)
            if full_reasoning_content:
                response.additional_kwargs["reasoning_content"] = full_reasoning_content

            info(f"[Agent] 完成 - content长度: {len(full_content)}, tool_calls: {len(final_tool_calls)}")

            return {
                "messages": [response],
                "llm_calls": getattr(state, "llm_calls", 0) + 1
            }

        except Exception as e:
            # Top-level boundary: log, notify the stream consumer, and re-raise.
            error(f"[Agent] 执行出错: {e}")
            error(f"[Agent] 堆栈: {traceback.format_exc()}")
            if is_streaming:
                await queue.put({"type": "error", "message": str(e)})
            raise

    return agent_node