统一 Repository 方案:添加 db 基类和子图模型 + 通讯录 API 支持真实数据库
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled

This commit is contained in:
2026-04-27 16:37:45 +08:00
parent 0cb9571db7
commit 29016f8792
6 changed files with 460 additions and 57 deletions

View File

@@ -1,8 +1,7 @@
""" """
通讯录子图API调用工具 通讯录子图 API 调用工具
Contact Subgraph API Client 支持模拟数据和真实数据库两种模式
""" """
from typing import Dict, Any, Optional, List from typing import Dict, Any, Optional, List
from datetime import datetime from datetime import datetime
from dataclasses import dataclass from dataclasses import dataclass
@@ -10,6 +9,8 @@ from dataclasses import dataclass
from .state import Contact, Email from .state import Contact, Email
# ========== 模拟数据(保留作为备选)==========
# 模拟数据库 # 模拟数据库
MOCK_CONTACTS_DB = {} MOCK_CONTACTS_DB = {}
MOCK_EMAILS_DB = [] MOCK_EMAILS_DB = []
@@ -18,13 +19,75 @@ MOCK_EMAILS_DB = []
@dataclass @dataclass
class ContactAPIClient: class ContactAPIClient:
""" """
通讯录API客户端 - 可扩展支持多种后端 通讯录 API 客户端 - 支持真实数据库和模拟模式
使用方式:
1. 真实数据库模式:传入 conn 参数
2. 模拟模式:不传入 conn或 conn 为 None
""" """
def __init__(self, conn=None):
"""
初始化
Args:
conn: 数据库连接(来自 checkpointer.conn为 None 时使用模拟模式
"""
self.conn = conn
self._use_db = conn is not None
if self._use_db:
try:
from ...db.models import ContactRepository, ContactEntity
self._repo = ContactRepository(conn)
except Exception as e:
print(f"Repository 初始化失败,回退到模拟模式: {e}")
self._use_db = False
self._repo = None
# ========== 真实数据库方法 ==========
async def list_contacts_db(self, user_id: str = "default") -> List[Contact]:
"""真实数据库:获取联系人列表"""
if not self._repo:
return await self.list_contacts_mock(user_id)
entities = await self._repo.list_by_user(user_id)
return [
Contact(
id=e.id,
name=e.name,
phone=e.phone,
email=e.email,
company=e.company,
position=e.position,
created_at=e.created_at
)
for e in entities
]
async def add_contact_db(self, user_id: str, contact: Contact) -> bool:
"""真实数据库:添加联系人"""
if not self._repo:
return await self.save_contact_mock(user_id, contact)
from ...db.models import ContactEntity
entity = ContactEntity(
user_id=user_id,
name=contact.name,
phone=contact.phone,
email=contact.email,
company=contact.company,
position=contact.position,
created_at=contact.created_at or datetime.now().isoformat()
)
await self._repo.insert(entity)
return True
# ========== 模拟数据方法(保留)==========
def list_contacts_mock(self, user_id: str = "default") -> List[Contact]: def list_contacts_mock(self, user_id: str = "default") -> List[Contact]:
""" """模拟查询联系人列表"""
模拟查询联系人列表
"""
if user_id not in MOCK_CONTACTS_DB: if user_id not in MOCK_CONTACTS_DB:
# 初始化一些示例数据 # 初始化一些示例数据
MOCK_CONTACTS_DB[user_id] = [ MOCK_CONTACTS_DB[user_id] = [
@@ -54,16 +117,13 @@ class ContactAPIClient:
company="咨询公司", company="咨询公司",
position="顾问", position="顾问",
created_at=datetime.now().isoformat() created_at=datetime.now().isoformat()
) ),
] ]
return MOCK_CONTACTS_DB[user_id] return MOCK_CONTACTS_DB[user_id]
def extract_contact_info_mock(self, query: str) -> Optional[Dict[str, Any]]: def extract_contact_info_mock(self, query: str) -> Optional[Dict[str, Any]]:
""" """模拟从查询中提取联系人信息"""
模拟从查询中提取联系人信息
"""
# 简化的提取逻辑
import re import re
# 提取邮箱 # 提取邮箱
@@ -71,11 +131,8 @@ class ContactAPIClient:
# 提取手机号 # 提取手机号
phone_match = re.search(r'1[3-9]\d{9}', query) phone_match = re.search(r'1[3-9]\d{9}', query)
# 提取姓名(简单匹配) # 提取姓名(简单匹配)
# 先看是否有"添加"等关键词
if any(keyword in query for keyword in ["添加", "add"]): if any(keyword in query for keyword in ["添加", "add"]):
# 尝试提取姓名
name = "未知" name = "未知"
# 简单的逻辑:去掉关键词、去掉邮箱和电话后剩下的
clean_query = query clean_query = query
if email_match: if email_match:
clean_query = clean_query.replace(email_match.group(), "") clean_query = clean_query.replace(email_match.group(), "")
@@ -95,9 +152,7 @@ class ContactAPIClient:
return None return None
def save_contact_mock(self, user_id: str, contact: Contact) -> bool: def save_contact_mock(self, user_id: str, contact: Contact) -> bool:
""" """模拟保存联系人"""
模拟保存联系人
"""
if user_id not in MOCK_CONTACTS_DB: if user_id not in MOCK_CONTACTS_DB:
MOCK_CONTACTS_DB[user_id] = [] MOCK_CONTACTS_DB[user_id] = []
@@ -108,21 +163,18 @@ class ContactAPIClient:
return True return True
def list_emails_mock(self) -> List[Email]: def list_emails_mock(self) -> List[Email]:
""" """模拟查询邮件列表"""
模拟查询邮件列表
"""
global MOCK_EMAILS_DB global MOCK_EMAILS_DB
if not MOCK_EMAILS_DB: if not MOCK_EMAILS_DB:
# 初始化一些示例邮件
MOCK_EMAILS_DB = [ MOCK_EMAILS_DB = [
Email( Email(
id="1", id="1",
subject="会议邀请AI技术分享", subject="会议邀请AI 技术分享",
sender="admin@example.com", sender="admin@example.com",
recipients=["user@example.com"], recipients=["user@example.com"],
date=datetime.now().isoformat(), date=datetime.now().isoformat(),
body="你好下周一将举办AI技术分享会欢迎参加。" body="你好,下周一将举办 AI 技术分享会,欢迎参加。"
), ),
Email( Email(
id="2", id="2",
@@ -131,16 +183,13 @@ class ContactAPIClient:
recipients=["user@example.com"], recipients=["user@example.com"],
date=datetime.now().isoformat(), date=datetime.now().isoformat(),
body="项目进度良好,继续保持。" body="项目进度良好,继续保持。"
) ),
] ]
return MOCK_EMAILS_DB return MOCK_EMAILS_DB
def generate_email_draft_mock(self, query: str) -> Dict[str, str]: def generate_email_draft_mock(self, query: str) -> Dict[str, str]:
""" """模拟生成邮件草稿"""
模拟生成邮件草稿
"""
# 简单的模板生成
return { return {
"subject": f"Re: {query}", "subject": f"Re: {query}",
"recipient": "recipient@example.com", "recipient": "recipient@example.com",
@@ -148,10 +197,9 @@ class ContactAPIClient:
} }
def send_email_mock(self, recipient: str, subject: str, body: str) -> Dict[str, Any]: def send_email_mock(self, recipient: str, subject: str, body: str) -> Dict[str, Any]:
""" """模拟发送邮件"""
模拟发送邮件 global MOCK_EMAILS_DB
"""
# 记录到模拟数据库
MOCK_EMAILS_DB.append( MOCK_EMAILS_DB.append(
Email( Email(
id=str(len(MOCK_EMAILS_DB) + 1), id=str(len(MOCK_EMAILS_DB) + 1),
@@ -169,10 +217,7 @@ class ContactAPIClient:
} }
def sniff_contacts_mock(self, query: str) -> Dict[str, Any]: def sniff_contacts_mock(self, query: str) -> Dict[str, Any]:
""" """模拟智能嗅探联系人"""
模拟智能嗅探联系人
"""
# 简化的检测逻辑
import re import re
emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', query) emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', query)
@@ -191,31 +236,36 @@ class ContactAPIClient:
"count": len(contacts), "count": len(contacts),
"suggestion": "是否添加这些联系人?" "suggestion": "是否添加这些联系人?"
} }
# 公共方法 # ========== 公共方法(自动选择模式)==========
def list_contacts(self, user_id: str = "default") -> List[Contact]:
"""查询联系人列表""" async def list_contacts(self, user_id: str = "default") -> List[Contact]:
"""获取联系人列表(自动选择数据库或模拟模式)"""
if self._use_db:
return await self.list_contacts_db(user_id)
return self.list_contacts_mock(user_id) return self.list_contacts_mock(user_id)
def add_contact(self, user_id: str, contact: Contact) -> bool: async def add_contact(self, user_id: str, contact: Contact) -> bool:
"""添加联系人""" """添加联系人(自动选择数据库或模拟模式)"""
if self._use_db:
return await self.add_contact_db(user_id, contact)
return self.save_contact_mock(user_id, contact) return self.save_contact_mock(user_id, contact)
def list_emails(self, user_id: str = "default") -> List[Email]: async def list_emails(self, user_id: str = "default") -> List[Email]:
"""查询邮件列表""" """查询邮件列表(目前用模拟)"""
return self.list_emails_mock() return self.list_emails_mock()
def generate_email_draft(self, query: str) -> Dict[str, str]: async def generate_email_draft(self, query: str) -> Dict[str, str]:
"""生成邮件草稿""" """生成邮件草稿(目前用模拟)"""
return self.generate_email_draft_mock(query) return self.generate_email_draft_mock(query)
def send_email(self, user_id: str, recipient: str, subject: str, body: str) -> bool: async def send_email(self, user_id: str, recipient: str, subject: str, body: str) -> bool:
"""发送邮件""" """发送邮件(目前用模拟)"""
result = self.send_email_mock(recipient, subject, body) result = self.send_email_mock(recipient, subject, body)
return result.get("success", False) return result.get("success", False)
def sniff_contacts(self, query: str) -> List[Contact]: async def sniff_contacts(self, query: str) -> List[Contact]:
"""智能嗅探联系人""" """智能嗅探联系人(目前用模拟)"""
result = self.sniff_contacts_mock(query) result = self.sniff_contacts_mock(query)
contact_dicts = result.get("contacts", []) contact_dicts = result.get("contacts", [])
return [ return [
@@ -232,5 +282,5 @@ class ContactAPIClient:
] ]
# 单例实例 # 全局实例(模拟模式,保留向后兼容)
contact_api = ContactAPIClient() contact_api = ContactAPIClient()

