diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py
index 555c6b2..142acdf 100644
--- a/backend/app/agent/agent_service.py
+++ b/backend/app/agent/agent_service.py
@@ -16,6 +16,7 @@ from app.core.intent_classifier import get_intent_classifier
 from app.logger import info, warning
 from app.main_graph.state import MainGraphState, CurrentAction
 
+
 class AIAgentService:
     def __init__(self, checkpointer):
         self.checkpointer = checkpointer
@@ -112,7 +113,7 @@ class AIAgentService:
         return str(value)
 
     async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
-        """Process a message as a stream and return an async generator (with hybrid routing)."""
+        """Process a message as a stream and return an async generator (everything goes through the React loop)."""
        graph = self.graphs.get(model_name)
         if not graph:
             raise ValueError(f"Model '{model_name}' not found or not initialized")
@@ -128,7 +129,7 @@ class AIAgentService:
             "current_action": CurrentAction.NONE
         }
 
-        # ========== New: hybrid routing ==========
+        # ========== Intent classification (kept for logging) ==========
         intent_result = await self.intent_classifier.classify(message)
         info(f"🧠 Intent: {intent_result.intent_type} (confidence: {intent_result.confidence:.2f})")
         info(f"📝 Reasoning: {intent_result.reasoning}")
@@ -141,269 +142,139 @@ class AIAgentService:
             "reasoning": intent_result.reasoning
         }
 
-        # Decide the path based on intent
-        use_react_loop = True
-        if intent_result.confidence >= 0.6:
-            intent_str = intent_result.intent_type.value
-            if intent_str in ["chitchat", "clarify"]:
-                use_react_loop = False
-            elif intent_str == "knowledge" and self.rag_pipeline:
-                use_react_loop = False
-
-        # Emit the path decision event
+        # Emit the path decision event (always react_loop now)
         yield {
             "type": "path_decision",
-            "path": "react_loop" if use_react_loop else "fast",
+            "path": "react_loop",
             "intent": intent_result.intent_type.value
         }
-        # ====================================
+        # ========================================
 
-        if use_react_loop:
-            # ========== React loop path ==========
-            current_node = None
-            tool_calls_in_progress = {}
-
-            async for chunk in graph.astream(
-                input_state,
-                config=config,
-                stream_mode=["messages", "updates", "custom"],
-                version="v2",
-                subgraphs=True
-            ):
-                chunk_type = chunk["type"]
-                processed_event = {}
-
-                if chunk_type == "messages":
-                    message_chunk, metadata = chunk["data"]
-                    node_name = metadata.get("langgraph_node", "unknown")
-
-                    # Detect node changes and emit node start events
-                    if node_name != current_node:
-                        if current_node:
-                            yield {
-                                "type": "node_end",
-                                "node": current_node
-                            }
-                        yield {
-                            "type": "node_start",
-                            "node": node_name
-                        }
-                        current_node = node_name
-
-                    # Extract the message content
-                    token_content = getattr(message_chunk, 'content', str(message_chunk))
-                    reasoning_token = ""
-                    if hasattr(message_chunk, 'additional_kwargs'):
-                        reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
-
-                    # Handle reasoning tokens
-                    if reasoning_token:
-                        processed_event = {
-                            "type": "llm_token",
-                            "node": node_name,
-                            "reasoning_token": reasoning_token
-                        }
-                    # Handle tool calls
-                    elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
-                        for tool_call in message_chunk.tool_calls:
-                            tool_call_id = tool_call.get("id", "")
-                            tool_name = tool_call.get("name", "")
-                            tool_args = tool_call.get("args", {})
-
-                            # Record the start of a tool call
-                            if tool_call_id not in tool_calls_in_progress:
-                                tool_calls_in_progress[tool_call_id] = {
-                                    "name": tool_name,
-                                    "args": tool_args
-                                }
-                                yield {
-                                    "type": "tool_call_start",
-                                    "tool": tool_name,
-                                    "args": tool_args,
-                                    "id": tool_call_id
-                                }
-                    # Handle ordinary tokens
-                    elif token_content:
-                        processed_event = {
-                            "type": "llm_token",
-                            "node": node_name,
-                            "token": token_content,  # ✅ changed to "token"
-                            "reasoning_token": reasoning_token
-                        }
-
-                elif chunk_type == "updates":
-                    updates_data = chunk["data"]
-                    serialized_data = self._serialize_value(updates_data)
-
-                    # Check for a human review request
-                    if "review_pending" in serialized_data and serialized_data["review_pending"]:
-                        review_id = serialized_data.get("review_id", "")
-                        content_to_review = serialized_data.get("content_to_review", "")
-                        yield {
-                            "type": "human_review_request",
-                            "review_id": review_id,
-                            "content": content_to_review
-                        }
-
-                    # Check for tool results
-                    if "messages" in serialized_data:
-                        for msg in serialized_data["messages"]:
-                            # Detect tool result messages
-                            if msg.get("role") == "tool":
-                                tool_call_id = msg.get("tool_call_id", "")
-                                tool_name = msg.get("name", "")
-                                tool_output = msg.get("content", "")
-
-                                if tool_call_id in tool_calls_in_progress:
-                                    yield {
-                                        "type": "tool_call_end",
-                                        "tool": tool_name,
-                                        "id": tool_call_id,
-                                        "result": tool_output
-                                    }
-                                    del tool_calls_in_progress[tool_call_id]
-
-                    processed_event = {
-                        "type": "state_update",
-                        "data": serialized_data
-                    }
-
-                elif chunk_type == "custom":
-                    serialized_data = self._serialize_value(chunk["data"])
-                    processed_event = {
-                        "type": "custom",
-                        "data": serialized_data
-                    }
-
-                if processed_event:
-                    yield processed_event
-
-            # Emit closing events
-            if current_node:
-                yield {
-                    "type": "node_end",
-                    "node": current_node
-                }
-            yield {
-                "type": "done"
-            }
-
-        else:
-            # ========== Fast path ==========
-            intent_str = intent_result.intent_type.value
-
-            if intent_str == "chitchat":
-                # Answer chit-chat directly
-                reply = await self._generate_fast_reply(
-                    message,
-                    "You are a friendly assistant; respond politely to the user's greeting or small talk."
-                )
-                for char in reply:
-                    yield {
-                        "type": "llm_token",
-                        "node": "fast_path",
-                        "token": char  # ✅ changed to "token"
-                    }
-                    await asyncio.sleep(0.03)
-
-            elif intent_str == "clarify":
-                # Ask a clarifying question
-                reply = await self._generate_fast_reply(
-                    message,
-                    "The user's question is not specific enough; politely ask for more details so you can help better."
-                )
-                for char in reply:
-                    yield {
-                        "type": "llm_token",
-                        "node": "fast_path",
-                        "token": char  # ✅ changed to "token"
-                    }
-                    await asyncio.sleep(0.03)
-
-            elif intent_str == "knowledge" and self.rag_pipeline:
-                # Fast RAG
-                yield {
-                    "type": "node_start",
-                    "node": "fast_rag"
-                }
-                yield {
-                    "type": "reasoning",
-                    "node": "fast_rag",
-                    "content": "Querying the knowledge base..."
-                }
-
-                # Simulate RAG retrieval latency
-                await asyncio.sleep(0.3)
-
-                # Generate the answer with RAG
-                reply = await self._generate_rag_reply(message)
-
-                yield {
-                    "type": "node_end",
-                    "node": "fast_rag"
-                }
-
-                for char in reply:
-                    yield {
-                        "type": "llm_token",
-                        "node": "fast_path",
-                        "token": char  # ✅ changed to "token"
-                    }
-                    await asyncio.sleep(0.03)
-
-            else:
-                # Fallback: answer directly
-                reply = await self._generate_fast_reply(
-                    message,
-                    "Answer the user's question concisely."
-                )
-                for char in reply:
-                    yield {
-                        "type": "llm_token",
-                        "node": "fast_path",
-                        "token": char  # ✅ changed to "token"
-                    }
-                    await asyncio.sleep(0.03)
-
-            yield {
-                "type": "done"
-            }
-
-    async def _generate_fast_reply(self, message: str, system_prompt: str) -> str:
-        """Generate a quick reply (bypassing the React loop)."""
-        # Use the default model to generate the reply
-        model_name = next(iter(self.graphs.keys()), "zhipu")
-        llm = get_all_chat_services().get(model_name)
-
-        if not llm:
-            return "Sorry, the service is temporarily unavailable."
-
-        prompt = f"{system_prompt}\n\nUser: {message}"
-        response = await llm.ainvoke(prompt)
-        return response.content if hasattr(response, 'content') else str(response)
-
-    async def _generate_rag_reply(self, message: str) -> str:
-        """Generate a reply using RAG."""
-        if not self.rag_pipeline:
-            return await self._generate_fast_reply(message, "Answer the user's question concisely.")
-
-        # Retrieve documents
-        docs = await self.rag_pipeline.aretrieve(message)
-        context = self.rag_pipeline.format_context(docs)
-
-        # Generate the answer
-        model_name = next(iter(self.graphs.keys()), "zhipu")
-        llm = get_all_chat_services().get(model_name)
-
-        if not llm:
-            return "Sorry, the service is temporarily unavailable."
-
-        prompt = f"""Answer the user's question based on the reference documents below.
-
-Reference documents:
-{context or "(no relevant documents)"}
-
-User question: {message}
-"""
-        response = await llm.ainvoke(prompt)
-        return response.content if hasattr(response, 'content') else str(response)
\ No newline at end of file
+        # ========== React loop path ==========
+        current_node = None
+        tool_calls_in_progress = {}
+
+        async for chunk in graph.astream(
+            input_state,
+            config=config,
+            stream_mode=["messages", "updates", "custom"],
+            version="v2",
+            subgraphs=True
+        ):
+            chunk_type = chunk["type"]
+            processed_event = {}
+
+            if chunk_type == "messages":
+                message_chunk, metadata = chunk["data"]
+                node_name = metadata.get("langgraph_node", "unknown")
+
+                # Detect node changes and emit node start events
+                if node_name != current_node:
+                    if current_node:
+                        yield {
+                            "type": "node_end",
+                            "node": current_node
+                        }
+                    yield {
+                        "type": "node_start",
+                        "node": node_name
+                    }
+                    current_node = node_name
+
+                # Extract the message content
+                token_content = getattr(message_chunk, 'content', str(message_chunk))
+                reasoning_token = ""
+                if hasattr(message_chunk, 'additional_kwargs'):
+                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
+
+                # Handle reasoning tokens
+                if reasoning_token:
+                    processed_event = {
+                        "type": "llm_token",
+                        "node": node_name,
+                        "reasoning_token": reasoning_token
+                    }
+                # Handle tool calls
+                elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
+                    for tool_call in message_chunk.tool_calls:
+                        tool_call_id = tool_call.get("id", "")
+                        tool_name = tool_call.get("name", "")
+                        tool_args = tool_call.get("args", {})
+
+                        # Record the start of a tool call
+                        if tool_call_id not in tool_calls_in_progress:
+                            tool_calls_in_progress[tool_call_id] = {
+                                "name": tool_name,
+                                "args": tool_args
+                            }
+                            yield {
+                                "type": "tool_call_start",
+                                "tool": tool_name,
+                                "args": tool_args,
+                                "id": tool_call_id
+                            }
+                # Handle ordinary tokens
+                elif token_content:
+                    processed_event = {
+                        "type": "llm_token",
+                        "node": node_name,
+                        "token": token_content,
+                        "reasoning_token": reasoning_token
+                    }
+
+            elif chunk_type == "updates":
+                updates_data = chunk["data"]
+                serialized_data = self._serialize_value(updates_data)
+
+                # Check for a human review request
+                if "review_pending" in serialized_data and serialized_data["review_pending"]:
+                    review_id = serialized_data.get("review_id", "")
+                    content_to_review = serialized_data.get("content_to_review", "")
+                    yield {
+                        "type": "human_review_request",
+                        "review_id": review_id,
+                        "content": content_to_review
+                    }
+
+                # Check for tool results
+                if "messages" in serialized_data:
+                    for msg in serialized_data["messages"]:
+                        # Detect tool result messages
+                        if msg.get("role") == "tool":
+                            tool_call_id = msg.get("tool_call_id", "")
+                            tool_name = msg.get("name", "")
+                            tool_output = msg.get("content", "")
+
+                            if tool_call_id in tool_calls_in_progress:
+                                yield {
+                                    "type": "tool_call_end",
+                                    "tool": tool_name,
+                                    "id": tool_call_id,
+                                    "result": tool_output
+                                }
+                                del tool_calls_in_progress[tool_call_id]
+
+                processed_event = {
+                    "type": "state_update",
+                    "data": serialized_data
+                }
+
+            elif chunk_type == "custom":
+                serialized_data = self._serialize_value(chunk["data"])
+                processed_event = {
+                    "type": "custom",
+                    "data": serialized_data
+                }
+
+            if processed_event:
+                yield processed_event
+
+        # Emit closing events
+        if current_node:
+            yield {
+                "type": "node_end",
+                "node": current_node
+            }
+        yield {
+            "type": "done"
+        }
\ No newline at end of file
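For reference, a minimal sketch of how a caller might consume the unified event stream after this change. The `AIAgentService` wiring, the thread id `"demo-thread"`, and reusing `"zhipu"` as the model name are illustrative assumptions, not part of this diff; the event types and keys come from the code above.

```python
import asyncio

async def consume(service, message: str) -> str:
    """Collect llm_token events into a final answer and log tool activity."""
    parts = []
    async for event in service.process_message_stream(
        message, thread_id="demo-thread", model_name="zhipu"  # hypothetical values
    ):
        etype = event.get("type")
        if etype == "llm_token":
            # May carry only a reasoning_token; .get() keeps this safe
            parts.append(event.get("token", ""))
        elif etype == "tool_call_start":
            print(f"[tool start] {event['tool']} args={event['args']}")
        elif etype == "tool_call_end":
            print(f"[tool end] {event['tool']}")
        elif etype == "human_review_request":
            print(f"[review requested] id={event['review_id']}")
        elif etype == "done":
            break
    return "".join(parts)

# asyncio.run(consume(service, "hello"))  # constructing `service` is app-specific
```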
diff --git a/frontend/src/components/chat_area.py b/frontend/src/components/chat_area.py
index 207ad9a..4711773 100644
--- a/frontend/src/components/chat_area.py
+++ b/frontend/src/components/chat_area.py
@@ -143,7 +143,7 @@ def _handle_ai_response():
     # 1. Handle the LLM token stream (typewriter effect)
     if event_type == "llm_token":
         # Only handle tokens coming from the LLM, so tool output is not rendered as tokens
-        if event.get("node") in ("llm_call", "fallback", "fast_path"):
+        if event.get("node") in ("llm_call", "fallback"):
             token = str(event.get("token", ""))
             reasoning_token = str(event.get("reasoning_token", ""))
 
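A companion sketch of the invariant the frontend now relies on: with the fast path removed, only `llm_token` events from the graph's own LLM nodes should reach the typewriter renderer. The helper name `should_render_token` is hypothetical; the node names `"llm_call"` and `"fallback"` come from the diff.

```python
RENDERABLE_NODES = {"llm_call", "fallback"}  # "fast_path" no longer exists

def should_render_token(event: dict) -> bool:
    """True if the event carries a user-visible LLM token (or reasoning token)."""
    return (
        event.get("type") == "llm_token"
        and event.get("node") in RENDERABLE_NODES
        and bool(event.get("token") or event.get("reasoning_token"))
    )
```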