重排,多路查询
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 35m37s

This commit is contained in:
2026-04-20 01:10:18 +08:00
parent 933d418d77
commit 3c906e91d9
21 changed files with 728 additions and 635 deletions

View File

@@ -1,35 +1,34 @@
"""
重排序器模块
使用 Cross-Encoder 模型对检索结果进行重排序,提高检索精度。
重排序器模块 (适配版)
使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder
"""
import requests
from typing import List
from langchain_core.documents import Document
class LLaMaCPPReranker:
"""使用远程 llama.cpp 服务对检索结果重排序。"""
class CrossEncoderReranker:
"""使用 Cross-Encoder 对检索结果重排序。"""
def __init__(self, model_name: str = "BAAI/bge-reranker-base", top_n: int = 5):
def __init__(self,
base_url: str = "http://127.0.0.1:8083",
top_n: int = 5,
api_key: str = "huang1998", # 你设置的 LLAMA_ARG_API_KEY
timeout: int = 60):
"""
初始化重排序器
初始化远程重排序器
Args:
model_name: 预训练模型名称
top_n: 返回前 N 个结果
base_url: llama.cpp 服务的地址和端口。
top_n: 返回前 N 个结果
api_key: 在容器中设置的 API 密钥。
timeout: 请求超时时间(秒)。
"""
self.model_name = model_name
self.base_url = base_url.rstrip('/')
self.top_n = top_n
self.model = None
self.api_key = api_key
self.timeout = timeout
self.endpoint = f"{self.base_url}/v1/rerank"
# 尝试加载 Cross-Encoder 模型
try:
from sentence_transformers import CrossEncoder
self.model = CrossEncoder(model_name)
except Exception as e:
print(f"警告: 无法加载 Cross-Encoder 模型 {model_name},将使用简单排序作为回退方案。错误: {e}")
def compress_documents(
self, documents: List[Document], query: str
) -> List[Document]:
@@ -45,21 +44,32 @@ class CrossEncoderReranker:
"""
if not documents:
return []
# 如果模型加载失败,返回前 top_n 个文档
if self.model is None:
return documents[:self.top_n]
# 使用 Cross-Encoder 进行重排序
# 准备请求体
# 根据 llama.cpp 的 OpenAI 兼容性,文档是一个字符串列表
payload = {
"model": "bge-reranker-v2-m3",
"query": query,
"documents": [doc.page_content for doc in documents],
"top_n": self.top_n
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
try:
pairs = [[query, doc.page_content] for doc in documents]
scores = self.model.predict(pairs)
response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout)
response.raise_for_status() # 检查请求是否成功
results = response.json()
# 解析返回结果
# 返回格式: {"results": [{"index": 0, "document": "...", "relevance_score": 0.8}, ...]}
# 按相关性得分降序排列
sorted_indices = [item["index"] for item in results["results"]]
sorted_docs = [documents[idx] for idx in sorted_indices]
return sorted_docs
# 按分数降序排序
scored_docs = sorted(
zip(documents, scores), key=lambda x: x[1], reverse=True
)
return [doc for doc, _ in scored_docs[:self.top_n]]
except Exception as e:
print(f"警告: 重排序过程出错,将使用原始排序。错误: {e}")
return documents[:self.top_n]
print(f"警告: 远程重排序过程出错,将使用原始排序。错误: {e}")
return documents[:self.top_n]