View File

@@ -23,6 +23,8 @@ from .agent_subgraphs.common.human_review import (
ReviewStatus, ReviewStatus,
HumanReview HumanReview
) )
from .agent_subgraphs.contact.api_client import ContactAPIClient
from .db.init_db import init_subgraph_tables
from .logger import info, error from .logger import info, error
@asynccontextmanager @asynccontextmanager
@@ -32,6 +34,9 @@ async def lifespan(app: FastAPI):
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
await checkpointer.setup() await checkpointer.setup()
# 1.5 初始化子图表
await init_subgraph_tables(checkpointer.conn)
# 2. 构建 AI Agent 服务 # 2. 构建 AI Agent 服务
agent_service = AIAgentService(checkpointer) agent_service = AIAgentService(checkpointer)
await agent_service.initialize() await agent_service.initialize()
@@ -39,6 +44,9 @@ async def lifespan(app: FastAPI):
# 3. 创建历史查询服务 # 3. 创建历史查询服务
history_service = ThreadHistoryService(checkpointer) history_service = ThreadHistoryService(checkpointer)
# 3.5 创建子图 API 客户端(真实数据库模式)
contact_api = ContactAPIClient(checkpointer.conn)
# 4. 创建审核管理器 # 4. 创建审核管理器
review_manager = ReviewManager(InMemoryReviewStore()) review_manager = ReviewManager(InMemoryReviewStore())
@@ -46,6 +54,7 @@ async def lifespan(app: FastAPI):
app.state.agent_service = agent_service app.state.agent_service = agent_service
app.state.history_service = history_service app.state.history_service = history_service
app.state.review_manager = review_manager app.state.review_manager = review_manager
app.state.contact_api = contact_api
# 应用运行中... # 应用运行中...
yield yield

