""" 异步 PostgreSQL 存储实现 - 用于生产环境。 使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。 """ import asyncio import json import logging from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence from langchain_core.documents import Document from langchain_core.stores import BaseStore import asyncpg logger = logging.getLogger(__name__) class PostgresDocStore(BaseStore[str, Any]): """ 异步 PostgreSQL 文档存储实现。 使用 asyncpg 作为异步 PostgreSQL 客户端,支持: - 真正的异步操作 - 连接池管理 - 自动表创建 - 批量操作(amget/amset/amdelete) - JSONB 数据存储 - 并发控制 适用于生产环境,提供高性能的异步数据持久化。 Attributes: dsn: PostgreSQL 连接字符串 table_name: 存储表名,默认为 "parent_documents" _pool: asyncpg 连接池实例 _semaphore: 控制并发数的信号量(可选) """ def __init__( self, connection_string: str, table_name: str = "parent_documents", pool_config: Optional[Dict[str, Any]] = None, max_concurrency: Optional[int] = None ): """ 初始化异步 PostgreSQL 文档存储。 Args: connection_string: PostgreSQL 连接 URL,格式: "postgresql://user:password@host:port/database?sslmode=disable" table_name: 存储表名,默认为 "parent_documents" pool_config: 连接池配置字典,包含: - min_size: 最小连接数(默认 2) - max_size: 最大连接数(默认 10) max_concurrency: 最大并发操作数,如果为 None 则不限制 Raises: ImportError: 未安装 asyncpg 时抛出 Example: >>> store = PostgresDocStore( ... "postgresql://user:pass@localhost:5432/mydb", ... table_name="parent_docs", ... pool_config={"min_size": 5, "max_size": 20}, ... max_concurrency=10 ... ) """ self.dsn = connection_string self.table_name = table_name self._pool: Optional["asyncpg.Pool"] = None self._pool_config = pool_config or {} # 并发控制信号量 self._semaphore = None if max_concurrency is not None and max_concurrency > 0: self._semaphore = asyncio.Semaphore(max_concurrency) # 注意:连接池的异步初始化延迟到第一次使用时 # 表结构创建也延迟到第一次操作时 async def _get_pool(self): """获取或创建 asyncpg 连接池。""" if self._pool is None: import asyncpg min_size = self._pool_config.get("min_size", 2) max_size = self._pool_config.get("max_size", 10) try: self._pool = await asyncpg.create_pool( dsn=self.dsn, min_size=min_size, max_size=max_size ) logger.info(f"PostgreSQL 异步连接池已创建: {self.table_name}") # 初始化表结构 await self._create_table() except Exception as e: raise RuntimeError(f"PostgreSQL 异步连接池创建失败: {e}") return self._pool async def _create_table(self): """创建存储表(如果不存在)。""" pool = await self._get_pool() async with pool.acquire() as conn: async with conn.transaction(): await conn.execute(f""" CREATE TABLE IF NOT EXISTS {self.table_name} ( key TEXT PRIMARY KEY, value JSONB NOT NULL, created_at TIMESTAMPTZ DEFAULT NOW() ) """) logger.info(f"表 {self.table_name} 已就绪") async def _with_concurrency_control(self, coro): """使用信号量控制并发执行。""" if self._semaphore is None: return await coro async with self._semaphore: return await coro # --- 同步方法(保持兼容性,但功能有限)--- def mget(self, keys: Sequence[str]) -> List[Optional[Any]]: """不支持同步操作,请使用异步 amget 方法。""" raise NotImplementedError("不支持同步操作,请使用异步 amget 方法。") def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None: """不支持同步操作,请使用异步 amset 方法。""" raise NotImplementedError("不支持同步操作,请使用异步 amset 方法。") def mdelete(self, keys: Sequence[str]) -> None: """不支持同步操作,请使用异步 amdelete 方法。""" raise NotImplementedError("不支持同步操作,请使用异步 amdelete 方法。") def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]: """不支持同步操作,请使用异步 ayield_keys 方法。""" raise NotImplementedError("不支持同步操作,请使用异步 ayield_keys 方法。") # --- 异步方法(真正的实现)--- async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]: """异步批量获取文档。""" if not keys: return [] async def _amget(): pool = await self._get_pool() async with pool.acquire() as conn: rows = await conn.fetch( f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)", keys ) result_map = {} for row in rows: val = row['value'] if isinstance(val, str): val = json.loads(val) if isinstance(val, dict) and 'page_content' in val: result_map[row['key']] = Document(**val) else: result_map[row['key']] = val return [result_map.get(key) for key in keys] return await self._with_concurrency_control(_amget()) async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None: """异步批量设置文档。""" if not key_value_pairs: return async def _amset(): pool = await self._get_pool() async with pool.acquire() as conn: async with conn.transaction(): await conn.executemany( f""" INSERT INTO {self.table_name} (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value """, [ (k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False)) for k, v in key_value_pairs ] ) logger.debug(f"已异步批量设置 {len(key_value_pairs)} 个文档") await self._with_concurrency_control(_amset()) async def amdelete(self, keys: Sequence[str]) -> None: """异步批量删除文档。""" if not keys: return async def _amdelete(): pool = await self._get_pool() async with pool.acquire() as conn: async with conn.transaction(): await conn.execute( f"DELETE FROM {self.table_name} WHERE key = ANY($1)", keys ) logger.debug(f"已异步批量删除 {len(keys)} 个文档") await self._with_concurrency_control(_amdelete()) async def ayield_keys(self, *, prefix: str | None = None) -> Iterator[str]: """异步迭代所有键。 注意:这是一个异步生成器,需要使用 async for 迭代。 """ pool = await self._get_pool() async with pool.acquire() as conn: if prefix: rows = await conn.fetch( f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key", f"{prefix}%" ) else: rows = await conn.fetch( f"SELECT key FROM {self.table_name} ORDER BY key" ) for row in rows: yield row['key'] async def aclose(self) -> None: """异步关闭连接池,释放资源。""" if self._pool: await self._pool.close() self._pool = None logger.info("PostgreSQL 异步连接池已关闭") def close(self) -> None: """同步关闭连接池(功能有限)。 注意:在异步环境中,请使用 aclose 方法。 """ pass