"""
RAG evaluation module.

Computes recall, precision, relevance and related retrieval metrics
for a RAG (Retrieval-Augmented Generation) system.
"""
import asyncio
import json
import math
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from langchain_core.documents import Document
||
@dataclass
class RetrievalTestCase:
    """One labelled retrieval test case: a query plus its ground-truth relevant docs."""
    query: str  # user query
    relevant_doc_ids: List[str]  # IDs of the documents known to be relevant
    expected_answer: Optional[str] = None  # expected answer (optional)
||
@dataclass
class RetrievalMetrics:
    """Aggregated retrieval-evaluation metrics over a set of test cases."""
    recall_at_k: Dict[int, float]  # Recall@k, e.g. {1: 0.8, 3: 0.9, 5: 1.0}
    precision_at_k: Dict[int, float]  # Precision@k
    f1_at_k: Dict[int, float]  # F1@k
    mrr: float  # mean reciprocal rank
    ndcg_at_k: Dict[int, float]  # NDCG@k
    relevance_scores: List[float]  # per-test-case relevance score
||
class RAGEvaluator:
    """Evaluates retrieval quality of a RAG pipeline against labelled test cases."""

    def __init__(self, rag_pipeline, test_cases: List["RetrievalTestCase"]):
        """
        Initialize the evaluator.

        Args:
            rag_pipeline: RAG pipeline object (must implement an async ``aretrieve`` method).
            test_cases: Labelled test cases to evaluate against.
        """
        self.rag_pipeline = rag_pipeline
        self.test_cases = test_cases

    async def evaluate_retrieval(self, k_list: Optional[List[int]] = None) -> "RetrievalMetrics":
        """
        Evaluate retrieval quality over all test cases.

        Args:
            k_list: Cut-off values to compute, e.g. [1, 3, 5]. Defaults to [1, 3, 5, 10].

        Returns:
            Aggregated retrieval metrics (averaged over all test cases).
        """
        if k_list is None:
            k_list = [1, 3, 5, 10]

        all_results = []
        all_ndcg = []
        all_mrr = []

        for test_case in self.test_cases:
            # Run retrieval; fall back to a content prefix when a doc has no "id" metadata.
            retrieved_docs = await self.rag_pipeline.aretrieve(test_case.query)
            retrieved_ids = [doc.metadata.get("id", doc.page_content[:50]) for doc in retrieved_docs]

            all_results.append(
                self._calculate_retrieval_metrics(retrieved_ids, test_case.relevant_doc_ids, k_list)
            )
            # Real NDCG@k per test case (previously never called; aggregation used F1 instead).
            all_ndcg.append(
                {k: self._calculate_ndcg(retrieved_ids, test_case.relevant_doc_ids, k) for k in k_list}
            )
            all_mrr.append(self._calculate_mrr(retrieved_ids, test_case.relevant_doc_ids))

        return self._aggregate_metrics(all_results, all_ndcg, all_mrr, k_list)

    def _calculate_retrieval_metrics(
        self,
        retrieved_ids: List[str],
        relevant_ids: List[str],
        k_list: List[int],
    ) -> Dict[int, Dict[str, float]]:
        """
        Compute recall/precision/F1 at each cut-off for a single test case.

        Returns:
            {k: {'recall': float, 'precision': float, 'f1': float}}
        """
        relevant_set = set(relevant_ids)  # hoisted: O(1) membership per cut-off
        results = {}
        for k in k_list:
            top_k = retrieved_ids[:k]
            hits = len(set(top_k) & relevant_set)
            # recall = relevant docs found in top k / total relevant docs (guard empty truth)
            recall = hits / len(relevant_ids) if relevant_ids else 0.0
            # precision = relevant docs found in top k / k
            precision = hits / k if k > 0 else 0.0
            # harmonic mean of recall and precision
            denom = recall + precision
            f1 = 2 * recall * precision / denom if denom > 0 else 0.0
            results[k] = {'recall': recall, 'precision': precision, 'f1': f1}
        return results

    def _calculate_mrr(self, retrieved_ids: List[str], relevant_ids: List[str]) -> float:
        """
        Reciprocal rank of the first relevant document for one query.

        Returns 1/rank of the first hit, or 0.0 when no relevant doc was retrieved.
        (The mean over queries — MRR — is taken in ``_aggregate_metrics``.)
        """
        for rank, doc_id in enumerate(retrieved_ids, start=1):
            if doc_id in relevant_ids:
                return 1.0 / rank
        return 0.0

    def _calculate_ndcg(
        self,
        retrieved_ids: List[str],
        relevant_ids: List[str],
        k: int,
    ) -> float:
        """
        NDCG@k (Normalized Discounted Cumulative Gain) with binary relevance.

        DCG@k = sum(rel_i / log2(i + 1) for i = 1..k); NDCG@k = DCG@k / IDCG@k.

        Bug fix: the previous ``i.bit_length() - 1`` "log2" raised
        ZeroDivisionError at rank 1 and was wrong for non-powers-of-two;
        the standard discount is log2(i + 1).
        """
        top_k = retrieved_ids[:k]
        relevant_set = set(relevant_ids)

        dcg = sum(
            1.0 / math.log2(i + 1)
            for i, doc_id in enumerate(top_k, start=1)
            if doc_id in relevant_set
        )
        # IDCG: the ideal ranking places every relevant doc first.
        idcg = sum(1.0 / math.log2(i + 1) for i in range(1, min(len(relevant_ids), k) + 1))

        return dcg / idcg if idcg > 0 else 0.0

    def _aggregate_metrics(
        self,
        all_results: List[Dict[int, Dict[str, float]]],
        all_ndcg: List[Dict[int, float]],
        all_mrr: List[float],
        k_list: List[int],
    ) -> "RetrievalMetrics":
        """Average the per-test-case metrics into a single RetrievalMetrics."""
        if not all_results:
            # No test cases: return all-zero metrics instead of dividing by zero.
            zeros = {k: 0.0 for k in k_list}
            return RetrievalMetrics(
                recall_at_k=dict(zeros),
                precision_at_k=dict(zeros),
                f1_at_k=dict(zeros),
                mrr=0.0,
                ndcg_at_k=dict(zeros),
                relevance_scores=[],
            )

        n = len(all_results)
        recall_at_k = {}
        precision_at_k = {}
        f1_at_k = {}
        ndcg_at_k = {}
        for k in k_list:
            recall_at_k[k] = sum(r[k]['recall'] for r in all_results) / n
            precision_at_k[k] = sum(r[k]['precision'] for r in all_results) / n
            f1_at_k[k] = sum(r[k]['f1'] for r in all_results) / n
            # Real NDCG average (previously approximated by copying the F1 average).
            ndcg_at_k[k] = sum(d[k] for d in all_ndcg) / n

        return RetrievalMetrics(
            recall_at_k=recall_at_k,
            precision_at_k=precision_at_k,
            f1_at_k=f1_at_k,
            mrr=sum(all_mrr) / n,
            ndcg_at_k=ndcg_at_k,
            relevance_scores=[1.0] * n,  # placeholder: use RelevanceEvaluator for real scores
        )
