193 lines
6.4 KiB
Python
193 lines
6.4 KiB
Python
|
|
# backend/app/agent/intent_classifier.py
|
|||
|
|
|
|||
|
|
from enum import Enum
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Optional, Dict, Any
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
# Add the directory two levels above this file to sys.path so the absolute "app.*" import below resolves.
|
|||
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
|||
|
|
|
|||
|
|
from app.model_services.chat_services import get_chat_service
|
|||
|
|
|
|||
|
|
|
|||
|
|
class IntentType(Enum):
    """Closed set of intent labels that can be assigned to a user message.

    Each member's value is the lowercase label string used in the
    classifier's JSON output. The comment above each member records the
    handling route the label is meant to lead to.
    """

    # Knowledge lookup -> RAG
    KNOWLEDGE = "knowledge"

    # Real-time data -> tools
    REALTIME = "realtime"

    # Perform an operation -> tools
    ACTION = "action"

    # Small talk -> answer directly
    CHITCHAT = "chitchat"

    # Needs clarification -> ask the user back
    CLARIFY = "clarify"

    # Complex task -> ReAct loop
    MIXED = "mixed"

    # Unknown intent
    UNKNOWN = "unknown"
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class IntentResult:
    """Outcome of classifying a single user message.

    Attributes:
        intent_type: The intent label chosen for the message.
        confidence: Numeric score attached to the chosen label.
        reasoning: Short human-readable justification for the label.
        metadata: Optional extra data attached to the result; ``None`` when
            nothing was supplied.
    """

    intent_type: IntentType
    confidence: float
    reasoning: str
    # Annotated as Optional because the default is None, not an empty dict.
    # The default stays None so callers testing `metadata is None` keep working.
    metadata: Optional[Dict[str, Any]] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class IntentClassifier:
    """Classify user input into an ``IntentType`` using a chat model.

    The classifier sends a few-shot prompt to the chat service, expects a JSON
    object in the reply, and falls back to keyword matching when the reply
    cannot be parsed.
    """

    def __init__(self):
        """Obtain the chat service and pre-build the few-shot example text."""
        self.llm = get_chat_service()
        self._intent_examples = self._build_examples()

    def _build_examples(self) -> str:
        """Return the few-shot examples embedded in every classification prompt."""
        return """
<示例>
用户: "公司的报销政策是什么?"
意图: knowledge
推理: 用户询问公司内部政策,需要查询知识库

用户: "帮我查一下订单 12345 的状态"
意图: realtime
推理: 需要查询实时订单数据

用户: "帮我申请退款,订单号 67890"
意图: action
推理: 需要执行退款操作

用户: "今天天气怎么样?"
意图: realtime
推理: 需要查询实时天气数据

用户: "帮我写一份邮件给客户,查询订单状态,然后附上退款政策"
意图: mixed
推理: 需要查询订单、查询政策、生成邮件,多步骤任务

用户: "你好"
意图: chitchat
推理: 简单寒暄

用户: "我想查点东西..."
意图: clarify
推理: 用户没有说清楚要查什么
</示例>
"""

    async def classify(self, user_input: str, context: Optional[str] = None) -> IntentResult:
        """Classify the intent of a user message.

        Args:
            user_input: The raw user message.
            context: Optional conversation context appended to the prompt.

        Returns:
            The parsed ``IntentResult``. If the model call or parsing raises,
            a ``MIXED`` result with confidence 0.5 is returned instead of
            propagating the error.
        """
        prompt = self._build_classification_prompt(user_input, context)

        try:
            response = await self.llm.ainvoke(prompt)
            # Pass the original message along so that the keyword fallback
            # inspects what the user wrote rather than the model's reply.
            return self._parse_response(response.content, user_input)
        except Exception as e:
            # Deliberate catch-all boundary: any failure degrades to the
            # general-purpose (mixed) route instead of failing the request.
            print(f"Intent classification error: {e}")
            return IntentResult(
                intent_type=IntentType.MIXED,
                confidence=0.5,
                reasoning="分类失败,走通用路径",
            )

    def _build_classification_prompt(self, user_input: str, context: Optional[str]) -> str:
        """Assemble the full classification prompt.

        Args:
            user_input: The raw user message inserted into the prompt.
            context: Optional conversation context; omitted when falsy.

        Returns:
            The prompt text: instructions, label descriptions, few-shot
            examples, the user message, and the required JSON output format.
        """
        context_part = f"\n对话上下文:\n{context}" if context else ""

        return f"""
你是一个专业的意图识别助手。请分析用户的输入,判断其意图类型。

可选意图类型:
- knowledge: 用户询问知识、政策、文档等,需要查询知识库
- realtime: 用户需要查询实时数据(订单状态、天气、股票等)
- action: 用户需要执行某项操作(退款、下单、发送邮件等)
- chitchat: 用户只是闲聊、打招呼,不需要工具或检索
- clarify: 用户的问题不明确,需要追问澄清
- mixed: 复杂任务,需要多步骤处理(同时需要检索+工具)

{self._intent_examples}

用户输入: {user_input}
{context_part}

请按以下格式输出(纯JSON):
{{
"intent": "knowledge|realtime|action|chitchat|clarify|mixed",
"confidence": 0.85,
"reasoning": "简要说明为什么这个意图"
}}
"""

    def _parse_response(self, response: str, user_input: Optional[str] = None) -> IntentResult:
        """Turn the model's reply into an ``IntentResult``.

        Args:
            response: Text returned by the chat model.
            user_input: The original user message. When given, the keyword
                fallback is applied to it; when omitted, the fallback is
                applied to ``response`` (the previous behaviour).

        Returns:
            The result decoded from the JSON object found in ``response``, or
            the keyword-fallback result when no usable JSON is present.
        """
        import json
        import re

        # Greedy match from the first "{" to the last "}" so that surrounding
        # prose or code fences around the JSON object are ignored.
        json_match = re.search(r'\{[\s\S]*\}', response)
        if json_match:
            try:
                data = json.loads(json_match.group())
                return IntentResult(
                    intent_type=IntentType(data['intent']),
                    confidence=float(data['confidence']),
                    reasoning=data['reasoning'],
                )
            except (KeyError, TypeError, ValueError):
                # Malformed JSON (JSONDecodeError is a ValueError), a missing
                # field, an unknown intent label, or a non-numeric confidence:
                # fall through to keyword matching below.
                pass

        fallback_text = response if user_input is None else user_input
        return self._fallback_classify(fallback_text)

    def _fallback_classify(self, user_input: str) -> IntentResult:
        """Classify by keyword lookup when the model reply is unusable.

        Categories are checked in declaration order and the first category
        with at least one keyword present in the text wins.

        Args:
            user_input: Text to scan for keywords (compared in lowercase).

        Returns:
            A result with confidence 0.7 naming the matched keywords, or a
            ``MIXED`` result with confidence 0.5 when nothing matches.
        """
        keywords = {
            IntentType.KNOWLEDGE: ['政策', '文档', '规定', '手册', '指南', '什么是', '怎么'],
            IntentType.REALTIME: ['订单', '状态', '天气', '股票', '价格', '库存'],
            IntentType.ACTION: ['退款', '取消', '发送', '申请', '修改', '删除'],
            IntentType.CHITCHAT: ['你好', 'hi', 'hello', '嗨', '早上好', '晚上好'],
        }

        # Lowercase once instead of once per keyword.
        lowered = user_input.lower()

        # NOTE(review): matching is by substring, so short ASCII keywords such
        # as "hi" also match inside longer words (e.g. "machine").
        for intent_type, words in keywords.items():
            matched = [word for word in words if word in lowered]
            if matched:
                return IntentResult(
                    intent_type=intent_type,
                    confidence=0.7,
                    # Report only the keywords that were actually found.
                    reasoning=f"关键词匹配: {', '.join(matched)}",
                )

        # Nothing matched: default to the general-purpose (mixed) route.
        return IntentResult(
            intent_type=IntentType.MIXED,
            confidence=0.5,
            reasoning="无法明确分类,走通用路径",
        )

    async def batch_classify(self, inputs: list[str]) -> list[IntentResult]:
        """Classify several inputs one after another.

        Inputs are processed sequentially in the given order. No caching is
        implemented; this is where it could be added.

        Args:
            inputs: User messages to classify.

        Returns:
            One ``IntentResult`` per input, in the same order.
        """
        return [await self.classify(inp) for inp in inputs]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Module-level shared instance; stays None until first requested.
_classifier: Optional[IntentClassifier] = None


def get_intent_classifier() -> IntentClassifier:
    """Return the shared ``IntentClassifier``, creating it on the first call."""
    global _classifier
    if _classifier is not None:
        return _classifier
    _classifier = IntentClassifier()
    return _classifier
|