Files
ailine/backend/app/model_services/rerank_services.py
root 8db63e7a8d 重构:添加模型服务模块,支持嵌入和重排服务的自动降级
新增功能:
- 创建 app/model_services 模块,提供统一的模型服务获取接口
- 实现 BaseServiceProvider 基类和 FallbackServiceChain 降级链
- 实现 get_embedding_service():优先本地 llama.cpp,降级到智谱 API
- 实现 get_rerank_service():优先本地 llama.cpp,降级到智谱 API
- 支持单例管理,确保全局只有一个服务实例

修改内容:
- 更新 app/config.py,添加智谱 API 相关配置
- 修改 rag_core/vector_store.py:支持接受外部传入的 embeddings
- 修改 rag_core/retriever_factory.py:支持接受外部传入的 embeddings
- 修改 app/agent/rag_initializer.py:使用 get_embedding_service()
- 修改 app/rag/pipeline.py:使用 get_rerank_service()
- 修改 app/memory/mem0_client.py:智能判断可用服务配置 mem0
- 修改 rag_indexer/index_builder.py:支持使用新服务,保持向后兼容
- 修改 rag_indexer/config.py:添加智谱配置

环境变量:
- ZHIPUAI_API_KEY:智谱 API 密钥(必选)
- ZHIPU_EMBEDDING_MODEL:可选,默认 embedding-3
- ZHIPU_RERANK_MODEL:可选,默认 rerank-2
- ZHIPU_API_BASE:可选,默认 https://open.bigmodel.cn/api/paas/v4
2026-04-24 22:52:36 +08:00

234 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
重排模型服务模块
本模块提供统一的重排模型服务获取接口,支持自动降级:
1. 优先使用本地 llama.cpp 重排服务
2. 本地服务不可用时,自动降级到智谱 API 重排服务
主要功能:
- LocalLlamaCppRerankProvider本地 llama.cpp 重排服务提供者
- ZhipuRerankProvider智谱 API 重排服务提供者
- get_rerank_service():获取重排服务的统一接口
"""
import logging
from typing import List
import requests
from langchain_core.documents import Document
from .base import (
BaseServiceProvider,
FallbackServiceChain,
SingletonServiceManager
)
from ..config import (
LLAMACPP_RERANKER_URL,
LLAMACPP_API_KEY,
ZHIPUAI_API_KEY,
ZHIPU_RERANK_MODEL,
ZHIPU_API_BASE
)
logger = logging.getLogger(__name__)
class BaseReranker:
"""
重排器基类,定义统一的接口
"""
def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
"""
对文档进行重排序
Args:
documents: 待排序的文档列表
query: 查询字符串
top_n: 返回前 N 个结果
Returns:
排序后的文档列表
"""
raise NotImplementedError
class LocalLlamaCppReranker(BaseReranker):
"""
使用远程 llama.cpp 服务对检索结果重排序
"""
def __init__(self, base_url: str, api_key: str, model: str = "bge-reranker-v2-m3", timeout: int = 60):
self.base_url = base_url
self.api_key = api_key
self.model = model
self.timeout = timeout
self.endpoint = f"{self.base_url}/rerank"
def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
"""
对文档进行重排序
"""
if not documents:
return []
# 准备请求体
payload = {
"model": self.model,
"query": query,
"documents": [doc.page_content for doc in documents],
"top_n": top_n
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
try:
response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout)
response.raise_for_status()
results = response.json()
# 解析返回结果
sorted_indices = [item["index"] for item in results["results"]]
sorted_docs = [documents[idx] for idx in sorted_indices]
return sorted_docs
except Exception as e:
logger.warning(f"远程重排序过程出错,返回原始前 {top_n} 个结果: {e}")
return documents[:top_n]
class ZhipuReranker(BaseReranker):
"""
使用智谱 API 对检索结果重排序
"""
def __init__(self, model: str | None = None):
self.model = model or ZHIPU_RERANK_MODEL
self.api_key = ZHIPUAI_API_KEY
def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
"""
对文档进行重排序
"""
if not documents:
return []
try:
from zhipuai import ZhipuAI
client = ZhipuAI(api_key=self.api_key)
response = client.rerank.create(
model=self.model,
query=query,
documents=[doc.page_content for doc in documents],
top_n=top_n
)
sorted_indices = [item.index for item in response.results]
sorted_docs = [documents[idx] for idx in sorted_indices]
return sorted_docs
except Exception as e:
logger.warning(f"智谱重排序过程出错,返回原始前 {top_n} 个结果: {e}")
return documents[:top_n]
class LocalLlamaCppRerankProvider(BaseServiceProvider[BaseReranker]):
"""
本地 llama.cpp 重排服务提供者
"""
def __init__(self, model: str = "bge-reranker-v2-m3"):
super().__init__("local_llamacpp_rerank")
self._model = model
def is_available(self) -> bool:
"""
检查本地 llama.cpp 重排服务是否可用
"""
if not LLAMACPP_RERANKER_URL:
logger.warning("LLAMACPP_RERANKER_URL 未配置")
return False
try:
# 测试重排服务
test_docs = [Document(page_content="test document 1"), Document(page_content="test document 2")]
reranker = LocalLlamaCppReranker(
base_url=LLAMACPP_RERANKER_URL,
api_key=LLAMACPP_API_KEY,
model=self._model
)
result = reranker.compress_documents(test_docs, "test query", top_n=1)
logger.info(f"本地 llama.cpp 重排服务可用")
return True
except Exception as e:
logger.warning(f"本地 llama.cpp 重排服务不可用: {e}")
return False
def get_service(self) -> BaseReranker:
"""
获取本地 llama.cpp 重排服务
"""
if self._service_instance is None:
self._service_instance = LocalLlamaCppReranker(
base_url=LLAMACPP_RERANKER_URL,
api_key=LLAMACPP_API_KEY,
model=self._model
)
return self._service_instance
class ZhipuRerankProvider(BaseServiceProvider[BaseReranker]):
"""
智谱 API 重排服务提供者
"""
def __init__(self, model: str | None = None):
super().__init__("zhipu_rerank")
self._model = model or ZHIPU_RERANK_MODEL
def is_available(self) -> bool:
"""
检查智谱 API 重排服务是否可用
"""
if not ZHIPUAI_API_KEY:
logger.warning("ZHIPUAI_API_KEY 未配置")
return False
try:
# 测试重排服务
test_docs = [Document(page_content="test document 1"), Document(page_content="test document 2")]
reranker = ZhipuReranker(model=self._model)
result = reranker.compress_documents(test_docs, "test query", top_n=1)
logger.info(f"智谱重排服务可用")
return True
except ImportError:
logger.warning("zhipuai 库未安装")
return False
except Exception as e:
logger.warning(f"智谱重排服务不可用: {e}")
return False
def get_service(self) -> BaseReranker:
"""
获取智谱 API 重排服务
"""
if self._service_instance is None:
self._service_instance = ZhipuReranker(model=self._model)
return self._service_instance
def get_rerank_service() -> BaseReranker:
"""
获取重排服务(带自动降级)
Returns:
BaseReranker: 重排服务实例
"""
def _create_chain():
primary = LocalLlamaCppRerankProvider()
fallback = ZhipuRerankProvider()
return FallbackServiceChain(primary, [fallback])
chain = SingletonServiceManager.get_or_create("rerank_service_chain", _create_chain)
return chain.get_available_service()