Files
ailine/backend/app/agent_subgraphs/common/intent.py

428 lines
12 KiB
Python
Raw Normal View History

"""
意图理解工具模块
提供标准化的意图分类和信息提取能力
功能
1. Intent - 意图数据类
2. IntentClassifier - 意图分类器
3. EntityExtractor - 实体提取器
4. IntentParser - 完整的意图解析器
5. IntentRegistry - 意图注册器
"""
import re
from typing import Dict, List, Any, Optional, Set, Tuple, Callable
from dataclasses import dataclass, field
from enum import Enum, auto
from abc import ABC, abstractmethod
class IntentType(Enum):
    """Closed set of conversational intent categories.

    Member values are produced by ``auto()``, so they follow declaration
    order starting at 1; reordering members changes their values.
    """

    UNKNOWN = auto()    # unrecognised intent
    GREETING = auto()   # greeting
    QUESTION = auto()   # asking a question
    REQUEST = auto()    # making a request
    COMMAND = auto()    # issuing a command
    INFORM = auto()     # providing information
    CONFIRM = auto()    # confirmation
    DENY = auto()       # denial
    THANKS = auto()     # expressing thanks
    GOODBYE = auto()    # saying goodbye
    COMPLAINT = auto()  # complaint
    PRAISE = auto()     # praise
    CLARIFY = auto()    # clarification
    SUGGEST = auto()    # suggestion
@dataclass
class Entity:
    """A single entity found in a piece of text.

    Attributes:
        entity_type: Label naming the kind of entity.
        value: The entity's text value.
        start_pos: Start offset of the entity (defaults to 0).
        end_pos: End offset of the entity (defaults to 0).
        confidence: Confidence score (defaults to 1.0).
        metadata: Free-form extra data; each instance gets its own dict.
    """

    entity_type: str
    value: str
    start_pos: int = 0
    end_pos: int = 0
    confidence: float = 1.0
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class Intent:
    """The outcome of interpreting one piece of user text.

    Attributes:
        intent_type: The recognised intent category.
        confidence: Confidence score for ``intent_type`` (defaults to 1.0).
        entities: Entities extracted from the text.
        parameters: Named parameters associated with the intent.
        original_text: The text exactly as it was received.
        normalized_text: The normalised form of the text.
        metadata: Free-form extra data.

    All container fields use a ``default_factory`` so instances never share
    a list or dict.
    """

    intent_type: IntentType
    confidence: float = 1.0
    entities: List[Entity] = field(default_factory=list)
    parameters: Dict[str, Any] = field(default_factory=dict)
    original_text: str = ""
    normalized_text: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
class BaseIntentClassifier(ABC):
    """Abstract interface that every intent classifier implements."""

    @abstractmethod
    def classify(self, text: str) -> Tuple[IntentType, float]:
        """Pick the single best intent for ``text``.

        Args:
            text: Input text.

        Returns:
            A ``(intent type, confidence)`` pair.
        """

    @abstractmethod
    def classify_with_scores(self, text: str) -> Dict[IntentType, float]:
        """Score ``text`` against the intent types.

        Args:
            text: Input text.

        Returns:
            A mapping of intent type to confidence.
        """
