- 创建 download_sparse_model.py 脚本用于下载稀疏模型到本地 - 添加 SPARSE_MODEL_PATH 和 SPARSE_MODEL_NAME 配置 - 修改 retriever.py 和 index_builder.py 使用 cache_dir - 更新 .gitignore 排除 models/ 目录 - 更新 Dockerfile 在构建时下载稀疏模型
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -21,6 +21,7 @@
|
|||||||
!test/**
|
!test/**
|
||||||
!.gitea/
|
!.gitea/
|
||||||
!.gitea/**
|
!.gitea/**
|
||||||
|
!download_sparse_model.py
|
||||||
|
|
||||||
# 3. 放行必要的根目录文件
|
# 3. 放行必要的根目录文件
|
||||||
!.gitignore
|
!.gitignore
|
||||||
@@ -40,6 +41,9 @@ __pycache__/
|
|||||||
*.so
|
*.so
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# 模型目录(不提交到 Git,在 Docker 构建时下载)
|
||||||
|
models/
|
||||||
|
|
||||||
# 包含敏感信息的环境变量配置(绝对不能传)
|
# 包含敏感信息的环境变量配置(绝对不能传)
|
||||||
.env
|
.env
|
||||||
.env.local
|
.env.local
|
||||||
|
|||||||
@@ -51,6 +51,10 @@ ZHIPU_RERANK_MODEL = _get_str("ZHIPU_RERANK_MODEL") or "rerank-2"
|
|||||||
ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4"
|
ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4"
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 稀疏模型配置 ==========
|
||||||
|
SPARSE_MODEL_PATH = _get_str("SPARSE_MODEL_PATH") or "./models/sparse"
|
||||||
|
SPARSE_MODEL_NAME = _get_str("SPARSE_MODEL_NAME") or "Qdrant/bm25"
|
||||||
|
|
||||||
# ========== llama.cpp 服务配置(URL + API密钥 配对) ==========
|
# ========== llama.cpp 服务配置(URL + API密钥 配对) ==========
|
||||||
# 主 LLM 服务
|
# 主 LLM 服务
|
||||||
VLLM_BASE_URL = _get_str("VLLM_BASE_URL")
|
VLLM_BASE_URL = _get_str("VLLM_BASE_URL")
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from langchain_core.retrievers import BaseRetriever
|
|||||||
from rag_core import QDRANT_URL, QDRANT_API_KEY
|
from rag_core import QDRANT_URL, QDRANT_API_KEY
|
||||||
from rag_core.client import create_qdrant_client as create_core_qdrant_client
|
from rag_core.client import create_qdrant_client as create_core_qdrant_client
|
||||||
from app.model_services import get_embedding_service
|
from app.model_services import get_embedding_service
|
||||||
|
from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME
|
||||||
from app.logger import info, warning
|
from app.logger import info, warning
|
||||||
|
|
||||||
# 模块级常量
|
# 模块级常量
|
||||||
@@ -134,9 +135,12 @@ def create_hybrid_retriever(
|
|||||||
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
|
raise ValueError(f"Qdrant 集合 '{collection_name}' 不存在")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 初始化稀疏嵌入
|
# 初始化稀疏嵌入(使用本地缓存目录)
|
||||||
sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")
|
sparse_embeddings = FastEmbedSparse(
|
||||||
info("✅ FastEmbedSparse 初始化成功")
|
model_name=SPARSE_MODEL_NAME,
|
||||||
|
cache_dir=SPARSE_MODEL_PATH
|
||||||
|
)
|
||||||
|
info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})")
|
||||||
|
|
||||||
# 创建混合模式的 QdrantVectorStore
|
# 创建混合模式的 QdrantVectorStore
|
||||||
vector_store = QdrantVectorStore(
|
vector_store = QdrantVectorStore(
|
||||||
|
|||||||
@@ -50,6 +50,12 @@ ENV BACKEND_PORT=8079
|
|||||||
ENV MEMORY_SUMMARIZE_INTERVAL=10
|
ENV MEMORY_SUMMARIZE_INTERVAL=10
|
||||||
ENV ENABLE_GRAPH_TRACE=false
|
ENV ENABLE_GRAPH_TRACE=false
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 稀疏模型配置
|
||||||
|
# =============================================================================
|
||||||
|
ENV SPARSE_MODEL_PATH=/app/models/sparse
|
||||||
|
ENV SPARSE_MODEL_NAME=Qdrant/bm25
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 日志配置(生产环境默认值)
|
# 日志配置(生产环境默认值)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -74,6 +80,14 @@ RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
|||||||
COPY backend/requirements.txt .
|
COPY backend/requirements.txt .
|
||||||
RUN pip install --no-cache-dir --default-timeout=300 -r requirements.txt
|
RUN pip install --no-cache-dir --default-timeout=300 -r requirements.txt
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 下载稀疏模型(关键步骤:在构建阶段下载到固定目录)
|
||||||
|
# =============================================================================
|
||||||
|
RUN mkdir -p /app/models/sparse
|
||||||
|
COPY download_sparse_model.py .
|
||||||
|
RUN python download_sparse_model.py --cache-dir /app/models/sparse --model-name Qdrant/bm25 && \
|
||||||
|
rm -f download_sparse_model.py
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 复制项目代码
|
# 复制项目代码
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
73
download_sparse_model.py
Normal file
73
download_sparse_model.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
下载稀疏嵌入模型到本地目录。
|
||||||
|
仅需在开发机或构建镜像时执行一次。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 添加 backend 目录到路径
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent / "backend"))
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(cache_dir: str = "./models/sparse", model_name: str = "Qdrant/bm25"):
|
||||||
|
"""
|
||||||
|
下载稀疏嵌入模型到指定目录。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_dir: 模型缓存目录
|
||||||
|
model_name: 模型名称
|
||||||
|
"""
|
||||||
|
cache_path = Path(cache_dir)
|
||||||
|
cache_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.info(f"准备下载模型 {model_name} 到 {cache_path.absolute()}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastembed import SparseTextEmbedding
|
||||||
|
|
||||||
|
# 下载并缓存模型
|
||||||
|
model = SparseTextEmbedding(model_name=model_name, cache_dir=str(cache_path))
|
||||||
|
logger.info(f"✅ 模型 {model_name} 下载/加载成功")
|
||||||
|
|
||||||
|
# 测试一下
|
||||||
|
test_result = model.embed(["测试文本"])
|
||||||
|
logger.info(f"✅ 模型测试成功,稀疏向量维度: {len(list(test_result)[0])}")
|
||||||
|
|
||||||
|
logger.info("✅ 所有步骤完成!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 模型下载失败: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="下载稀疏嵌入模型")
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache-dir",
|
||||||
|
default="./models/sparse",
|
||||||
|
help="模型缓存目录 (默认: ./models/sparse)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
default="Qdrant/bm25",
|
||||||
|
help="模型名称 (默认: Qdrant/bm25)"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
success = download_model(args.cache_dir, args.model_name)
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
@@ -41,6 +41,15 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_MODEL_SERVICES = False
|
HAS_MODEL_SERVICES = False
|
||||||
|
|
||||||
|
# 尝试导入稀疏模型配置(如果可用)
|
||||||
|
try:
|
||||||
|
from app.config import SPARSE_MODEL_PATH, SPARSE_MODEL_NAME
|
||||||
|
HAS_SPARSE_CONFIG = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_SPARSE_CONFIG = False
|
||||||
|
SPARSE_MODEL_PATH = "./models/sparse"
|
||||||
|
SPARSE_MODEL_NAME = "Qdrant/bm25"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---------- 配置数据类 ----------
|
# ---------- 配置数据类 ----------
|
||||||
@@ -118,10 +127,13 @@ class IndexBuilder:
|
|||||||
self.embedder = LlamaCppEmbedder()
|
self.embedder = LlamaCppEmbedder()
|
||||||
self.embeddings = self.embedder.as_langchain_embeddings()
|
self.embeddings = self.embedder.as_langchain_embeddings()
|
||||||
|
|
||||||
# 初始化稀疏嵌入
|
# 初始化稀疏嵌入(使用本地缓存目录)
|
||||||
from langchain_qdrant import FastEmbedSparse, RetrievalMode
|
from langchain_qdrant import FastEmbedSparse, RetrievalMode
|
||||||
self.sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")
|
self.sparse_embeddings = FastEmbedSparse(
|
||||||
logger.info("✅ FastEmbedSparse 初始化成功")
|
model_name=SPARSE_MODEL_NAME,
|
||||||
|
cache_dir=SPARSE_MODEL_PATH
|
||||||
|
)
|
||||||
|
logger.info(f"✅ FastEmbedSparse 初始化成功 (cache_dir={SPARSE_MODEL_PATH})")
|
||||||
|
|
||||||
# 初始化向量存储(混合检索模式)
|
# 初始化向量存储(混合检索模式)
|
||||||
self.vector_store = QdrantVectorStore(
|
self.vector_store = QdrantVectorStore(
|
||||||
|
|||||||
Reference in New Issue
Block a user