74 lines
2.0 KiB
Python
74 lines
2.0 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
下载稀疏嵌入模型到本地目录。
|
||
|
|
仅需在开发机或构建镜像时执行一次。
|
||
|
|
"""
|
||
|
|
|
||
|
|
import logging
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Root logging for this command-line script: INFO and above, one
# "<timestamp> - <LEVEL> - <message>" line per record.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)

# Module-level logger shared by the rest of the script.
logger = logging.getLogger(name=__name__)
|
||
|
|
|
||
|
|
# Prepend the "backend" directory that sits next to this script to the
# module search path, so it takes precedence over other entries.
# NOTE(review): nothing visible in this script imports from backend — confirm it is still needed.
sys.path.insert(0, str(Path(__file__).parent.joinpath("backend")))
|
||
|
|
|
||
|
|
|
||
|
|
def download_model(cache_dir: str = "./models/sparse", model_name: str = "Qdrant/bm25") -> bool:
    """Download (or load from cache) a sparse embedding model and smoke-test it.

    The model is instantiated through ``fastembed.SparseTextEmbedding`` with
    *cache_dir* as its cache directory, then a single short text is embedded to
    confirm the model actually runs. Intended to be run once on a development
    machine or while building an image.

    Args:
        cache_dir: Model cache directory; created (with parents) if missing.
        model_name: fastembed sparse model identifier.

    Returns:
        True if the model was loaded and the smoke test produced an embedding,
        False otherwise. Every failure is logged (with traceback where one exists).
    """
    # Import separately so a missing dependency gets an actionable message
    # instead of being reported as a download failure.
    try:
        from fastembed import SparseTextEmbedding
    except ImportError as e:
        logger.error(f"❌ 无法导入 fastembed ({e}),请先安装: pip install fastembed")
        return False

    cache_path = Path(cache_dir)
    try:
        # Inside the try so an unwritable cache_dir honours the True/False contract
        # instead of escaping as an uncaught OSError.
        cache_path.mkdir(parents=True, exist_ok=True)
        logger.info(f"准备下载模型 {model_name} 到 {cache_path.absolute()}")

        # Constructing the model downloads it into cache_path (or reuses cached files).
        model = SparseTextEmbedding(model_name=model_name, cache_dir=str(cache_path))
        logger.info(f"✅ 模型 {model_name} 下载/加载成功")

        # Smoke test: embed() yields one SparseEmbedding per input text. Take only
        # the first instead of materialising the whole generator.
        embedding = next(iter(model.embed(["测试文本"])), None)
        if embedding is None:
            logger.error("❌ 模型测试失败: 未返回任何嵌入结果")
            return False

        # SparseEmbedding stores parallel `indices`/`values` arrays and defines no
        # __len__; the number of stored indices is its count of non-zero entries.
        logger.info(f"✅ 模型测试成功,稀疏向量非零项数: {len(embedding.indices)}")
        logger.info("✅ 所有步骤完成!")
        return True
    except Exception as e:
        # Top-level boundary for this CLI helper: log with traceback and report failure.
        logger.exception(f"❌ 模型下载失败: {e}")
        return False
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    import argparse

    # Command-line interface; both option defaults mirror download_model()'s defaults.
    cli = argparse.ArgumentParser(description="下载稀疏嵌入模型")
    cli.add_argument(
        "--cache-dir",
        help="模型缓存目录 (默认: ./models/sparse)",
        default="./models/sparse",
    )
    cli.add_argument(
        "--model-name",
        help="模型名称 (默认: Qdrant/bm25)",
        default="Qdrant/bm25",
    )
    options = cli.parse_args()

    # Map the boolean result onto the process exit status: 0 = success, 1 = failure.
    ok = download_model(options.cache_dir, options.model_name)
    raise SystemExit(0 if ok else 1)
|