Files
ailine/backend/rag_core/client.py

106 lines
2.7 KiB
Python
Raw Normal View History

2026-04-21 11:02:16 +08:00
# rag_core/client.py
import os
from .config import QDRANT_URL, QDRANT_API_KEY, DOCSTORE_URI
from qdrant_client import QdrantClient, AsyncQdrantClient
from typing import Tuple
from langchain_core.stores import BaseStore
import logging
logger = logging.getLogger(__name__)
2026-04-21 11:02:16 +08:00
2026-04-21 19:06:34 +08:00
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
"""
创建并返回一个配置好的 Qdrant 客户端
2026-04-21 11:02:16 +08:00
2026-04-21 19:06:34 +08:00
Args:
timeout: 请求超时时间默认 300 索引构建需要较长超时
2026-04-21 11:02:16 +08:00
2026-04-21 19:06:34 +08:00
Returns:
配置好的 QdrantClient 实例
Raises:
ValueError: 如果 QDRANT_URL 未配置
"""
if not QDRANT_URL:
2026-04-21 11:02:16 +08:00
raise ValueError("Qdrant URL 未配置")
client_kwargs = {
2026-04-21 19:06:34 +08:00
"url": QDRANT_URL,
2026-04-21 11:02:16 +08:00
"timeout": timeout,
}
2026-04-21 19:06:34 +08:00
if QDRANT_API_KEY:
client_kwargs["api_key"] = QDRANT_API_KEY
2026-04-21 11:02:16 +08:00
2026-04-21 19:06:34 +08:00
return QdrantClient(**client_kwargs)
def create_async_qdrant_client(timeout: int = 300) -> AsyncQdrantClient:
"""
创建并返回一个配置好的 Qdrant 异步客户端
Args:
timeout: 请求超时时间默认 300
Returns:
配置好的 AsyncQdrantClient 实例
Raises:
ValueError: 如果 QDRANT_URL 未配置
"""
if not QDRANT_URL:
raise ValueError("Qdrant URL 未配置")
client_kwargs = {
"url": QDRANT_URL,
"timeout": timeout,
}
if QDRANT_API_KEY:
client_kwargs["api_key"] = QDRANT_API_KEY
return AsyncQdrantClient(**client_kwargs)
def get_docstore_uri() -> str:
"""获取 docstore 专用的数据库连接字符串(可与主库相同)"""
return DOCSTORE_URI
def create_docstore(
table_name: str = "parent_documents",
pool_config: dict | None = None,
max_concurrency: int | None = None
) -> Tuple[BaseStore, str]:
"""
工厂函数创建 PostgreSQL 文档存储
Args:
table_name: PostgreSQL 表名默认parent_documents
pool_config: 连接池配置
max_concurrency: 最大并发操作数如果为 None 则不限制
Returns:
元组 (存储实例, 连接字符串)
Raises:
ImportError: 缺少必要的依赖
Example:
>>> # 创建 PostgreSQL 存储
>>> store, conn = create_docstore(
... table_name="parent_docs",
... max_concurrency=10
... )
"""
from .doc_store import PostgresDocStore
conn_str = get_docstore_uri()
store = PostgresDocStore(
connection_string=conn_str,
table_name=table_name,
pool_config=pool_config,
max_concurrency=max_concurrency
)
logger.info(f"PostgreSQL docstore 已创建: {table_name}")
return store, conn_str