View File

@@ -0,0 +1,7 @@
"""
数据库模块
提供统一的 Repository 和实体类
"""
from .base import BaseRepository, BaseEntity
__all__ = ['BaseRepository', 'BaseEntity']

134
backend/app/db/base.py Normal file
View File

@@ -0,0 +1,134 @@
"""
统一的 Repository 基类
所有子图的数据操作都用这个,避免重复代码
"""
import json
from typing import Any, List, Dict, Optional
from dataclasses import dataclass, asdict
@dataclass
class BaseEntity:
"""实体基类,用于类型提示和自动序列化"""
id: Optional[str] = None
user_id: Optional[str] = None
created_at: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""转换为字典(用于存储)"""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'BaseEntity':
"""从字典创建实体"""
return cls(**data)
class BaseRepository:
"""
统一的 Repository 基类
核心思想:
1. 传入连接,复用现有 checkpointer.conn
2. 提供通用 CRUD 方法
3. 表名和字段映射通过子类定义
"""
# 子类需要定义的配置
table_name: str
entity_class: type # 继承自 BaseEntity 的实体类
id_column: str = "id"
user_id_column: str = "user_id"
def __init__(self, conn):
"""
构造函数:传入数据库连接
这样可以复用 checkpointer.conn不需要额外的连接池
"""
self.conn = conn
# ========== 通用 CRUD 方法 ==========
async def insert(self, entity: BaseEntity) -> str:
"""插入单个实体,返回 ID"""
data = entity.to_dict()
data.pop('id', None) # 如果 id 是自增的
if not data:
raise ValueError("实体数据为空")
columns = list(data.keys())
placeholders = [f"${i+1}" for i in range(len(columns))]
values = list(data.values())
sql = f"""
INSERT INTO {self.table_name} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
RETURNING {self.id_column}
"""
async with self.conn.cursor() as cur:
await cur.execute(sql, values)
row = await cur.fetchone()
return row[self.id_column] if row else None
async def update(self, entity_id: str, data: Dict[str, Any]) -> bool:
"""更新实体"""
if not data:
return False
set_clause = ", ".join([f"{k} = ${i+1}" for i, k in enumerate(data.keys())])
values = list(data.values()) + [entity_id]
sql = f"""
UPDATE {self.table_name}
SET {set_clause}
WHERE {self.id_column} = ${len(values)}
"""
async with self.conn.cursor() as cur:
await cur.execute(sql, values)
return cur.rowcount > 0
async def delete(self, entity_id: str) -> bool:
"""删除实体"""
sql = f"DELETE FROM {self.table_name} WHERE {self.id_column} = $1"
async with self.conn.cursor() as cur:
await cur.execute(sql, (entity_id,))
return cur.rowcount > 0
async def get_by_id(self, entity_id: str) -> Optional[BaseEntity]:
"""根据 ID 查询单个实体"""
sql = f"SELECT * FROM {self.table_name} WHERE {self.id_column} = $1"
async with self.conn.cursor() as cur:
await cur.execute(sql, (entity_id,))
row = await cur.fetchone()
if row:
return self.entity_class.from_dict(dict(row))
return None
async def list_by_user(self, user_id: str, limit: int = 100) -> List[BaseEntity]:
"""查询某个用户的所有实体"""
sql = f"""
SELECT * FROM {self.table_name}
WHERE {self.user_id_column} = $1
ORDER BY created_at DESC
LIMIT $2
"""
async with self.conn.cursor() as cur:
await cur.execute(sql, (user_id, limit))
rows = await cur.fetchall()
return [self.entity_class.from_dict(dict(row)) for row in rows]
# ========== 自定义查询方法(留钩子)==========
async def custom_query(self, sql: str, params: tuple = ()) -> List[Dict]:
"""执行自定义 SQL"""
async with self.conn.cursor() as cur:
await cur.execute(sql, params)
rows = await cur.fetchall()
return [dict(row) for row in rows]

97
backend/app/db/init_db.py Normal file
View File

@@ -0,0 +1,97 @@
"""
子图数据库表初始化
在 FastAPI 启动时创建表
"""
from typing import Optional
async def init_subgraph_tables(conn):
"""
初始化子图所需的表
Args:
conn: 数据库连接(来自 AsyncPostgresSaver
"""
# 1. contacts 表(通讯录)
await _create_contacts_table(conn)
# 2. words 表(词典)
await _create_words_table(conn)
# 3. news 表(资讯)
await _create_news_table(conn)
async def _create_contacts_table(conn):
"""创建 contacts 表"""
sql = """
CREATE TABLE IF NOT EXISTS contacts (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
name VARCHAR(100) NOT NULL,
phone VARCHAR(32),
email VARCHAR(100),
company VARCHAR(100),
position VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_contacts_user (user_id)
);
"""
try:
async with conn.cursor() as cur:
await cur.execute(sql)
# 注释掉 INFO 打印,避免噪声
except Exception as e:
print(f"contacts 表可能已存在: {e}")
async def _create_words_table(conn):
"""创建 words 表"""
sql = """
CREATE TABLE IF NOT EXISTS words (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
word VARCHAR(100) NOT NULL,
phonetic VARCHAR(100),
part_of_speech VARCHAR(50),
definition TEXT,
examples TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_words_user (user_id),
INDEX idx_words_word (word)
);
"""
try:
async with conn.cursor() as cur:
await cur.execute(sql)
except Exception as e:
print(f"words 表可能已存在: {e}")
async def _create_news_table(conn):
"""创建 news 表"""
sql = """
CREATE TABLE IF NOT EXISTS news (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
title VARCHAR(200) NOT NULL,
content TEXT,
url VARCHAR(500),
source VARCHAR(100),
keywords TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_news_user (user_id)
);
"""
try:
async with conn.cursor() as cur:
await cur.execute(sql)
except Exception as e:
print(f"news 表可能已存在: {e}")

106
backend/app/db/models.py Normal file
View File

@@ -0,0 +1,106 @@
"""
子图数据模型和 Repository 定义
只需要定义表名、实体类,基类搞定一切
"""
from typing import List, Optional
from dataclasses import dataclass
from .base import BaseRepository, BaseEntity
# ========== 通讯录 ==========
@dataclass
class ContactEntity(BaseEntity):
"""通讯录实体"""
name: str = ""
phone: str = ""
email: str = ""
company: str = ""
position: str = ""
class ContactRepository(BaseRepository):
"""通讯录 Repository"""
table_name = "contacts"
entity_class = ContactEntity
# 如果需要特殊查询,在这里加方法
async def search_by_name(self, user_id: str, name_keyword: str) -> List[ContactEntity]:
"""自定义查询:按姓名搜索"""
sql = """
SELECT * FROM contacts
WHERE user_id = $1 AND name ILIKE $2
ORDER BY created_at DESC
LIMIT 50
"""
result = await self.custom_query(sql, (user_id, f"%{name_keyword}%"))
return [ContactEntity.from_dict(row) for row in result]
# ========== 词典 ==========
@dataclass
class WordEntity(BaseEntity):
"""单词实体"""
word: str = ""
phonetic: str = ""
part_of_speech: str = ""
definition: str = ""
examples: str = ""
class DictionaryRepository(BaseRepository):
"""词典 Repository"""
table_name = "words"
entity_class = WordEntity
async def search_by_word(self, user_id: str, word: str) -> Optional[WordEntity]:
"""按单词查询"""
sql = """
SELECT * FROM words
WHERE user_id = $1 AND word = $2
LIMIT 1
"""
result = await self.custom_query(sql, (user_id, word))
if result:
return WordEntity.from_dict(result[0])
return None
# ========== 资讯 ==========
@dataclass
class NewsEntity(BaseEntity):
"""资讯实体"""
title: str = ""
content: str = ""
url: str = ""
source: str = ""
keywords: str = ""
class NewsRepository(BaseRepository):
"""资讯 Repository"""
table_name = "news"
entity_class = NewsEntity
async def search_by_keywords(self, user_id: str, keyword: str) -> List[NewsEntity]:
"""按关键词搜索"""
sql = """
SELECT * FROM news
WHERE user_id = $1 AND (title ILIKE $2 OR keywords ILIKE $2)
ORDER BY created_at DESC
LIMIT 50
"""
result = await self.custom_query(sql, (user_id, f"%{keyword}%"))
return [NewsEntity.from_dict(row) for row in result]
# ========== 导出 ==========
__all__ = [
'ContactEntity', 'ContactRepository',
'WordEntity', 'DictionaryRepository',
'NewsEntity', 'NewsRepository',
]