重构:增强 JSON 解析稳定性,优化 Prompt,改进状态结构
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m36s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m36s
主要改进: 1. 新增 json_parser.py - 统一的 JSON 解析工具 - 支持多种格式(纯 JSON、markdown、文本中的 JSON) - 多层 fallback 策略 - 安全的字段提取函数 2. 优化 intent.py 和 hybrid_router.py - 使用新的 json_parser - 优化 Prompt,更清晰的格式要求 - 更好的错误处理 3. 改进 state.py - 新增结构化状态字段 - ReactReasoningState、HybridRouterState、FastPathState - 向后兼容旧的 debug_info 4. 更新各节点模块 - 同时更新旧字段保持兼容 - reasoning.py - 更新 state.react_reasoning - hybrid_router.py - 更新 state.hybrid_router - fast_paths.py - 更新 state.fast_path
This commit is contained in:
@@ -1,18 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
意图理解与推理模块 (React 模式)
|
意图理解与推理模块(React 模式)
|
||||||
Intent Understanding & Reasoning Module (React Pattern)
|
|
||||||
|
|
||||||
这个模块实现了 React (Reasoning + Acting) 模式的意图理解节点,用于:
|
核心改进:
|
||||||
1. 理解用户的查询意图
|
1. 使用统一的 JSON 解析器,保证稳定性
|
||||||
2. 判断是否需要调用 RAG 检索
|
2. 优化 Prompt,更清晰的指令
|
||||||
3. 判断是否需要重新检索
|
3. 更好的错误处理和降级策略
|
||||||
4. 决定下一步的动作(路由到子图、直接回答等)
|
|
||||||
|
|
||||||
核心设计:
|
|
||||||
- 使用项目已有的 chat_services.py 进行 LLM 调用
|
|
||||||
- 保持与现有架构一致(服务层模式)
|
|
||||||
- 支持降级策略(LLM 失败时回退到规则)
|
|
||||||
- 与 react_nodes.py 无缝集成
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
@@ -21,6 +13,13 @@ from typing import Dict, Any, Optional, List
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
|
||||||
|
from backend.app.core.json_parser import (
|
||||||
|
extract_and_parse_json,
|
||||||
|
safe_get,
|
||||||
|
safe_get_float,
|
||||||
|
safe_get_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ========== 1. 核心数据类型 ==========
|
# ========== 1. 核心数据类型 ==========
|
||||||
|
|
||||||
@@ -208,99 +207,138 @@ class ReactIntentReasoner:
|
|||||||
return self._parse_llm_response(response.content, query)
|
return self._parse_llm_response(response.content, query)
|
||||||
|
|
||||||
def _build_reasoning_prompt(self, query: str, context: Dict[str, Any]) -> str:
|
def _build_reasoning_prompt(self, query: str, context: Dict[str, Any]) -> str:
|
||||||
"""构建推理提示词"""
|
"""
|
||||||
|
构建推理提示词(优化版)
|
||||||
|
|
||||||
|
改进点:
|
||||||
|
1. 更清晰的指令和格式要求
|
||||||
|
2. 明确要求纯 JSON 输出,不要 markdown
|
||||||
|
3. 更好的示例和决策规则
|
||||||
|
"""
|
||||||
# 构建上下文描述
|
# 构建上下文描述
|
||||||
context_parts = []
|
context_parts = []
|
||||||
if context.get("retrieved_docs"):
|
if context.get("retrieved_docs"):
|
||||||
context_parts.append(f"- 已检索文档: {len(context['retrieved_docs'])} 条")
|
context_parts.append(f"- 已检索文档: {len(context['retrieved_docs'])} 条")
|
||||||
if context.get("rag_confidence") is not None:
|
rag_confidence = context.get("rag_confidence")
|
||||||
context_parts.append(f"- RAG 置信度: {context['rag_confidence']:.2f}")
|
if rag_confidence is not None:
|
||||||
if context.get("rag_attempts"):
|
context_parts.append(f"- RAG 置信度: {rag_confidence:.2f}")
|
||||||
context_parts.append(f"- RAG 尝试次数: {context['rag_attempts']}")
|
rag_attempts = context.get("rag_attempts", 0)
|
||||||
if context.get("previous_actions"):
|
if rag_attempts:
|
||||||
context_parts.append(f"- 历史动作: {context['previous_actions']}")
|
context_parts.append(f"- RAG 尝试次数: {rag_attempts}")
|
||||||
|
previous_actions = context.get("previous_actions", [])
|
||||||
|
if previous_actions:
|
||||||
|
context_parts.append(f"- 历史动作: {previous_actions}")
|
||||||
|
|
||||||
context_str = "\n".join(context_parts) if context_parts else "无"
|
context_str = "\n".join(context_parts) if context_parts else "无"
|
||||||
|
|
||||||
|
return f"""你是一个决策控制器。你需要根据当前状态决定下一步操作。
|
||||||
|
|
||||||
return f"""你是一个专业的意图推理助手。请分析用户的查询,决定下一步应该做什么。
|
【格式要求】
|
||||||
|
你必须严格输出 JSON 格式,不要加任何 Markdown 代码块标记(如 ```json)。
|
||||||
|
仅输出纯 JSON 字符串,不要有其他解释文字。
|
||||||
|
|
||||||
可选动作:
|
【可用动作】
|
||||||
1. DIRECT_RESPONSE - 直接回答(闲聊、打招呼、不需要额外信息,或已有足够信息)
|
1. DIRECT_RESPONSE - 直接回答(已有足够信息,不需要额外工具)
|
||||||
2. RETRIEVE_RAG - 需要查询知识库(询问知识、政策、文档等)
|
2. RETRIEVE_RAG - 检索知识库(需要查询相关知识)
|
||||||
3. RE_RETRIEVE_RAG - 需要重新检索(之前的结果不够,或者用户明确说"再查查"、"更多")
|
3. RE_RETRIEVE_RAG - 重新检索(已有结果不够,需要再次尝试)
|
||||||
4. WEB_SEARCH - 需要联网搜索(询问最新资讯、热点、实时信息、知识库中没有的内容)
|
4. WEB_SEARCH - 联网搜索(需要最新资讯或知识库没有的内容)
|
||||||
5. ROUTE_SUBGRAPH - 需要路由到专门的子图:
|
5. ROUTE_SUBGRAPH - 路由到子图(通讯录/词典/资讯分析)
|
||||||
- contact: 通讯录、联系人、邮件相关
|
6. CLARIFY - 澄清问题(问题不明确,需要用户补充)
|
||||||
- dictionary: 词典、翻译、单词相关
|
|
||||||
- news_analysis: 资讯、新闻、热点分析相关
|
|
||||||
6. CLARIFY - 需要澄清用户的问题(问题不明确)
|
|
||||||
|
|
||||||
判断规则:
|
【动作参数说明】
|
||||||
- 如果 RAG 置信度 >= 0.6 且有检索文档,应返回 DIRECT_RESPONSE
|
每个动作需要的参数:
|
||||||
- 如果 RAG 置信度 < 0.6 且尝试次数 < 2,可返回 RETRIEVE_RAG 再试一次
|
- RETRIEVE_RAG: {{"retrieval_query": "优化后的检索查询字符串"}}
|
||||||
- 如果 RAG 置信度 < 0.6 且尝试次数 >= 2,应返回 WEB_SEARCH
|
- RE_RETRIEVE_RAG: {{"retrieval_query": "优化后的检索查询字符串"}}
|
||||||
- 如果已联网搜索过,应返回 DIRECT_RESPONSE
|
- WEB_SEARCH: {{"search_query": "优化后的搜索查询字符串"}}
|
||||||
|
- ROUTE_SUBGRAPH: {{"target_subgraph": "contact|dictionary|news_analysis"}}
|
||||||
|
- DIRECT_RESPONSE/CLARIFY: {{}}(无需参数)
|
||||||
|
|
||||||
|
【决策规则】
|
||||||
|
1. 如果 RAG 置信度 >= 0.6 且有检索文档,使用 DIRECT_RESPONSE
|
||||||
|
2. 如果 RAG 置信度 < 0.6 且尝试次数 < 2,使用 RETRIEVE_RAG/RE_RETRIEVE_RAG
|
||||||
|
3. 如果 RAG 置信度 < 0.6 且尝试次数 >= 2,使用 WEB_SEARCH
|
||||||
|
4. 如果已执行过联网搜索,使用 DIRECT_RESPONSE
|
||||||
|
5. 如果问题涉及通讯录/词典/资讯分析,使用 ROUTE_SUBGRAPH
|
||||||
|
6. 如果问题不明确,使用 CLARIFY
|
||||||
|
|
||||||
|
【输出格式】
|
||||||
|
{{
|
||||||
|
"action": "动作名称(大写)",
|
||||||
|
"confidence": 0.85,
|
||||||
|
"reasoning": "简要说明决策理由",
|
||||||
|
"target_subgraph": "contact|dictionary|news_analysis|null",
|
||||||
|
"retrieval_query": "优化后的检索查询(可选)",
|
||||||
|
"search_query": "优化后的搜索查询(可选)"
|
||||||
|
}}
|
||||||
|
|
||||||
|
【重要提示】
|
||||||
|
- target_subgraph 仅在 action=ROUTE_SUBGRAPH 时提供,否则设为 null 或不包含
|
||||||
|
- retrieval_query 仅在 action=RETRIEVE_RAG/RE_RETRIEVE_RAG 时提供
|
||||||
|
- search_query 仅在 action=WEB_SEARCH 时提供
|
||||||
|
- confidence 是你对当前决策的信心(0.0-1.0)
|
||||||
|
|
||||||
|
【当前状态】
|
||||||
用户查询: {query}
|
用户查询: {query}
|
||||||
当前上下文:
|
当前上下文:
|
||||||
{context_str}
|
{context_str}
|
||||||
|
|
||||||
请按以下 JSON 格式输出(仅输出 JSON,不要其他内容):
|
【现在开始】
|
||||||
{{
|
请根据以上信息,输出你的决策 JSON:"""
|
||||||
"action": "DIRECT_RESPONSE|RETRIEVE_RAG|RE_RETRIEVE_RAG|WEB_SEARCH|ROUTE_SUBGRAPH|CLARIFY",
|
|
||||||
"confidence": 0.85,
|
|
||||||
"reasoning": "简要说明理由",
|
|
||||||
"target_subgraph": "contact|dictionary|news_analysis|null (仅当 action=ROUTE_SUBGRAPH 时)",
|
|
||||||
"retrieval_query": "优化后的检索查询 (可选)",
|
|
||||||
"search_query": "优化后的搜索查询 (仅当 action=WEB_SEARCH 时)"
|
|
||||||
}}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _parse_llm_response(self, response: str, original_query: str) -> ReasoningResult:
|
def _parse_llm_response(self, response: str, original_query: str) -> ReasoningResult:
|
||||||
"""解析 LLM 响应"""
|
"""
|
||||||
|
解析 LLM 响应(优化版)
|
||||||
|
|
||||||
|
使用统一的 JSON 解析器,支持多种格式
|
||||||
|
"""
|
||||||
result = ReasoningResult(original_query=original_query)
|
result = ReasoningResult(original_query=original_query)
|
||||||
|
|
||||||
# 提取 JSON
|
# 使用新的 JSON 解析器
|
||||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
parse_result = extract_and_parse_json(response)
|
||||||
if not json_match:
|
|
||||||
# 没有 JSON,回退到规则
|
if not parse_result.success or not parse_result.data:
|
||||||
|
# 解析失败,使用规则推理降级
|
||||||
|
result.action = ReasoningAction.UNKNOWN
|
||||||
result.confidence = 0.0
|
result.confidence = 0.0
|
||||||
|
result.reasoning = f"LLM 响应解析失败: {parse_result.error or '未知错误'}"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
data = parse_result.data
|
||||||
|
|
||||||
|
# 安全地提取字段
|
||||||
|
action_str = safe_get_str(data, "action", "UNKNOWN")
|
||||||
|
confidence = safe_get_float(data, "confidence", 0.5)
|
||||||
|
reasoning = safe_get_str(data, "reasoning", "")
|
||||||
|
target_subgraph = safe_get_str(data, "target_subgraph", None)
|
||||||
|
retrieval_query = safe_get_str(data, "retrieval_query", original_query)
|
||||||
|
search_query = safe_get_str(data, "search_query", original_query)
|
||||||
|
|
||||||
|
# 转换为枚举
|
||||||
try:
|
try:
|
||||||
data = json.loads(json_match.group())
|
result.action = ReasoningAction[action_str]
|
||||||
action_str = data.get("action", "UNKNOWN")
|
except (KeyError, ValueError):
|
||||||
|
result.action = ReasoningAction.UNKNOWN
|
||||||
# 转换为枚举
|
|
||||||
try:
|
result.confidence = confidence
|
||||||
result.action = ReasoningAction[action_str]
|
result.reasoning = reasoning
|
||||||
except KeyError:
|
|
||||||
result.action = ReasoningAction.UNKNOWN
|
# 处理子图路由
|
||||||
|
if result.action == ReasoningAction.ROUTE_SUBGRAPH and target_subgraph:
|
||||||
result.confidence = float(data.get("confidence", 0.5))
|
result.retrieval_config.target_subgraph = target_subgraph
|
||||||
result.reasoning = data.get("reasoning", "")
|
result.metadata["target_subgraph"] = target_subgraph
|
||||||
|
|
||||||
# 处理子图路由
|
# 处理检索查询
|
||||||
if result.action == ReasoningAction.ROUTE_SUBGRAPH:
|
if result.action in (ReasoningAction.RETRIEVE_RAG, ReasoningAction.RE_RETRIEVE_RAG):
|
||||||
result.retrieval_config.target_subgraph = data.get("target_subgraph")
|
result.retrieval_config.need_retrieval = True
|
||||||
result.metadata["target_subgraph"] = data.get("target_subgraph")
|
result.retrieval_config.need_re_retrieval = (result.action == ReasoningAction.RE_RETRIEVE_RAG)
|
||||||
|
result.retrieval_config.retrieval_query = retrieval_query
|
||||||
# 处理检索查询
|
|
||||||
if result.action in [ReasoningAction.RETRIEVE_RAG, ReasoningAction.RE_RETRIEVE_RAG]:
|
# 处理联网搜索
|
||||||
result.retrieval_config.need_retrieval = True
|
if result.action == ReasoningAction.WEB_SEARCH:
|
||||||
result.retrieval_config.need_re_retrieval = (result.action == ReasoningAction.RE_RETRIEVE_RAG)
|
result.metadata["need_web_search"] = True
|
||||||
result.retrieval_config.retrieval_query = data.get("retrieval_query", original_query)
|
result.metadata["search_query"] = search_query
|
||||||
|
|
||||||
# 处理联网搜索
|
return result
|
||||||
if result.action == ReasoningAction.WEB_SEARCH:
|
|
||||||
result.metadata["need_web_search"] = True
|
|
||||||
result.metadata["search_query"] = data.get("search_query", original_query)
|
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[ReactReasoner] 解析 LLM 响应失败: {e}")
|
|
||||||
result.confidence = 0.0
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _reason_with_rules(
|
def _reason_with_rules(
|
||||||
self,
|
self,
|
||||||
|
|||||||
203
backend/app/core/json_parser.py
Normal file
203
backend/app/core/json_parser.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""
|
||||||
|
统一的 JSON 解析工具,保证 LLM JSON 输出的稳定性
|
||||||
|
|
||||||
|
处理各种边界情况:
|
||||||
|
1. 纯 JSON 字符串
|
||||||
|
2. JSON 在 markdown 代码块中
|
||||||
|
3. JSON 在文本中间
|
||||||
|
4. JSON 有多余的逗号
|
||||||
|
5. JSON 有尾随内容
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from typing import TypeVar, Type, Dict, Any, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from json import JSONDecodeError
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ParseResult:
|
||||||
|
"""JSON 解析结果"""
|
||||||
|
success: bool
|
||||||
|
data: Optional[Dict[str, Any]] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
raw_response: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def extract_and_parse_json(
|
||||||
|
response: str,
|
||||||
|
schema: Optional[Dict[str, Any]] = None
|
||||||
|
) -> ParseResult:
|
||||||
|
"""
|
||||||
|
从 LLM 响应中提取并解析 JSON,使用多种策略处理边界情况
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: LLM 的原始响应
|
||||||
|
schema: 可选的 JSON Schema(预留,暂未使用)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParseResult: 解析结果
|
||||||
|
"""
|
||||||
|
result = ParseResult(raw_response=response, success=False)
|
||||||
|
|
||||||
|
# 前置清理
|
||||||
|
cleaned = response.strip()
|
||||||
|
if not cleaned:
|
||||||
|
result.error = "响应为空"
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 策略1:尝试直接解析完整响应
|
||||||
|
try:
|
||||||
|
data = json.loads(cleaned)
|
||||||
|
result.data = data
|
||||||
|
result.success = True
|
||||||
|
return result
|
||||||
|
except JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 策略2:尝试匹配 markdown 代码块(优先)
|
||||||
|
codeblock_patterns = [
|
||||||
|
r'```(?:json)?\s*([\s\S]*?)\s*```', # ```json ... ```
|
||||||
|
r'```([\s\S]*?)```', # ``` ... ```
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in codeblock_patterns:
|
||||||
|
match = re.search(pattern, cleaned)
|
||||||
|
if match:
|
||||||
|
json_str = match.group(1).strip()
|
||||||
|
if json_str:
|
||||||
|
try:
|
||||||
|
data = json.loads(json_str)
|
||||||
|
result.data = data
|
||||||
|
result.success = True
|
||||||
|
return result
|
||||||
|
except JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 策略3:提取最外层的完整 {} 块(处理嵌套)
|
||||||
|
json_match = _extract_outermost_json(cleaned)
|
||||||
|
if json_match:
|
||||||
|
try:
|
||||||
|
data = json.loads(json_match)
|
||||||
|
result.data = data
|
||||||
|
result.success = True
|
||||||
|
return result
|
||||||
|
except JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 策略4:尝试修复常见问题
|
||||||
|
try:
|
||||||
|
# 去除多余的尾随逗号
|
||||||
|
fixed = re.sub(r',\s*([}\]])', r'\1', cleaned)
|
||||||
|
# 提取第一个 { 到最后一个 } 的内容
|
||||||
|
first_brace = fixed.find('{')
|
||||||
|
last_brace = fixed.rfind('}')
|
||||||
|
if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
|
||||||
|
json_str = fixed[first_brace:last_brace+1]
|
||||||
|
data = json.loads(json_str)
|
||||||
|
result.data = data
|
||||||
|
result.success = True
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 所有策略都失败
|
||||||
|
result.error = f"无法从响应中提取有效 JSON: {cleaned[:200]}..."
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_outermost_json(text: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
提取最外层的完整 JSON 块(处理嵌套)
|
||||||
|
|
||||||
|
使用栈方法,正确处理嵌套的 {}
|
||||||
|
"""
|
||||||
|
stack = []
|
||||||
|
start_idx = -1
|
||||||
|
|
||||||
|
for i, char in enumerate(text):
|
||||||
|
if char == '{':
|
||||||
|
if not stack:
|
||||||
|
start_idx = i
|
||||||
|
stack.append('{')
|
||||||
|
elif char == '}':
|
||||||
|
if stack:
|
||||||
|
stack.pop()
|
||||||
|
if not stack and start_idx != -1:
|
||||||
|
# 找到完整的外层块
|
||||||
|
return text[start_idx:i+1]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json_to_dataclass(
|
||||||
|
response: str,
|
||||||
|
dataclass_type: Type[T],
|
||||||
|
default_factory: callable
|
||||||
|
) -> T:
|
||||||
|
"""
|
||||||
|
解析 JSON 并转换为 dataclass 实例,失败时返回默认值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: LLM 响应
|
||||||
|
dataclass_type: 目标 dataclass 类型
|
||||||
|
default_factory: 生成默认值的工厂函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
T: dataclass 实例
|
||||||
|
"""
|
||||||
|
parse_result = extract_and_parse_json(response)
|
||||||
|
|
||||||
|
if not parse_result.success or not parse_result.data:
|
||||||
|
return default_factory()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return dataclass_type(**parse_result.data)
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
# 字段不匹配时尝试降级
|
||||||
|
return default_factory()
|
||||||
|
|
||||||
|
|
||||||
|
def safe_get(data: Dict[str, Any], key: str, default: Any = None) -> Any:
|
||||||
|
"""安全地从字典中获取值"""
|
||||||
|
if not data or not isinstance(data, dict):
|
||||||
|
return default
|
||||||
|
return data.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_get_bool(data: Dict[str, Any], key: str, default: bool = False) -> bool:
|
||||||
|
"""安全地获取布尔值"""
|
||||||
|
value = safe_get(data, key, default)
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.lower() in ('true', '1', 'yes', 'on')
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return bool(value)
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def safe_get_float(data: Dict[str, Any], key: str, default: float = 0.0) -> float:
|
||||||
|
"""安全地获取浮点值"""
|
||||||
|
value = safe_get(data, key, default)
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def safe_get_int(data: Dict[str, Any], key: str, default: int = 0) -> int:
|
||||||
|
"""安全地获取整数值"""
|
||||||
|
value = safe_get(data, key, default)
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def safe_get_str(data: Dict[str, Any], key: str, default: str = "") -> str:
|
||||||
|
"""安全地获取字符串值"""
|
||||||
|
value = safe_get(data, key, default)
|
||||||
|
return str(value) if value is not None else default
|
||||||
@@ -45,6 +45,9 @@ async def fast_chitchat_node(state: MainGraphState, config: Optional[RunnableCon
|
|||||||
state.success = True
|
state.success = True
|
||||||
state.current_phase = "llm_call"
|
state.current_phase = "llm_call"
|
||||||
state.debug_info["fast_chitchat_success"] = True
|
state.debug_info["fast_chitchat_success"] = True
|
||||||
|
|
||||||
|
# 更新新的结构化字段
|
||||||
|
state.fast_path.chitchat_success = True
|
||||||
|
|
||||||
# 发送完成事件
|
# 发送完成事件
|
||||||
await dispatch_custom_event("fast_path_end", {"path": "fast_chitchat", "success": True}, config)
|
await dispatch_custom_event("fast_path_end", {"path": "fast_chitchat", "success": True}, config)
|
||||||
@@ -200,6 +203,11 @@ def _mark_fast_path_failed(state: MainGraphState, reason: str = "") -> MainGraph
|
|||||||
state.debug_info["fast_path_failed"] = True
|
state.debug_info["fast_path_failed"] = True
|
||||||
state.debug_info["fast_path_fail_reason"] = reason
|
state.debug_info["fast_path_fail_reason"] = reason
|
||||||
state.success = False
|
state.success = False
|
||||||
|
|
||||||
|
# 更新新的结构化字段
|
||||||
|
state.fast_path.failed = True
|
||||||
|
state.fast_path.fail_reason = reason
|
||||||
|
|
||||||
info(f"[Fast Path] 标记失败,准备升级: {reason}")
|
info(f"[Fast Path] 标记失败,准备升级: {reason}")
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from ..state import MainGraphState
|
|||||||
from backend.app.logger import info, debug
|
from backend.app.logger import info, debug
|
||||||
from ...model_services.chat_services import get_small_llm_service
|
from ...model_services.chat_services import get_small_llm_service
|
||||||
from ._utils import dispatch_custom_event
|
from ._utils import dispatch_custom_event
|
||||||
|
from backend.app.core.json_parser import extract_and_parse_json, safe_get, safe_get_float, safe_get_str
|
||||||
|
|
||||||
|
|
||||||
# ========== 核心数据类型 ==========
|
# ========== 核心数据类型 ==========
|
||||||
@@ -44,24 +45,34 @@ SUBGRAPH_KEYWORDS = {
|
|||||||
# ========== 意图分类 Prompt 模板 ==========
|
# ========== 意图分类 Prompt 模板 ==========
|
||||||
INTENT_CLASSIFICATION_PROMPT = """你是一个专业的意图分类助手。请分析用户的查询,并输出 JSON 格式的结果。
|
INTENT_CLASSIFICATION_PROMPT = """你是一个专业的意图分类助手。请分析用户的查询,并输出 JSON 格式的结果。
|
||||||
|
|
||||||
意图类型(4选一):
|
【格式要求】
|
||||||
|
你必须严格输出 JSON 格式,不要加任何 Markdown 代码块标记(如 ```json)。
|
||||||
|
仅输出纯 JSON 字符串,不要有其他解释文字。
|
||||||
|
|
||||||
|
【意图类型(4选一):
|
||||||
- chitchat: 闲聊、问候、感谢、道别(不需要工具)
|
- chitchat: 闲聊、问候、感谢、道别(不需要工具)
|
||||||
- knowledge: 知识查询(需要查询知识库)
|
- knowledge: 知识查询(需要查询知识库)
|
||||||
- tool: 工具操作(需要调用通讯录/词典/新闻等子图)
|
- tool: 工具操作(需要调用通讯录/词典/新闻等子图)
|
||||||
- complex: 复杂任务(多步骤、不确定、或需要推理)
|
- complex: 复杂任务(多步骤、不确定、或需要推理)
|
||||||
|
|
||||||
用户查询:
|
【输出格式】
|
||||||
{query}
|
|
||||||
|
|
||||||
输出格式(仅 JSON,不要其他内容):
|
|
||||||
{{
|
{{
|
||||||
"intent": "chitchat|knowledge|tool|complex",
|
"intent": "chitchat|knowledge|tool|complex",
|
||||||
"confidence": 0.0-1.0,
|
"confidence": 0.85,
|
||||||
"reasoning": "简要说明理由",
|
"reasoning": "简要说明理由",
|
||||||
"suggested_tools": ["contact|dictionary|news_analysis", "other"]
|
"suggested_tools": ["contact|dictionary|news_analysis", "other"]
|
||||||
}}
|
}}
|
||||||
|
|
||||||
注意:如果不能100%确定意图,请选择 "complex",置信度设低一些。"""
|
【重要提示】
|
||||||
|
- 如果不能100%确定意图,请选择 "complex",置信度设低一些。
|
||||||
|
- confidence 是你对当前分类的信心(0.0-1.0)。
|
||||||
|
- suggested_tools 仅在 intent=tool 时提供,否则设为空数组。
|
||||||
|
|
||||||
|
【用户查询】
|
||||||
|
{query}
|
||||||
|
|
||||||
|
【现在开始】
|
||||||
|
请根据以上信息,输出你的分类 JSON:"""
|
||||||
|
|
||||||
|
|
||||||
# ========== 规则分流(<5ms) ==========
|
# ========== 规则分流(<5ms) ==========
|
||||||
@@ -109,13 +120,12 @@ async def _classify_with_llm(query: str) -> HybridRouterResult:
|
|||||||
prompt = INTENT_CLASSIFICATION_PROMPT.format(query=query)
|
prompt = INTENT_CLASSIFICATION_PROMPT.format(query=query)
|
||||||
response = await llm.ainvoke(prompt)
|
response = await llm.ainvoke(prompt)
|
||||||
|
|
||||||
# 解析 JSON
|
# 使用新的 JSON 解析器
|
||||||
json_match = re.search(r'\{[\s\S]*?\}', response.content)
|
parse_result = extract_and_parse_json(response.content)
|
||||||
if not json_match:
|
if not parse_result.success or not parse_result.data:
|
||||||
return _default_result()
|
return _default_result()
|
||||||
|
|
||||||
data = json.loads(json_match.group())
|
return _parse_classification_result(parse_result.data)
|
||||||
return _parse_classification_result(data)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
debug(f"LLM 分类失败: {e}")
|
debug(f"LLM 分类失败: {e}")
|
||||||
@@ -124,8 +134,10 @@ async def _classify_with_llm(query: str) -> HybridRouterResult:
|
|||||||
|
|
||||||
def _parse_classification_result(data: dict) -> HybridRouterResult:
|
def _parse_classification_result(data: dict) -> HybridRouterResult:
|
||||||
"""解析分类结果"""
|
"""解析分类结果"""
|
||||||
intent = data.get("intent", "complex")
|
intent = safe_get_str(data, "intent", "complex")
|
||||||
confidence = float(data.get("confidence", 0.3))
|
confidence = safe_get_float(data, "confidence", 0.3)
|
||||||
|
suggested_tools = safe_get(data, "suggested_tools", [])
|
||||||
|
reasoning = safe_get_str(data, "reasoning", "")
|
||||||
|
|
||||||
# 置信度低于阈值,走 complex
|
# 置信度低于阈值,走 complex
|
||||||
if confidence < 0.5:
|
if confidence < 0.5:
|
||||||
@@ -141,9 +153,9 @@ def _parse_classification_result(data: dict) -> HybridRouterResult:
|
|||||||
return HybridRouterResult(
|
return HybridRouterResult(
|
||||||
intent=intent,
|
intent=intent,
|
||||||
confidence=confidence,
|
confidence=confidence,
|
||||||
suggested_tools=data.get("suggested_tools", []),
|
suggested_tools=suggested_tools,
|
||||||
path=path_map.get(intent, "react_loop"),
|
path=path_map.get(intent, "react_loop"),
|
||||||
reasoning=data.get("reasoning", "")
|
reasoning=reasoning
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -175,7 +187,7 @@ async def hybrid_router_node(state: MainGraphState, config: Optional[RunnableCon
|
|||||||
info("[Hybrid Router] 规则未命中,使用 LLM 分类")
|
info("[Hybrid Router] 规则未命中,使用 LLM 分类")
|
||||||
decision = await _classify_with_llm(query)
|
decision = await _classify_with_llm(query)
|
||||||
|
|
||||||
# 3. 更新状态
|
# 3. 更新状态(同时更新旧的 debug_info 和新的结构化字段)
|
||||||
state.debug_info["hybrid_decision"] = {
|
state.debug_info["hybrid_decision"] = {
|
||||||
"intent": decision.intent,
|
"intent": decision.intent,
|
||||||
"confidence": decision.confidence,
|
"confidence": decision.confidence,
|
||||||
@@ -184,6 +196,10 @@ async def hybrid_router_node(state: MainGraphState, config: Optional[RunnableCon
|
|||||||
"suggested_tools": decision.suggested_tools
|
"suggested_tools": decision.suggested_tools
|
||||||
}
|
}
|
||||||
state.debug_info["hybrid_start_time"] = datetime.now().isoformat()
|
state.debug_info["hybrid_start_time"] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
# 更新新的结构化字段
|
||||||
|
state.hybrid_router.decision = decision
|
||||||
|
state.hybrid_router.start_time = datetime.now().isoformat()
|
||||||
|
|
||||||
# 4. 发送事件
|
# 4. 发送事件
|
||||||
await dispatch_custom_event("intent_classified", {
|
await dispatch_custom_event("intent_classified", {
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ async def react_reason_node(state: MainGraphState, config: Optional[RunnableConf
|
|||||||
"timestamp": datetime.now().isoformat()
|
"timestamp": datetime.now().isoformat()
|
||||||
})
|
})
|
||||||
|
|
||||||
# 步骤4: 更新调试信息
|
# 步骤4: 更新调试信息(同时更新旧的 debug_info 和新的结构化字段)
|
||||||
state.debug_info["last_reasoning"] = {
|
state.debug_info["last_reasoning"] = {
|
||||||
"action": result.action.name,
|
"action": result.action.name,
|
||||||
"confidence": result.confidence,
|
"confidence": result.confidence,
|
||||||
@@ -55,6 +55,14 @@ async def react_reason_node(state: MainGraphState, config: Optional[RunnableConf
|
|||||||
}
|
}
|
||||||
state.debug_info["reasoning_result"] = result
|
state.debug_info["reasoning_result"] = result
|
||||||
state.last_action = result.action.name
|
state.last_action = result.action.name
|
||||||
|
|
||||||
|
# 更新新的结构化字段
|
||||||
|
state.react_reasoning.last_reasoning = {
|
||||||
|
"action": result.action.name,
|
||||||
|
"confidence": result.confidence,
|
||||||
|
"reasoning": result.reasoning
|
||||||
|
}
|
||||||
|
state.react_reasoning.reasoning_result = result
|
||||||
|
|
||||||
# 步骤5: 发送推理事件
|
# 步骤5: 发送推理事件
|
||||||
await dispatch_custom_event(
|
await dispatch_custom_event(
|
||||||
|
|||||||
@@ -41,6 +41,30 @@ class ErrorRecord:
|
|||||||
context: Dict[str, Any] = field(default_factory=dict) # 错误上下文
|
context: Dict[str, Any] = field(default_factory=dict) # 错误上下文
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReactReasoningState:
|
||||||
|
"""React 推理状态 - 替代 debug_info 中的相关字段"""
|
||||||
|
last_reasoning: Optional[Dict[str, Any]] = None
|
||||||
|
reasoning_result: Optional[Any] = None # 实际类型是 ReasoningResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HybridRouterState:
|
||||||
|
"""混合路由状态 - 替代 debug_info 中的相关字段"""
|
||||||
|
decision: Optional[Any] = None # 实际类型是 HybridRouterResult
|
||||||
|
start_time: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FastPathState:
|
||||||
|
"""快速路径状态 - 替代 debug_info 中的相关字段"""
|
||||||
|
chitchat_success: bool = False
|
||||||
|
rag_success: bool = False
|
||||||
|
tool_success: bool = False
|
||||||
|
failed: bool = False
|
||||||
|
fail_reason: str = ""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MainGraphState:
|
class MainGraphState:
|
||||||
"""
|
"""
|
||||||
@@ -103,4 +127,11 @@ class MainGraphState:
|
|||||||
# ========== 元数据 ==========
|
# ========== 元数据 ==========
|
||||||
start_time: Optional[str] = None
|
start_time: Optional[str] = None
|
||||||
end_time: Optional[str] = None
|
end_time: Optional[str] = None
|
||||||
|
|
||||||
|
# ========== 结构化状态(替代黑盒 debug_info)==========
|
||||||
|
react_reasoning: ReactReasoningState = field(default_factory=ReactReasoningState)
|
||||||
|
hybrid_router: HybridRouterState = field(default_factory=HybridRouterState)
|
||||||
|
fast_path: FastPathState = field(default_factory=FastPathState)
|
||||||
|
|
||||||
|
# ========== 向后兼容 ==========
|
||||||
debug_info: Dict[str, Any] = field(default_factory=dict)
|
debug_info: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user