2026-04-27 16:37:45 +08:00
|
|
|
|
"""
|
|
|
|
|
|
统一的 Repository 基类
|
|
|
|
|
|
所有子图的数据操作都用这个,避免重复代码
|
|
|
|
|
|
"""
|
|
|
|
|
|
import json
|
2026-04-27 16:38:39 +08:00
|
|
|
|
import uuid
|
2026-04-27 16:37:45 +08:00
|
|
|
|
from typing import Any, List, Dict, Optional
|
|
|
|
|
|
from dataclasses import dataclass, asdict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BaseEntity:
    """Base class for stored entities: a common type plus dict (de)serialization.

    Subclasses are expected to be ``@dataclass``es whose field names match the
    column names of their table.
    """

    id: Optional[str] = None
    user_id: Optional[str] = None
    # NOTE(review): DB drivers may hand back a datetime here rather than a str — confirm.
    created_at: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the entity as a plain dict (used for storage).

        Built with ``dataclasses.asdict``, which recursively copies nested
        dataclasses and containers.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BaseEntity":
        """Build an instance of *cls* from a dict such as a database row.

        Keys that are not ``__init__`` fields of *cls* are ignored, so a
        ``SELECT *`` row containing extra columns does not raise ``TypeError``.
        Absent keys use the field defaults (a required field that is missing
        still raises ``TypeError``).

        Args:
            data: mapping of field/column name to value.

        Returns:
            A new instance of *cls*.
        """
        from dataclasses import fields

        accepted = {f.name for f in fields(cls) if f.init}
        return cls(**{key: value for key, value in data.items() if key in accepted})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseRepository:
    """Shared repository base class providing generic single-table CRUD.

    Design:
    1. The database connection is injected, so the existing
       ``checkpointer.conn`` can be reused.
    2. Generic CRUD helpers live here.
    3. The table name, entity class and column mapping come from the
       subclass via the class attributes below.

    The cursor API used (``conn.cursor()`` as an async context manager,
    ``cur.execute(sql, params)``, ``cur.rowcount``, ``cur.description``) is
    the DB-API / psycopg 3 ``AsyncConnection`` API, so queries use ``%s``
    placeholders.
    NOTE(review): rows are accessed by column name (``row[col]``,
    ``dict(row)``), which assumes the connection uses a dict row factory
    (e.g. ``psycopg.rows.dict_row``) — confirm for every connection passed in.
    """

    # Configuration that subclasses must / may override.
    table_name: str
    entity_class: type  # a BaseEntity subclass (must provide from_dict)
    id_column: str = "id"
    user_id_column: str = "user_id"
    created_at_column: str = "created_at"  # sort key used by list_by_user

    def __init__(self, conn):
        """Store the database connection.

        Args:
            conn: async database connection; passing ``checkpointer.conn``
                avoids creating an extra connection pool.
        """
        self.conn = conn

    # ========== Internal helpers ==========

    @staticmethod
    def _ident(name: str) -> str:
        """Return *name* unchanged if it is safe to splice into SQL as an identifier.

        Table/column names cannot be bound as query parameters, so they are
        interpolated into the SQL text. Accepting only plain (optionally
        dot-qualified, e.g. ``schema.table``) identifiers prevents SQL
        injection through, for example, ``update()`` dict keys.

        Raises:
            ValueError: if *name* is not a plain identifier.
        """
        if not isinstance(name, str) or not all(
            part.isidentifier() for part in name.split(".")
        ):
            raise ValueError(f"Invalid SQL identifier: {name!r}")
        return name

    # ========== Generic CRUD methods ==========

    async def insert(self, entity: "BaseEntity") -> str:
        """Insert one entity and return its ID.

        A UUID4 string is generated when the entity has no ID. Fields whose
        value is ``None`` are omitted from the INSERT so those columns take
        their database default (e.g. ``created_at DEFAULT now()``) instead
        of being forced to NULL.

        Args:
            entity: entity to store; its ``to_dict()`` keys are the column names.

        Returns:
            The ID reported by ``RETURNING``, or the ID that was sent if no
            row came back.

        Raises:
            ValueError: if the entity serializes to an empty dict, or a
                table/column name is not a valid identifier.
        """
        raw = entity.to_dict()
        if not raw:
            raise ValueError("实体数据为空")

        data = {key: value for key, value in raw.items() if value is not None}
        # Ensure the row has an ID.
        if not data.get(self.id_column):
            data[self.id_column] = str(uuid.uuid4())

        columns = [self._ident(col) for col in data]
        placeholders = ", ".join(["%s"] * len(columns))
        sql = (
            f"INSERT INTO {self._ident(self.table_name)} ({', '.join(columns)}) "
            f"VALUES ({placeholders}) "
            f"RETURNING {self._ident(self.id_column)}"
        )

        async with self.conn.cursor() as cur:
            await cur.execute(sql, list(data.values()))
            row = await cur.fetchone()
        return row[self.id_column] if row else data[self.id_column]

    async def update(self, entity_id: str, data: Dict[str, Any]) -> bool:
        """Update columns of the row whose ID equals *entity_id*.

        Args:
            entity_id: value matched against ``id_column``.
            data: mapping of column name to new value.

        Returns:
            True if at least one row was updated; False if *data* is empty
            or no row matched.

        Raises:
            ValueError: if a key of *data* is not a valid identifier.
        """
        if not data:
            return False

        set_clause = ", ".join(f"{self._ident(col)} = %s" for col in data)
        sql = (
            f"UPDATE {self._ident(self.table_name)} SET {set_clause} "
            f"WHERE {self._ident(self.id_column)} = %s"
        )
        params = [*data.values(), entity_id]

        async with self.conn.cursor() as cur:
            await cur.execute(sql, params)
            return cur.rowcount > 0

    async def delete(self, entity_id: str) -> bool:
        """Delete the row whose ID equals *entity_id*; return True if one was removed."""
        sql = (
            f"DELETE FROM {self._ident(self.table_name)} "
            f"WHERE {self._ident(self.id_column)} = %s"
        )
        async with self.conn.cursor() as cur:
            await cur.execute(sql, (entity_id,))
            return cur.rowcount > 0

    async def get_by_id(self, entity_id: str) -> Optional["BaseEntity"]:
        """Fetch one entity by ID, or None if no row matches."""
        sql = (
            f"SELECT * FROM {self._ident(self.table_name)} "
            f"WHERE {self._ident(self.id_column)} = %s"
        )
        async with self.conn.cursor() as cur:
            await cur.execute(sql, (entity_id,))
            row = await cur.fetchone()
        if row:
            return self.entity_class.from_dict(dict(row))
        return None

    async def list_by_user(self, user_id: str, limit: int = 100) -> List["BaseEntity"]:
        """Return up to *limit* entities of *user_id*, newest first.

        Rows are ordered by ``created_at_column`` descending.
        """
        sql = (
            f"SELECT * FROM {self._ident(self.table_name)} "
            f"WHERE {self._ident(self.user_id_column)} = %s "
            f"ORDER BY {self._ident(self.created_at_column)} DESC "
            "LIMIT %s"
        )
        async with self.conn.cursor() as cur:
            await cur.execute(sql, (user_id, limit))
            rows = await cur.fetchall()
        return [self.entity_class.from_dict(dict(row)) for row in rows]

    # ========== Custom query hook ==========

    async def custom_query(self, sql: str, params: tuple = ()) -> List[Dict]:
        """Run caller-supplied SQL and return the result rows as dicts.

        *sql* must use the driver's placeholder style (``%s``) with *params*.
        Statements that produce no result set (``cur.description is None``,
        e.g. a plain UPDATE) return ``[]`` instead of failing in ``fetchall``.
        """
        async with self.conn.cursor() as cur:
            await cur.execute(sql, params)
            if cur.description is None:
                return []
            rows = await cur.fetchall()
        return [dict(row) for row in rows]
|