class RelevanceEvaluator:
    """Relevance evaluator (LLM-as-judge)."""

    def __init__(self, llm):
        """
        Initialize the relevance evaluator.

        Args:
            llm: Language model used to judge relevance (must implement async ``ainvoke``).
        """
        self.llm = llm

    async def evaluate_relevance(
        self,
        query: str,
        document: "Document",
    ) -> Tuple[float, str]:
        """
        Score how relevant a document is to a query on a 0-5 scale.

        Args:
            query: User query.
            document: Document object (only ``page_content`` is read; truncated to 500 chars).

        Returns:
            (relevance score 0-5, judge's reasoning). On any failure the score
            is 0.0 and the reason carries the error message (best-effort contract).
        """
        prompt = f"""请评估以下文档与用户查询的相关性,给出 0-5 的评分:

用户查询:{query}

文档内容:{document.page_content[:500]}

请按以下标准评分:
5 = 完全相关,文档直接回答了用户查询
4 = 高度相关,文档包含回答查询的关键信息
3 = 部分相关,文档有一些相关信息但不够直接
2 = 弱相关,文档有少量提及但不太相关
1 = 不相关,文档内容与查询基本无关
0 = 完全无关

请只返回 JSON 格式,例如:{{"score": 4, "reason": "文档详细解释了用户查询的概念"}}"""

        try:
            response = await self.llm.ainvoke(prompt)
            result_text = response.content if hasattr(response, 'content') else str(response)

            # LLMs often wrap the JSON in markdown fences or surrounding prose;
            # extract the first {...} object instead of parsing the raw text,
            # which previously failed and silently collapsed to score 0.
            match = re.search(r'\{.*\}', result_text, re.DOTALL)
            if match is None:
                raise ValueError("响应中未找到 JSON 对象")
            result = json.loads(match.group(0))
            score = result.get('score', 0.0)
            reason = result.get('reason', '无理由')

            # Clamp the score into the documented [0, 5] range.
            score = max(0.0, min(5.0, float(score)))

            return score, reason

        except Exception as e:
            # Best-effort: any failure maps to score 0 with the error as the reason.
            return 0.0, f"评估失败:{str(e)}"
