204 lines
5.6 KiB
Python
204 lines
5.6 KiB
Python
|
|
"""
|
|||
|
|
统一的 JSON 解析工具,保证 LLM JSON 输出的稳定性
|
|||
|
|
|
|||
|
|
处理各种边界情况:
|
|||
|
|
1. 纯 JSON 字符串
|
|||
|
|
2. JSON 在 markdown 代码块中
|
|||
|
|
3. JSON 在文本中间
|
|||
|
|
4. JSON 有多余的逗号
|
|||
|
|
5. JSON 有尾随内容
|
|||
|
|
"""
|
|||
|
|
import re
|
|||
|
|
import json
|
|||
|
|
from typing import TypeVar, Type, Dict, Any, Optional
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from json import JSONDecodeError
|
|||
|
|
|
|||
|
|
# Generic type variable: parse_json_to_dataclass takes Type[T] and returns a T.
T = TypeVar("T")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class ParseResult:
|
|||
|
|
"""JSON 解析结果"""
|
|||
|
|
success: bool
|
|||
|
|
data: Optional[Dict[str, Any]] = None
|
|||
|
|
error: Optional[str] = None
|
|||
|
|
raw_response: str = ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_and_parse_json(
|
|||
|
|
response: str,
|
|||
|
|
schema: Optional[Dict[str, Any]] = None
|
|||
|
|
) -> ParseResult:
|
|||
|
|
"""
|
|||
|
|
从 LLM 响应中提取并解析 JSON,使用多种策略处理边界情况
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
response: LLM 的原始响应
|
|||
|
|
schema: 可选的 JSON Schema(预留,暂未使用)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
ParseResult: 解析结果
|
|||
|
|
"""
|
|||
|
|
result = ParseResult(raw_response=response, success=False)
|
|||
|
|
|
|||
|
|
# 前置清理
|
|||
|
|
cleaned = response.strip()
|
|||
|
|
if not cleaned:
|
|||
|
|
result.error = "响应为空"
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
# 策略1:尝试直接解析完整响应
|
|||
|
|
try:
|
|||
|
|
data = json.loads(cleaned)
|
|||
|
|
result.data = data
|
|||
|
|
result.success = True
|
|||
|
|
return result
|
|||
|
|
except JSONDecodeError:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
# 策略2:尝试匹配 markdown 代码块(优先)
|
|||
|
|
codeblock_patterns = [
|
|||
|
|
r'```(?:json)?\s*([\s\S]*?)\s*```', # ```json ... ```
|
|||
|
|
r'```([\s\S]*?)```', # ``` ... ```
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
for pattern in codeblock_patterns:
|
|||
|
|
match = re.search(pattern, cleaned)
|
|||
|
|
if match:
|
|||
|
|
json_str = match.group(1).strip()
|
|||
|
|
if json_str:
|
|||
|
|
try:
|
|||
|
|
data = json.loads(json_str)
|
|||
|
|
result.data = data
|
|||
|
|
result.success = True
|
|||
|
|
return result
|
|||
|
|
except JSONDecodeError:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
# 策略3:提取最外层的完整 {} 块(处理嵌套)
|
|||
|
|
json_match = _extract_outermost_json(cleaned)
|
|||
|
|
if json_match:
|
|||
|
|
try:
|
|||
|
|
data = json.loads(json_match)
|
|||
|
|
result.data = data
|
|||
|
|
result.success = True
|
|||
|
|
return result
|
|||
|
|
except JSONDecodeError:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
# 策略4:尝试修复常见问题
|
|||
|
|
try:
|
|||
|
|
# 去除多余的尾随逗号
|
|||
|
|
fixed = re.sub(r',\s*([}\]])', r'\1', cleaned)
|
|||
|
|
# 提取第一个 { 到最后一个 } 的内容
|
|||
|
|
first_brace = fixed.find('{')
|
|||
|
|
last_brace = fixed.rfind('}')
|
|||
|
|
if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
|
|||
|
|
json_str = fixed[first_brace:last_brace+1]
|
|||
|
|
data = json.loads(json_str)
|
|||
|
|
result.data = data
|
|||
|
|
result.success = True
|
|||
|
|
return result
|
|||
|
|
except Exception:
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
# 所有策略都失败
|
|||
|
|
result.error = f"无法从响应中提取有效 JSON: {cleaned[:200]}..."
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_outermost_json(text: str) -> Optional[str]:
|
|||
|
|
"""
|
|||
|
|
提取最外层的完整 JSON 块(处理嵌套)
|
|||
|
|
|
|||
|
|
使用栈方法,正确处理嵌套的 {}
|
|||
|
|
"""
|
|||
|
|
stack = []
|
|||
|
|
start_idx = -1
|
|||
|
|
|
|||
|
|
for i, char in enumerate(text):
|
|||
|
|
if char == '{':
|
|||
|
|
if not stack:
|
|||
|
|
start_idx = i
|
|||
|
|
stack.append('{')
|
|||
|
|
elif char == '}':
|
|||
|
|
if stack:
|
|||
|
|
stack.pop()
|
|||
|
|
if not stack and start_idx != -1:
|
|||
|
|
# 找到完整的外层块
|
|||
|
|
return text[start_idx:i+1]
|
|||
|
|
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def parse_json_to_dataclass(
    response: str,
    dataclass_type: Type[T],
    default_factory: callable
) -> T:
    """Parse an LLM response into an instance of *dataclass_type*.

    The decoded JSON object's keys are passed to *dataclass_type* as keyword
    arguments. If that fails because of keys the dataclass does not declare,
    one retry is made with only the declared ``init`` fields. Whenever no
    usable object is available, ``default_factory()`` is returned.

    Args:
        response: Raw LLM response.
        dataclass_type: Target dataclass type.
        default_factory: Zero-argument callable producing the fallback value.

    Returns:
        T: An instance built from the JSON, or ``default_factory()``.
    """
    import dataclasses

    parse_result = extract_and_parse_json(response)
    data = parse_result.data

    # Failure, an empty object, or a non-object JSON value (list, string,
    # number...) cannot be mapped onto constructor keywords: use the fallback.
    if not parse_result.success or not data or not isinstance(data, dict):
        return default_factory()

    try:
        return dataclass_type(**data)
    except (TypeError, ValueError):
        pass

    # Field mismatch: LLMs often add keys the dataclass does not declare.
    # Retry with only the declared init fields before giving up. Skip the
    # retry when no key matches, since that JSON is unrelated to the target.
    if dataclasses.is_dataclass(dataclass_type):
        known = {f.name for f in dataclasses.fields(dataclass_type) if f.init}
        filtered = {k: v for k, v in data.items() if k in known}
        if filtered and len(filtered) < len(data):
            try:
                return dataclass_type(**filtered)
            except (TypeError, ValueError):
                pass

    return default_factory()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def safe_get(data: Dict[str, Any], key: str, default: Any = None) -> Any:
    """Return ``data[key]``; return *default* if *data* is not a non-empty dict or lacks *key*."""
    if not data or not isinstance(data, dict):
        return default
    return data.get(key, default)


def safe_get_bool(data: Dict[str, Any], key: str, default: bool = False) -> bool:
    """Return ``data[key]`` interpreted as a bool.

    - Booleans are returned unchanged.
    - ints and floats use their truthiness.
    - Strings are stripped, lower-cased and matched against ``true/1/yes/on``
      (True) and ``false/0/no/off`` or the empty string (False).
    - Unrecognized strings and every other type yield *default*.
    """
    value = safe_get(data, key, default)
    # bool is checked first because it is a subclass of int.
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        token = value.strip().lower()
        if token in ('true', '1', 'yes', 'on'):
            return True
        if token in ('false', '0', 'no', 'off', ''):
            return False
        return default
    if isinstance(value, (int, float)):
        return bool(value)
    return default


def safe_get_float(data: Dict[str, Any], key: str, default: float = 0.0) -> float:
    """Return ``float(data[key])``, or *default* when it cannot be converted."""
    value = safe_get(data, key, default)
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: JSON integers can exceed the float range (e.g. 10**400).
        return default


def safe_get_int(data: Dict[str, Any], key: str, default: int = 0) -> int:
    """Return ``int(data[key])`` (floats are truncated), or *default* when it cannot be converted."""
    value = safe_get(data, key, default)
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: json.loads accepts ``Infinity``, and int(inf) raises it.
        return default


def safe_get_str(data: Dict[str, Any], key: str, default: str = "") -> str:
    """Return ``str(data[key])``, or *default* when the value is missing or None."""
    value = safe_get(data, key, default)
    return str(value) if value is not None else default
|