Files
ailine/backend/rag_core/store/postgres.py
root 8b354b7ccc
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 47m14s
重构代码,统一config配置
2026-04-21 11:02:16 +08:00

247 lines
9.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
异步 PostgreSQL 存储实现 - 用于生产环境。
使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。
"""
import asyncio
import json
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple

import asyncpg
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
logger = logging.getLogger(__name__)
class PostgresDocStore(BaseStore[str, Any]):
    """Asynchronous PostgreSQL document store for production use.

    Backed by asyncpg, it provides:
      - genuinely asynchronous operations
      - lazily created, shared connection pool
      - automatic table creation on first use
      - batched operations: amget / amset / amdelete
      - JSONB value storage
      - optional concurrency limiting via a semaphore

    Attributes:
        dsn: PostgreSQL connection string.
        table_name: Backing table name (default "parent_documents").
        _pool: Lazily created asyncpg pool instance.
        _semaphore: Optional semaphore bounding concurrent operations.
    """

    def __init__(
        self,
        connection_string: str,
        table_name: str = "parent_documents",
        pool_config: Optional[Dict[str, Any]] = None,
        max_concurrency: Optional[int] = None
    ):
        """Initialize the async PostgreSQL document store.

        Args:
            connection_string: PostgreSQL URL, e.g.
                "postgresql://user:password@host:port/database?sslmode=disable"
            table_name: Backing table name. NOTE(review): it is interpolated
                into SQL as an identifier (identifiers cannot be bound as
                parameters), so it must come from trusted configuration,
                never from end-user input.
            pool_config: Pool options dict:
                - min_size: minimum pool size (default 2)
                - max_size: maximum pool size (default 10)
            max_concurrency: Maximum number of concurrent store operations;
                None (or a non-positive value) means unlimited.

        Example:
            >>> store = PostgresDocStore(
            ...     "postgresql://user:pass@localhost:5432/mydb",
            ...     table_name="parent_docs",
            ...     pool_config={"min_size": 5, "max_size": 20},
            ...     max_concurrency=10
            ... )
        """
        self.dsn = connection_string
        self.table_name = table_name
        self._pool: Optional["asyncpg.Pool"] = None
        self._pool_config = pool_config or {}
        # Serializes lazy pool creation so that concurrent first calls
        # cannot race and create (and leak) multiple pools.
        self._pool_lock = asyncio.Lock()
        # Optional semaphore limiting concurrent store operations.
        self._semaphore: Optional[asyncio.Semaphore] = None
        if max_concurrency is not None and max_concurrency > 0:
            self._semaphore = asyncio.Semaphore(max_concurrency)
        # Pool creation and table setup are deferred to first use because
        # they require a running event loop.

    async def _get_pool(self) -> "asyncpg.Pool":
        """Return the shared connection pool, creating it on first use.

        Raises:
            RuntimeError: If the pool (or the backing table) cannot be
                created; the original exception is chained as ``__cause__``.
        """
        if self._pool is None:
            async with self._pool_lock:
                # Re-check under the lock: another task may have finished
                # initialization while we were waiting.
                if self._pool is None:
                    min_size = self._pool_config.get("min_size", 2)
                    max_size = self._pool_config.get("max_size", 10)
                    try:
                        self._pool = await asyncpg.create_pool(
                            dsn=self.dsn,
                            min_size=min_size,
                            max_size=max_size
                        )
                        logger.info(f"PostgreSQL 异步连接池已创建: {self.table_name}")
                        # Ensure the backing table exists before any
                        # operation touches it.  Safe re-entry: _get_pool
                        # short-circuits on the now-set self._pool, so the
                        # (non-reentrant) lock is not acquired again.
                        await self._create_table()
                    except Exception as e:
                        # Roll back partial initialization so the next call
                        # can retry from a clean state instead of reusing a
                        # half-initialized pool.
                        pool, self._pool = self._pool, None
                        if pool is not None:
                            try:
                                await pool.close()
                            except Exception:
                                logger.debug(
                                    "Failed to close partially initialized pool",
                                    exc_info=True,
                                )
                        # Chain the cause so the real failure is visible.
                        raise RuntimeError(f"PostgreSQL 异步连接池创建失败: {e}") from e
        return self._pool

    async def _create_table(self):
        """Create the backing table if it does not already exist."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.transaction():
                # table_name is trusted configuration (see __init__); it is
                # interpolated because identifiers cannot be parameterized.
                await conn.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.table_name} (
                        key TEXT PRIMARY KEY,
                        value JSONB NOT NULL,
                        created_at TIMESTAMPTZ DEFAULT NOW()
                    )
                """)
        logger.info(f"{self.table_name} 已就绪")

    async def _with_concurrency_control(self, coro):
        """Await *coro*, holding the concurrency semaphore if configured."""
        if self._semaphore is None:
            return await coro
        async with self._semaphore:
            return await coro

    # --- Sync API (unsupported; this store is async-only) ---

    def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Unsupported; use the async ``amget`` instead."""
        raise NotImplementedError("不支持同步操作,请使用异步 amget 方法。")

    def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Unsupported; use the async ``amset`` instead."""
        raise NotImplementedError("不支持同步操作,请使用异步 amset 方法。")

    def mdelete(self, keys: Sequence[str]) -> None:
        """Unsupported; use the async ``amdelete`` instead."""
        raise NotImplementedError("不支持同步操作,请使用异步 amdelete 方法。")

    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
        """Unsupported; use the async ``ayield_keys`` instead."""
        raise NotImplementedError("不支持同步操作,请使用异步 ayield_keys 方法。")

    # --- Async API (the real implementation) ---

    async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        """Fetch values for *keys*, preserving order; misses yield None.

        JSONB payloads that look like serialized Documents (i.e. contain a
        "page_content" field) are rehydrated into Document objects.
        """
        if not keys:
            return []

        async def _amget():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)",
                    # asyncpg needs a concrete list/tuple to encode an array;
                    # `keys` is only guaranteed to be a Sequence.
                    list(keys)
                )
            result_map = {}
            for row in rows:
                val = row['value']
                # Without a JSONB codec asyncpg returns the raw JSON text.
                if isinstance(val, str):
                    val = json.loads(val)
                if isinstance(val, dict) and 'page_content' in val:
                    result_map[row['key']] = Document(**val)
                else:
                    result_map[row['key']] = val
            # Preserve the caller's key order; None for missing keys.
            return [result_map.get(key) for key in keys]

        return await self._with_concurrency_control(_amget())

    async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        """Upsert (key, value) pairs; Document values are stored as JSONB."""
        if not key_value_pairs:
            return

        async def _amset():
            pool = await self._get_pool()
            # NOTE(review): Document.dict() is deprecated under pydantic v2;
            # consider model_dump() — verify against the pinned langchain
            # version before changing.
            records = [
                (k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False))
                for k, v in key_value_pairs
            ]
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.executemany(
                        f"""
                        INSERT INTO {self.table_name} (key, value)
                        VALUES ($1, $2)
                        ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
                        """,
                        records
                    )
            logger.debug(f"已异步批量设置 {len(key_value_pairs)} 个文档")

        await self._with_concurrency_control(_amset())

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys; keys that do not exist are ignored."""
        if not keys:
            return

        async def _amdelete():
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.execute(
                        f"DELETE FROM {self.table_name} WHERE key = ANY($1)",
                        list(keys)  # concrete sequence for array encoding
                    )
            logger.debug(f"已异步批量删除 {len(keys)} 个文档")

        await self._with_concurrency_control(_amdelete())

    async def ayield_keys(self, *, prefix: str | None = None) -> AsyncIterator[str]:
        """Asynchronously iterate over stored keys in ascending order.

        This is an async generator — iterate it with ``async for``.

        Args:
            prefix: If given, only keys starting with this prefix are
                yielded.  NOTE(review): implemented with SQL LIKE, so '%'
                and '_' inside the prefix act as wildcards.

        Yields:
            Keys, ordered ascending.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if prefix:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key",
                    f"{prefix}%"
                )
            else:
                rows = await conn.fetch(
                    f"SELECT key FROM {self.table_name} ORDER BY key"
                )
        for row in rows:
            yield row['key']

    async def aclose(self) -> None:
        """Close the connection pool and release its resources."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("PostgreSQL 异步连接池已关闭")

    def close(self) -> None:
        """Synchronous close — intentionally a no-op.

        The pool can only be closed from within a running event loop; use
        ``aclose`` in async contexts.
        """
        pass