106 lines
2.7 KiB
Python
106 lines
2.7 KiB
Python
# rag_core/client.py
|
||
import os
|
||
from .config import QDRANT_URL, QDRANT_API_KEY, DOCSTORE_URI
|
||
from qdrant_client import QdrantClient, AsyncQdrantClient
|
||
from typing import Tuple
|
||
from langchain_core.stores import BaseStore
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
|
||
"""
|
||
创建并返回一个配置好的 Qdrant 客户端。
|
||
|
||
Args:
|
||
timeout: 请求超时时间(秒),默认 300 秒(索引构建需要较长超时)。
|
||
|
||
Returns:
|
||
配置好的 QdrantClient 实例。
|
||
|
||
Raises:
|
||
ValueError: 如果 QDRANT_URL 未配置。
|
||
"""
|
||
if not QDRANT_URL:
|
||
raise ValueError("Qdrant URL 未配置")
|
||
|
||
client_kwargs = {
|
||
"url": QDRANT_URL,
|
||
"timeout": timeout,
|
||
}
|
||
if QDRANT_API_KEY:
|
||
client_kwargs["api_key"] = QDRANT_API_KEY
|
||
|
||
return QdrantClient(**client_kwargs)
|
||
|
||
|
||
def create_async_qdrant_client(timeout: int = 300) -> AsyncQdrantClient:
|
||
"""
|
||
创建并返回一个配置好的 Qdrant 异步客户端。
|
||
|
||
Args:
|
||
timeout: 请求超时时间(秒),默认 300 秒。
|
||
|
||
Returns:
|
||
配置好的 AsyncQdrantClient 实例。
|
||
|
||
Raises:
|
||
ValueError: 如果 QDRANT_URL 未配置。
|
||
"""
|
||
if not QDRANT_URL:
|
||
raise ValueError("Qdrant URL 未配置")
|
||
|
||
client_kwargs = {
|
||
"url": QDRANT_URL,
|
||
"timeout": timeout,
|
||
}
|
||
if QDRANT_API_KEY:
|
||
client_kwargs["api_key"] = QDRANT_API_KEY
|
||
|
||
return AsyncQdrantClient(**client_kwargs)
|
||
|
||
|
||
def get_docstore_uri() -> str:
|
||
"""获取 docstore 专用的数据库连接字符串(可与主库相同)"""
|
||
return DOCSTORE_URI
|
||
|
||
|
||
def create_docstore(
|
||
table_name: str = "parent_documents",
|
||
pool_config: dict | None = None,
|
||
max_concurrency: int | None = None
|
||
) -> Tuple[BaseStore, str]:
|
||
"""
|
||
工厂函数,创建 PostgreSQL 文档存储。
|
||
|
||
Args:
|
||
table_name: PostgreSQL 表名(默认:parent_documents)
|
||
pool_config: 连接池配置
|
||
max_concurrency: 最大并发操作数,如果为 None 则不限制
|
||
|
||
Returns:
|
||
元组 (存储实例, 连接字符串)
|
||
|
||
Raises:
|
||
ImportError: 缺少必要的依赖
|
||
|
||
Example:
|
||
>>> # 创建 PostgreSQL 存储
|
||
>>> store, conn = create_docstore(
|
||
... table_name="parent_docs",
|
||
... max_concurrency=10
|
||
... )
|
||
"""
|
||
from .doc_store import PostgresDocStore
|
||
|
||
conn_str = get_docstore_uri()
|
||
store = PostgresDocStore(
|
||
connection_string=conn_str,
|
||
table_name=table_name,
|
||
pool_config=pool_config,
|
||
max_concurrency=max_concurrency
|
||
)
|
||
logger.info(f"PostgreSQL docstore 已创建: {table_name}")
|
||
return store, conn_str
|