Files
ailine/app/rag/reranker.py

75 lines
2.6 KiB
Python
Raw Normal View History

2026-04-18 16:31:48 +08:00
"""
Reranker module (adapter version).

Uses a remote llama.cpp service (OpenAI-compatible rerank API) instead of a
local cross-encoder.
"""
2026-04-20 01:10:18 +08:00
import requests
from typing import List
2026-04-18 16:31:48 +08:00
from langchain_core.documents import Document
2026-04-20 01:10:18 +08:00
class LLaMaCPPReranker:
    """Rerank retrieved documents through a remote llama.cpp rerank service.

    The service is called at ``<base_url>/rerank`` and is expected to answer
    with ``{"results": [{"index": int, "relevance_score": float, ...}, ...]}``
    where ``index`` refers to the position of a document in the request.
    """

    def __init__(self,
                 base_url: str,
                 api_key: str,
                 top_n: int = 5,
                 timeout: int = 60,
                 model: str = "bge-reranker-v2-m3"):
        """Store the connection settings and derive the rerank endpoint.

        Args:
            base_url: Base address of the llama.cpp service, for example
                ``"http://127.0.0.1:8083"``. Required; no default is read
                from the environment by this class.
            api_key: API key sent as a ``Bearer`` token. Required; no default
                is read from the environment by this class.
            top_n: Maximum number of documents to return.
            timeout: Request timeout in seconds (passed to ``requests.post``).
            model: Model name sent in the request payload.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.top_n = top_n
        self.timeout = timeout
        self.model = model
        # Strip trailing slashes so "http://host/" does not become "http://host//rerank".
        self.endpoint = f"{base_url.rstrip('/')}/rerank"

    @staticmethod
    def _relevance(item: dict) -> float:
        """Return the item's relevance score, or -inf when it is missing or non-numeric."""
        score = item.get("relevance_score")
        if isinstance(score, (int, float)) and not isinstance(score, bool):
            return score
        return float("-inf")

    def compress_documents(
        self, documents: "List[Document]", query: str
    ) -> "List[Document]":
        """Rerank ``documents`` against ``query`` and keep the best ``top_n``.

        Args:
            documents: Documents to rerank; their ``page_content`` is sent to
                the service.
            query: Query string the documents are scored against.

        Returns:
            At most ``top_n`` documents ordered by descending relevance
            score. If the request or response handling fails for any reason,
            a warning is printed and the first ``top_n`` documents are
            returned in their original order.
        """
        if not documents:
            return []

        # The rerank API takes the documents as a plain list of strings.
        payload = {
            "model": self.model,
            "query": query,
            "documents": [doc.page_content for doc in documents],
            "top_n": self.top_n
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        try:
            response = requests.post(
                self.endpoint, json=payload, headers=headers, timeout=self.timeout
            )
            response.raise_for_status()  # turn HTTP error statuses into exceptions
            results = response.json()

            # Do not rely on the server's ordering: sort by score ourselves.
            # The sort is stable, so items without a usable score keep the
            # order in which the server returned them (after scored items).
            ranked = sorted(results["results"], key=self._relevance, reverse=True)

            reranked = []
            for item in ranked:
                idx = item["index"]
                # Reject bad indices explicitly; a negative index would
                # otherwise silently pick a document from the end of the list.
                if (not isinstance(idx, int) or isinstance(idx, bool)
                        or not 0 <= idx < len(documents)):
                    raise ValueError(f"invalid document index in rerank response: {idx!r}")
                reranked.append(documents[idx])

            # Enforce top_n locally as well, in case the server ignores it.
            return reranked[:self.top_n]

        except Exception as e:
            # Deliberate best-effort boundary: reranking must never break
            # retrieval, so any failure falls back to the original order.
            print(f"警告: 远程重排序过程出错，将使用原始排序。错误: {e}")
            return documents[:self.top_n]