"""
Asynchronous PostgreSQL storage implementation for production use.

Implements a truly asynchronous PostgreSQL document store on top of
asyncpg, supporting high-concurrency access.
"""
||
import asyncio
import json
import logging
import re
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple

import asyncpg
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
logger = logging.getLogger(__name__)
|
||
|
||
class PostgresDocStore(BaseStore[str, Any]):
    """Asynchronous PostgreSQL document store.

    Uses asyncpg as the async PostgreSQL client and supports:
    - truly asynchronous operations
    - connection-pool management
    - automatic table creation
    - batch operations (amget/amset/amdelete)
    - JSONB value storage
    - optional concurrency limiting

    Intended for production use; provides high-performance asynchronous
    data persistence. The synchronous ``BaseStore`` methods are stubs that
    raise ``NotImplementedError`` — only the async API is implemented.

    Attributes:
        dsn: PostgreSQL connection string.
        table_name: Name of the storage table (default "parent_documents").
        _pool: Lazily-created asyncpg connection pool instance.
        _semaphore: Optional semaphore bounding concurrent operations.
    """
def __init__(
|
||
self,
|
||
connection_string: str,
|
||
table_name: str = "parent_documents",
|
||
pool_config: Optional[Dict[str, Any]] = None,
|
||
max_concurrency: Optional[int] = None
|
||
):
|
||
"""
|
||
初始化异步 PostgreSQL 文档存储。
|
||
|
||
Args:
|
||
connection_string: PostgreSQL 连接 URL,格式:
|
||
"postgresql://user:password@host:port/database?sslmode=disable"
|
||
table_name: 存储表名,默认为 "parent_documents"
|
||
pool_config: 连接池配置字典,包含:
|
||
- min_size: 最小连接数(默认 2)
|
||
- max_size: 最大连接数(默认 10)
|
||
max_concurrency: 最大并发操作数,如果为 None 则不限制
|
||
|
||
Raises:
|
||
ImportError: 未安装 asyncpg 时抛出
|
||
|
||
Example:
|
||
>>> store = PostgresDocStore(
|
||
... "postgresql://user:pass@localhost:5432/mydb",
|
||
... table_name="parent_docs",
|
||
... pool_config={"min_size": 5, "max_size": 20},
|
||
... max_concurrency=10
|
||
... )
|
||
"""
|
||
|
||
|
||
self.dsn = connection_string
|
||
self.table_name = table_name
|
||
self._pool: Optional["asyncpg.Pool"] = None
|
||
self._pool_config = pool_config or {}
|
||
|
||
# 并发控制信号量
|
||
self._semaphore = None
|
||
if max_concurrency is not None and max_concurrency > 0:
|
||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||
|
||
# 注意:连接池的异步初始化延迟到第一次使用时
|
||
# 表结构创建也延迟到第一次操作时
|
||
|
||
async def _get_pool(self):
|
||
"""获取或创建 asyncpg 连接池。"""
|
||
if self._pool is None:
|
||
import asyncpg
|
||
min_size = self._pool_config.get("min_size", 2)
|
||
max_size = self._pool_config.get("max_size", 10)
|
||
|
||
try:
|
||
self._pool = await asyncpg.create_pool(
|
||
dsn=self.dsn,
|
||
min_size=min_size,
|
||
max_size=max_size
|
||
)
|
||
logger.info(f"PostgreSQL 异步连接池已创建: {self.table_name}")
|
||
|
||
# 初始化表结构
|
||
await self._create_table()
|
||
except Exception as e:
|
||
raise RuntimeError(f"PostgreSQL 异步连接池创建失败: {e}")
|
||
|
||
return self._pool
|
||
|
||
async def _create_table(self):
|
||
"""创建存储表(如果不存在)。"""
|
||
pool = await self._get_pool()
|
||
async with pool.acquire() as conn:
|
||
async with conn.transaction():
|
||
await conn.execute(f"""
|
||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||
key TEXT PRIMARY KEY,
|
||
value JSONB NOT NULL,
|
||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||
)
|
||
""")
|
||
logger.info(f"表 {self.table_name} 已就绪")
|
||
|
||
async def _with_concurrency_control(self, coro):
|
||
"""使用信号量控制并发执行。"""
|
||
if self._semaphore is None:
|
||
return await coro
|
||
async with self._semaphore:
|
||
return await coro
|
||
|
||
# --- Synchronous methods (kept for interface compatibility; unsupported) ---
|
||
|
||
    def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Synchronous access is unsupported; use the async ``amget`` instead."""
        raise NotImplementedError("不支持同步操作,请使用异步 amget 方法。")
|
||
    def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Synchronous access is unsupported; use the async ``amset`` instead."""
        raise NotImplementedError("不支持同步操作,请使用异步 amset 方法。")
