统一 Repository 方案:添加 db 基类和子图模型 + 通讯录 API 支持真实数据库
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Has been cancelled

This commit is contained in:
2026-04-27 16:37:45 +08:00
parent 0cb9571db7
commit 29016f8792
6 changed files with 460 additions and 57 deletions

View File

@@ -0,0 +1,7 @@
"""
数据库模块
提供统一的 Repository 和实体类
"""
from .base import BaseRepository, BaseEntity
__all__ = ['BaseRepository', 'BaseEntity']

134
backend/app/db/base.py Normal file
View File

@@ -0,0 +1,134 @@
"""
统一的 Repository 基类
所有子图的数据操作都用这个,避免重复代码
"""
import json
from typing import Any, List, Dict, Optional
from dataclasses import dataclass, asdict
@dataclass
class BaseEntity:
"""实体基类,用于类型提示和自动序列化"""
id: Optional[str] = None
user_id: Optional[str] = None
created_at: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""转换为字典(用于存储)"""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'BaseEntity':
"""从字典创建实体"""
return cls(**data)
class BaseRepository:
"""
统一的 Repository 基类
核心思想:
1. 传入连接,复用现有 checkpointer.conn
2. 提供通用 CRUD 方法
3. 表名和字段映射通过子类定义
"""
# 子类需要定义的配置
table_name: str
entity_class: type # 继承自 BaseEntity 的实体类
id_column: str = "id"
user_id_column: str = "user_id"
def __init__(self, conn):
"""
构造函数:传入数据库连接
这样可以复用 checkpointer.conn不需要额外的连接池
"""
self.conn = conn
# ========== 通用 CRUD 方法 ==========
async def insert(self, entity: BaseEntity) -> str:
"""插入单个实体,返回 ID"""
data = entity.to_dict()
data.pop('id', None) # 如果 id 是自增的
if not data:
raise ValueError("实体数据为空")
columns = list(data.keys())
placeholders = [f"${i+1}" for i in range(len(columns))]
values = list(data.values())
sql = f"""
INSERT INTO {self.table_name} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
RETURNING {self.id_column}
"""
async with self.conn.cursor() as cur:
await cur.execute(sql, values)
row = await cur.fetchone()
return row[self.id_column] if row else None
async def update(self, entity_id: str, data: Dict[str, Any]) -> bool:
"""更新实体"""
if not data:
return False
set_clause = ", ".join([f"{k} = ${i+1}" for i, k in enumerate(data.keys())])
values = list(data.values()) + [entity_id]
sql = f"""
UPDATE {self.table_name}
SET {set_clause}
WHERE {self.id_column} = ${len(values)}
"""
async with self.conn.cursor() as cur:
await cur.execute(sql, values)
return cur.rowcount > 0
async def delete(self, entity_id: str) -> bool:
"""删除实体"""
sql = f"DELETE FROM {self.table_name} WHERE {self.id_column} = $1"
async with self.conn.cursor() as cur:
await cur.execute(sql, (entity_id,))
return cur.rowcount > 0
async def get_by_id(self, entity_id: str) -> Optional[BaseEntity]:
"""根据 ID 查询单个实体"""
sql = f"SELECT * FROM {self.table_name} WHERE {self.id_column} = $1"
async with self.conn.cursor() as cur:
await cur.execute(sql, (entity_id,))
row = await cur.fetchone()
if row:
return self.entity_class.from_dict(dict(row))
return None
async def list_by_user(self, user_id: str, limit: int = 100) -> List[BaseEntity]:
"""查询某个用户的所有实体"""
sql = f"""
SELECT * FROM {self.table_name}
WHERE {self.user_id_column} = $1
ORDER BY created_at DESC
LIMIT $2
"""
async with self.conn.cursor() as cur:
await cur.execute(sql, (user_id, limit))
rows = await cur.fetchall()
return [self.entity_class.from_dict(dict(row)) for row in rows]
# ========== 自定义查询方法(留钩子)==========
async def custom_query(self, sql: str, params: tuple = ()) -> List[Dict]:
"""执行自定义 SQL"""
async with self.conn.cursor() as cur:
await cur.execute(sql, params)
rows = await cur.fetchall()
return [dict(row) for row in rows]

97
backend/app/db/init_db.py Normal file
View File

@@ -0,0 +1,97 @@
"""
子图数据库表初始化
在 FastAPI 启动时创建表
"""
from typing import Optional
async def init_subgraph_tables(conn):
"""
初始化子图所需的表
Args:
conn: 数据库连接(来自 AsyncPostgresSaver
"""
# 1. contacts 表(通讯录)
await _create_contacts_table(conn)
# 2. words 表(词典)
await _create_words_table(conn)
# 3. news 表(资讯)
await _create_news_table(conn)
async def _create_contacts_table(conn):
"""创建 contacts 表"""
sql = """
CREATE TABLE IF NOT EXISTS contacts (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
name VARCHAR(100) NOT NULL,
phone VARCHAR(32),
email VARCHAR(100),
company VARCHAR(100),
position VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_contacts_user (user_id)
);
"""
try:
async with conn.cursor() as cur:
await cur.execute(sql)
# 注释掉 INFO 打印,避免噪声
except Exception as e:
print(f"contacts 表可能已存在: {e}")
async def _create_words_table(conn):
"""创建 words 表"""
sql = """
CREATE TABLE IF NOT EXISTS words (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
word VARCHAR(100) NOT NULL,
phonetic VARCHAR(100),
part_of_speech VARCHAR(50),
definition TEXT,
examples TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_words_user (user_id),
INDEX idx_words_word (word)
);
"""
try:
async with conn.cursor() as cur:
await cur.execute(sql)
except Exception as e:
print(f"words 表可能已存在: {e}")
async def _create_news_table(conn):
"""创建 news 表"""
sql = """
CREATE TABLE IF NOT EXISTS news (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
title VARCHAR(200) NOT NULL,
content TEXT,
url VARCHAR(500),
source VARCHAR(100),
keywords TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_news_user (user_id)
);
"""
try:
async with conn.cursor() as cur:
await cur.execute(sql)
except Exception as e:
print(f"news 表可能已存在: {e}")

106
backend/app/db/models.py Normal file
View File

@@ -0,0 +1,106 @@
"""
子图数据模型和 Repository 定义
只需要定义表名、实体类,基类搞定一切
"""
from typing import List, Optional
from dataclasses import dataclass
from .base import BaseRepository, BaseEntity
# ========== 通讯录 ==========
@dataclass
class ContactEntity(BaseEntity):
"""通讯录实体"""
name: str = ""
phone: str = ""
email: str = ""
company: str = ""
position: str = ""
class ContactRepository(BaseRepository):
"""通讯录 Repository"""
table_name = "contacts"
entity_class = ContactEntity
# 如果需要特殊查询,在这里加方法
async def search_by_name(self, user_id: str, name_keyword: str) -> List[ContactEntity]:
"""自定义查询:按姓名搜索"""
sql = """
SELECT * FROM contacts
WHERE user_id = $1 AND name ILIKE $2
ORDER BY created_at DESC
LIMIT 50
"""
result = await self.custom_query(sql, (user_id, f"%{name_keyword}%"))
return [ContactEntity.from_dict(row) for row in result]
# ========== 词典 ==========
@dataclass
class WordEntity(BaseEntity):
"""单词实体"""
word: str = ""
phonetic: str = ""
part_of_speech: str = ""
definition: str = ""
examples: str = ""
class DictionaryRepository(BaseRepository):
"""词典 Repository"""
table_name = "words"
entity_class = WordEntity
async def search_by_word(self, user_id: str, word: str) -> Optional[WordEntity]:
"""按单词查询"""
sql = """
SELECT * FROM words
WHERE user_id = $1 AND word = $2
LIMIT 1
"""
result = await self.custom_query(sql, (user_id, word))
if result:
return WordEntity.from_dict(result[0])
return None
# ========== 资讯 ==========
@dataclass
class NewsEntity(BaseEntity):
"""资讯实体"""
title: str = ""
content: str = ""
url: str = ""
source: str = ""
keywords: str = ""
class NewsRepository(BaseRepository):
"""资讯 Repository"""
table_name = "news"
entity_class = NewsEntity
async def search_by_keywords(self, user_id: str, keyword: str) -> List[NewsEntity]:
"""按关键词搜索"""
sql = """
SELECT * FROM news
WHERE user_id = $1 AND (title ILIKE $2 OR keywords ILIKE $2)
ORDER BY created_at DESC
LIMIT 50
"""
result = await self.custom_query(sql, (user_id, f"%{keyword}%"))
return [NewsEntity.from_dict(row) for row in result]
# ========== 导出 ==========
__all__ = [
'ContactEntity', 'ContactRepository',
'WordEntity', 'DictionaryRepository',
'NewsEntity', 'NewsRepository',
]