优化: LLM降级重排一次调用给所有文档打分
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 9m56s
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 9m56s
This commit is contained in:
@@ -215,49 +215,56 @@ class LLMFallbackRerankService(BaseRerankService):
|
||||
|
||||
def compute_scores(self, query: str, documents: List[str]) -> List[float]:
    """Score every document against ``query`` with a single LLM call.

    All documents (each cut to at most 500 characters) are packed into one
    prompt and the model is asked to answer with a JSON object of the form
    ``{"scores": [...]}`` holding one relevance score per document.

    Args:
        query: The search query the documents are ranked against.
        documents: Candidate document texts, in the order to be scored.

    Returns:
        One float per document, in input order, clamped to ``[0.0, 1.0]``.
        An empty ``documents`` list yields ``[]``. When the LLM call fails,
        the reply holds no usable ``{"scores": [...]}`` object, or the
        number of scores differs from the number of documents, every
        document receives the neutral default ``0.5``. A single entry that
        is not a number receives ``0.5`` on its own without discarding the
        other scores.
    """
    # Function-scope imports: the original block imported json/re locally,
    # and the file's top-level import section is not visible from here.
    import json
    import math
    import re

    if not documents:
        return []

    default_score = 0.5
    max_doc_chars = 500
    fallback = [default_score] * len(documents)

    def _preview(text: str) -> str:
        # Mark a document with an ellipsis only when it was really cut.
        if len(text) <= max_doc_chars:
            return text
        return f"{text[:max_doc_chars]}..."

    def _to_score(value: object) -> float:
        # Coerce one entry of the reply; anything unusable becomes neutral.
        try:
            number = float(value)
        except (TypeError, ValueError):
            return default_score
        if math.isnan(number):
            return default_score
        return max(0.0, min(1.0, number))

    # Build the prompt from per-line literals so its text does not depend
    # on how this method is indented in the source file.
    docs_str = "\n".join(
        f"文档{i}: {_preview(doc)}" for i, doc in enumerate(documents)
    )
    prompt = (
        "你是一个文档相关性评分专家。请评估以下文档与查询的相关性,为每个文档返回一个0到1之间的分数:\n"
        "- 1.0表示完全相关\n"
        "- 0.0表示完全不相关\n"
        "\n"
        f"查询: {query}\n"
        "\n"
        f"{docs_str}\n"
        "\n"
        "请按以下JSON格式返回,不要解释:\n"
        "{\n"
        '"scores": [\n'
        "0.95,\n"
        "0.12,\n"
        "...\n"
        "]\n"
        "}"
    )

    try:
        result = self.llm.invoke(prompt)
        content = result.content if hasattr(result, 'content') else str(result)

        # Try each "{" as the start of a JSON object. raw_decode stops at
        # the end of the first complete value, so prose or stray braces
        # before/after the object no longer break parsing.
        decoder = json.JSONDecoder()
        raw_scores = None
        for brace in re.finditer(r'\{', content):
            try:
                candidate, _ = decoder.raw_decode(content, brace.start())
            except json.JSONDecodeError:
                continue
            if isinstance(candidate, dict) and isinstance(candidate.get("scores"), list):
                raw_scores = candidate["scores"]
                break
    except Exception as e:
        # Deliberate best-effort: reranking must never fail the caller.
        logger.warning(f"[LLMFallbackRerank] LLM 打分失败,返回默认分数 0.5: {e}")
        return fallback

    # A count mismatch means scores cannot be aligned with documents.
    if raw_scores is None or len(raw_scores) != len(documents):
        logger.warning(f"[LLMFallbackRerank] 无法从LLM响应提取分数,返回默认值")
        return fallback

    scores = [_to_score(value) for value in raw_scores]
    for i, score in enumerate(scores):
        logger.info(f"[LLMFallbackRerank] doc[{i}] score={score:.4f}")
    return scores
|
||||
|
||||
|
||||
class LLMFallbackRerankProvider(BaseServiceProvider[BaseRerankService]):
|
||||
|
||||
Reference in New Issue
Block a user