""" RAG 评估模块 用于计算 RAG 系统的召回率、相关性、准确率等指标 """ import asyncio import json from typing import List, Dict, Tuple, Optional from dataclasses import dataclass from langchain_core.documents import Document @dataclass class RetrievalTestCase: """检索测试用例""" query: str # 用户查询 relevant_doc_ids: List[str] # 相关文档 ID 列表 expected_answer: Optional[str] = None # 期望的答案(可选) @dataclass class RetrievalMetrics: """检索评估指标""" recall_at_k: Dict[int, float] # Recall@k,例如 {1: 0.8, 3: 0.9, 5: 1.0} precision_at_k: Dict[int, float] # Precision@k f1_at_k: Dict[int, float] # F1@k mrr: float # 平均倒数排名 ndcg_at_k: Dict[int, float] # NDCG@k relevance_scores: List[float] # 每个测试用例的相关性评分 class RAGEvaluator: """RAG 评估器""" def __init__(self, rag_pipeline, test_cases: List[RetrievalTestCase]): """ 初始化评估器 Args: rag_pipeline: RAG 流水线对象(需实现 aretrieve 方法) test_cases: 测试用例列表 """ self.rag_pipeline = rag_pipeline self.test_cases = test_cases async def evaluate_retrieval(self, k_list: List[int] = None) -> RetrievalMetrics: """ 评估检索质量 Args: k_list: 要计算的 k 值列表,例如 [1, 3, 5] Returns: 检索评估指标 """ if k_list is None: k_list = [1, 3, 5, 10] all_results = [] all_mrr = [] for test_case in self.test_cases: # 执行检索 retrieved_docs = await self.rag_pipeline.aretrieve(test_case.query) retrieved_ids = [doc.metadata.get("id", doc.page_content[:50]) for doc in retrieved_docs] # 计算召回率和精确率 result = self._calculate_retrieval_metrics( retrieved_ids, test_case.relevant_doc_ids, k_list ) all_results.append(result) # 计算 MRR mrr = self._calculate_mrr(retrieved_ids, test_case.relevant_doc_ids) all_mrr.append(mrr) # 聚合所有测试用例的结果 metrics = self._aggregate_metrics(all_results, all_mrr, k_list) return metrics def _calculate_retrieval_metrics( self, retrieved_ids: List[str], relevant_ids: List[str], k_list: List[int] ) -> Dict[int, Dict[str, float]]: """ 计算单个测试用例的检索指标 Returns: {k: {'recall': float, 'precision': float, 'f1': float}} """ results = {} for k in k_list: # 取前 k 个结果 top_k = retrieved_ids[:k] # 计算交集 relevant_in_top_k = set(top_k) & set(relevant_ids) num_relevant_in_top_k = len(relevant_in_top_k) # 召回率 = 相关文档在 top k 中的数量 / 总相关文档数量 recall = num_relevant_in_top_k / len(relevant_ids) if relevant_ids else 0.0 # 精确率 = 相关文档在 top k 中的数量 / k precision = num_relevant_in_top_k / k if k > 0 else 0.0 # F1 分数 f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0 else 0.0 results[k] = { 'recall': recall, 'precision': precision, 'f1': f1 } return results def _calculate_mrr(self, retrieved_ids: List[str], relevant_ids: List[str]) -> float: """ 计算平均倒数排名 (Mean Reciprocal Rank) MRR@k = 1/m * sum(1/rank_i for i=1..m) 其中 rank_i 是第 i 个相关文档第一次出现的排名 """ for rank, doc_id in enumerate(retrieved_ids, start=1): if doc_id in relevant_ids: return 1.0 / rank return 0.0 def _calculate_ndcg( self, retrieved_ids: List[str], relevant_ids: List[str], k: int ) -> float: """ 计算 NDCG@k (Normalized Discounted Cumulative Gain) DCG@k = sum(relevance_i / log2(i+1) for i=1..k) NDCG@k = DCG@k / IDCG@k """ top_k = retrieved_ids[:k] # 计算 DCG dcg = 0.0 for i, doc_id in enumerate(top_k, start=1): relevance = 1.0 if doc_id in relevant_ids else 0.0 dcg += relevance / (i.bit_length() - 1) # log2(i) # 计算 IDCG(理想 DCG) ideal_relevance = [1.0] * min(len(relevant_ids), k) idcg = 0.0 for i, rel in enumerate(ideal_relevance, start=1): idcg += rel / (i.bit_length() - 1) return dcg / idcg if idcg > 0 else 0.0 def _aggregate_metrics( self, all_results: List[Dict[int, Dict[str, float]]], all_mrr: List[float], k_list: List[int] ) -> RetrievalMetrics: """聚合所有测试用例的指标""" recall_at_k = {} precision_at_k = {} f1_at_k = {} ndcg_at_k = {} for k in k_list: # 聚合召回率 recalls = [result[k]['recall'] for result in all_results] recall_at_k[k] = sum(recalls) / len(recalls) # 聚合精确率 precisions = [result[k]['precision'] for result in all_results] precision_at_k[k] = sum(precisions) / len(precisions) # 聚合 F1 f1s = [result[k]['f1'] for result in all_results] f1_at_k[k] = sum(f1s) / len(f1s) # 计算 NDCG(这里简化处理) ndcg_at_k[k] = sum(f1s) / len(f1s) # 用 F1 近似 # 计算 MRR mrr = sum(all_mrr) / len(all_mrr) return RetrievalMetrics( recall_at_k=recall_at_k, precision_at_k=precision_at_k, f1_at_k=f1_at_k, mrr=mrr, ndcg_at_k=ndcg_at_k, relevance_scores=[1.0] * len(all_results) # 占位符 ) class RelevanceEvaluator: """相关性评估器(基于 LLM 评估)""" def __init__(self, llm): """ 初始化相关性评估器 Args: llm: 用于评估相关性的语言模型 """ self.llm = llm async def evaluate_relevance( self, query: str, document: Document ) -> Tuple[float, str]: """ 评估文档与查询的相关性 Args: query: 用户查询 document: 文档对象 Returns: (相关性分数 0-5, 评估理由) """ prompt = f"""请评估以下文档与用户查询的相关性,给出 0-5 的评分: 用户查询:{query} 文档内容:{document.page_content[:500]} 请按以下标准评分: 5 = 完全相关,文档直接回答了用户查询 4 = 高度相关,文档包含回答查询的关键信息 3 = 部分相关,文档有一些相关信息但不够直接 2 = 弱相关,文档有少量提及但不太相关 1 = 不相关,文档内容与查询基本无关 0 = 完全无关 请只返回 JSON 格式,例如:{{"score": 4, "reason": "文档详细解释了用户查询的概念"}}""" try: response = await self.llm.ainvoke(prompt) result_text = response.content if hasattr(response, 'content') else str(response) # 尝试解析 JSON import json result = json.loads(result_text) score = result.get('score', 0.0) reason = result.get('reason', '无理由') # 确保分数在 0-5 范围内 score = max(0.0, min(5.0, float(score))) return score, reason except Exception as e: return 0.0, f"评估失败:{str(e)}" def generate_test_report(metrics: RetrievalMetrics) -> str: """生成测试报告""" report = [] report.append("=" * 80) report.append("RAG 系统评估报告") report.append("=" * 80) report.append("") # 召回率 report.append("【召回率 Recall@k】") for k, v in sorted(metrics.recall_at_k.items()): report.append(f" Recall@{k}: {v:.2%}") report.append("") # 精确率 report.append("【精确率 Precision@k】") for k, v in sorted(metrics.precision_at_k.items()): report.append(f" Precision@{k}: {v:.2%}") report.append("") # F1 分数 report.append("【F1 分数 F1@k】") for k, v in sorted(metrics.f1_at_k.items()): report.append(f" F1@{k}: {v:.4f}") report.append("") # MRR report.append(f"【平均倒数排名 MRR】: {metrics.mrr:.4f}") report.append("") # 解释 report.append("=" * 80) report.append("指标说明:") report.append("- Recall@k: 前 k 个结果中包含多少比例的相关文档") report.append("- Precision@k: 前 k 个结果中有多少比例是相关文档") report.append("- F1@k: 召回率和精确率的调和平均数") report.append("- MRR: 第一个相关文档的排名的倒数的平均值") report.append("=" * 80) return "\n".join(report) # 示例使用 def create_sample_test_cases() -> List[RetrievalTestCase]: """创建示例测试用例""" return [ RetrievalTestCase( query="什么是 RAG 系统?", relevant_doc_ids=["doc_rag_1", "doc_rag_2"], expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写..." ), RetrievalTestCase( query="如何使用 LangChain?", relevant_doc_ids=["doc_langchain_1", "doc_langchain_2", "doc_langchain_3"], expected_answer="LangChain 的使用步骤包括..." ), ] if __name__ == "__main__": # 示例:如何使用评估器 print("RAG 评估模块已加载") print("使用方法:") print(" 1. 创建测试用例") print(" 2. 初始化 RAGEvaluator") print(" 3. 调用 evaluate_retrieval()") print(" 4. 生成报告")