""" 统一的 Repository 基类 所有子图的数据操作都用这个,避免重复代码 """ import json from typing import Any, List, Dict, Optional from dataclasses import dataclass, asdict @dataclass class BaseEntity: """实体基类,用于类型提示和自动序列化""" id: Optional[str] = None user_id: Optional[str] = None created_at: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """转换为字典(用于存储)""" return asdict(self) @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'BaseEntity': """从字典创建实体""" return cls(**data) class BaseRepository: """ 统一的 Repository 基类 核心思想: 1. 传入连接,复用现有 checkpointer.conn 2. 提供通用 CRUD 方法 3. 表名和字段映射通过子类定义 """ # 子类需要定义的配置 table_name: str entity_class: type # 继承自 BaseEntity 的实体类 id_column: str = "id" user_id_column: str = "user_id" def __init__(self, conn): """ 构造函数:传入数据库连接 这样可以复用 checkpointer.conn,不需要额外的连接池 """ self.conn = conn # ========== 通用 CRUD 方法 ========== async def insert(self, entity: BaseEntity) -> str: """插入单个实体,返回 ID""" data = entity.to_dict() data.pop('id', None) # 如果 id 是自增的 if not data: raise ValueError("实体数据为空") columns = list(data.keys()) placeholders = [f"${i+1}" for i in range(len(columns))] values = list(data.values()) sql = f""" INSERT INTO {self.table_name} ({', '.join(columns)}) VALUES ({', '.join(placeholders)}) RETURNING {self.id_column} """ async with self.conn.cursor() as cur: await cur.execute(sql, values) row = await cur.fetchone() return row[self.id_column] if row else None async def update(self, entity_id: str, data: Dict[str, Any]) -> bool: """更新实体""" if not data: return False set_clause = ", ".join([f"{k} = ${i+1}" for i, k in enumerate(data.keys())]) values = list(data.values()) + [entity_id] sql = f""" UPDATE {self.table_name} SET {set_clause} WHERE {self.id_column} = ${len(values)} """ async with self.conn.cursor() as cur: await cur.execute(sql, values) return cur.rowcount > 0 async def delete(self, entity_id: str) -> bool: """删除实体""" sql = f"DELETE FROM {self.table_name} WHERE {self.id_column} = $1" async with self.conn.cursor() as cur: await cur.execute(sql, (entity_id,)) return cur.rowcount > 0 async def get_by_id(self, entity_id: str) -> Optional[BaseEntity]: """根据 ID 查询单个实体""" sql = f"SELECT * FROM {self.table_name} WHERE {self.id_column} = $1" async with self.conn.cursor() as cur: await cur.execute(sql, (entity_id,)) row = await cur.fetchone() if row: return self.entity_class.from_dict(dict(row)) return None async def list_by_user(self, user_id: str, limit: int = 100) -> List[BaseEntity]: """查询某个用户的所有实体""" sql = f""" SELECT * FROM {self.table_name} WHERE {self.user_id_column} = $1 ORDER BY created_at DESC LIMIT $2 """ async with self.conn.cursor() as cur: await cur.execute(sql, (user_id, limit)) rows = await cur.fetchall() return [self.entity_class.from_dict(dict(row)) for row in rows] # ========== 自定义查询方法(留钩子)========== async def custom_query(self, sql: str, params: tuple = ()) -> List[Dict]: """执行自定义 SQL""" async with self.conn.cursor() as cur: await cur.execute(sql, params) rows = await cur.fetchall() return [dict(row) for row in rows]