|
||
    def mdelete(self, keys: Sequence[str]) -> None:
        """Synchronous access is unsupported; use the async ``amdelete`` instead."""
        raise NotImplementedError("不支持同步操作,请使用异步 amdelete 方法。")
|
||
    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
        """Synchronous iteration is unsupported; use the async ``ayield_keys`` instead."""
        raise NotImplementedError("不支持同步操作,请使用异步 ayield_keys 方法。")
|
||
# --- Async methods (the real implementation) ---
|
||
|
||
async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]:
|
||
"""异步批量获取文档。"""
|
||
if not keys:
|
||
return []
|
||
|
||
async def _amget():
|
||
pool = await self._get_pool()
|
||
async with pool.acquire() as conn:
|
||
rows = await conn.fetch(
|
||
f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)",
|
||
keys
|
||
)
|
||
result_map = {}
|
||
for row in rows:
|
||
val = row['value']
|
||
if isinstance(val, str):
|
||
val = json.loads(val)
|
||
if isinstance(val, dict) and 'page_content' in val:
|
||
result_map[row['key']] = Document(**val)
|
||
else:
|
||
result_map[row['key']] = val
|
||
return [result_map.get(key) for key in keys]
|
||
|
||
return await self._with_concurrency_control(_amget())
|
||
|
||
async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
|
||
"""异步批量设置文档。"""
|
||
if not key_value_pairs:
|
||
return
|
||
|
||
async def _amset():
|
||
pool = await self._get_pool()
|
||
async with pool.acquire() as conn:
|
||
async with conn.transaction():
|
||
await conn.executemany(
|
||
f"""
|
||
INSERT INTO {self.table_name} (key, value)
|
||
VALUES ($1, $2)
|
||
ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
|
||
""",
|
||
[
|
||
(k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False))
|
||
for k, v in key_value_pairs
|
||
]
|
||
)
|
||
logger.debug(f"已异步批量设置 {len(key_value_pairs)} 个文档")
|
||
|
||
await self._with_concurrency_control(_amset())
|
||
|
||
async def amdelete(self, keys: Sequence[str]) -> None:
|
||
"""异步批量删除文档。"""
|
||
if not keys:
|
||
return
|
||
|
||
async def _amdelete():
|
||
pool = await self._get_pool()
|
||
async with pool.acquire() as conn:
|
||
async with conn.transaction():
|
||
await conn.execute(
|
||
f"DELETE FROM {self.table_name} WHERE key = ANY($1)",
|
||
keys
|
||
)
|
||
logger.debug(f"已异步批量删除 {len(keys)} 个文档")
|
||
|
||
await self._with_concurrency_control(_amdelete())
|
||
|
||
async def ayield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
|
||
"""异步迭代所有键。
|
||
|
||
注意:这是一个异步生成器,需要使用 async for 迭代。
|
||
"""
|
||
pool = await self._get_pool()
|
||
async with pool.acquire() as conn:
|
||
if prefix:
|
||
rows = await conn.fetch(
|
||
f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key",
|
||
f"{prefix}%"
|
||
)
|
||
else:
|
||
rows = await conn.fetch(
|
||
f"SELECT key FROM {self.table_name} ORDER BY key"
|
||
)
|
||
|
||
for row in rows:
|
||
yield row['key']
|
||
|
||
async def aclose(self) -> None:
|
||
"""异步关闭连接池,释放资源。"""
|
||
if self._pool:
|
||
await self._pool.close()
|
||
self._pool = None
|
||
logger.info("PostgreSQL 异步连接池已关闭")
|
||
|
||
    def close(self) -> None:
        """Synchronously close the pool (limited functionality).

        Intentionally a no-op: an asyncpg pool can only be closed from a
        running event loop. In async contexts, use ``aclose`` instead.
        """
        pass
|