def generate_test_report(metrics: "RetrievalMetrics") -> str:
    """
    Render a human-readable evaluation report from aggregated metrics.

    Args:
        metrics: Aggregated retrieval metrics (only attribute access is used).

    Returns:
        The report as a single multi-line string.
    """
    report = []
    report.append("=" * 80)
    report.append("RAG 系统评估报告")
    report.append("=" * 80)
    report.append("")

    # Recall section
    report.append("【召回率 Recall@k】")
    for k, v in sorted(metrics.recall_at_k.items()):
        report.append(f"  Recall@{k}: {v:.2%}")
    report.append("")

    # Precision section
    report.append("【精确率 Precision@k】")
    for k, v in sorted(metrics.precision_at_k.items()):
        report.append(f"  Precision@{k}: {v:.2%}")
    report.append("")

    # F1 section
    report.append("【F1 分数 F1@k】")
    for k, v in sorted(metrics.f1_at_k.items()):
        report.append(f"  F1@{k}: {v:.4f}")
    report.append("")

    # NDCG section — the metric existed in RetrievalMetrics but was missing from the report
    report.append("【NDCG@k】")
    for k, v in sorted(metrics.ndcg_at_k.items()):
        report.append(f"  NDCG@{k}: {v:.4f}")
    report.append("")

    # MRR
    report.append(f"【平均倒数排名 MRR】: {metrics.mrr:.4f}")
    report.append("")

    # Glossary
    report.append("=" * 80)
    report.append("指标说明:")
    report.append("- Recall@k: 前 k 个结果中包含多少比例的相关文档")
    report.append("- Precision@k: 前 k 个结果中有多少比例是相关文档")
    report.append("- F1@k: 召回率和精确率的调和平均数")
    report.append("- NDCG@k: 考虑排名位置的归一化折损累计增益")
    report.append("- MRR: 第一个相关文档的排名的倒数的平均值")
    report.append("=" * 80)

    return "\n".join(report)
||
# Example usage
def create_sample_test_cases() -> List[RetrievalTestCase]:
    """Build a small set of example test cases for trying out the evaluator."""
    rag_case = RetrievalTestCase(
        query="什么是 RAG 系统?",
        relevant_doc_ids=["doc_rag_1", "doc_rag_2"],
        expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写...",
    )
    langchain_case = RetrievalTestCase(
        query="如何使用 LangChain?",
        relevant_doc_ids=["doc_langchain_1", "doc_langchain_2", "doc_langchain_3"],
        expected_answer="LangChain 的使用步骤包括...",
    )
    return [rag_case, langchain_case]
if __name__ == "__main__":
    # Example: how to use the evaluator
    usage_lines = (
        "RAG 评估模块已加载",
        "使用方法:",
        "  1. 创建测试用例",
        "  2. 初始化 RAGEvaluator",
        "  3. 调用 evaluate_retrieval()",
        "  4. 生成报告",
    )
    for line in usage_lines:
        print(line)