feat: 添加 RAG 评估模块,支持召回率和相关性评估
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m13s

This commit is contained in:
2026-04-26 15:39:05 +08:00
parent 6404ea8c42
commit 92863e86dc
4 changed files with 864 additions and 0 deletions

333
backend/app/rag/evaluate.py Normal file
View File

@@ -0,0 +1,333 @@
"""
RAG 评估模块
用于计算 RAG 系统的召回率、相关性、准确率等指标
"""
import asyncio
import json
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from langchain_core.documents import Document
@dataclass
class RetrievalTestCase:
"""检索测试用例"""
query: str # 用户查询
relevant_doc_ids: List[str] # 相关文档 ID 列表
expected_answer: Optional[str] = None # 期望的答案(可选)
@dataclass
class RetrievalMetrics:
"""检索评估指标"""
recall_at_k: Dict[int, float] # Recall@k例如 {1: 0.8, 3: 0.9, 5: 1.0}
precision_at_k: Dict[int, float] # Precision@k
f1_at_k: Dict[int, float] # F1@k
mrr: float # 平均倒数排名
ndcg_at_k: Dict[int, float] # NDCG@k
relevance_scores: List[float] # 每个测试用例的相关性评分
class RAGEvaluator:
"""RAG 评估器"""
def __init__(self, rag_pipeline, test_cases: List[RetrievalTestCase]):
"""
初始化评估器
Args:
rag_pipeline: RAG 流水线对象(需实现 aretrieve 方法)
test_cases: 测试用例列表
"""
self.rag_pipeline = rag_pipeline
self.test_cases = test_cases
async def evaluate_retrieval(self, k_list: List[int] = None) -> RetrievalMetrics:
"""
评估检索质量
Args:
k_list: 要计算的 k 值列表,例如 [1, 3, 5]
Returns:
检索评估指标
"""
if k_list is None:
k_list = [1, 3, 5, 10]
all_results = []
all_mrr = []
for test_case in self.test_cases:
# 执行检索
retrieved_docs = await self.rag_pipeline.aretrieve(test_case.query)
retrieved_ids = [doc.metadata.get("id", doc.page_content[:50]) for doc in retrieved_docs]
# 计算召回率和精确率
result = self._calculate_retrieval_metrics(
retrieved_ids,
test_case.relevant_doc_ids,
k_list
)
all_results.append(result)
# 计算 MRR
mrr = self._calculate_mrr(retrieved_ids, test_case.relevant_doc_ids)
all_mrr.append(mrr)
# 聚合所有测试用例的结果
metrics = self._aggregate_metrics(all_results, all_mrr, k_list)
return metrics
def _calculate_retrieval_metrics(
self,
retrieved_ids: List[str],
relevant_ids: List[str],
k_list: List[int]
) -> Dict[int, Dict[str, float]]:
"""
计算单个测试用例的检索指标
Returns:
{k: {'recall': float, 'precision': float, 'f1': float}}
"""
results = {}
for k in k_list:
# 取前 k 个结果
top_k = retrieved_ids[:k]
# 计算交集
relevant_in_top_k = set(top_k) & set(relevant_ids)
num_relevant_in_top_k = len(relevant_in_top_k)
# 召回率 = 相关文档在 top k 中的数量 / 总相关文档数量
recall = num_relevant_in_top_k / len(relevant_ids) if relevant_ids else 0.0
# 精确率 = 相关文档在 top k 中的数量 / k
precision = num_relevant_in_top_k / k if k > 0 else 0.0
# F1 分数
f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0 else 0.0
results[k] = {
'recall': recall,
'precision': precision,
'f1': f1
}
return results
def _calculate_mrr(self, retrieved_ids: List[str], relevant_ids: List[str]) -> float:
"""
计算平均倒数排名 (Mean Reciprocal Rank)
MRR@k = 1/m * sum(1/rank_i for i=1..m)
其中 rank_i 是第 i 个相关文档第一次出现的排名
"""
for rank, doc_id in enumerate(retrieved_ids, start=1):
if doc_id in relevant_ids:
return 1.0 / rank
return 0.0
def _calculate_ndcg(
self,
retrieved_ids: List[str],
relevant_ids: List[str],
k: int
) -> float:
"""
计算 NDCG@k (Normalized Discounted Cumulative Gain)
DCG@k = sum(relevance_i / log2(i+1) for i=1..k)
NDCG@k = DCG@k / IDCG@k
"""
top_k = retrieved_ids[:k]
# 计算 DCG
dcg = 0.0
for i, doc_id in enumerate(top_k, start=1):
relevance = 1.0 if doc_id in relevant_ids else 0.0
dcg += relevance / (i.bit_length() - 1) # log2(i)
# 计算 IDCG理想 DCG
ideal_relevance = [1.0] * min(len(relevant_ids), k)
idcg = 0.0
for i, rel in enumerate(ideal_relevance, start=1):
idcg += rel / (i.bit_length() - 1)
return dcg / idcg if idcg > 0 else 0.0
def _aggregate_metrics(
self,
all_results: List[Dict[int, Dict[str, float]]],
all_mrr: List[float],
k_list: List[int]
) -> RetrievalMetrics:
"""聚合所有测试用例的指标"""
recall_at_k = {}
precision_at_k = {}
f1_at_k = {}
ndcg_at_k = {}
for k in k_list:
# 聚合召回率
recalls = [result[k]['recall'] for result in all_results]
recall_at_k[k] = sum(recalls) / len(recalls)
# 聚合精确率
precisions = [result[k]['precision'] for result in all_results]
precision_at_k[k] = sum(precisions) / len(precisions)
# 聚合 F1
f1s = [result[k]['f1'] for result in all_results]
f1_at_k[k] = sum(f1s) / len(f1s)
# 计算 NDCG这里简化处理
ndcg_at_k[k] = sum(f1s) / len(f1s) # 用 F1 近似
# 计算 MRR
mrr = sum(all_mrr) / len(all_mrr)
return RetrievalMetrics(
recall_at_k=recall_at_k,
precision_at_k=precision_at_k,
f1_at_k=f1_at_k,
mrr=mrr,
ndcg_at_k=ndcg_at_k,
relevance_scores=[1.0] * len(all_results) # 占位符
)
class RelevanceEvaluator:
"""相关性评估器(基于 LLM 评估)"""
def __init__(self, llm):
"""
初始化相关性评估器
Args:
llm: 用于评估相关性的语言模型
"""
self.llm = llm
async def evaluate_relevance(
self,
query: str,
document: Document
) -> Tuple[float, str]:
"""
评估文档与查询的相关性
Args:
query: 用户查询
document: 文档对象
Returns:
(相关性分数 0-5, 评估理由)
"""
prompt = f"""请评估以下文档与用户查询的相关性,给出 0-5 的评分:
用户查询:{query}
文档内容:{document.page_content[:500]}
请按以下标准评分:
5 = 完全相关,文档直接回答了用户查询
4 = 高度相关,文档包含回答查询的关键信息
3 = 部分相关,文档有一些相关信息但不够直接
2 = 弱相关,文档有少量提及但不太相关
1 = 不相关,文档内容与查询基本无关
0 = 完全无关
请只返回 JSON 格式,例如:{{"score": 4, "reason": "文档详细解释了用户查询的概念"}}"""
try:
response = await self.llm.ainvoke(prompt)
result_text = response.content if hasattr(response, 'content') else str(response)
# 尝试解析 JSON
import json
result = json.loads(result_text)
score = result.get('score', 0.0)
reason = result.get('reason', '无理由')
# 确保分数在 0-5 范围内
score = max(0.0, min(5.0, float(score)))
return score, reason
except Exception as e:
return 0.0, f"评估失败:{str(e)}"
def generate_test_report(metrics: RetrievalMetrics) -> str:
"""生成测试报告"""
report = []
report.append("=" * 80)
report.append("RAG 系统评估报告")
report.append("=" * 80)
report.append("")
# 召回率
report.append("【召回率 Recall@k】")
for k, v in sorted(metrics.recall_at_k.items()):
report.append(f" Recall@{k}: {v:.2%}")
report.append("")
# 精确率
report.append("【精确率 Precision@k】")
for k, v in sorted(metrics.precision_at_k.items()):
report.append(f" Precision@{k}: {v:.2%}")
report.append("")
# F1 分数
report.append("【F1 分数 F1@k】")
for k, v in sorted(metrics.f1_at_k.items()):
report.append(f" F1@{k}: {v:.4f}")
report.append("")
# MRR
report.append(f"【平均倒数排名 MRR】: {metrics.mrr:.4f}")
report.append("")
# 解释
report.append("=" * 80)
report.append("指标说明:")
report.append("- Recall@k: 前 k 个结果中包含多少比例的相关文档")
report.append("- Precision@k: 前 k 个结果中有多少比例是相关文档")
report.append("- F1@k: 召回率和精确率的调和平均数")
report.append("- MRR: 第一个相关文档的排名的倒数的平均值")
report.append("=" * 80)
return "\n".join(report)
# 示例使用
def create_sample_test_cases() -> List[RetrievalTestCase]:
"""创建示例测试用例"""
return [
RetrievalTestCase(
query="什么是 RAG 系统?",
relevant_doc_ids=["doc_rag_1", "doc_rag_2"],
expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写..."
),
RetrievalTestCase(
query="如何使用 LangChain",
relevant_doc_ids=["doc_langchain_1", "doc_langchain_2", "doc_langchain_3"],
expected_answer="LangChain 的使用步骤包括..."
),
]
if __name__ == "__main__":
# 示例:如何使用评估器
print("RAG 评估模块已加载")
print("使用方法:")
print(" 1. 创建测试用例")
print(" 2. 初始化 RAGEvaluator")
print(" 3. 调用 evaluate_retrieval()")
print(" 4. 生成报告")