class RuleBasedIntentClassifier(BaseIntentClassifier):
    """Keyword-driven intent classifier.

    Every intent type owns a set of keywords. A text is scored per intent by
    counting how many of that intent's keywords occur in it as
    case-insensitive substrings; the score is ``count / 3`` capped at 1.0.
    """

    def __init__(self):
        # intent type -> keywords that vote for it
        self._rules: Dict[IntentType, Set[str]] = {}
        self._initialize_default_rules()

    def _initialize_default_rules(self) -> None:
        """Register the built-in keyword sets.

        Registration order matters: it is the tie-break order used by
        ``classify`` when several intents reach the same score.
        """
        # NOTE(review): the CONFIRM, DENY and QUESTION sets appear to be
        # missing some (probably single-character) keywords that had been
        # stored as empty strings. An empty string is a substring of every
        # text, so such entries must not be registered; restore the real
        # keywords from the authoritative copy of this file.

        # Greeting
        self.add_rule(IntentType.GREETING, {
            "你好", "您好", "hi", "hello", "hey", "早上好", "下午好", "晚上好", "哈喽"
        })
        # Goodbye
        self.add_rule(IntentType.GOODBYE, {
            "再见", "拜拜", "bye", "goodbye", "回见", "下次见", "再见了"
        })
        # Thanks
        self.add_rule(IntentType.THANKS, {
            "谢谢", "感谢", "多谢", "thanks", "thank you", "3q", "谢谢了"
        })
        # Confirmation
        self.add_rule(IntentType.CONFIRM, {
            "是的", "没错", "好的", "可以", "同意", "确认", "yes", "yep"
        })
        # Denial
        self.add_rule(IntentType.DENY, {
            "不是", "不对", "不行", "不要", "拒绝", "no", "nope", "没有"
        })
        # Question
        self.add_rule(IntentType.QUESTION, {
            "?", "什么", "怎么", "如何", "为什么", "多少"
        })

    def add_rule(self, intent_type: IntentType, keywords: Set[str]) -> None:
        """Add keywords for an intent type, merging with any existing ones.

        Empty keywords are dropped: ``"" in text`` is always true, so an
        empty keyword would make the intent match every input.

        Args:
            intent_type: Intent type the keywords vote for.
            keywords: Keywords to add.
        """
        usable = {keyword for keyword in keywords if keyword}
        self._rules.setdefault(intent_type, set()).update(usable)

    def classify(self, text: str) -> Tuple[IntentType, float]:
        """Return the highest-scoring intent for ``text``.

        When several intents share the top score, the one whose rules were
        registered first wins (``max`` keeps the first maximal item and the
        rule dict preserves insertion order).

        Args:
            text: Input text.

        Returns:
            A ``(intent type, confidence)`` pair.
        """
        scores = self.classify_with_scores(text)
        if not scores:
            # Defensive: classify_with_scores currently never returns an
            # empty mapping, but subclasses overriding it might.
            return IntentType.UNKNOWN, 0.0
        best_type, best_score = max(scores.items(), key=lambda item: item[1])
        return best_type, best_score

    def classify_with_scores(self, text: str) -> Dict[IntentType, float]:
        """Score ``text`` against every intent that has at least one hit.

        Args:
            text: Input text.

        Returns:
            A mapping of intent type to confidence. Intents with no matching
            keyword are omitted; if nothing matches at all the result is
            ``{IntentType.UNKNOWN: 0.5}``.
        """
        lowered = text.lower()
        scores: Dict[IntentType, float] = {}
        for intent_type, keywords in self._rules.items():
            # Skip empty keywords even if they slipped into _rules directly.
            hits = sum(
                1 for keyword in keywords
                if keyword and keyword.lower() in lowered
            )
            if hits:
                scores[intent_type] = min(1.0, hits / 3.0)
        if not scores:
            scores[IntentType.UNKNOWN] = 0.5
        return scores
class BaseEntityExtractor(ABC):
    """Abstract interface that every entity extractor implements."""

    @abstractmethod
    def extract(self, text: str) -> List[Entity]:
        """Find the entities contained in ``text``.

        Args:
            text: Input text.

        Returns:
            The list of extracted entities.
        """
class RuleBasedEntityExtractor(BaseEntityExtractor):
    """Entity extractor driven by regular expressions and keyword lists."""

    def __init__(self):
        self._patterns: Dict[str, re.Pattern] = {}  # entity type -> compiled regex
        self._keywords: Dict[str, Set[str]] = {}    # entity type -> keywords
        self._initialize_default_patterns()

    def _initialize_default_patterns(self) -> None:
        """Register the built-in regex patterns."""
        # Email address
        self.add_regex_pattern(
            "email",
            r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        )
        # Phone number
        self.add_regex_pattern(
            "phone",
            r'1[3-9]\d{9}'
        )
        # Date (simple pattern)
        self.add_regex_pattern(
            "date",
            r'\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日号]?|\d{1,2}[-/月]\d{1,2}[日号]?'
        )
        # Number
        self.add_regex_pattern(
            "number",
            r'\d+\.?\d*'
        )

    def add_regex_pattern(self, entity_type: str, pattern: str) -> None:
        """Register a regex for an entity type, replacing any previous one.

        The pattern is compiled with ``re.IGNORECASE``. An invalid pattern is
        not registered; a warning is logged instead of failing silently, and
        no exception is raised.

        Args:
            entity_type: Entity type produced by matches of the pattern.
            pattern: Regular expression source.
        """
        try:
            compiled = re.compile(pattern, re.IGNORECASE)
        except re.error as exc:
            # Local import keeps this block self-contained.
            import logging
            logging.getLogger(__name__).warning(
                "Ignoring invalid regex for entity type %r: %r (%s)",
                entity_type, pattern, exc,
            )
            return
        self._patterns[entity_type] = compiled

    def add_keywords(self, entity_type: str, keywords: Set[str]) -> None:
        """Add keywords for an entity type, merging with any existing ones.

        Empty keywords are dropped because they cannot identify an entity
        and would stall the keyword scan in ``extract``.

        Args:
            entity_type: Entity type produced by matches of the keywords.
            keywords: Keywords to add.
        """
        usable = {keyword for keyword in keywords if keyword}
        self._keywords.setdefault(entity_type, set()).update(usable)

    def extract(self, text: str) -> List[Entity]:
        """Extract all regex and keyword entities from ``text``.

        Regex matches get confidence 0.95 and keyword matches 0.9. Matches
        from different rules may overlap; no de-duplication is performed.

        Args:
            text: Input text.

        Returns:
            Entities sorted by start position.
        """
        found: List[Entity] = []

        # Regex matches
        for entity_type, pattern in self._patterns.items():
            for match in pattern.finditer(text):
                found.append(Entity(
                    entity_type=entity_type,
                    value=match.group(),
                    start_pos=match.start(),
                    end_pos=match.end(),
                    confidence=0.95,
                ))

        # Keyword matches: case-insensitive, non-overlapping per keyword.
        lowered = text.lower()  # computed once instead of once per search
        for entity_type, keywords in self._keywords.items():
            for keyword in keywords:
                if not keyword:
                    # find("") never advances the cursor -> endless loop.
                    continue
                needle = keyword.lower()
                width = len(keyword)
                cursor = lowered.find(needle)
                while cursor != -1:
                    found.append(Entity(
                        entity_type=entity_type,
                        # Slice the original text to keep its casing.
                        value=text[cursor:cursor + width],
                        start_pos=cursor,
                        end_pos=cursor + width,
                        confidence=0.9,
                    ))
                    cursor = lowered.find(needle, cursor + width)

        found.sort(key=lambda entity: entity.start_pos)
        return found
