Files
ailine/backend/app/main_graph/nodes/agent.py
root 5b41598d50
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m41s
重构:简化流式架构,将 ReAct 循环移入 agent 节点
主要变更:
- 简化 agent_service:移除复杂双协程,只用 stream_mode=["updates"]
- stream_context:提供更清晰的 API (set_stream_queue/get_stream_queue)
- main_graph_builder:简化图结构,移除 tools 节点和条件边
- agent 节点:包含完整 ReAct 循环 + 流式 Tool Calling 拼接
- 前端:适配新的事件格式
- 添加测试文件:test_full_react_streaming.py, test_stream.py
2026-05-07 02:56:35 +08:00

289 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Agent 节点:完整的 ReAct 循环 + 流式 Tool Calling 拼接
完全参考指南实现!
"""
from typing import Dict, Any, Optional, List
from langchain_core.messages import SystemMessage, AIMessage, AIMessageChunk, ToolMessage
from langchain_core.runnables.config import RunnableConfig
from backend.app.main_graph.state import AgentState
from backend.app.logger import info, warning, error
from backend.app.agent.stream_context import get_stream_queue
from backend.app.tools import ALL_TOOLS
# 系统提示词(从 main_graph_builder.py 搬过来)
SYSTEM_PROMPT = """你是一个智能助手,可以使用多种工具完成复杂任务。你必须用中文回复。
## 核心工具与能力
你可以使用以下工具(函数),但只能在真正需要时调用,禁止无意义的测试调用或重复调用:
1. rag_search 从内部知识库中检索文档,输入为优化后的查询字符串。
2. web_search 联网搜索获取最新信息,输入为搜索关键词。
3. contact_lookup 查询企业通讯录,输入姓名、部门或邮箱等。
4. dictionary_lookup 翻译单词、查询词典或提取术语。
5. news_analysis 获取或分析新闻资讯。
## 工作流程ReAct 决策闭环)
你必须严格按照思考 → 行动 → 观察的闭环来处理每个请求,具体规则如下:
### 1. 初始决策
- 如果用户的问题很明确且你已有足够内部知识,可以直接回答,无需调用任何工具。
- 如果需要外部信息,请按以下优先级选择工具:
- 优先使用 rag_search。
- 若第一次 rag_search 返回的结果不相关或质量低,你可以改写查询关键词再次调用 rag_search最多重复一次
- 如果两次 rag_search 均无法获得满意信息,或者用户明确要求实时资讯,则必须切换为 web_search。
- 遇到通讯录、词典、新闻类明确需求,直接调用对应的专用工具。
### 2. 观察与反思
- 每次工具调用返回结果后,你必须先评估结果质量(内容是否相关、是否充分)。
- 如果信息不足,根据上述规则决定下一步行动;如果信息足够,则直接生成最终答案,绝不再调用任何工具。
- 在整个过程中,禁止使用工具返回的信息直接重复或编造来源,必须如实标注。
### 3. 结束条件
当你认为已经拥有足够信息回答用户时,输出最终回复并停止调用工具。若连续调用工具超过 5 轮仍未解决,也必须基于当前收集到的信息给出最佳回答并说明局限性。
## 回答规范
1. 来源标注:回答开头用方括号注明信息来源,如多处来源按使用顺序列出:
- 知识库:【知识库:相关文档主题】
- 联网搜索:【联网搜索:来源网站或摘要】
2. 思维链:对于需要复杂推理的问题,请将推理过程放在 <think>...</think> 标签内,并置于回答最前面(来源标注之前)。
3. 内容要求:回答应重点突出、条理清晰,优先结合用户背景信息进行个性化;若无任何可靠依据,如实说明“暂时无法回答”。
## 特别注意
- 不要向用户暴露任何工具调用的技术细节(如参数、函数名)。
- 如果用户只是闲聊、问候或道别,直接友好回复,严禁调用任何工具。
- 所有联网搜索必须以获取帮助用户为目的,不得搜索无关内容。
现在,请遵循以上规则处理用户的每一次输入。记住:思考 → 行动 → 观察 → 直到完成。"""
def create_agent_node(llm_with_tools, llm):
"""创建 Agent 节点函数,完整 ReAct 循环"""
async def agent_node(state: AgentState, config: Optional[RunnableConfig] = None) -> Dict[str, Any]:
"""
Agent 节点:完整的 ReAct 循环,带流式 token 和工具调用事件
兼容流式和非流式两种情况!
Args:
state: 当前状态
config: 运行配置
Returns:
状态更新字典
"""
# 获取队列
queue = get_stream_queue()
is_streaming = queue is not None
# 获取当前步数
current_step = getattr(state, "current_step", 0)
max_steps = getattr(state, "max_steps", 10)
info(f"[Agent] 从第 {current_step} 步开始,最大步数: {max_steps},流式: {is_streaming}")
# 组装完整消息
messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(state.messages)
turn = current_step # 轮次从当前步数开始
try:
while turn < max_steps:
turn += 1
info(f"[Agent] 第 {turn} 轮思考")
# 告诉前端:新的一轮开始(如果流式)
if is_streaming:
await queue.put({
"type": "turn_start",
"turn": turn,
})
# 选择 LLM
if turn >= max_steps:
info(f"[Agent] 达到步数上限,用不带工具的 LLM")
current_llm = llm.bind_tools([])
else:
current_llm = llm_with_tools
# 初始化变量
full_content = ""
full_reasoning_content = ""
pending_tool_calls = {} # key: index, value: {id, name, args_str}
final_tool_calls = []
# 只有流式的时候用 astream非流式直接用 ainvoke 更快!
if is_streaming:
async for chunk in current_llm.astream(messages):
if isinstance(chunk, AIMessageChunk):
# 1. 处理文本 token
if chunk.content:
full_content += chunk.content
await queue.put({
"type": "llm_token",
"turn": turn,
"phase": "answering",
"token": chunk.content,
"reasoning_token": ""
})
# 2. 处理 reasoning token
if hasattr(chunk, 'additional_kwargs') and chunk.additional_kwargs:
reasoning_content = chunk.additional_kwargs.get("reasoning_content", "")
if reasoning_content:
full_reasoning_content += reasoning_content
await queue.put({
"type": "llm_token",
"turn": turn,
"phase": "reasoning",
"token": "",
"reasoning_token": reasoning_content
})
# 3. 流式 Tool Calling 拼接逻辑(核心!用 tool_call_chunks
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
for tc_chunk in chunk.tool_call_chunks:
idx = tc_chunk.get("index", 0)
if idx not in pending_tool_calls:
pending_tool_calls[idx] = {
"id": "",
"name": "",
"args": "" # 初始化为字符串
}
if tc_chunk.get("id"):
pending_tool_calls[idx]["id"] += tc_chunk["id"]
if tc_chunk.get("name"):
pending_tool_calls[idx]["name"] += tc_chunk["name"]
if tc_chunk.get("args"):
args_val = tc_chunk["args"]
if isinstance(args_val, str):
pending_tool_calls[idx]["args"] += args_val
else:
import json
pending_tool_calls[idx]["args"] += json.dumps(args_val)
else:
# 非流式,直接 ainvoke
result = await current_llm.ainvoke(messages)
full_content = result.content if result.content else ""
if hasattr(result, 'tool_calls') and result.tool_calls:
final_tool_calls = result.tool_calls
if hasattr(result, 'additional_kwargs') and result.additional_kwargs:
full_reasoning_content = result.additional_kwargs.get("reasoning_content", "")
# 流式调用结束后,整理最终的 tool_calls只在流式时处理 pending
if is_streaming:
for idx in sorted(pending_tool_calls.keys()):
tc_data = pending_tool_calls[idx]
if tc_data["name"]: # 只有有名字的才是有效工具调用
# 解析参数字符串
args = {}
if tc_data["args"]:
try:
import json
args = json.loads(tc_data["args"])
except Exception as e:
info(f"[Agent] Failed to parse args JSON: {e}, raw: {tc_data['args']}")
final_tool_calls.append({
"id": tc_data["id"],
"name": tc_data["name"],
"args": args
})
# 判断是否有工具调用
if final_tool_calls:
info(f"[Agent] 第 {turn} 轮:调用 {len(final_tool_calls)} 个工具")
# 执行工具调用
new_messages = []
for tc in final_tool_calls:
tool_name = tc["name"]
tool_args = tc["args"]
tool_id = tc["id"]
# 发送工具开始事件(如果流式)
if is_streaming:
await queue.put({
"type": "tool_start",
"turn": turn,
"tool": tool_name,
"args": tool_args,
"id": tool_id
})
# 找到并执行对应工具
tool_result = ""
tool_found = False
for tool in ALL_TOOLS:
if tool.name == tool_name:
tool_found = True
try:
tool_result = await tool.ainvoke(tool_args)
except Exception as e:
tool_result = f"工具调用出错: {str(e)}"
error(f"[Agent] 工具 {tool_name} 调用出错: {e}")
break
if not tool_found:
tool_result = f"未找到工具: {tool_name}"
# 发送工具结束事件(如果流式)
if is_streaming:
await queue.put({
"type": "tool_end",
"turn": turn,
"tool": tool_name,
"id": tool_id,
"result": str(tool_result)
})
# 构造 ToolMessage
tool_msg = ToolMessage(
content=str(tool_result),
tool_call_id=tool_id,
name=tool_name
)
new_messages.append(tool_msg)
# 添加到 messages继续下一轮
messages.extend(new_messages)
continue
else:
# 没有工具调用,最终输出
info(f"[Agent] 第 {turn} 轮:完成,无工具调用")
if is_streaming:
await queue.put({
"type": "final_answer",
"turn": turn,
"content": full_content
})
break
# 构建完整的 AIMessage 用于状态更新
response_kwargs = {"content": full_content}
if final_tool_calls:
response_kwargs["tool_calls"] = final_tool_calls
response = AIMessage(**response_kwargs)
if full_reasoning_content:
response.additional_kwargs["reasoning_content"] = full_reasoning_content
# 返回状态更新
return {
"messages": [response],
"current_step": turn,
"llm_calls": getattr(state, "llm_calls", 0) + 1
}
except Exception as e:
error(f"[Agent] ❌ 第 {turn} 轮出错: {e}")
import traceback
error(f"[Agent] 堆栈: {traceback.format_exc()}")
# 发送错误事件(如果流式)
if is_streaming:
await queue.put({
"type": "error",
"message": str(e)
})
raise
return agent_node