统一 Repository 方案:添加 db 基类和子图模型 + 通讯录 API 支持真实数据库
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled
This commit is contained in:
@@ -1,8 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
通讯录子图API调用工具
|
通讯录子图 API 调用工具
|
||||||
Contact Subgraph API Client
|
支持模拟数据和真实数据库两种模式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Any, Optional, List
|
from typing import Dict, Any, Optional, List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -10,6 +9,8 @@ from dataclasses import dataclass
|
|||||||
from .state import Contact, Email
|
from .state import Contact, Email
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 模拟数据(保留作为备选)==========
|
||||||
|
|
||||||
# 模拟数据库
|
# 模拟数据库
|
||||||
MOCK_CONTACTS_DB = {}
|
MOCK_CONTACTS_DB = {}
|
||||||
MOCK_EMAILS_DB = []
|
MOCK_EMAILS_DB = []
|
||||||
@@ -18,13 +19,75 @@ MOCK_EMAILS_DB = []
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ContactAPIClient:
|
class ContactAPIClient:
|
||||||
"""
|
"""
|
||||||
通讯录API客户端 - 可扩展支持多种后端
|
通讯录 API 客户端 - 支持真实数据库和模拟模式
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
1. 真实数据库模式:传入 conn 参数
|
||||||
|
2. 模拟模式:不传入 conn,或 conn 为 None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, conn=None):
|
||||||
|
"""
|
||||||
|
初始化
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conn: 数据库连接(来自 checkpointer.conn),为 None 时使用模拟模式
|
||||||
|
"""
|
||||||
|
self.conn = conn
|
||||||
|
self._use_db = conn is not None
|
||||||
|
|
||||||
|
if self._use_db:
|
||||||
|
try:
|
||||||
|
from ...db.models import ContactRepository, ContactEntity
|
||||||
|
self._repo = ContactRepository(conn)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Repository 初始化失败,回退到模拟模式: {e}")
|
||||||
|
self._use_db = False
|
||||||
|
self._repo = None
|
||||||
|
|
||||||
|
# ========== 真实数据库方法 ==========
|
||||||
|
|
||||||
|
async def list_contacts_db(self, user_id: str = "default") -> List[Contact]:
|
||||||
|
"""真实数据库:获取联系人列表"""
|
||||||
|
if not self._repo:
|
||||||
|
return await self.list_contacts_mock(user_id)
|
||||||
|
|
||||||
|
entities = await self._repo.list_by_user(user_id)
|
||||||
|
return [
|
||||||
|
Contact(
|
||||||
|
id=e.id,
|
||||||
|
name=e.name,
|
||||||
|
phone=e.phone,
|
||||||
|
email=e.email,
|
||||||
|
company=e.company,
|
||||||
|
position=e.position,
|
||||||
|
created_at=e.created_at
|
||||||
|
)
|
||||||
|
for e in entities
|
||||||
|
]
|
||||||
|
|
||||||
|
async def add_contact_db(self, user_id: str, contact: Contact) -> bool:
|
||||||
|
"""真实数据库:添加联系人"""
|
||||||
|
if not self._repo:
|
||||||
|
return await self.save_contact_mock(user_id, contact)
|
||||||
|
|
||||||
|
from ...db.models import ContactEntity
|
||||||
|
entity = ContactEntity(
|
||||||
|
user_id=user_id,
|
||||||
|
name=contact.name,
|
||||||
|
phone=contact.phone,
|
||||||
|
email=contact.email,
|
||||||
|
company=contact.company,
|
||||||
|
position=contact.position,
|
||||||
|
created_at=contact.created_at or datetime.now().isoformat()
|
||||||
|
)
|
||||||
|
await self._repo.insert(entity)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ========== 模拟数据方法(保留)==========
|
||||||
|
|
||||||
def list_contacts_mock(self, user_id: str = "default") -> List[Contact]:
|
def list_contacts_mock(self, user_id: str = "default") -> List[Contact]:
|
||||||
"""
|
"""模拟查询联系人列表"""
|
||||||
模拟查询联系人列表
|
|
||||||
"""
|
|
||||||
if user_id not in MOCK_CONTACTS_DB:
|
if user_id not in MOCK_CONTACTS_DB:
|
||||||
# 初始化一些示例数据
|
# 初始化一些示例数据
|
||||||
MOCK_CONTACTS_DB[user_id] = [
|
MOCK_CONTACTS_DB[user_id] = [
|
||||||
@@ -54,16 +117,13 @@ class ContactAPIClient:
|
|||||||
company="咨询公司",
|
company="咨询公司",
|
||||||
position="顾问",
|
position="顾问",
|
||||||
created_at=datetime.now().isoformat()
|
created_at=datetime.now().isoformat()
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
return MOCK_CONTACTS_DB[user_id]
|
return MOCK_CONTACTS_DB[user_id]
|
||||||
|
|
||||||
def extract_contact_info_mock(self, query: str) -> Optional[Dict[str, Any]]:
|
def extract_contact_info_mock(self, query: str) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""模拟从查询中提取联系人信息"""
|
||||||
模拟从查询中提取联系人信息
|
|
||||||
"""
|
|
||||||
# 简化的提取逻辑
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# 提取邮箱
|
# 提取邮箱
|
||||||
@@ -71,11 +131,8 @@ class ContactAPIClient:
|
|||||||
# 提取手机号
|
# 提取手机号
|
||||||
phone_match = re.search(r'1[3-9]\d{9}', query)
|
phone_match = re.search(r'1[3-9]\d{9}', query)
|
||||||
# 提取姓名(简单匹配)
|
# 提取姓名(简单匹配)
|
||||||
# 先看是否有"添加"等关键词
|
|
||||||
if any(keyword in query for keyword in ["添加", "add"]):
|
if any(keyword in query for keyword in ["添加", "add"]):
|
||||||
# 尝试提取姓名
|
|
||||||
name = "未知"
|
name = "未知"
|
||||||
# 简单的逻辑:去掉关键词、去掉邮箱和电话后剩下的
|
|
||||||
clean_query = query
|
clean_query = query
|
||||||
if email_match:
|
if email_match:
|
||||||
clean_query = clean_query.replace(email_match.group(), "")
|
clean_query = clean_query.replace(email_match.group(), "")
|
||||||
@@ -95,9 +152,7 @@ class ContactAPIClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def save_contact_mock(self, user_id: str, contact: Contact) -> bool:
|
def save_contact_mock(self, user_id: str, contact: Contact) -> bool:
|
||||||
"""
|
"""模拟保存联系人"""
|
||||||
模拟保存联系人
|
|
||||||
"""
|
|
||||||
if user_id not in MOCK_CONTACTS_DB:
|
if user_id not in MOCK_CONTACTS_DB:
|
||||||
MOCK_CONTACTS_DB[user_id] = []
|
MOCK_CONTACTS_DB[user_id] = []
|
||||||
|
|
||||||
@@ -108,21 +163,18 @@ class ContactAPIClient:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def list_emails_mock(self) -> List[Email]:
|
def list_emails_mock(self) -> List[Email]:
|
||||||
"""
|
"""模拟查询邮件列表"""
|
||||||
模拟查询邮件列表
|
|
||||||
"""
|
|
||||||
global MOCK_EMAILS_DB
|
global MOCK_EMAILS_DB
|
||||||
|
|
||||||
if not MOCK_EMAILS_DB:
|
if not MOCK_EMAILS_DB:
|
||||||
# 初始化一些示例邮件
|
|
||||||
MOCK_EMAILS_DB = [
|
MOCK_EMAILS_DB = [
|
||||||
Email(
|
Email(
|
||||||
id="1",
|
id="1",
|
||||||
subject="会议邀请:AI技术分享",
|
subject="会议邀请:AI 技术分享",
|
||||||
sender="admin@example.com",
|
sender="admin@example.com",
|
||||||
recipients=["user@example.com"],
|
recipients=["user@example.com"],
|
||||||
date=datetime.now().isoformat(),
|
date=datetime.now().isoformat(),
|
||||||
body="你好,下周一将举办AI技术分享会,欢迎参加。"
|
body="你好,下周一将举办 AI 技术分享会,欢迎参加。"
|
||||||
),
|
),
|
||||||
Email(
|
Email(
|
||||||
id="2",
|
id="2",
|
||||||
@@ -131,16 +183,13 @@ class ContactAPIClient:
|
|||||||
recipients=["user@example.com"],
|
recipients=["user@example.com"],
|
||||||
date=datetime.now().isoformat(),
|
date=datetime.now().isoformat(),
|
||||||
body="项目进度良好,继续保持。"
|
body="项目进度良好,继续保持。"
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
return MOCK_EMAILS_DB
|
return MOCK_EMAILS_DB
|
||||||
|
|
||||||
def generate_email_draft_mock(self, query: str) -> Dict[str, str]:
|
def generate_email_draft_mock(self, query: str) -> Dict[str, str]:
|
||||||
"""
|
"""模拟生成邮件草稿"""
|
||||||
模拟生成邮件草稿
|
|
||||||
"""
|
|
||||||
# 简单的模板生成
|
|
||||||
return {
|
return {
|
||||||
"subject": f"Re: {query}",
|
"subject": f"Re: {query}",
|
||||||
"recipient": "recipient@example.com",
|
"recipient": "recipient@example.com",
|
||||||
@@ -148,10 +197,9 @@ class ContactAPIClient:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def send_email_mock(self, recipient: str, subject: str, body: str) -> Dict[str, Any]:
|
def send_email_mock(self, recipient: str, subject: str, body: str) -> Dict[str, Any]:
|
||||||
"""
|
"""模拟发送邮件"""
|
||||||
模拟发送邮件
|
global MOCK_EMAILS_DB
|
||||||
"""
|
|
||||||
# 记录到模拟数据库
|
|
||||||
MOCK_EMAILS_DB.append(
|
MOCK_EMAILS_DB.append(
|
||||||
Email(
|
Email(
|
||||||
id=str(len(MOCK_EMAILS_DB) + 1),
|
id=str(len(MOCK_EMAILS_DB) + 1),
|
||||||
@@ -169,10 +217,7 @@ class ContactAPIClient:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def sniff_contacts_mock(self, query: str) -> Dict[str, Any]:
|
def sniff_contacts_mock(self, query: str) -> Dict[str, Any]:
|
||||||
"""
|
"""模拟智能嗅探联系人"""
|
||||||
模拟智能嗅探联系人
|
|
||||||
"""
|
|
||||||
# 简化的检测逻辑
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', query)
|
emails = re.findall(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', query)
|
||||||
@@ -191,31 +236,36 @@ class ContactAPIClient:
|
|||||||
"count": len(contacts),
|
"count": len(contacts),
|
||||||
"suggestion": "是否添加这些联系人?"
|
"suggestion": "是否添加这些联系人?"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 公共方法
|
# ========== 公共方法(自动选择模式)==========
|
||||||
def list_contacts(self, user_id: str = "default") -> List[Contact]:
|
|
||||||
"""查询联系人列表"""
|
async def list_contacts(self, user_id: str = "default") -> List[Contact]:
|
||||||
|
"""获取联系人列表(自动选择数据库或模拟模式)"""
|
||||||
|
if self._use_db:
|
||||||
|
return await self.list_contacts_db(user_id)
|
||||||
return self.list_contacts_mock(user_id)
|
return self.list_contacts_mock(user_id)
|
||||||
|
|
||||||
def add_contact(self, user_id: str, contact: Contact) -> bool:
|
async def add_contact(self, user_id: str, contact: Contact) -> bool:
|
||||||
"""添加联系人"""
|
"""添加联系人(自动选择数据库或模拟模式)"""
|
||||||
|
if self._use_db:
|
||||||
|
return await self.add_contact_db(user_id, contact)
|
||||||
return self.save_contact_mock(user_id, contact)
|
return self.save_contact_mock(user_id, contact)
|
||||||
|
|
||||||
def list_emails(self, user_id: str = "default") -> List[Email]:
|
async def list_emails(self, user_id: str = "default") -> List[Email]:
|
||||||
"""查询邮件列表"""
|
"""查询邮件列表(目前用模拟)"""
|
||||||
return self.list_emails_mock()
|
return self.list_emails_mock()
|
||||||
|
|
||||||
def generate_email_draft(self, query: str) -> Dict[str, str]:
|
async def generate_email_draft(self, query: str) -> Dict[str, str]:
|
||||||
"""生成邮件草稿"""
|
"""生成邮件草稿(目前用模拟)"""
|
||||||
return self.generate_email_draft_mock(query)
|
return self.generate_email_draft_mock(query)
|
||||||
|
|
||||||
def send_email(self, user_id: str, recipient: str, subject: str, body: str) -> bool:
|
async def send_email(self, user_id: str, recipient: str, subject: str, body: str) -> bool:
|
||||||
"""发送邮件"""
|
"""发送邮件(目前用模拟)"""
|
||||||
result = self.send_email_mock(recipient, subject, body)
|
result = self.send_email_mock(recipient, subject, body)
|
||||||
return result.get("success", False)
|
return result.get("success", False)
|
||||||
|
|
||||||
def sniff_contacts(self, query: str) -> List[Contact]:
|
async def sniff_contacts(self, query: str) -> List[Contact]:
|
||||||
"""智能嗅探联系人"""
|
"""智能嗅探联系人(目前用模拟)"""
|
||||||
result = self.sniff_contacts_mock(query)
|
result = self.sniff_contacts_mock(query)
|
||||||
contact_dicts = result.get("contacts", [])
|
contact_dicts = result.get("contacts", [])
|
||||||
return [
|
return [
|
||||||
@@ -232,5 +282,5 @@ class ContactAPIClient:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# 单例实例
|
# 全局实例(模拟模式,保留向后兼容)
|
||||||
contact_api = ContactAPIClient()
|
contact_api = ContactAPIClient()
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ from .agent_subgraphs.common.human_review import (
|
|||||||
ReviewStatus,
|
ReviewStatus,
|
||||||
HumanReview
|
HumanReview
|
||||||
)
|
)
|
||||||
|
from .agent_subgraphs.contact.api_client import ContactAPIClient
|
||||||
|
from .db.init_db import init_subgraph_tables
|
||||||
from .logger import info, error
|
from .logger import info, error
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -32,6 +34,9 @@ async def lifespan(app: FastAPI):
|
|||||||
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
||||||
await checkpointer.setup()
|
await checkpointer.setup()
|
||||||
|
|
||||||
|
# 1.5 初始化子图表
|
||||||
|
await init_subgraph_tables(checkpointer.conn)
|
||||||
|
|
||||||
# 2. 构建 AI Agent 服务
|
# 2. 构建 AI Agent 服务
|
||||||
agent_service = AIAgentService(checkpointer)
|
agent_service = AIAgentService(checkpointer)
|
||||||
await agent_service.initialize()
|
await agent_service.initialize()
|
||||||
@@ -39,6 +44,9 @@ async def lifespan(app: FastAPI):
|
|||||||
# 3. 创建历史查询服务
|
# 3. 创建历史查询服务
|
||||||
history_service = ThreadHistoryService(checkpointer)
|
history_service = ThreadHistoryService(checkpointer)
|
||||||
|
|
||||||
|
# 3.5 创建子图 API 客户端(真实数据库模式)
|
||||||
|
contact_api = ContactAPIClient(checkpointer.conn)
|
||||||
|
|
||||||
# 4. 创建审核管理器
|
# 4. 创建审核管理器
|
||||||
review_manager = ReviewManager(InMemoryReviewStore())
|
review_manager = ReviewManager(InMemoryReviewStore())
|
||||||
|
|
||||||
@@ -46,6 +54,7 @@ async def lifespan(app: FastAPI):
|
|||||||
app.state.agent_service = agent_service
|
app.state.agent_service = agent_service
|
||||||
app.state.history_service = history_service
|
app.state.history_service = history_service
|
||||||
app.state.review_manager = review_manager
|
app.state.review_manager = review_manager
|
||||||
|
app.state.contact_api = contact_api
|
||||||
|
|
||||||
# 应用运行中...
|
# 应用运行中...
|
||||||
yield
|
yield
|
||||||
|
|||||||
7
backend/app/db/__init__.py
Normal file
7
backend/app/db/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
数据库模块
|
||||||
|
提供统一的 Repository 和实体类
|
||||||
|
"""
|
||||||
|
from .base import BaseRepository, BaseEntity
|
||||||
|
|
||||||
|
__all__ = ['BaseRepository', 'BaseEntity']
|
||||||
134
backend/app/db/base.py
Normal file
134
backend/app/db/base.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""
|
||||||
|
统一的 Repository 基类
|
||||||
|
所有子图的数据操作都用这个,避免重复代码
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from typing import Any, List, Dict, Optional
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseEntity:
|
||||||
|
"""实体基类,用于类型提示和自动序列化"""
|
||||||
|
id: Optional[str] = None
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
created_at: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典(用于存储)"""
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> 'BaseEntity':
|
||||||
|
"""从字典创建实体"""
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRepository:
|
||||||
|
"""
|
||||||
|
统一的 Repository 基类
|
||||||
|
|
||||||
|
核心思想:
|
||||||
|
1. 传入连接,复用现有 checkpointer.conn
|
||||||
|
2. 提供通用 CRUD 方法
|
||||||
|
3. 表名和字段映射通过子类定义
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 子类需要定义的配置
|
||||||
|
table_name: str
|
||||||
|
entity_class: type # 继承自 BaseEntity 的实体类
|
||||||
|
id_column: str = "id"
|
||||||
|
user_id_column: str = "user_id"
|
||||||
|
|
||||||
|
def __init__(self, conn):
|
||||||
|
"""
|
||||||
|
构造函数:传入数据库连接
|
||||||
|
|
||||||
|
这样可以复用 checkpointer.conn,不需要额外的连接池
|
||||||
|
"""
|
||||||
|
self.conn = conn
|
||||||
|
|
||||||
|
# ========== 通用 CRUD 方法 ==========
|
||||||
|
|
||||||
|
async def insert(self, entity: BaseEntity) -> str:
|
||||||
|
"""插入单个实体,返回 ID"""
|
||||||
|
data = entity.to_dict()
|
||||||
|
data.pop('id', None) # 如果 id 是自增的
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
raise ValueError("实体数据为空")
|
||||||
|
|
||||||
|
columns = list(data.keys())
|
||||||
|
placeholders = [f"${i+1}" for i in range(len(columns))]
|
||||||
|
values = list(data.values())
|
||||||
|
|
||||||
|
sql = f"""
|
||||||
|
INSERT INTO {self.table_name} ({', '.join(columns)})
|
||||||
|
VALUES ({', '.join(placeholders)})
|
||||||
|
RETURNING {self.id_column}
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with self.conn.cursor() as cur:
|
||||||
|
await cur.execute(sql, values)
|
||||||
|
row = await cur.fetchone()
|
||||||
|
return row[self.id_column] if row else None
|
||||||
|
|
||||||
|
async def update(self, entity_id: str, data: Dict[str, Any]) -> bool:
|
||||||
|
"""更新实体"""
|
||||||
|
if not data:
|
||||||
|
return False
|
||||||
|
|
||||||
|
set_clause = ", ".join([f"{k} = ${i+1}" for i, k in enumerate(data.keys())])
|
||||||
|
values = list(data.values()) + [entity_id]
|
||||||
|
|
||||||
|
sql = f"""
|
||||||
|
UPDATE {self.table_name}
|
||||||
|
SET {set_clause}
|
||||||
|
WHERE {self.id_column} = ${len(values)}
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with self.conn.cursor() as cur:
|
||||||
|
await cur.execute(sql, values)
|
||||||
|
return cur.rowcount > 0
|
||||||
|
|
||||||
|
async def delete(self, entity_id: str) -> bool:
|
||||||
|
"""删除实体"""
|
||||||
|
sql = f"DELETE FROM {self.table_name} WHERE {self.id_column} = $1"
|
||||||
|
|
||||||
|
async with self.conn.cursor() as cur:
|
||||||
|
await cur.execute(sql, (entity_id,))
|
||||||
|
return cur.rowcount > 0
|
||||||
|
|
||||||
|
async def get_by_id(self, entity_id: str) -> Optional[BaseEntity]:
|
||||||
|
"""根据 ID 查询单个实体"""
|
||||||
|
sql = f"SELECT * FROM {self.table_name} WHERE {self.id_column} = $1"
|
||||||
|
|
||||||
|
async with self.conn.cursor() as cur:
|
||||||
|
await cur.execute(sql, (entity_id,))
|
||||||
|
row = await cur.fetchone()
|
||||||
|
if row:
|
||||||
|
return self.entity_class.from_dict(dict(row))
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def list_by_user(self, user_id: str, limit: int = 100) -> List[BaseEntity]:
|
||||||
|
"""查询某个用户的所有实体"""
|
||||||
|
sql = f"""
|
||||||
|
SELECT * FROM {self.table_name}
|
||||||
|
WHERE {self.user_id_column} = $1
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT $2
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with self.conn.cursor() as cur:
|
||||||
|
await cur.execute(sql, (user_id, limit))
|
||||||
|
rows = await cur.fetchall()
|
||||||
|
return [self.entity_class.from_dict(dict(row)) for row in rows]
|
||||||
|
|
||||||
|
# ========== 自定义查询方法(留钩子)==========
|
||||||
|
|
||||||
|
async def custom_query(self, sql: str, params: tuple = ()) -> List[Dict]:
|
||||||
|
"""执行自定义 SQL"""
|
||||||
|
async with self.conn.cursor() as cur:
|
||||||
|
await cur.execute(sql, params)
|
||||||
|
rows = await cur.fetchall()
|
||||||
|
return [dict(row) for row in rows]
|
||||||
97
backend/app/db/init_db.py
Normal file
97
backend/app/db/init_db.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
子图数据库表初始化
|
||||||
|
在 FastAPI 启动时创建表
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
async def init_subgraph_tables(conn):
|
||||||
|
"""
|
||||||
|
初始化子图所需的表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conn: 数据库连接(来自 AsyncPostgresSaver)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 1. contacts 表(通讯录)
|
||||||
|
await _create_contacts_table(conn)
|
||||||
|
|
||||||
|
# 2. words 表(词典)
|
||||||
|
await _create_words_table(conn)
|
||||||
|
|
||||||
|
# 3. news 表(资讯)
|
||||||
|
await _create_news_table(conn)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_contacts_table(conn):
|
||||||
|
"""创建 contacts 表"""
|
||||||
|
sql = """
|
||||||
|
CREATE TABLE IF NOT EXISTS contacts (
|
||||||
|
id VARCHAR(64) PRIMARY KEY,
|
||||||
|
user_id VARCHAR(64) NOT NULL,
|
||||||
|
name VARCHAR(100) NOT NULL,
|
||||||
|
phone VARCHAR(32),
|
||||||
|
email VARCHAR(100),
|
||||||
|
company VARCHAR(100),
|
||||||
|
position VARCHAR(100),
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
|
||||||
|
INDEX idx_contacts_user (user_id)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
await cur.execute(sql)
|
||||||
|
# 注释掉 INFO 打印,避免噪声
|
||||||
|
except Exception as e:
|
||||||
|
print(f"contacts 表可能已存在: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_words_table(conn):
|
||||||
|
"""创建 words 表"""
|
||||||
|
sql = """
|
||||||
|
CREATE TABLE IF NOT EXISTS words (
|
||||||
|
id VARCHAR(64) PRIMARY KEY,
|
||||||
|
user_id VARCHAR(64) NOT NULL,
|
||||||
|
word VARCHAR(100) NOT NULL,
|
||||||
|
phonetic VARCHAR(100),
|
||||||
|
part_of_speech VARCHAR(50),
|
||||||
|
definition TEXT,
|
||||||
|
examples TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
|
||||||
|
INDEX idx_words_user (user_id),
|
||||||
|
INDEX idx_words_word (word)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
await cur.execute(sql)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"words 表可能已存在: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_news_table(conn):
|
||||||
|
"""创建 news 表"""
|
||||||
|
sql = """
|
||||||
|
CREATE TABLE IF NOT EXISTS news (
|
||||||
|
id VARCHAR(64) PRIMARY KEY,
|
||||||
|
user_id VARCHAR(64) NOT NULL,
|
||||||
|
title VARCHAR(200) NOT NULL,
|
||||||
|
content TEXT,
|
||||||
|
url VARCHAR(500),
|
||||||
|
source VARCHAR(100),
|
||||||
|
keywords TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
|
||||||
|
INDEX idx_news_user (user_id)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
await cur.execute(sql)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"news 表可能已存在: {e}")
|
||||||
106
backend/app/db/models.py
Normal file
106
backend/app/db/models.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
子图数据模型和 Repository 定义
|
||||||
|
只需要定义表名、实体类,基类搞定一切
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .base import BaseRepository, BaseEntity
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 通讯录 ==========
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ContactEntity(BaseEntity):
|
||||||
|
"""通讯录实体"""
|
||||||
|
name: str = ""
|
||||||
|
phone: str = ""
|
||||||
|
email: str = ""
|
||||||
|
company: str = ""
|
||||||
|
position: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class ContactRepository(BaseRepository):
|
||||||
|
"""通讯录 Repository"""
|
||||||
|
table_name = "contacts"
|
||||||
|
entity_class = ContactEntity
|
||||||
|
|
||||||
|
# 如果需要特殊查询,在这里加方法
|
||||||
|
async def search_by_name(self, user_id: str, name_keyword: str) -> List[ContactEntity]:
|
||||||
|
"""自定义查询:按姓名搜索"""
|
||||||
|
sql = """
|
||||||
|
SELECT * FROM contacts
|
||||||
|
WHERE user_id = $1 AND name ILIKE $2
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT 50
|
||||||
|
"""
|
||||||
|
result = await self.custom_query(sql, (user_id, f"%{name_keyword}%"))
|
||||||
|
return [ContactEntity.from_dict(row) for row in result]
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 词典 ==========
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WordEntity(BaseEntity):
|
||||||
|
"""单词实体"""
|
||||||
|
word: str = ""
|
||||||
|
phonetic: str = ""
|
||||||
|
part_of_speech: str = ""
|
||||||
|
definition: str = ""
|
||||||
|
examples: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class DictionaryRepository(BaseRepository):
|
||||||
|
"""词典 Repository"""
|
||||||
|
table_name = "words"
|
||||||
|
entity_class = WordEntity
|
||||||
|
|
||||||
|
async def search_by_word(self, user_id: str, word: str) -> Optional[WordEntity]:
|
||||||
|
"""按单词查询"""
|
||||||
|
sql = """
|
||||||
|
SELECT * FROM words
|
||||||
|
WHERE user_id = $1 AND word = $2
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
result = await self.custom_query(sql, (user_id, word))
|
||||||
|
if result:
|
||||||
|
return WordEntity.from_dict(result[0])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 资讯 ==========
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NewsEntity(BaseEntity):
|
||||||
|
"""资讯实体"""
|
||||||
|
title: str = ""
|
||||||
|
content: str = ""
|
||||||
|
url: str = ""
|
||||||
|
source: str = ""
|
||||||
|
keywords: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class NewsRepository(BaseRepository):
|
||||||
|
"""资讯 Repository"""
|
||||||
|
table_name = "news"
|
||||||
|
entity_class = NewsEntity
|
||||||
|
|
||||||
|
async def search_by_keywords(self, user_id: str, keyword: str) -> List[NewsEntity]:
|
||||||
|
"""按关键词搜索"""
|
||||||
|
sql = """
|
||||||
|
SELECT * FROM news
|
||||||
|
WHERE user_id = $1 AND (title ILIKE $2 OR keywords ILIKE $2)
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT 50
|
||||||
|
"""
|
||||||
|
result = await self.custom_query(sql, (user_id, f"%{keyword}%"))
|
||||||
|
return [NewsEntity.from_dict(row) for row in result]
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 导出 ==========
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'ContactEntity', 'ContactRepository',
|
||||||
|
'WordEntity', 'DictionaryRepository',
|
||||||
|
'NewsEntity', 'NewsRepository',
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user