Files
ailine/backend/app/db/base.py
root 5b12188d45
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m11s
修复:PostgreSQL 语法适配 + UUID 自动生成
2026-04-27 16:38:39 +08:00

139 lines
4.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
统一的 Repository 基类
所有子图的数据操作都用这个,避免重复代码
"""
import json
import uuid
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, List, Optional
@dataclass
class BaseEntity:
    """Base class for entities: shared columns plus dict (de)serialization.

    Subclasses are presumably dataclasses as well, so that ``asdict`` and
    ``fields`` also see their extra columns.
    """

    id: Optional[str] = None
    user_id: Optional[str] = None
    # NOTE(review): a timestamp column would come back from the driver as a
    # datetime rather than str — confirm the intended type.
    created_at: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the entity as a dict (used for storage); nested dataclasses are converted too."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'BaseEntity':
        """Create an entity of type ``cls`` from a dict such as a database row.

        Keys that are not init fields of ``cls`` are ignored, so rows from
        ``SELECT *`` that carry extra columns (e.g. ``updated_at``) no longer
        raise ``TypeError``. Missing keys fall back to the field defaults.
        """
        known = {f.name for f in fields(cls) if f.init}
        return cls(**{key: value for key, value in data.items() if key in known})
class BaseRepository:
    """
    Common Repository base class shared by all sub-graphs' data access.

    Core ideas:
    1. The connection is passed in, so the existing checkpointer.conn is reused.
    2. Generic CRUD methods are provided here.
    3. The table name and column mapping are defined by subclasses.

    Rows are read by column name (``row[col]`` / ``dict(row)``), so the
    connection presumably uses a dict-style row factory — confirm where the
    connection is created. NOTE(review): no method here commits; writes
    presumably rely on the connection running in autocommit mode — confirm.
    """

    # Configuration that subclasses must define
    table_name: str
    entity_class: type  # entity class derived from BaseEntity
    id_column: str = "id"
    user_id_column: str = "user_id"

    def __init__(self, conn):
        """
        Store the database connection.

        Passing it in lets the repository reuse checkpointer.conn instead of
        needing an extra connection pool.
        """
        self.conn = conn

    # ========== SQL building helpers ==========

    def _placeholder(self, index: int) -> str:
        """Return the bind marker for the 1-based parameter *index*.

        psycopg's ``cursor.execute`` only recognises ``%s`` markers and rejects
        a query written with ``$1``-style markers when parameters are passed.
        Subclasses whose driver uses numbered markers (asyncpg, psycopg raw
        cursors) can override this to return ``f"${index}"``.
        """
        return "%s"

    @staticmethod
    def _check_identifier(name: str) -> str:
        """Return *name* unchanged if it is a plain identifier, else raise ValueError.

        Column names cannot be bound as parameters and are spliced into the
        SQL text, so names containing spaces, quotes, semicolons, etc. are
        refused rather than risking SQL injection.
        """
        if not isinstance(name, str) or not name.isidentifier():
            raise ValueError(f"Invalid SQL identifier: {name!r}")
        return name

    def _build_insert_sql(self, columns: List[str]) -> str:
        """Build an INSERT for *columns* that returns the id column."""
        cols = [self._check_identifier(c) for c in columns]
        marks = ", ".join(self._placeholder(i) for i in range(1, len(cols) + 1))
        return (
            f"INSERT INTO {self.table_name} ({', '.join(cols)}) "
            f"VALUES ({marks}) "
            f"RETURNING {self.id_column}"
        )

    def _build_update_sql(self, columns: List[str]) -> str:
        """Build an UPDATE that sets *columns*; the entity id is the last parameter."""
        cols = [self._check_identifier(c) for c in columns]
        assignments = ", ".join(
            f"{c} = {self._placeholder(i)}" for i, c in enumerate(cols, start=1)
        )
        where_mark = self._placeholder(len(cols) + 1)
        return (
            f"UPDATE {self.table_name} SET {assignments} "
            f"WHERE {self.id_column} = {where_mark}"
        )

    # ========== Generic CRUD methods ==========

    async def insert(self, entity: 'BaseEntity') -> str:
        """Insert one entity and return its ID as a string.

        A UUID4 string is generated when the entity has no ``id``. Fields that
        are None are left out of the INSERT, so column defaults (for example on
        ``created_at``) apply instead of being overwritten with NULL.
        """
        # Skip unset fields so database-side defaults take effect.
        data = {k: v for k, v in entity.to_dict().items() if v is not None}
        # Ensure there is an ID (the 'id' key matches BaseEntity.id).
        if not data.get('id'):
            data['id'] = str(uuid.uuid4())
        columns = list(data)
        sql = self._build_insert_sql(columns)
        async with self.conn.cursor() as cur:
            await cur.execute(sql, [data[c] for c in columns])
            row = await cur.fetchone()
        # A uuid-typed column can come back as uuid.UUID; normalise to str.
        return str(row[self.id_column]) if row else data['id']

    async def update(self, entity_id: str, data: Dict[str, Any]) -> bool:
        """Set the columns in *data* on the row with *entity_id*.

        Returns False when *data* is empty (nothing is executed); otherwise
        True if at least one row was affected. Raises ValueError if a key of
        *data* is not a plain identifier.
        """
        if not data:
            return False
        columns = list(data)
        sql = self._build_update_sql(columns)
        values = [data[c] for c in columns] + [entity_id]
        async with self.conn.cursor() as cur:
            await cur.execute(sql, values)
            return cur.rowcount > 0

    async def delete(self, entity_id: str) -> bool:
        """Delete the row with *entity_id*; return True if a row was removed."""
        sql = f"DELETE FROM {self.table_name} WHERE {self.id_column} = {self._placeholder(1)}"
        async with self.conn.cursor() as cur:
            await cur.execute(sql, (entity_id,))
            return cur.rowcount > 0

    async def get_by_id(self, entity_id: str) -> Optional['BaseEntity']:
        """Fetch one entity by ID, or None if no row matches."""
        sql = f"SELECT * FROM {self.table_name} WHERE {self.id_column} = {self._placeholder(1)}"
        async with self.conn.cursor() as cur:
            await cur.execute(sql, (entity_id,))
            row = await cur.fetchone()
        if row:
            return self.entity_class.from_dict(dict(row))
        return None

    async def list_by_user(self, user_id: str, limit: int = 100) -> List['BaseEntity']:
        """Return up to *limit* entities of *user_id*, newest ``created_at`` first."""
        sql = (
            f"SELECT * FROM {self.table_name} "
            f"WHERE {self.user_id_column} = {self._placeholder(1)} "
            f"ORDER BY created_at DESC "
            f"LIMIT {self._placeholder(2)}"
        )
        async with self.conn.cursor() as cur:
            await cur.execute(sql, (user_id, limit))
            rows = await cur.fetchall()
        return [self.entity_class.from_dict(dict(row)) for row in rows]

    # ========== Custom query hook ==========

    async def custom_query(self, sql: str, params: tuple = ()) -> List[Dict]:
        """Run caller-supplied SQL and return its result rows as dicts.

        *sql* is executed as-is, so it must use the driver's placeholder style
        and must never be built from untrusted input — pass values via *params*.
        Statements that produce no result set (UPDATE/DELETE without RETURNING,
        DDL) return ``[]`` instead of calling ``fetchall``, which psycopg
        rejects for such statements.
        """
        async with self.conn.cursor() as cur:
            await cur.execute(sql, params)
            # DB-API: description is None when the statement returned no columns.
            if cur.description is None:
                return []
            rows = await cur.fetchall()
        return [dict(row) for row in rows]