""" Agent 节点 - 简化版本 直接定义 agent_node 函数,支持动态模型切换和循环检测 """ import hashlib from typing import Dict, Any, Optional, List from langchain_core.runnables.config import RunnableConfig from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage, ToolMessage from backend.app.main_graph.state import AgentState from backend.app.logger import info, error from backend.app.tools import ALL_TOOLS from backend.app.agent.stream_context import get_stream_queue from backend.app.agent.prompts import SYSTEM_PROMPT def _normalize_args(args: dict) -> str: """标准化工具参数用于比较""" return str(sorted(args.items())) def _is_similar_result(results: List[str], threshold: float = 0.8) -> bool: """检测结果是否相似(简单实现:长度相似+部分内容重复)""" if len(results) < 2: return False latest = results[-1] prev = results[-2] # 长度差异太大,不算相似 if len(latest) == 0 or len(prev) == 0: return len(latest) == len(prev) len_ratio = min(len(latest), len(prev)) / max(len(latest), len(prev)) if len_ratio < 0.5: return False # 检查内容重复度(简单:前100字符) common_len = 0 for a, b in zip(latest[:100], prev[:100]): if a == b: common_len += 1 else: break return (common_len / 100) > threshold def _should_stop_for_loop(tool_calls: List[dict], tool_results: List[str]) -> bool: """ 检测是否应该停止(循环检测) 条件:连续2次调用相同工具 + 参数相似 + 结果相似 """ if len(tool_calls) < 2: return False # 检查最近的工具调用是否相同 last_tc = tool_calls[-1] prev_tc = tool_calls[-2] if last_tc["name"] != prev_tc["name"]: return False # 参数是否相似 last_args = _normalize_args(last_tc["args"]) prev_args = _normalize_args(prev_tc["args"]) if last_args != prev_args: return False # 结果是否相似 if len(tool_results) >= 2: return _is_similar_result(tool_results[-2:]) return False def create_agent_node(chat_services: dict): """ 创建 Agent 节点 - 支持动态模型切换 简化设计: - 直接返回 async 函数,无需工厂包装 - 从 config 中获取模型名称,运行时动态切换 """ async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]: """Agent 节点:完整的 ReAct 循环""" queue = get_stream_queue() is_streaming = queue is not None # 获取步数 current_step = getattr(state, "current_step", 0) max_steps = getattr(state, "max_steps", 10) info(f"[Agent] 从第 {current_step} 步开始,最大步数: {max_steps},流式: {is_streaming}") # 动态获取模型 model_name = "primary" if config: configurable = config.get("configurable", {}) model_name = configurable.get("model", "primary") llm = chat_services.get(model_name) if llm is None: llm = next(iter(chat_services.values())) info(f"[Agent] 模型 '{model_name}' 不可用,使用 '{type(llm).__name__}'") llm_with_tools = llm.bind_tools(ALL_TOOLS) # 获取记忆上下文 memory_context = getattr(state, "memory_context", "暂无用户背景信息") # 组装消息(注入记忆上下文到提示词) prompt_with_memory = SYSTEM_PROMPT.format(memory_context=memory_context) messages = [SystemMessage(content=prompt_with_memory)] + list(state.messages) turn = current_step try: while turn < max_steps: turn += 1 info(f"[Agent] 第 {turn} 轮思考") if is_streaming: await queue.put({"type": "node_start", "node": "agent"}) # 选择 LLM(最后一轮不带工具) if turn >= max_steps: current_llm = llm.bind_tools([]) info(f"[Agent] 达到步数上限,使用无工具模型") else: current_llm = llm_with_tools # 初始化 full_content = "" full_reasoning_content = "" pending_tool_calls = {} final_tool_calls = [] # 循环检测:记录历史调用 tool_call_history: List[dict] = [] tool_result_history: List[str] = [] # 调用 LLM if is_streaming: async for chunk in current_llm.astream(messages): if isinstance(chunk, AIMessageChunk): if chunk.content: full_content += chunk.content await queue.put({ "type": "llm_token", "node": "agent", "token": chunk.content, "reasoning_token": "" }) if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs: reasoning = chunk.additional_kwargs.get("reasoning_content", "") if reasoning: full_reasoning_content += reasoning await queue.put({ "type": "llm_token", "node": "agent", "token": "", "reasoning_token": reasoning }) if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks: for tc_chunk in chunk.tool_call_chunks: idx = tc_chunk.get("index", 0) if idx not in pending_tool_calls: pending_tool_calls[idx] = {"id": "", "name": "", "args": ""} if tc_chunk.get("id"): pending_tool_calls[idx]["id"] += tc_chunk["id"] if tc_chunk.get("name"): pending_tool_calls[idx]["name"] += tc_chunk["name"] if tc_chunk.get("args"): args_val = tc_chunk["args"] if isinstance(args_val, str): pending_tool_calls[idx]["args"] += args_val else: import json pending_tool_calls[idx]["args"] += json.dumps(args_val) else: result = await current_llm.ainvoke(messages) full_content = result.content if result.content else "" if hasattr(result, 'tool_calls') and result.tool_calls: final_tool_calls = result.tool_calls if hasattr(result, 'additional_kwargs'): full_reasoning_content = result.additional_kwargs.get("reasoning_content", "") # 整理工具调用 if is_streaming: for idx in sorted(pending_tool_calls.keys()): tc_data = pending_tool_calls[idx] if tc_data["name"]: args = {} if tc_data["args"]: try: import json args = json.loads(tc_data["args"]) except Exception as e: info(f"[Agent] 解析参数失败: {e}") final_tool_calls.append({ "id": tc_data["id"], "name": tc_data["name"], "args": args }) # 执行工具 if final_tool_calls: info(f"[Agent] 第 {turn} 轮:调用 {len(final_tool_calls)} 个工具") new_messages = [] for tc in final_tool_calls: tool_name = tc["name"] tool_args = tc["args"] tool_id = tc["id"] if is_streaming: await queue.put({ "type": "custom", "data": {"type": "tool_start", "tool": tool_name, "args": tool_args, "id": tool_id} }) # 查找并执行工具 tool_result = "" tool_found = False for tool in ALL_TOOLS: if tool.name == tool_name: tool_found = True try: tool_result = await tool.ainvoke(tool_args) except Exception as e: tool_result = f"工具调用出错: {str(e)}" error(f"[Agent] 工具 {tool_name} 调用出错: {e}") break if not tool_found: tool_result = f"未找到工具: {tool_name}" if is_streaming: await queue.put({ "type": "custom", "data": {"type": "tool_end", "tool": tool_name, "id": tool_id, "result": str(tool_result)} }) # 记录历史(用于循环检测) tool_call_history.append({"name": tool_name, "args": tool_args}) tool_result_history.append(str(tool_result)) new_messages.append(ToolMessage(content=str(tool_result), tool_call_id=tool_id, name=tool_name)) # 循环检测:相同工具 + 相似参数 + 相似结果 → 终止 if _should_stop_for_loop(tool_call_history, tool_result_history): info(f"[Agent] ⚠️ 检测到循环,强制终止") # 添加一条终止消息 messages.append(AIMessage(content="[系统] 检测到工具调用循环,已终止。")) break messages.extend(new_messages) continue else: info(f"[Agent] 第 {turn} 轮:完成,无工具调用") break # 构建响应 response_kwargs = {"content": full_content} if final_tool_calls: response_kwargs["tool_calls"] = final_tool_calls response = AIMessage(**response_kwargs) if full_reasoning_content: response.additional_kwargs["reasoning_content"] = full_reasoning_content return { "messages": [response], "current_step": turn, "llm_calls": getattr(state, "llm_calls", 0) + 1 } except Exception as e: error(f"[Agent] ❌ 第 {turn} 轮出错: {e}") import traceback error(f"[Agent] 堆栈: {traceback.format_exc()}") if is_streaming: await queue.put({"type": "error", "message": str(e)}) raise return agent_node