Files
ailine/backend/app/core/json_parser.py

204 lines
5.6 KiB
Python
Raw Normal View History

"""
统一的 JSON 解析工具保证 LLM JSON 输出的稳定性
处理各种边界情况
1. JSON 字符串
2. JSON markdown 代码块中
3. JSON 在文本中间
4. JSON 有多余的逗号
5. JSON 有尾随内容
"""
import json
import re
from dataclasses import dataclass
from json import JSONDecodeError
from typing import Any, Callable, Dict, Iterator, Optional, Type, TypeVar
# Generic type variable: the instance type produced by parse_json_to_dataclass().
T = TypeVar("T")
@dataclass
class ParseResult:
"""JSON 解析结果"""
success: bool
data: Optional[Dict[str, Any]] = None
error: Optional[str] = None
raw_response: str = ""
def extract_and_parse_json(
response: str,
schema: Optional[Dict[str, Any]] = None
) -> ParseResult:
"""
LLM 响应中提取并解析 JSON使用多种策略处理边界情况
Args:
response: LLM 的原始响应
schema: 可选的 JSON Schema预留暂未使用
Returns:
ParseResult: 解析结果
"""
result = ParseResult(raw_response=response, success=False)
# 前置清理
cleaned = response.strip()
if not cleaned:
result.error = "响应为空"
return result
# 策略1尝试直接解析完整响应
try:
data = json.loads(cleaned)
result.data = data
result.success = True
return result
except JSONDecodeError:
pass
# 策略2尝试匹配 markdown 代码块(优先)
codeblock_patterns = [
r'```(?:json)?\s*([\s\S]*?)\s*```', # ```json ... ```
r'```([\s\S]*?)```', # ``` ... ```
]
for pattern in codeblock_patterns:
match = re.search(pattern, cleaned)
if match:
json_str = match.group(1).strip()
if json_str:
try:
data = json.loads(json_str)
result.data = data
result.success = True
return result
except JSONDecodeError:
continue
# 策略3提取最外层的完整 {} 块(处理嵌套)
json_match = _extract_outermost_json(cleaned)
if json_match:
try:
data = json.loads(json_match)
result.data = data
result.success = True
return result
except JSONDecodeError:
pass
# 策略4尝试修复常见问题
try:
# 去除多余的尾随逗号
fixed = re.sub(r',\s*([}\]])', r'\1', cleaned)
# 提取第一个 { 到最后一个 } 的内容
first_brace = fixed.find('{')
last_brace = fixed.rfind('}')
if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
json_str = fixed[first_brace:last_brace+1]
data = json.loads(json_str)
result.data = data
result.success = True
return result
except Exception:
pass
# 所有策略都失败
result.error = f"无法从响应中提取有效 JSON: {cleaned[:200]}..."
return result
def _extract_outermost_json(text: str) -> Optional[str]:
"""
提取最外层的完整 JSON 处理嵌套
使用栈方法正确处理嵌套的 {}
"""
stack = []
start_idx = -1
for i, char in enumerate(text):
if char == '{':
if not stack:
start_idx = i
stack.append('{')
elif char == '}':
if stack:
stack.pop()
if not stack and start_idx != -1:
# 找到完整的外层块
return text[start_idx:i+1]
return None
def parse_json_to_dataclass(
    response: str,
    dataclass_type: Type[T],
    default_factory: Callable[[], T]
) -> T:
    """Parse an LLM response as JSON and build a ``dataclass_type`` instance.

    The parsed JSON object's keys are passed to ``dataclass_type`` as keyword
    arguments. ``default_factory()`` is returned instead when extraction
    fails, when the parsed value is falsy (e.g. ``{}`` or ``null``), or when
    the constructor raises ``TypeError`` / ``ValueError``.

    Args:
        response: Raw LLM response text.
        dataclass_type: Class to instantiate with the parsed fields.
        default_factory: Zero-argument callable producing the fallback value.

    Returns:
        T: The constructed instance, or ``default_factory()`` on failure.
    """
    parse_result = extract_and_parse_json(response)
    if not parse_result.success or not parse_result.data:
        return default_factory()
    try:
        return dataclass_type(**parse_result.data)
    except (TypeError, ValueError):
        # Unknown/missing fields, a non-mapping payload (``**`` on a JSON
        # array raises TypeError) or a validation error: fall back.
        return default_factory()
def safe_get(data: Dict[str, Any], key: str, default: Any = None) -> Any:
    """Look up *key* in *data*, tolerating missing or non-dict containers.

    Returns *default* when *data* is falsy (``None``, ``{}``, ...), is not a
    ``dict``, or does not contain *key*.
    """
    if data and isinstance(data, dict):
        return data.get(key, default)
    return default
def safe_get_bool(data: Dict[str, Any], key: str, default: bool = False) -> bool:
    """Read *key* from *data* as a boolean.

    A missing key is treated as if its value were *default*. Then:

    - ``bool`` values are returned unchanged.
    - Strings are true iff, after trimming surrounding whitespace and
      lower-casing, they equal ``"true"``, ``"1"``, ``"yes"`` or ``"on"``;
      any other string is false.
    - ``int`` / ``float`` values use Python truthiness (``0`` is false).
    - Anything else (``None``, lists, dicts, ...) yields *default*.
    """
    value = safe_get(data, key, default)
    # bool must be checked before int/float because bool subclasses int.
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        # Trim so " true" behaves like "true", matching int()/float(), which
        # also ignore surrounding whitespace.
        return value.strip().lower() in ('true', '1', 'yes', 'on')
    if isinstance(value, (int, float)):
        return bool(value)
    return default
def safe_get_float(data: Dict[str, Any], key: str, default: float = 0.0) -> float:
    """Read *key* from *data* as a float, or return *default*.

    A missing key is treated as if its value were *default*. The value is
    converted with ``float()``, so numbers and numeric strings are accepted.
    *default* is returned when conversion fails: ``None``, non-numeric
    values, or an int too large to represent as a float.
    """
    value = safe_get(data, key, default)
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: e.g. float(10 ** 400); json.loads produces
        # arbitrarily large ints from long integer literals.
        return default
def safe_get_int(data: Dict[str, Any], key: str, default: int = 0) -> int:
    """Read *key* from *data* as an int, or return *default*.

    A missing key is treated as if its value were *default*. The value is
    converted with ``int()``: floats are truncated toward zero, and strings
    must be integer literals (``"3.0"`` is rejected). *default* is returned
    when conversion fails, including for infinite or NaN floats.
    """
    value = safe_get(data, key, default)
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float('inf')); json.loads maps the literal
        # `Infinity` to float('inf'). NaN raises ValueError.
        return default
def safe_get_str(data: Dict[str, Any], key: str, default: str = "") -> str:
    """Read *key* from *data* as a string.

    A missing key is treated as if its value were *default*. ``None`` yields
    *default*; any other value is converted with ``str()`` (so ``1`` becomes
    ``"1"``, and containers use their Python repr).
    """
    value = safe_get(data, key, default)
    if value is None:
        return default
    return str(value)