重构:添加模型服务模块,支持嵌入和重排服务的自动降级
新增功能: - 创建 app/model_services 模块,提供统一的模型服务获取接口 - 实现 BaseServiceProvider 基类和 FallbackServiceChain 降级链 - 实现 get_embedding_service():优先本地 llama.cpp,降级到智谱 API - 实现 get_rerank_service():优先本地 llama.cpp,降级到智谱 API - 支持单例管理,确保全局只有一个服务实例 修改内容: - 更新 app/config.py,添加智谱 API 相关配置 - 修改 rag_core/vector_store.py:支持接受外部传入的 embeddings - 修改 rag_core/retriever_factory.py:支持接受外部传入的 embeddings - 修改 app/agent/rag_initializer.py:使用 get_embedding_service() - 修改 app/rag/pipeline.py:使用 get_rerank_service() - 修改 app/memory/mem0_client.py:智能判断可用服务配置 mem0 - 修改 rag_indexer/index_builder.py:支持使用新服务,保持向后兼容 - 修改 rag_indexer/config.py:添加智谱配置 环境变量: - ZHIPUAI_API_KEY:智谱 API 密钥(必选) - ZHIPU_EMBEDDING_MODEL:可选,默认 embedding-3 - ZHIPU_RERANK_MODEL:可选,默认 rerank-2 - ZHIPU_API_BASE:可选,默认 https://open.bigmodel.cn/api/paas/v4
This commit is contained in:
@@ -31,6 +31,13 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
|
||||
|
||||
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
|
||||
|
||||
# 尝试导入新的 model_services(如果可用)
|
||||
try:
|
||||
from app.model_services import get_embedding_service
|
||||
HAS_MODEL_SERVICES = True
|
||||
except ImportError:
|
||||
HAS_MODEL_SERVICES = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------- 配置数据类 ----------
|
||||
@@ -69,10 +76,11 @@ class IndexBuilderConfig:
|
||||
class IndexBuilder:
|
||||
"""RAG 索引构建主流水线,支持单块切分与父子块切分。"""
|
||||
|
||||
def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
|
||||
def __init__(self, config: Optional[IndexBuilderConfig] = None, embeddings: Optional[Embeddings] = None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
config: 索引构建器配置对象,优先级高于 kwargs
|
||||
embeddings: 可选的外部嵌入模型实例,如果提供则使用它
|
||||
**kwargs: 可直接传入配置参数,会合并到 config 中(为方便使用保留)
|
||||
"""
|
||||
if config is None:
|
||||
@@ -88,12 +96,29 @@ class IndexBuilder:
|
||||
|
||||
# 初始化基础组件
|
||||
self.loader = DocumentLoader()
|
||||
self.embedder = LlamaCppEmbedder()
|
||||
self.embeddings: Embeddings = self.embedder.as_langchain_embeddings()
|
||||
|
||||
# 设置嵌入模型 - 优先使用外部提供的,然后尝试使用新服务,最后回退到原来的方式
|
||||
if embeddings is not None:
|
||||
self.embeddings = embeddings
|
||||
self.embedder = None
|
||||
logger.info("使用外部提供的嵌入模型")
|
||||
elif HAS_MODEL_SERVICES:
|
||||
try:
|
||||
self.embeddings = get_embedding_service()
|
||||
self.embedder = None
|
||||
logger.info("使用 model_services 提供的嵌入服务")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取嵌入服务失败,回退到 LlamaCppEmbedder: {e}")
|
||||
self.embedder = LlamaCppEmbedder()
|
||||
self.embeddings = self.embedder.as_langchain_embeddings()
|
||||
else:
|
||||
self.embedder = LlamaCppEmbedder()
|
||||
self.embeddings = self.embedder.as_langchain_embeddings()
|
||||
|
||||
# 初始化向量存储
|
||||
self.vector_store = QdrantVectorStore(
|
||||
collection_name=config.collection_name,
|
||||
embeddings=self.embeddings if self.embedder is None else None,
|
||||
)
|
||||
|
||||
# 根据切分类型初始化相关组件
|
||||
@@ -149,6 +174,7 @@ class IndexBuilder:
|
||||
child_splitter=self.child_splitter,
|
||||
docstore=self.docstore,
|
||||
search_k=cfg.search_k,
|
||||
embeddings=self.embeddings if self.embedder is None else None,
|
||||
)
|
||||
logger.info("ParentDocumentRetriever 初始化完成")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user