class IntentRegistry:
    """Lookup table from intent type to the callable that handles it."""

    def __init__(self):
        self._intent_handlers: Dict[IntentType, Callable] = {}

    def register(self, intent_type: IntentType, handler: Callable) -> None:
        """Bind ``handler`` to ``intent_type``, replacing any earlier binding.

        Args:
            intent_type: Intent type to handle.
            handler: Callable to invoke for that intent type.
        """
        self._intent_handlers.update({intent_type: handler})

    def get_handler(self, intent_type: IntentType) -> Optional[Callable]:
        """Look up the handler bound to ``intent_type``.

        Args:
            intent_type: Intent type to look up.

        Returns:
            The registered handler, or ``None`` when there is none.
        """
        try:
            return self._intent_handlers[intent_type]
        except KeyError:
            return None
class IntentParser:
    """End-to-end intent parser: classify, extract entities, dispatch."""

    def __init__(
        self,
        classifier: Optional[BaseIntentClassifier] = None,
        extractor: Optional[BaseEntityExtractor] = None,
        registry: Optional[IntentRegistry] = None
    ):
        """Wire up the parser's collaborators.

        A default implementation is created only for arguments that are
        ``None``. Testing with ``is None`` (rather than truthiness) keeps an
        injected object that happens to be falsy, e.g. a registry subclass
        defining ``__len__`` that is still empty.

        Args:
            classifier: Intent classifier; defaults to a rule-based one.
            extractor: Entity extractor; defaults to a rule-based one.
            registry: Handler registry; defaults to an empty registry.
        """
        self.classifier = (
            RuleBasedIntentClassifier() if classifier is None else classifier
        )
        self.extractor = (
            RuleBasedEntityExtractor() if extractor is None else extractor
        )
        self.registry = IntentRegistry() if registry is None else registry

    def parse(self, text: str) -> Intent:
        """Parse ``text`` into a fully populated intent object.

        Args:
            text: Input text.

        Returns:
            An intent carrying the classified type and confidence, the
            extracted entities, the original text, its lower-cased and
            stripped form, and a ``parameters`` dict keyed by entity type.
            When several entities share a type, the last one in the
            extractor's output wins.
        """
        intent_type, confidence = self.classifier.classify(text)
        entities = self.extractor.extract(text)
        return Intent(
            intent_type=intent_type,
            confidence=confidence,
            entities=entities,
            parameters={entity.entity_type: entity.value for entity in entities},
            original_text=text,
            normalized_text=text.lower().strip(),
        )

    def parse_and_execute(self, text: str, context: Optional[Dict[str, Any]] = None) -> Any:
        """Parse ``text`` and run the handler registered for its intent.

        Args:
            text: Input text.
            context: Context passed to the handler. The caller's dict is
                forwarded as-is (even when empty) so that changes made by
                the handler are visible to the caller; a fresh dict is used
                only when ``context`` is ``None``.

        Returns:
            Whatever the handler returns, or ``None`` when no handler is
            registered for the parsed intent type.
        """
        intent = self.parse(text)
        handler = self.registry.get_handler(intent.intent_type)
        if handler is None:
            return None
        if context is None:
            context = {}
        return handler(intent, context)
def create_default_intent_parser() -> IntentParser:
    """Build an intent parser with canned replies pre-registered.

    Handlers are registered for the GREETING, THANKS and GOODBYE intents;
    each ignores its arguments and returns a fixed reply string.

    Returns:
        The configured intent parser.
    """

    def greeting_handler(intent: Intent, context: Dict) -> str:
        return "你好!很高兴为你服务。"

    def thanks_handler(intent: Intent, context: Dict) -> str:
        return "不客气!"

    def goodbye_handler(intent: Intent, context: Dict) -> str:
        return "再见!有需要随时找我。"

    parser = IntentParser()
    default_handlers = (
        (IntentType.GREETING, greeting_handler),
        (IntentType.THANKS, thanks_handler),
        (IntentType.GOODBYE, goodbye_handler),
    )
    for intent_type, handler in default_handlers:
        parser.registry.register(intent_type, handler)
    return parser