refactor: 重构RAG核心组件,简化代码结构和测试文件
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
This commit is contained in:
@@ -6,8 +6,13 @@ RAG Core - 公共 RAG 组件包
|
||||
from .embedders import get_embeddings, get_embedding_dimension
|
||||
from .vector_store import QdrantHybridStore
|
||||
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||
from .store import PostgresDocStore, create_docstore
|
||||
from .client import create_qdrant_client, create_async_qdrant_client
|
||||
from .doc_store import PostgresDocStore
|
||||
from .client import (
|
||||
create_qdrant_client,
|
||||
create_async_qdrant_client,
|
||||
create_docstore,
|
||||
get_docstore_uri
|
||||
)
|
||||
from .config import (
|
||||
QDRANT_URL,
|
||||
QDRANT_API_KEY,
|
||||
@@ -24,14 +29,15 @@ __all__ = [
|
||||
"QdrantHybridStore",
|
||||
"BM25SparseEmbedder",
|
||||
"get_sparse_embedder",
|
||||
"PostgresDocStore",
|
||||
"create_docstore",
|
||||
"get_docstore_uri",
|
||||
"create_qdrant_client",
|
||||
"create_async_qdrant_client",
|
||||
"QDRANT_URL",
|
||||
"QDRANT_API_KEY",
|
||||
"LLAMACPP_EMBEDDING_URL",
|
||||
"LLAMACPP_API_KEY",
|
||||
"DB_URI",
|
||||
"DOCSTORE_URI",
|
||||
"PostgresDocStore",
|
||||
"create_docstore",
|
||||
"create_qdrant_client",
|
||||
"create_async_qdrant_client",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
# rag_core/client.py
|
||||
import os
|
||||
from .config import QDRANT_URL, QDRANT_API_KEY
|
||||
from .config import QDRANT_URL, QDRANT_API_KEY, DOCSTORE_URI
|
||||
from qdrant_client import QdrantClient, AsyncQdrantClient
|
||||
from typing import Tuple
|
||||
from langchain_core.stores import BaseStore
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
|
||||
@@ -54,3 +59,47 @@ def create_async_qdrant_client(timeout: int = 300) -> AsyncQdrantClient:
|
||||
client_kwargs["api_key"] = QDRANT_API_KEY
|
||||
|
||||
return AsyncQdrantClient(**client_kwargs)
|
||||
|
||||
|
||||
def get_docstore_uri() -> str:
    """Return the database connection string dedicated to the docstore.

    The value is the ``DOCSTORE_URI`` setting imported from the package
    config; it may be the same as the main database's connection string.
    """
    docstore_uri = DOCSTORE_URI
    return docstore_uri
|
||||
|
||||
|
||||
def create_docstore(
    table_name: str = "parent_documents",
    pool_config: dict | None = None,
    max_concurrency: int | None = None,
) -> Tuple[BaseStore, str]:
    """Factory function that creates a PostgreSQL document store.

    Args:
        table_name: PostgreSQL table name (default: ``parent_documents``).
        pool_config: Connection-pool configuration, forwarded unchanged to
            ``PostgresDocStore``.
        max_concurrency: Maximum number of concurrent operations; ``None``
            means no limit.

    Returns:
        A ``(store, connection_string)`` tuple.

    Raises:
        ImportError: A required dependency is missing.

    Example:
        >>> # Create a PostgreSQL store
        >>> store, conn = create_docstore(
        ...     table_name="parent_docs",
        ...     max_concurrency=10
        ... )
    """
    # Imported at call time rather than at module import time.
    # NOTE(review): presumably this avoids a circular import between
    # client.py and doc_store.py — confirm before hoisting it.
    from .doc_store import PostgresDocStore

    conn_str = get_docstore_uri()
    store = PostgresDocStore(
        connection_string=conn_str,
        table_name=table_name,
        pool_config=pool_config,
        max_concurrency=max_concurrency,
    )
    # Lazy %-style arguments: the message is only formatted when the INFO
    # level is actually enabled; the rendered text is unchanged.
    logger.info("PostgreSQL docstore 已创建: %s", table_name)
    return store, conn_str
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
异步 PostgreSQL 存储实现 - 用于生产环境。
|
||||
异步 PostgreSQL 文档存储
|
||||
|
||||
使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。
|
||||
用于 ParentDocumentRetriever 的父文档存储,支持高并发访问。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -16,6 +16,7 @@ import asyncpg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostgresDocStore(BaseStore[str, Any]):
|
||||
"""
|
||||
异步 PostgreSQL 文档存储实现。
|
||||
@@ -49,7 +50,7 @@ class PostgresDocStore(BaseStore[str, Any]):
|
||||
|
||||
Args:
|
||||
connection_string: PostgreSQL 连接 URL,格式:
|
||||
"postgresql://user:password@host:port/database?sslmode=disable"
|
||||
"postgresql://user:***@host:port/database?sslmode=disable"
|
||||
table_name: 存储表名,默认为 "parent_documents"
|
||||
pool_config: 连接池配置字典,包含:
|
||||
- min_size: 最小连接数(默认 2)
|
||||
@@ -57,17 +58,16 @@ class PostgresDocStore(BaseStore[str, Any]):
|
||||
max_concurrency: 最大并发操作数,如果为 None 则不限制
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 asyncpg 时抛出
|
||||
ImportError: 缺少必要的依赖
|
||||
|
||||
Example:
|
||||
>>> store = PostgresDocStore(
|
||||
... "postgresql://user:pass@localhost:5432/mydb",
|
||||
... "postgresql://user:***@localhost:5432/mydb",
|
||||
... table_name="parent_docs",
|
||||
... pool_config={"min_size": 5, "max_size": 20},
|
||||
... max_concurrency=10
|
||||
... )
|
||||
"""
|
||||
|
||||
|
||||
self.dsn = connection_string
|
||||
self.table_name = table_name
|
||||
@@ -244,3 +244,4 @@ class PostgresDocStore(BaseStore[str, Any]):
|
||||
注意:在异步环境中,请使用 aclose 方法。
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""
|
||||
文档存储模块 - 用于 ParentDocumentRetriever 的父文档存储。
|
||||
|
||||
提供 PostgreSQL 存储后端:
|
||||
- PostgresDocStore: PostgreSQL 数据库存储(生产环境)
|
||||
|
||||
示例用法:
|
||||
>>> from rag_core.store import create_docstore
|
||||
|
||||
>>> # 创建 PostgreSQL 存储
|
||||
>>> store, conn = create_docstore(
|
||||
... table_name="parent_docs"
|
||||
... )
|
||||
"""
|
||||
|
||||
|
||||
from .postgres import PostgresDocStore
|
||||
from .factory import create_docstore, get_docstore_uri
|
||||
|
||||
__version__ = "2.0.0"
|
||||
|
||||
__all__ = [
|
||||
# 具体实现
|
||||
"PostgresDocStore",
|
||||
|
||||
# 工厂函数
|
||||
"create_docstore",
|
||||
"get_docstore_uri",
|
||||
]
|
||||
@@ -1,56 +0,0 @@
|
||||
"""
|
||||
文档存储工厂 - 创建不同类型的存储实例。
|
||||
|
||||
提供统一的接口来创建本地文件存储或 PostgreSQL 存储。
|
||||
"""
|
||||
|
||||
import os
|
||||
from ..config import DOCSTORE_URI
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from langchain_core.stores import BaseStore
|
||||
from .postgres import PostgresDocStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_docstore_uri() -> str:
    """Return the database connection string dedicated to the docstore.

    This is the ``DOCSTORE_URI`` value imported from the package config; it
    may be the same as the main database's connection string.
    """
    uri = DOCSTORE_URI
    return uri
|
||||
|
||||
|
||||
def create_docstore(
    table_name: str = "parent_documents",
    pool_config: dict | None = None,
    max_concurrency: int | None = None
) -> Tuple[BaseStore, str]:
    """Factory function that creates a PostgreSQL document store.

    Args:
        table_name: PostgreSQL table name (default: ``parent_documents``).
        pool_config: Connection-pool configuration, forwarded unchanged to
            ``PostgresDocStore``.
        max_concurrency: Maximum number of concurrent operations; ``None``
            means no limit.

    Returns:
        A ``(store, connection_string)`` tuple.

    Raises:
        ImportError: A required dependency is missing.

    Example:
        >>> # Create a PostgreSQL store
        >>> store, conn = create_docstore(
        ...     table_name="parent_docs",
        ...     max_concurrency=10
        ... )
    """
    # Resolve the connection string first so the same value is both handed
    # to the store and returned to the caller.
    dsn = get_docstore_uri()
    docstore = PostgresDocStore(
        connection_string=dsn,
        table_name=table_name,
        pool_config=pool_config,
        max_concurrency=max_concurrency,
    )
    return docstore, dsn
|
||||
@@ -33,8 +33,6 @@ class QdrantHybridStore:
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
sparse_embedder: Optional[BM25SparseEmbedder] = None,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self._client: Optional[QdrantClient] = None
|
||||
@@ -43,13 +41,10 @@ class QdrantHybridStore:
|
||||
self._last_connection_time: Optional[float] = None
|
||||
|
||||
# 稠密嵌入模型
|
||||
if embeddings is None:
|
||||
self.embeddings = get_embeddings()
|
||||
else:
|
||||
self.embeddings = embeddings
|
||||
self.embeddings = get_embeddings()
|
||||
|
||||
# 稀疏嵌入模型
|
||||
self.sparse_embedder = sparse_embedder or get_sparse_embedder()
|
||||
self.sparse_embedder = get_sparse_embedder()
|
||||
|
||||
# 集合初始化
|
||||
self.create_collection()
|
||||
@@ -176,7 +171,7 @@ class QdrantHybridStore:
|
||||
texts = [doc.page_content for doc in documents]
|
||||
|
||||
# 生成稠密向量
|
||||
dense_vectors = await self._aembed_texts(texts)
|
||||
dense_vectors = await self.aembed_documents(texts)
|
||||
|
||||
# 生成稀疏向量
|
||||
sparse_vectors = self.sparse_embedder.embed_documents(texts)
|
||||
@@ -210,14 +205,18 @@ class QdrantHybridStore:
|
||||
|
||||
return [p.id for p in points]
|
||||
|
||||
async def _aembed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""异步生成稠密向量(适配同步 Embeddings 接口)"""
|
||||
# 注意:LangChain 的 Embeddings 接口目前主要是同步的
|
||||
# 使用线程池或直接调用(如果 embedding 内部有异步支持)
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
    """Asynchronously generate dense vectors for a list of texts.

    The synchronous ``self.embeddings.embed_documents`` call is dispatched
    to the event loop's default executor so the loop is not blocked while
    the embeddings are computed.

    Args:
        texts: The texts to embed.

    Returns:
        The result of ``self.embeddings.embed_documents(texts)``.
    """
    import asyncio

    # Inside a coroutine a loop is always running; get_running_loop() is the
    # preferred accessor there and never creates a loop implicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.embeddings.embed_documents, texts)
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously generate the dense vector for a query string.

    The synchronous ``self.embeddings.embed_query`` call is dispatched to
    the event loop's default executor so the loop is not blocked while the
    embedding is computed.

    Args:
        text: The query text to embed.

    Returns:
        The result of ``self.embeddings.embed_query(text)``.
    """
    import asyncio

    # Inside a coroutine a loop is always running; get_running_loop() is the
    # preferred accessor there and never creates a loop implicitly.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.embeddings.embed_query, text)
|
||||
|
||||
# ---------- 异步检索方法 ----------
|
||||
async def asimilarity_search(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""
|
||||
@@ -227,7 +226,7 @@ class QdrantHybridStore:
|
||||
client = self.get_async_client()
|
||||
|
||||
# 生成查询向量
|
||||
dense_query = await self._aembed_query(query)
|
||||
dense_query = await self.aembed_query(query)
|
||||
sparse_query = self.sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
@@ -264,12 +263,6 @@ class QdrantHybridStore:
|
||||
logger.debug("混合检索返回 %d 个文档", len(results))
|
||||
return results
|
||||
|
||||
async def _aembed_query(self, text: str) -> List[float]:
|
||||
"""异步生成查询稠密向量"""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self.embeddings.embed_query, text)
|
||||
|
||||
# ---------- 同步管理方法(保留,用于初始化和管理) ----------
|
||||
def delete_collection(self):
|
||||
self.get_client().delete_collection(self.collection_name)
|
||||
|
||||
Reference in New Issue
Block a user