This commit is contained in:
1
backend/app/core/__init__.py
Normal file
1
backend/app/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""核心模块 - 基类和通用工具"""
|
||||
482
backend/app/core/formatter.py
Normal file
482
backend/app/core/formatter.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""
|
||||
格式化输出工具模块
|
||||
提供基于 Jinja2 模板的 Markdown 格式化输出能力
|
||||
|
||||
功能:
|
||||
1. TemplateManager - 模板管理器,支持加载和渲染 Jinja2 模板
|
||||
2. MarkdownFormatter - Markdown 格式化工具,提供常用格式(表格、列表、引用等)
|
||||
3. OutputRenderer - 输出渲染器,统一接口生成最终输出
|
||||
4. PresetTemplates - 预置模板(对话摘要、报告、列表等)
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# 尝试导入 Jinja2,如果没有则提供基础实现
|
||||
try:
|
||||
from jinja2 import Template as JinjaTemplate, Environment, BaseLoader
|
||||
HAS_JINJA2 = True
|
||||
except ImportError:
|
||||
HAS_JINJA2 = False
|
||||
|
||||
|
||||
class BaseFormatter(ABC):
|
||||
"""格式化器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def format(self, data: Any) -> str:
|
||||
"""格式化数据为字符串"""
|
||||
pass
|
||||
|
||||
|
||||
class MarkdownFormatter(BaseFormatter):
|
||||
"""Markdown 格式化工具"""
|
||||
|
||||
@staticmethod
|
||||
def table(data: List[Dict[str, Any]], headers: Optional[List[str]] = None) -> str:
|
||||
"""
|
||||
生成 Markdown 表格
|
||||
|
||||
Args:
|
||||
data: 数据列表,每个元素是一个字典
|
||||
headers: 表头列表,如果为 None 则使用字典的键
|
||||
|
||||
Returns:
|
||||
Markdown 表格字符串
|
||||
"""
|
||||
if not data:
|
||||
return ""
|
||||
|
||||
if headers is None:
|
||||
headers = list(data[0].keys()) if data else []
|
||||
|
||||
if not headers:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
|
||||
# 表头行
|
||||
header_line = "| " + " | ".join(str(h) for h in headers) + " |"
|
||||
lines.append(header_line)
|
||||
|
||||
# 分隔线
|
||||
separator_line = "| " + " | ".join("---" for _ in headers) + " |"
|
||||
lines.append(separator_line)
|
||||
|
||||
# 数据行
|
||||
for row in data:
|
||||
row_values = [str(row.get(h, "")) for h in headers]
|
||||
row_line = "| " + " | ".join(row_values) + " |"
|
||||
lines.append(row_line)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def bullet_list(items: List[str], indent: int = 0) -> str:
|
||||
"""
|
||||
生成无序列表
|
||||
|
||||
Args:
|
||||
items: 列表项
|
||||
indent: 缩进层级
|
||||
|
||||
Returns:
|
||||
Markdown 无序列表字符串
|
||||
"""
|
||||
indent_str = " " * indent
|
||||
return "\n".join(f"{indent_str}- {item}" for item in items)
|
||||
|
||||
@staticmethod
|
||||
def numbered_list(items: List[str], start: int = 1, indent: int = 0) -> str:
|
||||
"""
|
||||
生成有序列表
|
||||
|
||||
Args:
|
||||
items: 列表项
|
||||
start: 起始编号
|
||||
indent: 缩进层级
|
||||
|
||||
Returns:
|
||||
Markdown 有序列表字符串
|
||||
"""
|
||||
indent_str = " " * indent
|
||||
return "\n".join(f"{indent_str}{i}. {item}" for i, item in enumerate(items, start=start))
|
||||
|
||||
@staticmethod
|
||||
def quote(text: str, author: Optional[str] = None) -> str:
|
||||
"""
|
||||
生成引用块
|
||||
|
||||
Args:
|
||||
text: 引用文本
|
||||
author: 作者(可选)
|
||||
|
||||
Returns:
|
||||
Markdown 引用块字符串
|
||||
"""
|
||||
quoted_lines = "\n".join(f"> {line}" for line in text.split("\n"))
|
||||
if author:
|
||||
quoted_lines += f"\n> — {author}"
|
||||
return quoted_lines
|
||||
|
||||
@staticmethod
|
||||
def code(code: str, language: str = "") -> str:
|
||||
"""
|
||||
生成代码块
|
||||
|
||||
Args:
|
||||
code: 代码内容
|
||||
language: 语言标识符
|
||||
|
||||
Returns:
|
||||
Markdown 代码块字符串
|
||||
"""
|
||||
return f"```{language}\n{code}\n```"
|
||||
|
||||
@staticmethod
|
||||
def heading(text: str, level: int = 1) -> str:
|
||||
"""
|
||||
生成标题
|
||||
|
||||
Args:
|
||||
text: 标题文本
|
||||
level: 标题级别(1-6)
|
||||
|
||||
Returns:
|
||||
Markdown 标题字符串
|
||||
"""
|
||||
level = max(1, min(6, level))
|
||||
return f"{'#' * level} {text}"
|
||||
|
||||
@staticmethod
|
||||
def link(text: str, url: str) -> str:
|
||||
"""
|
||||
生成链接
|
||||
|
||||
Args:
|
||||
text: 链接文本
|
||||
url: 链接地址
|
||||
|
||||
Returns:
|
||||
Markdown 链接字符串
|
||||
"""
|
||||
return f"[{text}]({url})"
|
||||
|
||||
@staticmethod
|
||||
def bold(text: str) -> str:
|
||||
"""生成粗体"""
|
||||
return f"**{text}**"
|
||||
|
||||
@staticmethod
|
||||
def italic(text: str) -> str:
|
||||
"""生成斜体"""
|
||||
return f"*{text}*"
|
||||
|
||||
@staticmethod
|
||||
def divider() -> str:
|
||||
"""生成分割线"""
|
||||
return "---"
|
||||
|
||||
def format(self, data: Any) -> str:
|
||||
"""实现基类方法,根据数据类型自动选择格式化方式"""
|
||||
if isinstance(data, list):
|
||||
if len(data) > 0 and isinstance(data[0], dict):
|
||||
return self.table(data)
|
||||
else:
|
||||
return self.bullet_list([str(item) for item in data])
|
||||
elif isinstance(data, dict):
|
||||
return self.table([data])
|
||||
else:
|
||||
return str(data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
"""模板数据类"""
|
||||
name: str
|
||||
content: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class DictLoader(BaseLoader):
|
||||
"""字典模板加载器
|
||||
|
||||
用于从内存字典中加载模板
|
||||
"""
|
||||
|
||||
def __init__(self, templates: Dict[str, str]):
|
||||
self.templates = templates
|
||||
|
||||
def get_source(self, environment, template):
|
||||
if template not in self.templates:
|
||||
raise TemplateNotFound(template)
|
||||
source = self.templates[template]
|
||||
return source, None, lambda: True
|
||||
|
||||
|
||||
class TemplateManager:
|
||||
"""Jinja2 模板管理器"""
|
||||
|
||||
def __init__(self, template_dir: Optional[Path] = None):
|
||||
"""
|
||||
初始化模板管理器
|
||||
|
||||
Args:
|
||||
template_dir: 模板目录路径
|
||||
"""
|
||||
self._templates: Dict[str, Template] = {}
|
||||
self.template_dir = template_dir
|
||||
self._env: Optional[Environment] = None
|
||||
|
||||
if HAS_JINJA2:
|
||||
self._env = Environment(loader=DictLoader({}))
|
||||
|
||||
def _refresh_env(self) -> None:
|
||||
"""刷新 Jinja2 环境"""
|
||||
if HAS_JINJA2 and self._env is not None:
|
||||
template_dict = {name: t.content for name, t in self._templates.items()}
|
||||
self._env = Environment(loader=DictLoader(template_dict))
|
||||
|
||||
def add_template(self, name: str, content: str, description: str = "") -> None:
|
||||
"""
|
||||
添加模板
|
||||
|
||||
Args:
|
||||
name: 模板名称
|
||||
content: 模板内容
|
||||
description: 模板描述
|
||||
"""
|
||||
self._templates[name] = Template(name=name, content=content, description=description)
|
||||
self._refresh_env()
|
||||
|
||||
def load_template(self, name: str, file_path: Path) -> None:
|
||||
"""
|
||||
从文件加载模板
|
||||
|
||||
Args:
|
||||
name: 模板名称
|
||||
file_path: 模板文件路径
|
||||
"""
|
||||
if file_path.exists():
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
self.add_template(name, content, f"从文件加载: {file_path}")
|
||||
|
||||
def get_template(self, name: str) -> Optional[Template]:
|
||||
"""
|
||||
获取模板
|
||||
|
||||
Args:
|
||||
name: 模板名称
|
||||
|
||||
Returns:
|
||||
模板对象,如果不存在返回 None
|
||||
"""
|
||||
return self._templates.get(name)
|
||||
|
||||
def render(self, template_name: str, context: Dict[str, Any]) -> str:
|
||||
"""
|
||||
渲染模板
|
||||
|
||||
Args:
|
||||
template_name: 模板名称
|
||||
context: 渲染上下文
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
"""
|
||||
template = self.get_template(template_name)
|
||||
if template is None:
|
||||
raise ValueError(f"模板不存在: {template_name}")
|
||||
|
||||
return self.render_string(template.content, context)
|
||||
|
||||
def render_string(self, template_string: str, context: Dict[str, Any]) -> str:
|
||||
"""
|
||||
渲染模板字符串
|
||||
|
||||
Args:
|
||||
template_string: 模板字符串
|
||||
context: 渲染上下文
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
"""
|
||||
if HAS_JINJA2 and self._env is not None:
|
||||
try:
|
||||
jinja_template = self._env.from_string(template_string)
|
||||
return jinja_template.render(**context)
|
||||
except Exception:
|
||||
# 如果 Jinja2 渲染失败,使用简单替换
|
||||
pass
|
||||
|
||||
# 简单的字符串替换作为备选方案
|
||||
result = template_string
|
||||
for key, value in context.items():
|
||||
result = result.replace(f"{{{{{key}}}}}", str(value))
|
||||
result = result.replace(f"{{{{ {key} }}}}", str(value))
|
||||
return result
|
||||
|
||||
|
||||
class PresetTemplates:
|
||||
"""预置模板集合"""
|
||||
|
||||
@staticmethod
|
||||
def conversation_summary() -> str:
|
||||
"""对话摘要模板"""
|
||||
return """# 对话摘要
|
||||
|
||||
**时间**: {{ timestamp }}
|
||||
|
||||
**参与者**: {{ participants }}
|
||||
|
||||
---
|
||||
|
||||
## 对话要点
|
||||
{{ bullet_list(points) }}
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
{{ summary }}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def research_report() -> str:
|
||||
"""研究报告模板"""
|
||||
return """# {{ title }}
|
||||
|
||||
**日期**: {{ date }}
|
||||
**作者**: {{ author }}
|
||||
|
||||
---
|
||||
|
||||
## 摘要
|
||||
{{ summary }}
|
||||
|
||||
---
|
||||
|
||||
## 发现
|
||||
{{ bullet_list(findings) }}
|
||||
|
||||
---
|
||||
|
||||
## 数据来源
|
||||
{{ sources }}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def task_list() -> str:
|
||||
"""任务列表模板"""
|
||||
return """# 任务列表
|
||||
|
||||
**更新时间**: {{ update_time }}
|
||||
|
||||
---
|
||||
|
||||
## 待办
|
||||
{{ numbered_list(todos) }}
|
||||
|
||||
---
|
||||
|
||||
## 已完成
|
||||
{{ numbered_list(completed) }}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def data_summary() -> str:
|
||||
"""数据摘要模板"""
|
||||
return """# 数据摘要
|
||||
|
||||
**生成时间**: {{ timestamp }}
|
||||
|
||||
---
|
||||
|
||||
## 数据概览
|
||||
{{ table(data_overview) }}
|
||||
|
||||
---
|
||||
|
||||
## 关键指标
|
||||
{{ bullet_list(metrics) }}
|
||||
"""
|
||||
|
||||
|
||||
class OutputRenderer:
|
||||
"""输出渲染器"""
|
||||
|
||||
def __init__(self, template_manager: Optional[TemplateManager] = None):
|
||||
"""
|
||||
初始化输出渲染器
|
||||
|
||||
Args:
|
||||
template_manager: 模板管理器
|
||||
"""
|
||||
self.template_manager = template_manager or TemplateManager()
|
||||
self.markdown = MarkdownFormatter()
|
||||
|
||||
# 自动注册预置模板
|
||||
self._register_presets()
|
||||
|
||||
def _register_presets(self) -> None:
|
||||
"""注册预置模板"""
|
||||
self.template_manager.add_template(
|
||||
"conversation_summary",
|
||||
PresetTemplates.conversation_summary(),
|
||||
"对话摘要模板"
|
||||
)
|
||||
self.template_manager.add_template(
|
||||
"research_report",
|
||||
PresetTemplates.research_report(),
|
||||
"研究报告模板"
|
||||
)
|
||||
self.template_manager.add_template(
|
||||
"task_list",
|
||||
PresetTemplates.task_list(),
|
||||
"任务列表模板"
|
||||
)
|
||||
self.template_manager.add_template(
|
||||
"data_summary",
|
||||
PresetTemplates.data_summary(),
|
||||
"数据摘要模板"
|
||||
)
|
||||
|
||||
def render(self, template_name: str, context: Dict[str, Any]) -> str:
|
||||
"""
|
||||
使用模板渲染输出
|
||||
|
||||
Args:
|
||||
template_name: 模板名称
|
||||
context: 渲染上下文
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
"""
|
||||
# 将格式化工具注入上下文
|
||||
render_context = context.copy()
|
||||
render_context["bullet_list"] = self.markdown.bullet_list
|
||||
render_context["numbered_list"] = self.markdown.numbered_list
|
||||
render_context["table"] = self.markdown.table
|
||||
render_context["quote"] = self.markdown.quote
|
||||
render_context["code"] = self.markdown.code
|
||||
render_context["heading"] = self.markdown.heading
|
||||
render_context["link"] = self.markdown.link
|
||||
render_context["bold"] = self.markdown.bold
|
||||
render_context["italic"] = self.markdown.italic
|
||||
render_context["divider"] = self.markdown.divider
|
||||
|
||||
return self.template_manager.render(template_name, render_context)
|
||||
|
||||
def render_plain(self, data: Any) -> str:
|
||||
"""
|
||||
直接格式化数据为 Markdown
|
||||
|
||||
Args:
|
||||
data: 数据
|
||||
|
||||
Returns:
|
||||
格式化后的字符串
|
||||
"""
|
||||
return self.markdown.format(data)
|
||||
465
backend/app/core/human_review.py
Normal file
465
backend/app/core/human_review.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
人工审核工具模块
|
||||
提供 LangGraph interrupt 机制和状态持久化能力
|
||||
|
||||
功能:
|
||||
1. HumanReview - 人工审核数据类
|
||||
2. ReviewStatus - 审核状态枚举
|
||||
3. HumanReviewStore - 审核存储接口
|
||||
4. InMemoryReviewStore - 内存存储实现
|
||||
5. HumanReviewNode - LangGraph 审核节点
|
||||
6. ReviewManager - 审核管理器
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class ReviewStatus(Enum):
|
||||
"""审核状态枚举"""
|
||||
PENDING = auto() # 待审核
|
||||
APPROVED = auto() # 已通过
|
||||
REJECTED = auto() # 已拒绝
|
||||
MODIFIED = auto() # 已修改
|
||||
TIMEOUT = auto() # 已超时
|
||||
|
||||
|
||||
@dataclass
|
||||
class HumanReview:
|
||||
"""人工审核数据类"""
|
||||
review_id: str # 审核ID
|
||||
thread_id: str # 线程ID
|
||||
user_id: str # 用户ID
|
||||
status: ReviewStatus # 审核状态
|
||||
content_to_review: str # 待审核内容
|
||||
review_comment: str = "" # 审核意见
|
||||
modified_content: str = "" # 修改后的内容
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
reviewed_at: Optional[datetime] = None
|
||||
reviewer: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class HumanReviewStore(ABC):
|
||||
"""审核存储接口"""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, review: HumanReview) -> None:
|
||||
"""
|
||||
保存审核
|
||||
|
||||
Args:
|
||||
review: 审核对象
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, review_id: str) -> Optional[HumanReview]:
|
||||
"""
|
||||
获取审核
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
|
||||
Returns:
|
||||
审核对象,如果不存在返回 None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_by_thread(self, thread_id: str) -> List[HumanReview]:
|
||||
"""
|
||||
获取线程的所有审核
|
||||
|
||||
Args:
|
||||
thread_id: 线程ID
|
||||
|
||||
Returns:
|
||||
审核列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pending(self, limit: int = 100) -> List[HumanReview]:
|
||||
"""
|
||||
获取待审核的列表
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
待审核列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_status(
|
||||
self,
|
||||
review_id: str,
|
||||
status: ReviewStatus,
|
||||
reviewer: Optional[str] = None,
|
||||
comment: str = "",
|
||||
modified_content: str = ""
|
||||
) -> bool:
|
||||
"""
|
||||
更新审核状态
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
status: 新状态
|
||||
reviewer: 审核人
|
||||
comment: 审核意见
|
||||
modified_content: 修改后的内容
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryReviewStore(HumanReviewStore):
|
||||
"""内存存储实现"""
|
||||
|
||||
def __init__(self):
|
||||
self._reviews: Dict[str, HumanReview] = {}
|
||||
|
||||
def save(self, review: HumanReview) -> None:
|
||||
"""
|
||||
保存审核
|
||||
|
||||
Args:
|
||||
review: 审核对象
|
||||
"""
|
||||
self._reviews[review.review_id] = review
|
||||
|
||||
def get(self, review_id: str) -> Optional[HumanReview]:
|
||||
"""
|
||||
获取审核
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
|
||||
Returns:
|
||||
审核对象,如果不存在返回 None
|
||||
"""
|
||||
return self._reviews.get(review_id)
|
||||
|
||||
def get_by_thread(self, thread_id: str) -> List[HumanReview]:
|
||||
"""
|
||||
获取线程的所有审核
|
||||
|
||||
Args:
|
||||
thread_id: 线程ID
|
||||
|
||||
Returns:
|
||||
审核列表
|
||||
"""
|
||||
return [
|
||||
review for review in self._reviews.values()
|
||||
if review.thread_id == thread_id
|
||||
]
|
||||
|
||||
def get_pending(self, limit: int = 100) -> List[HumanReview]:
|
||||
"""
|
||||
获取待审核的列表
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
待审核列表
|
||||
"""
|
||||
pending = [
|
||||
review for review in self._reviews.values()
|
||||
if review.status == ReviewStatus.PENDING
|
||||
]
|
||||
pending.sort(key=lambda r: r.created_at)
|
||||
return pending[:limit]
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
review_id: str,
|
||||
status: ReviewStatus,
|
||||
reviewer: Optional[str] = None,
|
||||
comment: str = "",
|
||||
modified_content: str = ""
|
||||
) -> bool:
|
||||
"""
|
||||
更新审核状态
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
status: 新状态
|
||||
reviewer: 审核人
|
||||
comment: 审核意见
|
||||
modified_content: 修改后的内容
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
review = self._reviews.get(review_id)
|
||||
if review is None:
|
||||
return False
|
||||
|
||||
review.status = status
|
||||
review.review_comment = comment
|
||||
review.modified_content = modified_content
|
||||
review.reviewer = reviewer
|
||||
review.reviewed_at = datetime.now()
|
||||
return True
|
||||
|
||||
|
||||
class HumanReviewNode:
|
||||
"""LangGraph 审核节点"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: HumanReviewStore,
|
||||
should_review: Optional[Callable[[Any], bool]] = None
|
||||
):
|
||||
"""
|
||||
初始化审核节点
|
||||
|
||||
Args:
|
||||
store: 审核存储
|
||||
should_review: 判断是否需要审核的函数
|
||||
"""
|
||||
self.store = store
|
||||
self.should_review = should_review or (lambda state: True)
|
||||
|
||||
def create_review(
|
||||
self,
|
||||
state: Any,
|
||||
thread_id: str,
|
||||
user_id: str,
|
||||
content_to_review: str
|
||||
) -> str:
|
||||
"""
|
||||
创建审核
|
||||
|
||||
Args:
|
||||
state: 状态
|
||||
thread_id: 线程ID
|
||||
user_id: 用户ID
|
||||
content_to_review: 待审核内容
|
||||
|
||||
Returns:
|
||||
审核ID
|
||||
"""
|
||||
review_id = str(uuid.uuid4())
|
||||
review = HumanReview(
|
||||
review_id=review_id,
|
||||
thread_id=thread_id,
|
||||
user_id=user_id,
|
||||
status=ReviewStatus.PENDING,
|
||||
content_to_review=content_to_review
|
||||
)
|
||||
self.store.save(review)
|
||||
return review_id
|
||||
|
||||
def check_review_status(self, review_id: str) -> Optional[ReviewStatus]:
|
||||
"""
|
||||
检查审核状态
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
|
||||
Returns:
|
||||
审核状态,如果不存在返回 None
|
||||
"""
|
||||
review = self.store.get(review_id)
|
||||
return review.status if review else None
|
||||
|
||||
def get_review_result(self, review_id: str) -> Optional[HumanReview]:
|
||||
"""
|
||||
获取审核结果
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
|
||||
Returns:
|
||||
审核对象,如果不存在返回 None
|
||||
"""
|
||||
return self.store.get(review_id)
|
||||
|
||||
async def __call__(self, state: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
节点执行方法(LangGraph 兼容)
|
||||
|
||||
Args:
|
||||
state: 状态
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
# 检查是否需要审核
|
||||
if not self.should_review(state):
|
||||
return {"review_skipped": True}
|
||||
|
||||
# 从状态中提取信息
|
||||
thread_id = getattr(state, "thread_id", str(uuid.uuid4()))
|
||||
user_id = getattr(state, "user_id", "default_user")
|
||||
|
||||
# 获取待审核内容
|
||||
content_to_review = ""
|
||||
if hasattr(state, "messages") and state.messages:
|
||||
last_msg = state.messages[-1] if state.messages else None
|
||||
if last_msg and hasattr(last_msg, "content"):
|
||||
content_to_review = last_msg.content
|
||||
|
||||
# 创建审核
|
||||
review_id = self.create_review(state, thread_id, user_id, content_to_review)
|
||||
|
||||
# 返回状态更新
|
||||
return {
|
||||
"review_id": review_id,
|
||||
"review_pending": True,
|
||||
"interrupt": True # 标记需要中断
|
||||
}
|
||||
|
||||
|
||||
class ReviewManager:
|
||||
"""审核管理器"""
|
||||
|
||||
def __init__(self, store: Optional[HumanReviewStore] = None):
|
||||
"""
|
||||
初始化审核管理器
|
||||
|
||||
Args:
|
||||
store: 审核存储
|
||||
"""
|
||||
self.store = store or InMemoryReviewStore()
|
||||
|
||||
def request_review(
|
||||
self,
|
||||
thread_id: str,
|
||||
user_id: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
请求审核
|
||||
|
||||
Args:
|
||||
thread_id: 线程ID
|
||||
user_id: 用户ID
|
||||
content: 待审核内容
|
||||
metadata: 元数据
|
||||
|
||||
Returns:
|
||||
审核ID
|
||||
"""
|
||||
review_id = str(uuid.uuid4())
|
||||
review = HumanReview(
|
||||
review_id=review_id,
|
||||
thread_id=thread_id,
|
||||
user_id=user_id,
|
||||
status=ReviewStatus.PENDING,
|
||||
content_to_review=content,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
self.store.save(review)
|
||||
return review_id
|
||||
|
||||
def approve(
|
||||
self,
|
||||
review_id: str,
|
||||
reviewer: str,
|
||||
comment: str = ""
|
||||
) -> bool:
|
||||
"""
|
||||
审核通过
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
reviewer: 审核人
|
||||
comment: 审核意见
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
return self.store.update_status(
|
||||
review_id=review_id,
|
||||
status=ReviewStatus.APPROVED,
|
||||
reviewer=reviewer,
|
||||
comment=comment
|
||||
)
|
||||
|
||||
def reject(
|
||||
self,
|
||||
review_id: str,
|
||||
reviewer: str,
|
||||
comment: str = ""
|
||||
) -> bool:
|
||||
"""
|
||||
审核拒绝
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
reviewer: 审核人
|
||||
comment: 审核意见
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
return self.store.update_status(
|
||||
review_id=review_id,
|
||||
status=ReviewStatus.REJECTED,
|
||||
reviewer=reviewer,
|
||||
comment=comment
|
||||
)
|
||||
|
||||
def modify(
|
||||
self,
|
||||
review_id: str,
|
||||
reviewer: str,
|
||||
modified_content: str,
|
||||
comment: str = ""
|
||||
) -> bool:
|
||||
"""
|
||||
审核修改
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
reviewer: 审核人
|
||||
modified_content: 修改后的内容
|
||||
comment: 审核意见
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
return self.store.update_status(
|
||||
review_id=review_id,
|
||||
status=ReviewStatus.MODIFIED,
|
||||
reviewer=reviewer,
|
||||
comment=comment,
|
||||
modified_content=modified_content
|
||||
)
|
||||
|
||||
def get_pending_reviews(self, limit: int = 100) -> List[HumanReview]:
|
||||
"""
|
||||
获取待审核列表
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
待审核列表
|
||||
"""
|
||||
return self.store.get_pending(limit)
|
||||
|
||||
def get_review(self, review_id: str) -> Optional[HumanReview]:
|
||||
"""
|
||||
获取审核详情
|
||||
|
||||
Args:
|
||||
review_id: 审核ID
|
||||
|
||||
Returns:
|
||||
审核对象,如果不存在返回 None
|
||||
"""
|
||||
return self.store.get(review_id)
|
||||
382
backend/app/core/intent.py
Normal file
382
backend/app/core/intent.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""
|
||||
意图理解与推理模块 (React 模式)
|
||||
Intent Understanding & Reasoning Module (React Pattern)
|
||||
|
||||
这个模块实现了 React (Reasoning + Acting) 模式的意图理解节点,用于:
|
||||
1. 理解用户的查询意图
|
||||
2. 判断是否需要调用 RAG 检索
|
||||
3. 判断是否需要重新检索
|
||||
4. 决定下一步的动作(路由到子图、直接回答等)
|
||||
|
||||
核心设计:
|
||||
- 使用项目已有的 chat_services.py 进行 LLM 调用
|
||||
- 保持与现有架构一致(服务层模式)
|
||||
- 支持降级策略(LLM 失败时回退到规则)
|
||||
- 与 react_nodes.py 无缝集成
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
# ========== 1. 核心数据类型 ==========
|
||||
|
||||
class ReasoningAction(Enum):
|
||||
"""推理动作枚举 - 决定下一步做什么"""
|
||||
DIRECT_RESPONSE = auto() # 直接回答,不需要额外信息
|
||||
RETRIEVE_RAG = auto() # 需要调用 RAG 检索
|
||||
RE_RETRIEVE_RAG = auto() # 需要重新检索(更多/更好结果)
|
||||
ROUTE_SUBGRAPH = auto() # 需要路由到子图(contact/dictionary/news_analysis)
|
||||
CLARIFY = auto() # 需要澄清用户的问题
|
||||
UNKNOWN = auto() # 未知动作
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalConfig:
|
||||
"""检索配置"""
|
||||
need_retrieval: bool = False
|
||||
need_re_retrieval: bool = False
|
||||
retrieval_query: Optional[str] = None
|
||||
target_subgraph: Optional[str] = None
|
||||
collection_name: Optional[str] = None
|
||||
k: int = 5
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReasoningResult:
|
||||
"""推理结果数据类"""
|
||||
action: ReasoningAction = ReasoningAction.UNKNOWN
|
||||
confidence: float = 0.0
|
||||
reasoning: str = ""
|
||||
retrieval_config: RetrievalConfig = field(default_factory=RetrievalConfig)
|
||||
extracted_entities: Dict[str, Any] = field(default_factory=dict)
|
||||
next_hints: List[str] = field(default_factory=list)
|
||||
original_query: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ========== 2. React 推理器 ==========
|
||||
|
||||
class ReactIntentReasoner:
|
||||
"""
|
||||
React 模式意图推理器
|
||||
|
||||
核心功能:
|
||||
1. 使用 LLM 分析用户意图
|
||||
2. 决定是否需要 RAG 检索/重新检索
|
||||
3. 决定是否需要路由到子图
|
||||
4. 提供降级策略(规则匹配)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化推理器 - 懒加载 LLM 服务"""
|
||||
self._llm_service = None
|
||||
self._subgraph_keywords = {
|
||||
"contact": ["通讯录", "联系人", "contact", "email", "邮件", "邮箱"],
|
||||
"dictionary": ["词典", "单词", "翻译", "dictionary", "translate", "生词"],
|
||||
"news_analysis": ["资讯", "新闻", "分析", "news", "report", "热点"]
|
||||
}
|
||||
|
||||
def _get_llm_service(self):
|
||||
"""懒加载 LLM 服务(避免循环导入)"""
|
||||
if self._llm_service is None:
|
||||
from app.model_services.chat_services import get_chat_service
|
||||
self._llm_service = get_chat_service()
|
||||
return self._llm_service
|
||||
|
||||
async def reason(
|
||||
self,
|
||||
query: str,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> ReasoningResult:
|
||||
"""
|
||||
推理意图,决定下一步动作
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
context: 上下文信息(可能包含已检索文档、对话历史等)
|
||||
|
||||
Returns:
|
||||
ReasoningResult
|
||||
"""
|
||||
context = context or {}
|
||||
result = ReasoningResult(original_query=query)
|
||||
|
||||
# 策略1: 尝试使用 LLM 推理
|
||||
try:
|
||||
llm_result = await self._reason_with_llm(query, context)
|
||||
if llm_result.confidence >= 0.6: # 置信度足够高,直接返回
|
||||
return llm_result
|
||||
except Exception as e:
|
||||
print(f"[ReactReasoner] LLM 推理失败: {e}, 回退到规则")
|
||||
|
||||
# 策略2: LLM 失败或置信度低,使用规则匹配
|
||||
return self._reason_with_rules(query, context)
|
||||
|
||||
async def _reason_with_llm(
|
||||
self,
|
||||
query: str,
|
||||
context: Dict[str, Any]
|
||||
) -> ReasoningResult:
|
||||
"""使用 LLM 进行推理"""
|
||||
prompt = self._build_reasoning_prompt(query, context)
|
||||
llm = self._get_llm_service()
|
||||
|
||||
response = await llm.ainvoke(prompt)
|
||||
return self._parse_llm_response(response.content, query)
|
||||
|
||||
def _build_reasoning_prompt(self, query: str, context: Dict[str, Any]) -> str:
|
||||
"""构建推理提示词"""
|
||||
# 构建上下文描述
|
||||
context_parts = []
|
||||
if context.get("retrieved_docs"):
|
||||
context_parts.append(f"- 已检索文档: {len(context['retrieved_docs'])} 条")
|
||||
if context.get("previous_actions"):
|
||||
context_parts.append(f"- 历史动作: {context['previous_actions']}")
|
||||
|
||||
context_str = "\n".join(context_parts) if context_parts else "无"
|
||||
|
||||
return f"""你是一个专业的意图推理助手。请分析用户的查询,决定下一步应该做什么。
|
||||
|
||||
可选动作:
|
||||
1. DIRECT_RESPONSE - 直接回答(闲聊、打招呼、不需要额外信息)
|
||||
2. RETRIEVE_RAG - 需要查询知识库(询问知识、政策、文档等)
|
||||
3. RE_RETRIEVE_RAG - 需要重新检索(之前的结果不够,或者用户明确说"再查查"、"更多")
|
||||
4. ROUTE_SUBGRAPH - 需要路由到专门的子图:
|
||||
- contact: 通讯录、联系人、邮件相关
|
||||
- dictionary: 词典、翻译、单词相关
|
||||
- news_analysis: 资讯、新闻、热点分析相关
|
||||
5. CLARIFY - 需要澄清用户的问题(问题不明确)
|
||||
|
||||
用户查询: {query}
|
||||
当前上下文:
|
||||
{context_str}
|
||||
|
||||
请按以下 JSON 格式输出(仅输出 JSON,不要其他内容):
|
||||
{{
|
||||
"action": "DIRECT_RESPONSE|RETRIEVE_RAG|RE_RETRIEVE_RAG|ROUTE_SUBGRAPH|CLARIFY",
|
||||
"confidence": 0.85,
|
||||
"reasoning": "简要说明理由",
|
||||
"target_subgraph": "contact|dictionary|news_analysis|null (仅当 action=ROUTE_SUBGRAPH 时)",
|
||||
"retrieval_query": "优化后的检索查询 (可选)"
|
||||
}}
|
||||
"""
|
||||
|
||||
def _parse_llm_response(self, response: str, original_query: str) -> ReasoningResult:
|
||||
"""解析 LLM 响应"""
|
||||
result = ReasoningResult(original_query=original_query)
|
||||
|
||||
# 提取 JSON
|
||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
||||
if not json_match:
|
||||
# 没有 JSON,回退到规则
|
||||
result.confidence = 0.0
|
||||
return result
|
||||
|
||||
try:
|
||||
data = json.loads(json_match.group())
|
||||
action_str = data.get("action", "UNKNOWN")
|
||||
|
||||
# 转换为枚举
|
||||
try:
|
||||
result.action = ReasoningAction[action_str]
|
||||
except KeyError:
|
||||
result.action = ReasoningAction.UNKNOWN
|
||||
|
||||
result.confidence = float(data.get("confidence", 0.5))
|
||||
result.reasoning = data.get("reasoning", "")
|
||||
|
||||
# 处理子图路由
|
||||
if result.action == ReasoningAction.ROUTE_SUBGRAPH:
|
||||
result.retrieval_config.target_subgraph = data.get("target_subgraph")
|
||||
result.metadata["target_subgraph"] = data.get("target_subgraph")
|
||||
|
||||
# 处理检索查询
|
||||
if result.action in [ReasoningAction.RETRIEVE_RAG, ReasoningAction.RE_RETRIEVE_RAG]:
|
||||
result.retrieval_config.need_retrieval = True
|
||||
result.retrieval_config.need_re_retrieval = (result.action == ReasoningAction.RE_RETRIEVE_RAG)
|
||||
result.retrieval_config.retrieval_query = data.get("retrieval_query", original_query)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"[ReactReasoner] 解析 LLM 响应失败: {e}")
|
||||
result.confidence = 0.0
|
||||
return result
|
||||
|
||||
def _reason_with_rules(
|
||||
self,
|
||||
query: str,
|
||||
context: Dict[str, Any]
|
||||
) -> ReasoningResult:
|
||||
"""基于规则的降级推理"""
|
||||
result = ReasoningResult(original_query=query)
|
||||
query_lower = query.lower()
|
||||
|
||||
# 1. 检查子图路由(最高优先级)
|
||||
for subgraph_name, keywords in self._subgraph_keywords.items():
|
||||
if any(kw in query_lower for kw in keywords):
|
||||
result.action = ReasoningAction.ROUTE_SUBGRAPH
|
||||
result.confidence = 0.85
|
||||
result.reasoning = f"关键词匹配: {subgraph_name} 子图"
|
||||
result.retrieval_config.target_subgraph = subgraph_name
|
||||
result.metadata["target_subgraph"] = subgraph_name
|
||||
return result
|
||||
|
||||
# 2. 检查是否需要重新检索
|
||||
re_retrieve_keywords = ["再", "重新", "更多", "不够", "其他", "没找到", "找不到", "不对", "another", "again", "more"]
|
||||
has_re_retrieve = any(kw in query_lower for kw in re_retrieve_keywords)
|
||||
has_docs = context.get("retrieved_docs") and len(context["retrieved_docs"]) > 0
|
||||
|
||||
if has_re_retrieve or (has_docs and len(context["retrieved_docs"]) < 2):
|
||||
result.action = ReasoningAction.RE_RETRIEVE_RAG
|
||||
result.confidence = 0.8 if has_re_retrieve else 0.65
|
||||
result.reasoning = "需要重新检索更多/更好结果"
|
||||
result.retrieval_config.need_retrieval = True
|
||||
result.retrieval_config.need_re_retrieval = True
|
||||
result.retrieval_config.retrieval_query = query
|
||||
return result
|
||||
|
||||
# 3. 检查是否需要 RAG 检索
|
||||
retrieve_keywords = ["什么", "怎么", "如何", "为什么", "哪", "谁", "介绍", "解释", "说明", "资料", "文档", "查询", "搜索", "what", "how", "why", "where", "who", "tell me", "explain", "about", "information"]
|
||||
has_retrieve = any(kw in query_lower for kw in retrieve_keywords)
|
||||
|
||||
if has_retrieve or len(query.strip()) > 5:
|
||||
result.action = ReasoningAction.RETRIEVE_RAG
|
||||
result.confidence = 0.8 if has_retrieve else 0.6
|
||||
result.reasoning = "需要查询知识库"
|
||||
result.retrieval_config.need_retrieval = True
|
||||
result.retrieval_config.retrieval_query = query
|
||||
return result
|
||||
|
||||
# 4. 检查直接回答
|
||||
direct_keywords = ["你好", "您好", "hi", "hello", "hey", "早上好", "晚上好", "下午好", "嗨", "谢谢", "感谢", "多谢", "thanks", "thank you", "再见", "拜拜", "goodbye", "回见"]
|
||||
if any(kw in query_lower for kw in direct_keywords):
|
||||
result.action = ReasoningAction.DIRECT_RESPONSE
|
||||
result.confidence = 0.9
|
||||
result.reasoning = "直接回答(问候/感谢/道别)"
|
||||
return result
|
||||
|
||||
# 5. 检查是否需要澄清
|
||||
if len(query.strip()) < 3 or any(q in query for q in ["?", "?", "哪个", "哪些", "什么意思", "请", "能详细"]):
|
||||
result.action = ReasoningAction.CLARIFY
|
||||
result.confidence = 0.7
|
||||
result.reasoning = "需要澄清问题"
|
||||
result.next_hints = ["请提供更多细节", "您想了解什么方面的内容?", "能否具体说明一下?"]
|
||||
return result
|
||||
|
||||
# 6. 默认直接回答
|
||||
result.action = ReasoningAction.DIRECT_RESPONSE
|
||||
result.confidence = 0.5
|
||||
result.reasoning = "默认直接回答模式"
|
||||
return result
|
||||
|
||||
|
||||
# ========== 3. 便捷函数(保持与旧代码兼容) ==========
|
||||
|
||||
# 全局推理器实例(懒加载)
|
||||
_reasoner: Optional[ReactIntentReasoner] = None
|
||||
|
||||
|
||||
def _get_reasoner() -> ReactIntentReasoner:
|
||||
"""获取推理器实例"""
|
||||
global _reasoner
|
||||
if _reasoner is None:
|
||||
_reasoner = ReactIntentReasoner()
|
||||
return _reasoner
|
||||
|
||||
|
||||
async def react_reason_async(
|
||||
query: str,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> ReasoningResult:
|
||||
"""
|
||||
便捷函数:异步 React 推理(推荐使用)
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
context: 上下文
|
||||
|
||||
Returns:
|
||||
ReasoningResult
|
||||
"""
|
||||
reasoner = _get_reasoner()
|
||||
return await reasoner.reason(query, context)
|
||||
|
||||
|
||||
def react_reason(
|
||||
query: str,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> ReasoningResult:
|
||||
"""
|
||||
便捷函数:同步 React 推理(保持向后兼容)
|
||||
|
||||
注意:内部会运行事件循环,建议在异步环境中使用 react_reason_async
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
context: 上下文
|
||||
|
||||
Returns:
|
||||
ReasoningResult
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
# 尝试获取现有事件循环
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# 已经在运行的循环中,创建任务
|
||||
task = loop.create_task(react_reason_async(query, context))
|
||||
# 注意:这里不能真正等待,会导致死锁
|
||||
# 降级到规则推理
|
||||
print("[ReactReasoner] 检测到运行中的事件循环,使用规则推理")
|
||||
reasoner = _get_reasoner()
|
||||
return reasoner._reason_with_rules(query, context or {})
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(react_reason_async(query, context))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def get_route_by_reasoning(result: ReasoningResult) -> str:
|
||||
"""
|
||||
根据推理结果获取路由字符串(与旧代码兼容)
|
||||
|
||||
Args:
|
||||
result: ReasoningResult
|
||||
|
||||
Returns:
|
||||
str: 路由标识
|
||||
"""
|
||||
action_to_route = {
|
||||
ReasoningAction.DIRECT_RESPONSE: "direct_response",
|
||||
ReasoningAction.RETRIEVE_RAG: "retrieve_rag",
|
||||
ReasoningAction.RE_RETRIEVE_RAG: "re_retrieve_rag",
|
||||
ReasoningAction.CLARIFY: "clarify",
|
||||
ReasoningAction.ROUTE_SUBGRAPH: result.metadata.get("target_subgraph", "unknown_subgraph"),
|
||||
ReasoningAction.UNKNOWN: "unknown",
|
||||
}
|
||||
return action_to_route.get(result.action, "unknown")
|
||||
|
||||
|
||||
# ========== 4. 导出 ==========
|
||||
|
||||
__all__ = [
|
||||
"ReasoningAction",
|
||||
"RetrievalConfig",
|
||||
"ReasoningResult",
|
||||
"ReactIntentReasoner",
|
||||
"react_reason",
|
||||
"react_reason_async",
|
||||
"get_route_by_reasoning"
|
||||
]
|
||||
193
backend/app/core/intent_classifier.py
Normal file
193
backend/app/core/intent_classifier.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# backend/app/agent/intent_classifier.py
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
|
||||
from app.model_services.chat_services import get_chat_service
|
||||
|
||||
|
||||
class IntentType(Enum):
|
||||
"""意图类型枚举"""
|
||||
KNOWLEDGE = "knowledge" # 知识查询 → RAG
|
||||
REALTIME = "realtime" # 实时数据 → 工具
|
||||
ACTION = "action" # 执行操作 → 工具
|
||||
CHITCHAT = "chitchat" # 闲聊 → 直接回答
|
||||
CLARIFY = "clarify" # 需要澄清 → 反问用户
|
||||
MIXED = "mixed" # 复杂任务 → React 循环
|
||||
UNKNOWN = "unknown" # 未知意图
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntentResult:
|
||||
"""意图识别结果"""
|
||||
intent_type: IntentType
|
||||
confidence: float
|
||||
reasoning: str
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
|
||||
class IntentClassifier:
|
||||
"""意图分类器"""
|
||||
|
||||
def __init__(self):
|
||||
self.llm = get_chat_service()
|
||||
self._intent_examples = self._build_examples()
|
||||
|
||||
def _build_examples(self) -> str:
|
||||
"""构建少样本示例"""
|
||||
return """
|
||||
<示例>
|
||||
用户: "公司的报销政策是什么?"
|
||||
意图: knowledge
|
||||
推理: 用户询问公司内部政策,需要查询知识库
|
||||
|
||||
用户: "帮我查一下订单 12345 的状态"
|
||||
意图: realtime
|
||||
推理: 需要查询实时订单数据
|
||||
|
||||
用户: "帮我申请退款,订单号 67890"
|
||||
意图: action
|
||||
推理: 需要执行退款操作
|
||||
|
||||
用户: "今天天气怎么样?"
|
||||
意图: realtime
|
||||
推理: 需要查询实时天气数据
|
||||
|
||||
用户: "帮我写一份邮件给客户,查询订单状态,然后附上退款政策"
|
||||
意图: mixed
|
||||
推理: 需要查询订单、查询政策、生成邮件,多步骤任务
|
||||
|
||||
用户: "你好"
|
||||
意图: chitchat
|
||||
推理: 简单寒暄
|
||||
|
||||
用户: "我想查点东西..."
|
||||
意图: clarify
|
||||
推理: 用户没有说清楚要查什么
|
||||
</示例>
|
||||
"""
|
||||
|
||||
async def classify(self, user_input: str, context: Optional[str] = None) -> IntentResult:
|
||||
"""
|
||||
分类用户意图
|
||||
|
||||
Args:
|
||||
user_input: 用户输入
|
||||
context: 对话上下文(可选)
|
||||
|
||||
Returns:
|
||||
IntentResult
|
||||
"""
|
||||
prompt = self._build_classification_prompt(user_input, context)
|
||||
|
||||
try:
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
result = self._parse_response(response.content)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"Intent classification error: {e}")
|
||||
# 降级策略:默认返回 mixed,走 React 循环
|
||||
return IntentResult(
|
||||
intent_type=IntentType.MIXED,
|
||||
confidence=0.5,
|
||||
reasoning="分类失败,走通用路径"
|
||||
)
|
||||
|
||||
def _build_classification_prompt(self, user_input: str, context: Optional[str]) -> str:
|
||||
"""构建分类提示词"""
|
||||
context_part = f"\n对话上下文:\n{context}" if context else ""
|
||||
|
||||
return f"""
|
||||
你是一个专业的意图识别助手。请分析用户的输入,判断其意图类型。
|
||||
|
||||
可选意图类型:
|
||||
- knowledge: 用户询问知识、政策、文档等,需要查询知识库
|
||||
- realtime: 用户需要查询实时数据(订单状态、天气、股票等)
|
||||
- action: 用户需要执行某项操作(退款、下单、发送邮件等)
|
||||
- chitchat: 用户只是闲聊、打招呼,不需要工具或检索
|
||||
- clarify: 用户的问题不明确,需要追问澄清
|
||||
- mixed: 复杂任务,需要多步骤处理(同时需要检索+工具)
|
||||
|
||||
{self._intent_examples}
|
||||
|
||||
用户输入: {user_input}
|
||||
{context_part}
|
||||
|
||||
请按以下格式输出(纯JSON):
|
||||
{{
|
||||
"intent": "knowledge|realtime|action|chitchat|clarify|mixed",
|
||||
"confidence": 0.85,
|
||||
"reasoning": "简要说明为什么这个意图"
|
||||
}}
|
||||
"""
|
||||
|
||||
def _parse_response(self, response: str) -> IntentResult:
|
||||
"""解析 LLM 响应"""
|
||||
import json
|
||||
import re
|
||||
|
||||
# 尝试提取 JSON
|
||||
json_match = re.search(r'\{[\s\S]*\}', response)
|
||||
if json_match:
|
||||
try:
|
||||
data = json.loads(json_match.group())
|
||||
return IntentResult(
|
||||
intent_type=IntentType(data['intent']),
|
||||
confidence=float(data['confidence']),
|
||||
reasoning=data['reasoning']
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 降级策略:关键词匹配
|
||||
return self._fallback_classify(response)
|
||||
|
||||
def _fallback_classify(self, user_input: str) -> IntentResult:
|
||||
"""关键词匹配降级策略"""
|
||||
keywords = {
|
||||
IntentType.KNOWLEDGE: ['政策', '文档', '规定', '手册', '指南', '什么是', '怎么'],
|
||||
IntentType.REALTIME: ['订单', '状态', '天气', '股票', '价格', '库存'],
|
||||
IntentType.ACTION: ['退款', '取消', '发送', '申请', '修改', '删除'],
|
||||
IntentType.CHITCHAT: ['你好', 'hi', 'hello', '嗨', '早上好', '晚上好'],
|
||||
}
|
||||
|
||||
for intent_type, words in keywords.items():
|
||||
if any(word in user_input.lower() for word in words):
|
||||
return IntentResult(
|
||||
intent_type=intent_type,
|
||||
confidence=0.7,
|
||||
reasoning=f"关键词匹配: {', '.join(words)}"
|
||||
)
|
||||
|
||||
# 默认走混合路径
|
||||
return IntentResult(
|
||||
intent_type=IntentType.MIXED,
|
||||
confidence=0.5,
|
||||
reasoning="无法明确分类,走通用路径"
|
||||
)
|
||||
|
||||
async def batch_classify(self, inputs: list[str]) -> list[IntentResult]:
|
||||
"""批量分类(带缓存)"""
|
||||
# 可以添加缓存逻辑
|
||||
results = []
|
||||
for inp in inputs:
|
||||
results.append(await self.classify(inp))
|
||||
return results
|
||||
|
||||
|
||||
# 全局实例
|
||||
_classifier: Optional[IntentClassifier] = None
|
||||
|
||||
|
||||
def get_intent_classifier() -> IntentClassifier:
|
||||
"""获取意图分类器实例"""
|
||||
global _classifier
|
||||
if _classifier is None:
|
||||
_classifier = IntentClassifier()
|
||||
return _classifier
|
||||
125
backend/app/core/state_base.py
Normal file
125
backend/app/core/state_base.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
状态基类工具模块
|
||||
提供类型安全的 LangGraph 状态基类和常用状态操作工具
|
||||
|
||||
功能:
|
||||
1. BaseState - 基础状态基类,包含通用字段(消息、token统计、耗时等)
|
||||
2. StateUtils - 状态操作工具类,提供常用的状态访问和修改方法
|
||||
3. TypedStateBuilder - 类型化状态构建器,支持链式创建自定义状态
|
||||
4. StateValidation - 状态验证工具,确保状态完整性
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class Phase(Enum):
|
||||
"""执行阶段枚举"""
|
||||
INIT = auto()
|
||||
INTENT_PARSING = auto()
|
||||
EXECUTING = auto()
|
||||
FORMATTING = auto()
|
||||
COMPLETED = auto()
|
||||
ERROR = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenUsage:
|
||||
"""Token 使用统计"""
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
def add(self, other: 'TokenUsage') -> 'TokenUsage':
|
||||
"""累加另一个统计"""
|
||||
return TokenUsage(
|
||||
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
|
||||
completion_tokens=self.completion_tokens + other.completion_tokens,
|
||||
total_tokens=self.total_tokens + other.total_tokens
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseState:
|
||||
"""
|
||||
基础状态基类
|
||||
所有子图的状态都应继承此类
|
||||
"""
|
||||
# 核心字段
|
||||
user_query: str = ""
|
||||
user_id: str = "default"
|
||||
thread_id: Optional[str] = None
|
||||
|
||||
# 执行阶段
|
||||
current_phase: Phase = Phase.INIT
|
||||
phase_history: List[Phase] = field(default_factory=list)
|
||||
|
||||
# 结果
|
||||
final_result: str = ""
|
||||
success: bool = True
|
||||
error_message: str = ""
|
||||
|
||||
# 统计
|
||||
token_usage: TokenUsage = field(default_factory=TokenUsage)
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
|
||||
# 元数据
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后调用"""
|
||||
if self.start_time is None:
|
||||
self.start_time = datetime.now()
|
||||
if not self.phase_history:
|
||||
self.phase_history.append(self.current_phase)
|
||||
|
||||
def transition_to(self, phase: Phase) -> None:
|
||||
"""转换到新阶段"""
|
||||
self.current_phase = phase
|
||||
self.phase_history.append(phase)
|
||||
|
||||
def complete(self, result: str, success: bool = True) -> None:
|
||||
"""完成执行"""
|
||||
self.final_result = result
|
||||
self.success = success
|
||||
self.end_time = datetime.now()
|
||||
self.transition_to(Phase.COMPLETED)
|
||||
|
||||
def fail(self, error: str) -> None:
|
||||
"""执行失败"""
|
||||
self.error_message = error
|
||||
self.success = False
|
||||
self.end_time = datetime.now()
|
||||
self.transition_to(Phase.ERROR)
|
||||
|
||||
@property
|
||||
def elapsed_time(self) -> float:
|
||||
"""获取耗时(秒)"""
|
||||
if self.start_time and self.end_time:
|
||||
return (self.end_time - self.start_time).total_seconds()
|
||||
return 0.0
|
||||
|
||||
|
||||
class StateUtils:
|
||||
"""状态操作工具类"""
|
||||
|
||||
@staticmethod
|
||||
def merge_metadata(base: Dict[str, Any], overlay: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""合并元数据"""
|
||||
result = base.copy()
|
||||
result.update(overlay)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def create_snapshot(state: BaseState) -> Dict[str, Any]:
|
||||
"""创建状态快照(用于调试)"""
|
||||
return {
|
||||
"user_query": state.user_query,
|
||||
"user_id": state.user_id,
|
||||
"current_phase": state.current_phase.name,
|
||||
"success": state.success,
|
||||
"elapsed_time": state.elapsed_time
|
||||
}
|
||||
Reference in New Issue
Block a user