diff --git a/.gitignore b/.gitignore index 5c86892..ff42873 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ !test/** !.gitea/ !.gitea/** +!download_sparse_model.py # 3. 放行必要的根目录文件 !.gitignore @@ -40,6 +41,9 @@ __pycache__/ *.so .DS_Store +# 模型目录(不提交到 Git,在 Docker 构建时下载) +models/ + # 包含敏感信息的环境变量配置(绝对不能传) .env .env.local diff --git a/backend/app/config.py b/backend/app/config.py index 986a05d..3a43009 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -51,6 +51,10 @@ ZHIPU_RERANK_MODEL = _get_str("ZHIPU_RERANK_MODEL") or "rerank-2" ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4" +# ========== 稀疏模型配置 ========== +SPARSE_MODEL_PATH = _get_str("SPARSE_MODEL_PATH") or "./models/sparse" +SPARSE_MODEL_NAME = _get_str("SPARSE_MODEL_NAME") or "Qdrant/bm25" + # ========== llama.cpp 服务配置(URL + API密钥 配对) ========== # 主 LLM 服务 VLLM_BASE_URL = _get_str("VLLM_BASE_URL") diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index 65694d1..472b09e 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -28,6 +28,7 @@ from langchain_core.retrievers import BaseRetriever from rag_core import QDRANT_URL, QDRANT_API_KEY from rag_core.client import create_qdrant_client as create_core_qdrant_client from app.model_services import get_embedding_service +from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME from app.logger import info, warning # 模块级常量 @@ -134,9 +135,12 @@ def create_hybrid_retriever( raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在") raise - # 初始化稀疏嵌入 - sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25") - info("✅ FastEmbedSparse 初始化成功") + # 初始化稀疏嵌入(使用本地缓存目录) + sparse_embeddings = FastEmbedSparse( + model_name=SPARSE_MODEL_NAME, + cache_dir=SPARSE_MODEL_PATH + ) + info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})") # 创建混合模式的 QdrantVectorStore vector_store = QdrantVectorStore( diff --git a/docker/backend/Dockerfile b/docker/backend/Dockerfile index f98225a..ea56819 100644 --- a/docker/backend/Dockerfile +++ b/docker/backend/Dockerfile @@ -50,6 +50,12 @@ ENV BACKEND_PORT=8079 ENV MEMORY_SUMMARIZE_INTERVAL=10 ENV ENABLE_GRAPH_TRACE=false +# ============================================================================= +# 稀疏模型配置 +# ============================================================================= +ENV SPARSE_MODEL_PATH=/app/models/sparse +ENV SPARSE_MODEL_NAME=Qdrant/bm25 + # ============================================================================= # 日志配置(生产环境默认值) # ============================================================================= @@ -74,6 +80,14 @@ RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple COPY backend/requirements.txt . RUN pip install --no-cache-dir --default-timeout=300 -r requirements.txt +# ============================================================================= +# 下载稀疏模型(关键步骤:在构建阶段下载到固定目录) +# ============================================================================= +RUN mkdir -p /app/models/sparse +COPY download_sparse_model.py . +RUN python download_sparse_model.py --cache-dir /app/models/sparse --model-name Qdrant/bm25 && \ + rm -f download_sparse_model.py + # ============================================================================= # 复制项目代码 # ============================================================================= diff --git a/download_sparse_model.py b/download_sparse_model.py new file mode 100644 index 0000000..22ff2fe --- /dev/null +++ b/download_sparse_model.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +""" +下载稀疏嵌入模型到本地目录。 +仅需在开发机或构建镜像时执行一次。 +""" + +import logging +import sys +from pathlib import Path + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +# 添加 backend 目录到路径 +sys.path.insert(0, str(Path(__file__).parent / "backend")) + + +def download_model(cache_dir: str = "./models/sparse", model_name: str = "Qdrant/bm25"): + """ + 下载稀疏嵌入模型到指定目录。 + + Args: + cache_dir: 模型缓存目录 + model_name: 模型名称 + """ + cache_path = Path(cache_dir) + cache_path.mkdir(parents=True, exist_ok=True) + logger.info(f"准备下载模型 {model_name} 到 {cache_path.absolute()}") + + try: + from fastembed import SparseTextEmbedding + + # 下载并缓存模型 + model = SparseTextEmbedding(model_name=model_name, cache_dir=str(cache_path)) + logger.info(f"✅ 模型 {model_name} 下载/加载成功") + + # 测试一下 + test_result = model.embed(["测试文本"]) + logger.info(f"✅ 模型测试成功,稀疏向量维度: {len(list(test_result)[0])}") + + logger.info("✅ 所有步骤完成!") + return True + + except Exception as e: + logger.error(f"❌ 模型下载失败: {e}") + import traceback + logger.error(traceback.format_exc()) + return False + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="下载稀疏嵌入模型") + parser.add_argument( + "--cache-dir", + default="./models/sparse", + help="模型缓存目录 (默认: ./models/sparse)" + ) + parser.add_argument( + "--model-name", + default="Qdrant/bm25", + help="模型名称 (默认: Qdrant/bm25)" + ) + + args = parser.parse_args() + + success = download_model(args.cache_dir, args.model_name) + sys.exit(0 if success else 1) diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py index a6b8149..e17a56c 100644 --- a/rag_indexer/index_builder.py +++ b/rag_indexer/index_builder.py @@ -41,6 +41,15 @@ try: except ImportError: HAS_MODEL_SERVICES = False +# 尝试导入稀疏模型配置(如果可用) +try: + from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME + HAS_SPARSE_CONFIG = True +except ImportError: + HAS_SPARSE_CONFIG = False + SPARSE_MODEL_PATH = "./models/sparse" + SPARSE_MODEL_NAME = "Qdrant/bm25" + logger = logging.getLogger(__name__) # ---------- 配置数据类 ---------- @@ -118,10 +127,13 @@ class IndexBuilder: self.embedder = LlamaCppEmbedder() self.embeddings = self.embedder.as_langchain_embeddings() - # 初始化稀疏嵌入 + # 初始化稀疏嵌入(使用本地缓存目录) from langchain_qdrant import FastEmbedSparse, RetrievalMode - self.sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25") - logger.info("✅ FastEmbedSparse 初始化成功") + self.sparse_embeddings = FastEmbedSparse( + model_name=SPARSE_MODEL_NAME, + cache_dir=SPARSE_MODEL_PATH + ) + logger.info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})") # 初始化向量存储(混合检索模式) self.vector_store = QdrantVectorStore(