ailine/backend/app/rag/evaluate.py
"""
RAG 评估模块
用于计算 RAG 系统的召回率、相关性、准确率等指标
"""
import asyncio
import json
import math
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from langchain_core.documents import Document
@dataclass
class RetrievalTestCase:
"""检索测试用例"""
query: str # 用户查询
relevant_doc_ids: List[str] # 相关文档 ID 列表
expected_answer: Optional[str] = None # 期望的答案(可选)
@dataclass
class RetrievalMetrics:
"""检索评估指标"""
recall_at_k: Dict[int, float] # Recall@k例如 {1: 0.8, 3: 0.9, 5: 1.0}
precision_at_k: Dict[int, float] # Precision@k
f1_at_k: Dict[int, float] # F1@k
mrr: float # 平均倒数排名
ndcg_at_k: Dict[int, float] # NDCG@k
relevance_scores: List[float] # 每个测试用例的相关性评分
class RAGEvaluator:
"""RAG 评估器"""
def __init__(self, rag_pipeline, test_cases: List[RetrievalTestCase]):
"""
初始化评估器
Args:
rag_pipeline: RAG 流水线对象(需实现 aretrieve 方法)
test_cases: 测试用例列表
"""
self.rag_pipeline = rag_pipeline
self.test_cases = test_cases
    async def evaluate_retrieval(self, k_list: Optional[List[int]] = None) -> RetrievalMetrics:
        """
        Evaluate retrieval quality.
        Args:
            k_list: list of k values to evaluate, e.g. [1, 3, 5]
        Returns:
            aggregated retrieval metrics
        """
if k_list is None:
k_list = [1, 3, 5, 10]
all_results = []
all_mrr = []
for test_case in self.test_cases:
            # run retrieval for this query
retrieved_docs = await self.rag_pipeline.aretrieve(test_case.query)
retrieved_ids = [doc.metadata.get("id", doc.page_content[:50]) for doc in retrieved_docs]
            # compute recall / precision / F1 at each k
result = self._calculate_retrieval_metrics(
retrieved_ids,
test_case.relevant_doc_ids,
k_list
)
all_results.append(result)
            # compute the reciprocal rank for this query
mrr = self._calculate_mrr(retrieved_ids, test_case.relevant_doc_ids)
all_mrr.append(mrr)
        # aggregate results over all test cases
metrics = self._aggregate_metrics(all_results, all_mrr, k_list)
return metrics
def _calculate_retrieval_metrics(
self,
retrieved_ids: List[str],
relevant_ids: List[str],
k_list: List[int]
) -> Dict[int, Dict[str, float]]:
"""
        Compute retrieval metrics for a single test case.
Returns:
{k: {'recall': float, 'precision': float, 'f1': float}}
"""
results = {}
for k in k_list:
            # take the top k results
top_k = retrieved_ids[:k]
            # intersection of retrieved and relevant document IDs
relevant_in_top_k = set(top_k) & set(relevant_ids)
num_relevant_in_top_k = len(relevant_in_top_k)
            # recall = number of relevant docs in the top k / total number of relevant docs
recall = num_relevant_in_top_k / len(relevant_ids) if relevant_ids else 0.0
            # precision = number of relevant docs in the top k / k
precision = num_relevant_in_top_k / k if k > 0 else 0.0
            # F1 score (harmonic mean of recall and precision)
f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0 else 0.0
results[k] = {
'recall': recall,
'precision': precision,
'f1': f1
}
return results
def _calculate_mrr(self, retrieved_ids: List[str], relevant_ids: List[str]) -> float:
"""
计算平均倒数排名 (Mean Reciprocal Rank)
MRR@k = 1/m * sum(1/rank_i for i=1..m)
其中 rank_i 是第 i 个相关文档第一次出现的排名
"""
for rank, doc_id in enumerate(retrieved_ids, start=1):
if doc_id in relevant_ids:
return 1.0 / rank
return 0.0
def _calculate_ndcg(
self,
retrieved_ids: List[str],
relevant_ids: List[str],
k: int
) -> float:
"""
计算 NDCG@k (Normalized Discounted Cumulative Gain)
DCG@k = sum(relevance_i / log2(i+1) for i=1..k)
NDCG@k = DCG@k / IDCG@k
"""
top_k = retrieved_ids[:k]
# 计算 DCG
dcg = 0.0
for i, doc_id in enumerate(top_k, start=1):
relevance = 1.0 if doc_id in relevant_ids else 0.0
dcg += relevance / (i.bit_length() - 1) # log2(i)
# 计算 IDCG理想 DCG
ideal_relevance = [1.0] * min(len(relevant_ids), k)
idcg = 0.0
for i, rel in enumerate(ideal_relevance, start=1):
idcg += rel / (i.bit_length() - 1)
return dcg / idcg if idcg > 0 else 0.0
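    # Worked example for _calculate_ndcg (illustrative values, not taken from any test suite):
    # retrieved_ids = ["d1", "d2", "d3"], relevant_ids = ["d1", "d3"], k = 3
    #   DCG@3  = 1/log2(2) + 0/log2(3) + 1/log2(4) = 1.0 + 0.0 + 0.5 = 1.5
    #   IDCG@3 = 1/log2(2) + 1/log2(3) ≈ 1.0 + 0.631 ≈ 1.631
    #   NDCG@3 ≈ 1.5 / 1.631 ≈ 0.92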
def _aggregate_metrics(
self,
all_results: List[Dict[int, Dict[str, float]]],
all_mrr: List[float],
k_list: List[int]
) -> RetrievalMetrics:
"""聚合所有测试用例的指标"""
recall_at_k = {}
precision_at_k = {}
f1_at_k = {}
ndcg_at_k = {}
for k in k_list:
            # aggregate recall
recalls = [result[k]['recall'] for result in all_results]
recall_at_k[k] = sum(recalls) / len(recalls)
            # aggregate precision
precisions = [result[k]['precision'] for result in all_results]
precision_at_k[k] = sum(precisions) / len(precisions)
            # aggregate F1
f1s = [result[k]['f1'] for result in all_results]
f1_at_k[k] = sum(f1s) / len(f1s)
            # NDCG is approximated with the mean F1 here (simplification; the exact
            # per-query formula is implemented in _calculate_ndcg but not wired up)
            ndcg_at_k[k] = sum(f1s) / len(f1s)
        # mean reciprocal rank over all test cases
mrr = sum(all_mrr) / len(all_mrr)
return RetrievalMetrics(
recall_at_k=recall_at_k,
precision_at_k=precision_at_k,
f1_at_k=f1_at_k,
mrr=mrr,
ndcg_at_k=ndcg_at_k,
            relevance_scores=[1.0] * len(all_results)  # placeholder
)
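# Worked example for the per-query retrieval metrics (illustrative values):
# relevant_doc_ids = ["a", "b"], retrieved order = ["x", "a", "y", "b", "z"]
#   Recall@3        = 1 / 2 = 0.50   (only "a" appears in the top 3)
#   Precision@3     = 1 / 3 ≈ 0.33
#   F1@3            = 2 * 0.50 * 0.33 / (0.50 + 0.33) ≈ 0.40
#   Reciprocal rank = 1 / 2 = 0.50   ("a" is the first relevant hit, at rank 2)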
class RelevanceEvaluator:
"""相关性评估器(基于 LLM 评估)"""
def __init__(self, llm):
"""
初始化相关性评估器
Args:
llm: 用于评估相关性的语言模型
"""
self.llm = llm
async def evaluate_relevance(
self,
query: str,
document: Document
) -> Tuple[float, str]:
"""
评估文档与查询的相关性
Args:
query: 用户查询
document: 文档对象
Returns:
(相关性分数 0-5, 评估理由)
"""
        prompt = f"""Rate the relevance of the following document to the user query and give a score from 0 to 5:
User query: {query}
Document content: {document.page_content[:500]}
Score according to the following rubric:
5 = fully relevant, the document directly answers the query
4 = highly relevant, the document contains the key information needed to answer the query
3 = partially relevant, the document has some related information but is not a direct answer
2 = weakly relevant, the document only briefly mentions the topic
1 = not relevant, the document is essentially unrelated to the query
0 = completely unrelated
Return JSON only, for example: {{"score": 4, "reason": "the document explains the queried concept in detail"}}"""
try:
response = await self.llm.ainvoke(prompt)
result_text = response.content if hasattr(response, 'content') else str(response)
            # parse the JSON response
            result = json.loads(result_text)
            score = result.get('score', 0.0)
            reason = result.get('reason', 'no reason given')
            # clamp the score to the 0-5 range
score = max(0.0, min(5.0, float(score)))
return score, reason
except Exception as e:
return 0.0, f"评估失败:{str(e)}"
def generate_test_report(metrics: RetrievalMetrics) -> str:
"""生成测试报告"""
report = []
report.append("=" * 80)
report.append("RAG 系统评估报告")
report.append("=" * 80)
report.append("")
    # recall
    report.append("[Recall@k]")
for k, v in sorted(metrics.recall_at_k.items()):
report.append(f" Recall@{k}: {v:.2%}")
report.append("")
    # precision
    report.append("[Precision@k]")
for k, v in sorted(metrics.precision_at_k.items()):
report.append(f" Precision@{k}: {v:.2%}")
report.append("")
    # F1 score
    report.append("[F1@k]")
for k, v in sorted(metrics.f1_at_k.items()):
report.append(f" F1@{k}: {v:.4f}")
report.append("")
    # MRR
    report.append(f"[Mean Reciprocal Rank (MRR)]: {metrics.mrr:.4f}")
report.append("")
    # metric definitions
    report.append("=" * 80)
    report.append("Metric definitions:")
    report.append("- Recall@k: fraction of the relevant documents that appear in the top k results")
    report.append("- Precision@k: fraction of the top k results that are relevant")
    report.append("- F1@k: harmonic mean of recall and precision")
    report.append("- MRR: mean over queries of the reciprocal rank of the first relevant document")
report.append("=" * 80)
return "\n".join(report)
# Example usage
def create_sample_test_cases() -> List[RetrievalTestCase]:
"""创建示例测试用例"""
return [
RetrievalTestCase(
query="什么是 RAG 系统?",
relevant_doc_ids=["doc_rag_1", "doc_rag_2"],
expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写..."
),
RetrievalTestCase(
query="如何使用 LangChain",
relevant_doc_ids=["doc_langchain_1", "doc_langchain_2", "doc_langchain_3"],
expected_answer="LangChain 的使用步骤包括..."
),
]
if __name__ == "__main__":
    # Example: how to use the evaluator (a commented wiring sketch follows below)
    print("RAG evaluation module loaded")
    print("Usage:")
    print(" 1. Create test cases")
    print(" 2. Initialize RAGEvaluator")
    print(" 3. Call evaluate_retrieval()")
    print(" 4. Generate the report")