本地RAG尝试
This commit is contained in:
141
app/rag/reranker.py
Normal file
141
app/rag/reranker.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Cross-Encoder 重排序器
|
||||
|
||||
使用 sentence-transformers 加载交叉编码器模型,对检索结果进行精排。
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
||||
from langchain_core.documents import Document
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
|
||||
class CrossEncoderReranker:
    """Cross-Encoder reranker wrapper.

    Wraps a ``sentence_transformers.CrossEncoder`` model (e.g.
    ``BAAI/bge-reranker-base``) and exposes it through LangChain's
    document-compressor interface so retrieval results can be re-scored
    and precision-ranked. The underlying model is loaded lazily on first
    use to keep construction cheap.

    NOTE: this class intentionally (re)uses the name
    ``CrossEncoderReranker``, which shadows LangChain's class of the same
    name imported at module level; the LangChain class is re-imported
    under an alias where it is actually needed (see
    ``_create_langchain_reranker``).
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-base",
        top_n: int = 5,
        device: Optional[str] = None,
        cache_folder: Optional[str] = None,
    ):
        """Initialize the reranker.

        Args:
            model_name: Model name or local path.
            top_n: Number of top documents to keep after reranking.
            device: Device to run on ("cpu"/"cuda"); auto-selected by
                sentence-transformers when ``None``.
            cache_folder: Model cache directory; defaults to
                ``~/.cache/sentence_transformers``.
        """
        self.model_name = model_name
        self.top_n = top_n
        self.device = device
        self.cache_folder = cache_folder or os.path.join(
            os.path.expanduser("~"), ".cache", "sentence_transformers"
        )

        # Lazily-loaded state: the raw cross-encoder model and the
        # LangChain compressor built on top of it.
        self._model = None
        self._langchain_reranker = None

    def _load_model(self):
        """Load the cross-encoder model, falling back to a backup model.

        On first call, tries ``self.model_name``; if that fails (e.g. the
        model is unavailable locally), falls back to
        ``BAAI/bge-reranker-v2-m3``. Subsequent calls are no-ops.
        """
        if self._model is None:
            try:
                self._model = CrossEncoder(
                    self.model_name,
                    device=self.device,
                    cache_folder=self.cache_folder,
                )
            except Exception as e:
                # Configured model failed to load — try the fallback model.
                print(f"加载模型 {self.model_name} 失败: {e}")
                print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...")
                self._model = CrossEncoder(
                    "BAAI/bge-reranker-v2-m3",
                    device=self.device,
                    cache_folder=self.cache_folder,
                )

    def _create_langchain_reranker(self):
        """Create (once) the LangChain compressor wrapping the loaded model.

        BUGFIX: the bare name ``CrossEncoderReranker`` resolves to this
        wrapper class (it shadows the module-level LangChain import), so
        calling it here would recursively instantiate this wrapper and
        raise ``TypeError`` on the unexpected ``model=`` argument. Import
        the LangChain class under an alias instead.
        """
        if self._langchain_reranker is None:
            self._load_model()
            # Function-scope aliased import to bypass the name shadowing.
            from langchain.retrievers.document_compressors import (
                CrossEncoderReranker as _LCCrossEncoderReranker,
            )

            # NOTE(review): LangChain's reranker declares its ``model``
            # field as a BaseCrossEncoder (with a ``.score()`` method) —
            # verify a raw sentence-transformers CrossEncoder is accepted,
            # or wrap it in HuggingFaceCrossEncoder.
            self._langchain_reranker = _LCCrossEncoderReranker(
                model=self._model,
                top_n=self.top_n,
            )

    def rerank(
        self,
        query: str,
        documents: List[Document],
    ) -> List[Document]:
        """Rerank documents against a query.

        Args:
            query: Query text.
            documents: Candidate documents to re-score.

        Returns:
            The top ``self.top_n`` documents, best first.
        """
        self._create_langchain_reranker()
        # compress_documents returns a Sequence; coerce to list to match
        # the declared return type.
        return list(
            self._langchain_reranker.compress_documents(
                documents=documents,
                query=query,
            )
        )

    def create_contextual_compression_retriever(
        self,
        base_retriever: Any,
    ) -> Any:
        """Build a retriever that reranks the base retriever's results.

        Args:
            base_retriever: Any LangChain-compatible base retriever.

        Returns:
            A ``ContextualCompressionRetriever`` that applies this
            reranker to the base retriever's output.
        """
        from langchain.retrievers import ContextualCompressionRetriever

        self._create_langchain_reranker()

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self._langchain_reranker,
            base_retriever=base_retriever,
        )

        return compression_retriever

    @classmethod
    def create_from_config(
        cls,
        config: Optional[Dict[str, Any]] = None,
    ) -> "CrossEncoderReranker":
        """Create a reranker from a configuration dict.

        Args:
            config: Optional dict with ``model_name``, ``top_n``,
                ``device``, and ``cache_folder`` keys; missing keys fall
                back to the constructor defaults.

        Returns:
            A configured ``CrossEncoderReranker`` instance.
        """
        config = config or {}
        return cls(
            model_name=config.get("model_name", "BAAI/bge-reranker-base"),
            top_n=config.get("top_n", 5),
            device=config.get("device", None),
            cache_folder=config.get("cache_folder", None),
        )
|
||||
Reference in New Issue
Block a user