重构:通讯录子图支持 async 和 API 注入工厂模式
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
This commit is contained in:
@@ -1,11 +1,31 @@
|
|||||||
"""
|
"""
|
||||||
通讯录子图构建器
|
通讯录子图构建器
|
||||||
Contact Subgraph Builder
|
Contact Subgraph Builder
|
||||||
|
支持 API 注入的工厂模式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
|
|
||||||
from .state import ContactState
|
from .state import ContactState
|
||||||
|
from .nodes import create_contact_nodes
|
||||||
|
|
||||||
|
|
||||||
|
def build_contact_subgraph(contact_api=None):
|
||||||
|
"""
|
||||||
|
构建通讯录子图(工厂模式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contact_api: 可选的 ContactAPIClient 实例(支持真实数据库或模拟模式)
|
||||||
|
不传入则使用默认模拟 API(向后兼容)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的 StateGraph
|
||||||
|
"""
|
||||||
|
# 创建节点(传入 API)
|
||||||
|
nodes = create_contact_nodes(contact_api) if contact_api else None
|
||||||
|
|
||||||
|
# 如果没有传入 API,使用向后兼容的导入
|
||||||
|
if nodes is None:
|
||||||
from .nodes import (
|
from .nodes import (
|
||||||
parse_intent,
|
parse_intent,
|
||||||
list_contacts,
|
list_contacts,
|
||||||
@@ -18,15 +38,18 @@ from .nodes import (
|
|||||||
format_result,
|
format_result,
|
||||||
should_continue
|
should_continue
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
parse_intent = nodes["parse_intent"]
|
||||||
|
list_contacts = nodes["list_contacts"]
|
||||||
|
add_contact = nodes["add_contact"]
|
||||||
|
list_emails = nodes["list_emails"]
|
||||||
|
generate_email_draft = nodes["generate_email_draft"]
|
||||||
|
human_review = nodes["human_review"]
|
||||||
|
send_email = nodes["send_email"]
|
||||||
|
sniff_contacts = nodes["sniff_contacts"]
|
||||||
|
format_result = nodes["format_result"]
|
||||||
|
should_continue = nodes["should_continue"]
|
||||||
|
|
||||||
|
|
||||||
def build_contact_subgraph() -> StateGraph:
|
|
||||||
"""
|
|
||||||
构建通讯录子图
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
配置好的 StateGraph
|
|
||||||
"""
|
|
||||||
# 创建图
|
# 创建图
|
||||||
graph = StateGraph(ContactState)
|
graph = StateGraph(ContactState)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
通讯录子图节点 - 使用公共工具版本
|
通讯录子图节点 - 使用公共工具版本
|
||||||
Contact Subgraph Nodes - Using Common Tools
|
Contact Subgraph Nodes - Using Common Tools
|
||||||
|
支持 async 和 API 注入
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
@@ -10,14 +11,25 @@ from datetime import datetime
|
|||||||
from ..common import MarkdownFormatter
|
from ..common import MarkdownFormatter
|
||||||
|
|
||||||
from .state import ContactState, ContactAction, Contact, Email
|
from .state import ContactState, ContactAction, Contact, Email
|
||||||
from .api_client import contact_api
|
from .api_client import ContactAPIClient
|
||||||
|
|
||||||
|
|
||||||
# 模拟联系人数据库(临时存储)
|
# 模拟联系人数据库(临时存储,保留作为备选)
|
||||||
CONTACT_DB = {}
|
CONTACT_DB = {}
|
||||||
|
|
||||||
|
|
||||||
def parse_intent(state: ContactState) -> ContactState:
|
def create_contact_nodes(contact_api: ContactAPIClient):
|
||||||
|
"""
|
||||||
|
创建通讯录子图节点工厂函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contact_api: 已初始化的 ContactAPIClient(支持真实数据库或模拟模式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
节点函数字典
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def parse_intent(state: ContactState) -> ContactState:
|
||||||
"""
|
"""
|
||||||
解析用户意图节点
|
解析用户意图节点
|
||||||
"""
|
"""
|
||||||
@@ -37,21 +49,19 @@ def parse_intent(state: ContactState) -> ContactState:
|
|||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
async def list_contacts(state: ContactState) -> ContactState:
|
||||||
def list_contacts(state: ContactState) -> ContactState:
|
|
||||||
"""
|
"""
|
||||||
列出联系人节点
|
列出联系人节点
|
||||||
"""
|
"""
|
||||||
state.current_phase = "executing"
|
state.current_phase = "executing"
|
||||||
|
|
||||||
# 使用 API 客户端
|
# 使用 API 客户端(async)
|
||||||
contacts = contact_api.list_contacts(state.user_id)
|
contacts = await contact_api.list_contacts(state.user_id)
|
||||||
state.contacts = contacts
|
state.contacts = contacts
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
async def add_contact(state: ContactState) -> ContactState:
|
||||||
def add_contact(state: ContactState) -> ContactState:
|
|
||||||
"""
|
"""
|
||||||
添加联系人节点
|
添加联系人节点
|
||||||
"""
|
"""
|
||||||
@@ -59,59 +69,57 @@ def add_contact(state: ContactState) -> ContactState:
|
|||||||
|
|
||||||
# 使用 API 客户端(简化添加,实际项目应解析用户输入)
|
# 使用 API 客户端(简化添加,实际项目应解析用户输入)
|
||||||
new_contact = Contact(
|
new_contact = Contact(
|
||||||
id=len(CONTACT_DB) + 1,
|
id=str(len(CONTACT_DB) + 1),
|
||||||
name="新联系人",
|
name="新联系人",
|
||||||
email="new@example.com",
|
email="new@example.com",
|
||||||
phone="13800000000",
|
phone="13800000000",
|
||||||
created_at=datetime.now()
|
created_at=datetime.now().isoformat()
|
||||||
)
|
)
|
||||||
|
# 保存到数据库
|
||||||
|
await contact_api.add_contact(state.user_id, new_contact)
|
||||||
state.current_contact = new_contact
|
state.current_contact = new_contact
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
async def list_emails(state: ContactState) -> ContactState:
|
||||||
def list_emails(state: ContactState) -> ContactState:
|
|
||||||
"""
|
"""
|
||||||
列出邮件节点
|
列出邮件节点
|
||||||
"""
|
"""
|
||||||
state.current_phase = "executing"
|
state.current_phase = "executing"
|
||||||
|
|
||||||
# 使用 API 客户端
|
# 使用 API 客户端(async)
|
||||||
emails = contact_api.list_emails(state.user_id)
|
emails = await contact_api.list_emails(state.user_id)
|
||||||
state.emails = emails
|
state.emails = emails
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
async def generate_email_draft(state: ContactState) -> ContactState:
|
||||||
def generate_email_draft(state: ContactState) -> ContactState:
|
|
||||||
"""
|
"""
|
||||||
生成邮件草稿节点
|
生成邮件草稿节点
|
||||||
"""
|
"""
|
||||||
state.current_phase = "executing"
|
state.current_phase = "executing"
|
||||||
|
|
||||||
# 使用 API 客户端
|
# 使用 API 客户端(async)
|
||||||
draft = contact_api.generate_email_draft(state.user_query)
|
draft = await contact_api.generate_email_draft(state.user_query)
|
||||||
state.draft_recipient = draft.get("recipient", "recipient@example.com")
|
state.draft_recipient = draft.get("recipient", "recipient@example.com")
|
||||||
state.draft_subject = draft.get("subject", "邮件主题")
|
state.draft_subject = draft.get("subject", "邮件主题")
|
||||||
state.draft_body = draft.get("body", "邮件正文")
|
state.draft_body = draft.get("body", "邮件正文")
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
async def sniff_contacts(state: ContactState) -> ContactState:
|
||||||
def sniff_contacts(state: ContactState) -> ContactState:
|
|
||||||
"""
|
"""
|
||||||
嗅探联系人节点
|
嗅探联系人节点
|
||||||
"""
|
"""
|
||||||
state.current_phase = "executing"
|
state.current_phase = "executing"
|
||||||
|
|
||||||
# 使用 API 客户端
|
# 使用 API 客户端(async)
|
||||||
contacts = contact_api.sniff_contacts(state.user_query)
|
contacts = await contact_api.sniff_contacts(state.user_query)
|
||||||
state.sniffed_contacts = contacts
|
state.sniffed_contacts = contacts
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
async def format_result(state: ContactState) -> ContactState:
|
||||||
def format_result(state: ContactState) -> ContactState:
|
|
||||||
"""
|
"""
|
||||||
格式化结果节点(使用公共工具)
|
格式化结果节点(使用公共工具)
|
||||||
"""
|
"""
|
||||||
@@ -137,10 +145,20 @@ def format_result(state: ContactState) -> ContactState:
|
|||||||
elif state.action == ContactAction.EMAIL_LIST and state.emails:
|
elif state.action == ContactAction.EMAIL_LIST and state.emails:
|
||||||
output_lines.append(md.heading("📬 邮件列表", 2))
|
output_lines.append(md.heading("📬 邮件列表", 2))
|
||||||
output_lines.append("")
|
output_lines.append("")
|
||||||
email_data = [
|
# 兼容两种 date 格式
|
||||||
{"发件人": e.sender, "主题": e.subject, "时间": e.received_at.strftime('%Y-%m-%d %H:%M')}
|
email_data = []
|
||||||
for e in state.emails
|
for e in state.emails:
|
||||||
]
|
date_str = e.date
|
||||||
|
if hasattr(e, 'received_at') and e.received_at:
|
||||||
|
try:
|
||||||
|
date_str = e.received_at.strftime('%Y-%m-%d %H:%M')
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
email_data.append({
|
||||||
|
"发件人": e.sender,
|
||||||
|
"主题": e.subject,
|
||||||
|
"时间": date_str
|
||||||
|
})
|
||||||
output_lines.append(md.table(email_data))
|
output_lines.append(md.table(email_data))
|
||||||
|
|
||||||
elif state.action == ContactAction.EMAIL_SEND and state.draft_subject:
|
elif state.action == ContactAction.EMAIL_SEND and state.draft_subject:
|
||||||
@@ -175,8 +193,7 @@ def format_result(state: ContactState) -> ContactState:
|
|||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
async def human_review(state: ContactState) -> ContactState:
|
||||||
def human_review(state: ContactState) -> ContactState:
|
|
||||||
"""
|
"""
|
||||||
人工审核节点(用于邮件草稿)
|
人工审核节点(用于邮件草稿)
|
||||||
"""
|
"""
|
||||||
@@ -185,14 +202,13 @@ def human_review(state: ContactState) -> ContactState:
|
|||||||
state.needs_approval = True
|
state.needs_approval = True
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
async def send_email(state: ContactState) -> ContactState:
|
||||||
def send_email(state: ContactState) -> ContactState:
|
|
||||||
"""
|
"""
|
||||||
发送邮件节点
|
发送邮件节点
|
||||||
"""
|
"""
|
||||||
state.current_phase = "executing"
|
state.current_phase = "executing"
|
||||||
# 使用 API 客户端发送邮件
|
# 使用 API 客户端发送邮件(async)
|
||||||
success = contact_api.send_email(
|
success = await contact_api.send_email(
|
||||||
state.user_id,
|
state.user_id,
|
||||||
state.draft_recipient,
|
state.draft_recipient,
|
||||||
state.draft_subject,
|
state.draft_subject,
|
||||||
@@ -201,7 +217,6 @@ def send_email(state: ContactState) -> ContactState:
|
|||||||
state.success = success
|
state.success = success
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
def should_continue(state: ContactState) -> str:
|
def should_continue(state: ContactState) -> str:
|
||||||
"""
|
"""
|
||||||
条件路由函数:根据 action 和状态决定下一个节点
|
条件路由函数:根据 action 和状态决定下一个节点
|
||||||
@@ -228,3 +243,36 @@ def should_continue(state: ContactState) -> str:
|
|||||||
return "sniff_contacts"
|
return "sniff_contacts"
|
||||||
else:
|
else:
|
||||||
return "format_result"
|
return "format_result"
|
||||||
|
|
||||||
|
# 返回节点字典
|
||||||
|
return {
|
||||||
|
"parse_intent": parse_intent,
|
||||||
|
"list_contacts": list_contacts,
|
||||||
|
"add_contact": add_contact,
|
||||||
|
"list_emails": list_emails,
|
||||||
|
"generate_email_draft": generate_email_draft,
|
||||||
|
"sniff_contacts": sniff_contacts,
|
||||||
|
"format_result": format_result,
|
||||||
|
"human_review": human_review,
|
||||||
|
"send_email": send_email,
|
||||||
|
"should_continue": should_continue
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 向后兼容的全局版本(使用模拟 API) ==========
|
||||||
|
from .api_client import contact_api as _default_contact_api
|
||||||
|
|
||||||
|
# 创建默认节点(用模拟 API,保持向后兼容)
|
||||||
|
_default_nodes = create_contact_nodes(_default_contact_api)
|
||||||
|
|
||||||
|
# 导出默认节点
|
||||||
|
parse_intent = _default_nodes["parse_intent"]
|
||||||
|
list_contacts = _default_nodes["list_contacts"]
|
||||||
|
add_contact = _default_nodes["add_contact"]
|
||||||
|
list_emails = _default_nodes["list_emails"]
|
||||||
|
generate_email_draft = _default_nodes["generate_email_draft"]
|
||||||
|
sniff_contacts = _default_nodes["sniff_contacts"]
|
||||||
|
format_result = _default_nodes["format_result"]
|
||||||
|
human_review = _default_nodes["human_review"]
|
||||||
|
send_email = _default_nodes["send_email"]
|
||||||
|
should_continue = _default_nodes["should_continue"]
|
||||||
|
|||||||
Reference in New Issue
Block a user