From 29016f87928351e65cf5da7c14172a39820efbd4 Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Mon, 27 Apr 2026 16:37:45 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=9F=E4=B8=80=20Repository=20=E6=96=B9?= =?UTF-8?q?=E6=A1=88=EF=BC=9A=E6=B7=BB=E5=8A=A0=20db=20=E5=9F=BA=E7=B1=BB?= =?UTF-8?q?=E5=92=8C=E5=AD=90=E5=9B=BE=E6=A8=A1=E5=9E=8B=20+=20=E9=80=9A?= =?UTF-8?q?=E8=AE=AF=E5=BD=95=20API=20=E6=94=AF=E6=8C=81=E7=9C=9F=E5=AE=9E?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/agent_subgraphs/contact/api_client.py | 164 ++++++++++++------ backend/app/backend.py | 9 + backend/app/db/__init__.py | 7 + backend/app/db/base.py | 134 ++++++++++++++ backend/app/db/init_db.py | 97 +++++++++++ backend/app/db/models.py | 106 +++++++++++ 6 files changed, 460 insertions(+), 57 deletions(-) create mode 100644 backend/app/db/__init__.py create mode 100644 backend/app/db/base.py create mode 100644 backend/app/db/init_db.py create mode 100644 backend/app/db/models.py diff --git a/backend/app/agent_subgraphs/contact/api_client.py b/backend/app/agent_subgraphs/contact/api_client.py index 4dbaef0..ea6b3f3 100644 --- a/backend/app/agent_subgraphs/contact/api_client.py +++ b/backend/app/agent_subgraphs/contact/api_client.py @@ -1,8 +1,7 @@ """ -通讯录子图API调用工具 -Contact Subgraph API Client +通讯录子图 API 调用工具 +支持模拟数据和真实数据库两种模式 """ - from typing import Dict, Any, Optional, List from datetime import datetime from dataclasses import dataclass @@ -10,6 +9,8 @@ from dataclasses import dataclass from .state import Contact, Email +# ========== 模拟数据(保留作为备选)========== + # 模拟数据库 MOCK_CONTACTS_DB = {} MOCK_EMAILS_DB = [] @@ -18,13 +19,75 @@ MOCK_EMAILS_DB = [] @dataclass class ContactAPIClient: """ - 通讯录API客户端 - 可扩展支持多种后端 + 通讯录 API 客户端 - 支持真实数据库和模拟模式 + + 使用方式: + 1. 真实数据库模式:传入 conn 参数 + 2. 模拟模式:不传入 conn,或 conn 为 None """ + def __init__(self, conn=None): + """ + 初始化 + + Args: + conn: 数据库连接(来自 checkpointer.conn),为 None 时使用模拟模式 + """ + self.conn = conn + self._use_db = conn is not None + + if self._use_db: + try: + from ...db.models import ContactRepository, ContactEntity + self._repo = ContactRepository(conn) + except Exception as e: + print(f"Repository 初始化失败,回退到模拟模式: {e}") + self._use_db = False + self._repo = None + + # ========== 真实数据库方法 ========== + + async def list_contacts_db(self, user_id: str = "default") -> List[Contact]: + """真实数据库:获取联系人列表""" + if not self._repo: + return await self.list_contacts_mock(user_id) + + entities = await self._repo.list_by_user(user_id) + return [ + Contact( + id=e.id, + name=e.name, + phone=e.phone, + email=e.email, + company=e.company, + position=e.position, + created_at=e.created_at + ) + for e in entities + ] + + async def add_contact_db(self, user_id: str, contact: Contact) -> bool: + """真实数据库:添加联系人""" + if not self._repo: + return await self.save_contact_mock(user_id, contact) + + from ...db.models import ContactEntity + entity = ContactEntity( + user_id=user_id, + name=contact.name, + phone=contact.phone, + email=contact.email, + company=contact.company, + position=contact.position, + created_at=contact.created_at or datetime.now().isoformat() + ) + await self._repo.insert(entity) + return True + + # ========== 模拟数据方法(保留)========== + def list_contacts_mock(self, user_id: str = "default") -> List[Contact]: - """ - 模拟查询联系人列表 - """ + """模拟查询联系人列表""" if user_id not in MOCK_CONTACTS_DB: # 初始化一些示例数据 MOCK_CONTACTS_DB[user_id] = [ @@ -54,16 +117,13 @@ class ContactAPIClient: company="咨询公司", position="顾问", created_at=datetime.now().isoformat() - ) + ), ] return MOCK_CONTACTS_DB[user_id] def extract_contact_info_mock(self, query: str) -> Optional[Dict[str, Any]]: - """ - 模拟从查询中提取联系人信息 - """ - # 简化的提取逻辑 + """模拟从查询中提取联系人信息""" import re # 提取邮箱 @@ -71,11 +131,8 @@ class ContactAPIClient: # 提取手机号 phone_match = re.search(r'1[3-9]\d{9}', query) # 提取姓名(简单匹配) - # 先看是否有"添加"等关键词 if any(keyword in query for keyword in ["添加", "add"]): - # 尝试提取姓名 name = "未知" - # 简单的逻辑:去掉关键词、去掉邮箱和电话后剩下的 clean_query = query if email_match: clean_query = clean_query.replace(email_match.group(), "") @@ -95,9 +152,7 @@ class ContactAPIClient: return None def save_contact_mock(self, user_id: str, contact: Contact) -> bool: - """ - 模拟保存联系人 - """ + """模拟保存联系人""" if user_id not in MOCK_CONTACTS_DB: MOCK_CONTACTS_DB[user_id] = [] @@ -108,21 +163,18 @@ class ContactAPIClient: return True def list_emails_mock(self) -> List[Email]: - """ - 模拟查询邮件列表 - """ + """模拟查询邮件列表""" global MOCK_EMAILS_DB if not MOCK_EMAILS_DB: - # 初始化一些示例邮件 MOCK_EMAILS_DB = [ Email( id="1", - subject="会议邀请:AI技术分享", + subject="会议邀请:AI 技术分享", sender="admin@example.com", recipients=["user@example.com"], date=datetime.now().isoformat(), - body="你好,下周一将举办AI技术分享会,欢迎参加。" + body="你好,下周一将举办 AI 技术分享会,欢迎参加。" ), Email( id="2", @@ -131,16 +183,13 @@ class ContactAPIClient: recipients=["user@example.com"], date=datetime.now().isoformat(), body="项目进度良好,继续保持。" - ) + ), ] return MOCK_EMAILS_DB def generate_email_draft_mock(self, query: str) -> Dict[str, str]: - """ - 模拟生成邮件草稿 - """ - # 简单的模板生成 + """模拟生成邮件草稿""" return { "subject": f"Re: {query}", "recipient": "recipient@example.com", @@ -148,10 +197,9 @@ class ContactAPIClient: } def send_email_mock(self, recipient: str, subject: str, body: str) -> Dict[str, Any]: - """ - 模拟发送邮件 - """ - # 记录到模拟数据库 + """模拟发送邮件""" + global MOCK_EMAILS_DB + MOCK_EMAILS_DB.append( Email( id=str(len(MOCK_EMAILS_DB) + 1), @@ -169,10 +217,7 @@ class ContactAPIClient: } def sniff_contacts_mock(self, query: str) -> Dict[str, Any]: - """ - 模拟智能嗅探联系人 - """ - # 简化的检测逻辑 + """模拟智能嗅探联系人""" import re emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', query) @@ -191,31 +236,36 @@ class ContactAPIClient: "count": len(contacts), "suggestion": "是否添加这些联系人?" } - - # 公共方法 - def list_contacts(self, user_id: str = "default") -> List[Contact]: - """查询联系人列表""" + + # ========== 公共方法(自动选择模式)========== + + async def list_contacts(self, user_id: str = "default") -> List[Contact]: + """获取联系人列表(自动选择数据库或模拟模式)""" + if self._use_db: + return await self.list_contacts_db(user_id) return self.list_contacts_mock(user_id) - - def add_contact(self, user_id: str, contact: Contact) -> bool: - """添加联系人""" + + async def add_contact(self, user_id: str, contact: Contact) -> bool: + """添加联系人(自动选择数据库或模拟模式)""" + if self._use_db: + return await self.add_contact_db(user_id, contact) return self.save_contact_mock(user_id, contact) - - def list_emails(self, user_id: str = "default") -> List[Email]: - """查询邮件列表""" + + async def list_emails(self, user_id: str = "default") -> List[Email]: + """查询邮件列表(目前用模拟)""" return self.list_emails_mock() - - def generate_email_draft(self, query: str) -> Dict[str, str]: - """生成邮件草稿""" + + async def generate_email_draft(self, query: str) -> Dict[str, str]: + """生成邮件草稿(目前用模拟)""" return self.generate_email_draft_mock(query) - - def send_email(self, user_id: str, recipient: str, subject: str, body: str) -> bool: - """发送邮件""" + + async def send_email(self, user_id: str, recipient: str, subject: str, body: str) -> bool: + """发送邮件(目前用模拟)""" result = self.send_email_mock(recipient, subject, body) return result.get("success", False) - - def sniff_contacts(self, query: str) -> List[Contact]: - """智能嗅探联系人""" + + async def sniff_contacts(self, query: str) -> List[Contact]: + """智能嗅探联系人(目前用模拟)""" result = self.sniff_contacts_mock(query) contact_dicts = result.get("contacts", []) return [ @@ -232,5 +282,5 @@ class ContactAPIClient: ] -# 单例实例 +# 全局实例(模拟模式,保留向后兼容) contact_api = ContactAPIClient() diff --git a/backend/app/backend.py b/backend/app/backend.py index ceee392..f9e9e20 100644 --- a/backend/app/backend.py +++ b/backend/app/backend.py @@ -23,6 +23,8 @@ from .agent_subgraphs.common.human_review import ( ReviewStatus, HumanReview ) +from .agent_subgraphs.contact.api_client import ContactAPIClient +from .db.init_db import init_subgraph_tables from .logger import info, error @asynccontextmanager @@ -32,6 +34,9 @@ async def lifespan(app: FastAPI): async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: await checkpointer.setup() + # 1.5 初始化子图表 + await init_subgraph_tables(checkpointer.conn) + # 2. 构建 AI Agent 服务 agent_service = AIAgentService(checkpointer) await agent_service.initialize() @@ -39,6 +44,9 @@ async def lifespan(app: FastAPI): # 3. 创建历史查询服务 history_service = ThreadHistoryService(checkpointer) + # 3.5 创建子图 API 客户端(真实数据库模式) + contact_api = ContactAPIClient(checkpointer.conn) + # 4. 创建审核管理器 review_manager = ReviewManager(InMemoryReviewStore()) @@ -46,6 +54,7 @@ async def lifespan(app: FastAPI): app.state.agent_service = agent_service app.state.history_service = history_service app.state.review_manager = review_manager + app.state.contact_api = contact_api # 应用运行中... yield diff --git a/backend/app/db/__init__.py b/backend/app/db/__init__.py new file mode 100644 index 0000000..86407cf --- /dev/null +++ b/backend/app/db/__init__.py @@ -0,0 +1,7 @@ +""" +数据库模块 +提供统一的 Repository 和实体类 +""" +from .base import BaseRepository, BaseEntity + +__all__ = ['BaseRepository', 'BaseEntity'] diff --git a/backend/app/db/base.py b/backend/app/db/base.py new file mode 100644 index 0000000..3900dcc --- /dev/null +++ b/backend/app/db/base.py @@ -0,0 +1,134 @@ +""" +统一的 Repository 基类 +所有子图的数据操作都用这个,避免重复代码 +""" +import json +from typing import Any, List, Dict, Optional +from dataclasses import dataclass, asdict + + +@dataclass +class BaseEntity: + """实体基类,用于类型提示和自动序列化""" + id: Optional[str] = None + user_id: Optional[str] = None + created_at: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """转换为字典(用于存储)""" + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'BaseEntity': + """从字典创建实体""" + return cls(**data) + + +class BaseRepository: + """ + 统一的 Repository 基类 + + 核心思想: + 1. 传入连接,复用现有 checkpointer.conn + 2. 提供通用 CRUD 方法 + 3. 表名和字段映射通过子类定义 + """ + + # 子类需要定义的配置 + table_name: str + entity_class: type # 继承自 BaseEntity 的实体类 + id_column: str = "id" + user_id_column: str = "user_id" + + def __init__(self, conn): + """ + 构造函数:传入数据库连接 + + 这样可以复用 checkpointer.conn,不需要额外的连接池 + """ + self.conn = conn + + # ========== 通用 CRUD 方法 ========== + + async def insert(self, entity: BaseEntity) -> str: + """插入单个实体,返回 ID""" + data = entity.to_dict() + data.pop('id', None) # 如果 id 是自增的 + + if not data: + raise ValueError("实体数据为空") + + columns = list(data.keys()) + placeholders = [f"${i+1}" for i in range(len(columns))] + values = list(data.values()) + + sql = f""" + INSERT INTO {self.table_name} ({', '.join(columns)}) + VALUES ({', '.join(placeholders)}) + RETURNING {self.id_column} + """ + + async with self.conn.cursor() as cur: + await cur.execute(sql, values) + row = await cur.fetchone() + return row[self.id_column] if row else None + + async def update(self, entity_id: str, data: Dict[str, Any]) -> bool: + """更新实体""" + if not data: + return False + + set_clause = ", ".join([f"{k} = ${i+1}" for i, k in enumerate(data.keys())]) + values = list(data.values()) + [entity_id] + + sql = f""" + UPDATE {self.table_name} + SET {set_clause} + WHERE {self.id_column} = ${len(values)} + """ + + async with self.conn.cursor() as cur: + await cur.execute(sql, values) + return cur.rowcount > 0 + + async def delete(self, entity_id: str) -> bool: + """删除实体""" + sql = f"DELETE FROM {self.table_name} WHERE {self.id_column} = $1" + + async with self.conn.cursor() as cur: + await cur.execute(sql, (entity_id,)) + return cur.rowcount > 0 + + async def get_by_id(self, entity_id: str) -> Optional[BaseEntity]: + """根据 ID 查询单个实体""" + sql = f"SELECT * FROM {self.table_name} WHERE {self.id_column} = $1" + + async with self.conn.cursor() as cur: + await cur.execute(sql, (entity_id,)) + row = await cur.fetchone() + if row: + return self.entity_class.from_dict(dict(row)) + return None + + async def list_by_user(self, user_id: str, limit: int = 100) -> List[BaseEntity]: + """查询某个用户的所有实体""" + sql = f""" + SELECT * FROM {self.table_name} + WHERE {self.user_id_column} = $1 + ORDER BY created_at DESC + LIMIT $2 + """ + + async with self.conn.cursor() as cur: + await cur.execute(sql, (user_id, limit)) + rows = await cur.fetchall() + return [self.entity_class.from_dict(dict(row)) for row in rows] + + # ========== 自定义查询方法(留钩子)========== + + async def custom_query(self, sql: str, params: tuple = ()) -> List[Dict]: + """执行自定义 SQL""" + async with self.conn.cursor() as cur: + await cur.execute(sql, params) + rows = await cur.fetchall() + return [dict(row) for row in rows] diff --git a/backend/app/db/init_db.py b/backend/app/db/init_db.py new file mode 100644 index 0000000..0891d99 --- /dev/null +++ b/backend/app/db/init_db.py @@ -0,0 +1,97 @@ +""" +子图数据库表初始化 +在 FastAPI 启动时创建表 +""" +from typing import Optional + + +async def init_subgraph_tables(conn): + """ + 初始化子图所需的表 + + Args: + conn: 数据库连接(来自 AsyncPostgresSaver) + """ + + # 1. contacts 表(通讯录) + await _create_contacts_table(conn) + + # 2. words 表(词典) + await _create_words_table(conn) + + # 3. news 表(资讯) + await _create_news_table(conn) + + +async def _create_contacts_table(conn): + """创建 contacts 表""" + sql = """ + CREATE TABLE IF NOT EXISTS contacts ( + id VARCHAR(64) PRIMARY KEY, + user_id VARCHAR(64) NOT NULL, + name VARCHAR(100) NOT NULL, + phone VARCHAR(32), + email VARCHAR(100), + company VARCHAR(100), + position VARCHAR(100), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + INDEX idx_contacts_user (user_id) + ); + """ + + try: + async with conn.cursor() as cur: + await cur.execute(sql) + # 注释掉 INFO 打印,避免噪声 + except Exception as e: + print(f"contacts 表可能已存在: {e}") + + +async def _create_words_table(conn): + """创建 words 表""" + sql = """ + CREATE TABLE IF NOT EXISTS words ( + id VARCHAR(64) PRIMARY KEY, + user_id VARCHAR(64) NOT NULL, + word VARCHAR(100) NOT NULL, + phonetic VARCHAR(100), + part_of_speech VARCHAR(50), + definition TEXT, + examples TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + INDEX idx_words_user (user_id), + INDEX idx_words_word (word) + ); + """ + + try: + async with conn.cursor() as cur: + await cur.execute(sql) + except Exception as e: + print(f"words 表可能已存在: {e}") + + +async def _create_news_table(conn): + """创建 news 表""" + sql = """ + CREATE TABLE IF NOT EXISTS news ( + id VARCHAR(64) PRIMARY KEY, + user_id VARCHAR(64) NOT NULL, + title VARCHAR(200) NOT NULL, + content TEXT, + url VARCHAR(500), + source VARCHAR(100), + keywords TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + INDEX idx_news_user (user_id) + ); + """ + + try: + async with conn.cursor() as cur: + await cur.execute(sql) + except Exception as e: + print(f"news 表可能已存在: {e}") diff --git a/backend/app/db/models.py b/backend/app/db/models.py new file mode 100644 index 0000000..2c417ef --- /dev/null +++ b/backend/app/db/models.py @@ -0,0 +1,106 @@ +""" +子图数据模型和 Repository 定义 +只需要定义表名、实体类,基类搞定一切 +""" +from typing import List, Optional +from dataclasses import dataclass + +from .base import BaseRepository, BaseEntity + + +# ========== 通讯录 ========== + +@dataclass +class ContactEntity(BaseEntity): + """通讯录实体""" + name: str = "" + phone: str = "" + email: str = "" + company: str = "" + position: str = "" + + +class ContactRepository(BaseRepository): + """通讯录 Repository""" + table_name = "contacts" + entity_class = ContactEntity + + # 如果需要特殊查询,在这里加方法 + async def search_by_name(self, user_id: str, name_keyword: str) -> List[ContactEntity]: + """自定义查询:按姓名搜索""" + sql = """ + SELECT * FROM contacts + WHERE user_id = $1 AND name ILIKE $2 + ORDER BY created_at DESC + LIMIT 50 + """ + result = await self.custom_query(sql, (user_id, f"%{name_keyword}%")) + return [ContactEntity.from_dict(row) for row in result] + + +# ========== 词典 ========== + +@dataclass +class WordEntity(BaseEntity): + """单词实体""" + word: str = "" + phonetic: str = "" + part_of_speech: str = "" + definition: str = "" + examples: str = "" + + +class DictionaryRepository(BaseRepository): + """词典 Repository""" + table_name = "words" + entity_class = WordEntity + + async def search_by_word(self, user_id: str, word: str) -> Optional[WordEntity]: + """按单词查询""" + sql = """ + SELECT * FROM words + WHERE user_id = $1 AND word = $2 + LIMIT 1 + """ + result = await self.custom_query(sql, (user_id, word)) + if result: + return WordEntity.from_dict(result[0]) + return None + + +# ========== 资讯 ========== + +@dataclass +class NewsEntity(BaseEntity): + """资讯实体""" + title: str = "" + content: str = "" + url: str = "" + source: str = "" + keywords: str = "" + + +class NewsRepository(BaseRepository): + """资讯 Repository""" + table_name = "news" + entity_class = NewsEntity + + async def search_by_keywords(self, user_id: str, keyword: str) -> List[NewsEntity]: + """按关键词搜索""" + sql = """ + SELECT * FROM news + WHERE user_id = $1 AND (title ILIKE $2 OR keywords ILIKE $2) + ORDER BY created_at DESC + LIMIT 50 + """ + result = await self.custom_query(sql, (user_id, f"%{keyword}%")) + return [NewsEntity.from_dict(row) for row in result] + + +# ========== 导出 ========== + +__all__ = [ + 'ContactEntity', 'ContactRepository', + 'WordEntity', 'DictionaryRepository', + 'NewsEntity', 'NewsRepository', +]