Files
ailine/backend/rag_core/client.py
root 9841f47432
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
refactor: 重构RAG核心组件,简化代码结构和测试文件
2026-05-04 17:58:10 +08:00

106 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# rag_core/client.py
import os
from .config import QDRANT_URL, QDRANT_API_KEY, DOCSTORE_URI
from qdrant_client import QdrantClient, AsyncQdrantClient
from typing import Tuple
from langchain_core.stores import BaseStore
import logging
logger = logging.getLogger(__name__)
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
"""
创建并返回一个配置好的 Qdrant 客户端。
Args:
timeout: 请求超时时间(秒),默认 300 秒(索引构建需要较长超时)。
Returns:
配置好的 QdrantClient 实例。
Raises:
ValueError: 如果 QDRANT_URL 未配置。
"""
if not QDRANT_URL:
raise ValueError("Qdrant URL 未配置")
client_kwargs = {
"url": QDRANT_URL,
"timeout": timeout,
}
if QDRANT_API_KEY:
client_kwargs["api_key"] = QDRANT_API_KEY
return QdrantClient(**client_kwargs)
def create_async_qdrant_client(timeout: int = 300) -> AsyncQdrantClient:
"""
创建并返回一个配置好的 Qdrant 异步客户端。
Args:
timeout: 请求超时时间(秒),默认 300 秒。
Returns:
配置好的 AsyncQdrantClient 实例。
Raises:
ValueError: 如果 QDRANT_URL 未配置。
"""
if not QDRANT_URL:
raise ValueError("Qdrant URL 未配置")
client_kwargs = {
"url": QDRANT_URL,
"timeout": timeout,
}
if QDRANT_API_KEY:
client_kwargs["api_key"] = QDRANT_API_KEY
return AsyncQdrantClient(**client_kwargs)
def get_docstore_uri() -> str:
"""获取 docstore 专用的数据库连接字符串(可与主库相同)"""
return DOCSTORE_URI
def create_docstore(
table_name: str = "parent_documents",
pool_config: dict | None = None,
max_concurrency: int | None = None
) -> Tuple[BaseStore, str]:
"""
工厂函数,创建 PostgreSQL 文档存储。
Args:
table_name: PostgreSQL 表名默认parent_documents
pool_config: 连接池配置
max_concurrency: 最大并发操作数,如果为 None 则不限制
Returns:
元组 (存储实例, 连接字符串)
Raises:
ImportError: 缺少必要的依赖
Example:
>>> # 创建 PostgreSQL 存储
>>> store, conn = create_docstore(
... table_name="parent_docs",
... max_concurrency=10
... )
"""
from .doc_store import PostgresDocStore
conn_str = get_docstore_uri()
store = PostgresDocStore(
connection_string=conn_str,
table_name=table_name,
pool_config=pool_config,
max_concurrency=max_concurrency
)
logger.info(f"PostgreSQL docstore 已创建: {table_name}")
return store, conn_str