# ailine/backend/rag_core/store/postgres.py
"""
异步 PostgreSQL 存储实现 - 用于生产环境
使用 asyncpg 实现真正的异步 PostgreSQL 文档存储支持高并发访问
"""
import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple

from langchain_core.documents import Document
from langchain_core.stores import BaseStore

if TYPE_CHECKING:
    # asyncpg is imported lazily in _get_pool, so the module itself loads
    # even when the driver is not installed (see the ImportError note below).
    import asyncpg

logger = logging.getLogger(__name__)


class PostgresDocStore(BaseStore[str, Any]):
"""
异步 PostgreSQL 文档存储实现
使用 asyncpg 作为异步 PostgreSQL 客户端支持
- 真正的异步操作
- 连接池管理
- 自动表创建
- 批量操作amget/amset/amdelete
- JSONB 数据存储
- 并发控制
适用于生产环境提供高性能的异步数据持久化
Attributes:
dsn: PostgreSQL 连接字符串
table_name: 存储表名默认为 "parent_documents"
_pool: asyncpg 连接池实例
_semaphore: 控制并发数的信号量可选
"""
def __init__(
self,
connection_string: str,
table_name: str = "parent_documents",
pool_config: Optional[Dict[str, Any]] = None,
max_concurrency: Optional[int] = None
):
"""
初始化异步 PostgreSQL 文档存储
Args:
connection_string: PostgreSQL 连接 URL格式
"postgresql://user:password@host:port/database?sslmode=disable"
table_name: 存储表名默认为 "parent_documents"
pool_config: 连接池配置字典包含
- min_size: 最小连接数默认 2
- max_size: 最大连接数默认 10
max_concurrency: 最大并发操作数如果为 None 则不限制
Raises:
ImportError: 未安装 asyncpg 时抛出
Example:
>>> store = PostgresDocStore(
... "postgresql://user:pass@localhost:5432/mydb",
... table_name="parent_docs",
... pool_config={"min_size": 5, "max_size": 20},
... max_concurrency=10
... )
"""
        self.dsn = connection_string
        self.table_name = table_name
        self._pool: Optional["asyncpg.Pool"] = None
        self._pool_config = pool_config or {}
        # Lock serializing lazy pool creation so that concurrent first calls
        # do not each create their own pool.
        self._init_lock = asyncio.Lock()
        # Optional semaphore bounding the number of in-flight operations.
        self._semaphore: Optional[asyncio.Semaphore] = None
        if max_concurrency is not None and max_concurrency > 0:
            self._semaphore = asyncio.Semaphore(max_concurrency)
        # Note: pool creation and table setup are deferred to first use,
        # since they require a running event loop.

    async def _get_pool(self) -> "asyncpg.Pool":
        """Get the asyncpg connection pool, creating it lazily on first use."""
        if self._pool is None:
            async with self._init_lock:
                # Re-check inside the lock: another coroutine may have
                # finished initialization while we were waiting.
                if self._pool is None:
                    import asyncpg  # Lazy import; raises ImportError if missing.
                    min_size = self._pool_config.get("min_size", 2)
                    max_size = self._pool_config.get("max_size", 10)
                    try:
                        pool = await asyncpg.create_pool(
                            dsn=self.dsn,
                            min_size=min_size,
                            max_size=max_size
                        )
                        logger.info(f"PostgreSQL async connection pool created: {self.table_name}")
                        # Create the table before publishing the pool so no
                        # caller ever sees a pool without the backing table.
                        await self._create_table(pool)
                        self._pool = pool
                    except Exception as e:
                        raise RuntimeError(f"Failed to create PostgreSQL connection pool: {e}") from e
        return self._pool

    async def _create_table(self, pool: "asyncpg.Pool") -> None:
        """Create the storage table if it does not exist."""
        async with pool.acquire() as conn:
            async with conn.transaction():
                # table_name is interpolated directly; it must be a trusted
                # identifier (see __init__).
                await conn.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.table_name} (
                        key TEXT PRIMARY KEY,
                        value JSONB NOT NULL,
                        created_at TIMESTAMPTZ DEFAULT NOW()
                    )
                """)
        logger.info(f"Table {self.table_name} is ready")
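
    # Note (sketch, not executed by this class): key lookups are served by the
    # PRIMARY KEY index created above. If callers later filter on fields inside
    # the JSONB value (e.g. document metadata), a GIN index is the usual
    # addition:
    #   CREATE INDEX IF NOT EXISTS idx_parent_documents_value
    #       ON parent_documents USING GIN (value);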

    async def _with_concurrency_control(self, coro):
        """Run a coroutine, bounded by the semaphore when one is configured."""
        if self._semaphore is None:
            return await coro
        async with self._semaphore:
            return await coro
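
    # Note: callers pass an un-awaited coroutine object, e.g.
    # ``await self._with_concurrency_control(_amget())``. Creating the
    # coroutine does not start it; no database work happens until it is
    # awaited above, i.e. after a semaphore slot has been acquired.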

    # --- Sync methods (kept for BaseStore compatibility; not supported) ---
    def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Synchronous access is not supported; use the async amget instead."""
        raise NotImplementedError("Synchronous access is not supported; use the async amget instead.")

    def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Synchronous access is not supported; use the async amset instead."""
        raise NotImplementedError("Synchronous access is not supported; use the async amset instead.")

    def mdelete(self, keys: Sequence[str]) -> None:
        """Synchronous access is not supported; use the async amdelete instead."""
        raise NotImplementedError("Synchronous access is not supported; use the async amdelete instead.")

    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
        """Synchronous access is not supported; use the async ayield_keys instead."""
        raise NotImplementedError("Synchronous access is not supported; use the async ayield_keys instead.")

    # --- Async methods (the real implementation) ---
    async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Fetch documents for the given keys in one batch.

        Returns values in the same order as ``keys``; missing keys yield None.
        """
        if not keys:
            return []

        async def _amget():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)",
                    list(keys)  # asyncpg expects a concrete list for array params
                )
            result_map = {}
            for row in rows:
                val = row['value']
                # asyncpg returns JSONB as a str unless a custom codec is registered.
                if isinstance(val, str):
                    val = json.loads(val)
                # Payloads written by amset are serialized Documents; rebuild them.
                if isinstance(val, dict) and 'page_content' in val:
                    result_map[row['key']] = Document(**val)
                else:
                    result_map[row['key']] = val
            # Preserve input order; missing keys map to None.
            return [result_map.get(key) for key in keys]

        return await self._with_concurrency_control(_amget())

    async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Upsert documents in one batch."""
        if not key_value_pairs:
            return

        async def _amset():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.executemany(
                        f"""
                        INSERT INTO {self.table_name} (key, value)
                        VALUES ($1, $2)
                        ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
                        """,
                        [
                            # Documents are stored in their dict form so that
                            # amget can reconstruct them via Document(**value).
                            (k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False))
                            for k, v in key_value_pairs
                        ]
                    )
            logger.debug(f"Batch-upserted {len(key_value_pairs)} documents")

        await self._with_concurrency_control(_amset())

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Delete documents for the given keys in one batch."""
        if not keys:
            return

        async def _amdelete():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.execute(
                        f"DELETE FROM {self.table_name} WHERE key = ANY($1)",
                        list(keys)
                    )
            logger.debug(f"Batch-deleted {len(keys)} documents")

        await self._with_concurrency_control(_amdelete())

    async def ayield_keys(self, *, prefix: str | None = None) -> AsyncIterator[str]:
        """Iterate over all keys, optionally restricted to a prefix.

        This is an async generator; consume it with ``async for``.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if prefix:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key",
                    f"{prefix}%"
                )
            else:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} ORDER BY key"
                )
        # All keys are fetched eagerly; the connection is released before
        # yielding so a slow consumer does not hold a pool slot.
        for row in rows:
            yield row['key']

    async def aclose(self) -> None:
        """Close the connection pool and release its resources."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("PostgreSQL async connection pool closed")

    def close(self) -> None:
        """Synchronous close; intentionally a no-op.

        The pool can only be closed from a running event loop;
        use aclose() in async contexts.
        """
        # Nothing to do synchronously; aclose() owns the actual cleanup.
        pass
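

# --- Usage sketch ---
# A minimal end-to-end example, assuming a reachable PostgreSQL instance.
# The DSN below is a placeholder and must be replaced with a real one.
async def _demo() -> None:
    store = PostgresDocStore(
        "postgresql://user:pass@localhost:5432/mydb",  # placeholder DSN
        max_concurrency=10,
    )
    try:
        doc = Document(page_content="hello", metadata={"source": "demo"})
        await store.amset([("doc:1", doc)])
        fetched = await store.amget(["doc:1", "missing"])  # -> [Document, None]
        logger.info(f"fetched: {fetched}")
        async for key in store.ayield_keys(prefix="doc:"):
            logger.info(f"key: {key}")
        await store.amdelete(["doc:1"])
    finally:
        await store.aclose()


if __name__ == "__main__":
    asyncio.run(_demo())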