diff --git a/README.md b/README.md index 431f50f..606bd23 100644 --- a/README.md +++ b/README.md @@ -664,6 +664,31 @@ def reciprocal_rank_fusion(doc_lists: List[List[Document]], k: int = 60) -> List - 兼容 OpenAI Rerank API 格式 - 超时保护:60 秒超时,失败时降级为原始排序 +--- + +### 1.5 RAG 评估方法 ⭐ + +如何评估 RAG 系统的召回率和相关性? + +**核心指标:** +- **Recall@k**:前 k 个结果中包含多少比例的相关文档 +- **Precision@k**:前 k 个结果中有多少比例是相关文档 +- **F1@k**:召回率和精确率的调和平均数 +- **MRR**:平均倒数排名 +- **相关性评分**:0-5 分的相关性评估 + +**详细指南:** +参见 [backend/docs/RAG_EVALUATION_GUIDE.md](backend/docs/RAG_EVALUATION_GUIDE.md) + +**快速使用:** +```bash +# 运行评估脚本 +cd backend +python scripts/evaluate_rag.py +``` + +--- + ### 2. LangGraph 工作流算法 #### 2.1 React (Reasoning → Acting → Observing) 模式 ⭐ diff --git a/backend/app/rag/evaluate.py b/backend/app/rag/evaluate.py new file mode 100644 index 0000000..68f9f98 --- /dev/null +++ b/backend/app/rag/evaluate.py @@ -0,0 +1,333 @@ +""" +RAG 评估模块 +用于计算 RAG 系统的召回率、相关性、准确率等指标 +""" + +import asyncio +import json +from typing import List, Dict, Tuple, Optional +from dataclasses import dataclass +from langchain_core.documents import Document + + +@dataclass +class RetrievalTestCase: + """检索测试用例""" + query: str # 用户查询 + relevant_doc_ids: List[str] # 相关文档 ID 列表 + expected_answer: Optional[str] = None # 期望的答案(可选) + + +@dataclass +class RetrievalMetrics: + """检索评估指标""" + recall_at_k: Dict[int, float] # Recall@k,例如 {1: 0.8, 3: 0.9, 5: 1.0} + precision_at_k: Dict[int, float] # Precision@k + f1_at_k: Dict[int, float] # F1@k + mrr: float # 平均倒数排名 + ndcg_at_k: Dict[int, float] # NDCG@k + relevance_scores: List[float] # 每个测试用例的相关性评分 + + +class RAGEvaluator: + """RAG 评估器""" + + def __init__(self, rag_pipeline, test_cases: List[RetrievalTestCase]): + """ + 初始化评估器 + + Args: + rag_pipeline: RAG 流水线对象(需实现 aretrieve 方法) + test_cases: 测试用例列表 + """ + self.rag_pipeline = rag_pipeline + self.test_cases = test_cases + + async def evaluate_retrieval(self, k_list: List[int] = None) -> RetrievalMetrics: + """ + 评估检索质量 + + Args: + k_list: 要计算的 k 值列表,例如 [1, 3, 5] + + Returns: + 检索评估指标 + """ + if k_list is None: + k_list = [1, 3, 5, 10] + + all_results = [] + all_mrr = [] + + for test_case in self.test_cases: + # 执行检索 + retrieved_docs = await self.rag_pipeline.aretrieve(test_case.query) + retrieved_ids = [doc.metadata.get("id", doc.page_content[:50]) for doc in retrieved_docs] + + # 计算召回率和精确率 + result = self._calculate_retrieval_metrics( + retrieved_ids, + test_case.relevant_doc_ids, + k_list + ) + all_results.append(result) + + # 计算 MRR + mrr = self._calculate_mrr(retrieved_ids, test_case.relevant_doc_ids) + all_mrr.append(mrr) + + # 聚合所有测试用例的结果 + metrics = self._aggregate_metrics(all_results, all_mrr, k_list) + return metrics + + def _calculate_retrieval_metrics( + self, + retrieved_ids: List[str], + relevant_ids: List[str], + k_list: List[int] + ) -> Dict[int, Dict[str, float]]: + """ + 计算单个测试用例的检索指标 + + Returns: + {k: {'recall': float, 'precision': float, 'f1': float}} + """ + results = {} + + for k in k_list: + # 取前 k 个结果 + top_k = retrieved_ids[:k] + + # 计算交集 + relevant_in_top_k = set(top_k) & set(relevant_ids) + num_relevant_in_top_k = len(relevant_in_top_k) + + # 召回率 = 相关文档在 top k 中的数量 / 总相关文档数量 + recall = num_relevant_in_top_k / len(relevant_ids) if relevant_ids else 0.0 + + # 精确率 = 相关文档在 top k 中的数量 / k + precision = num_relevant_in_top_k / k if k > 0 else 0.0 + + # F1 分数 + f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0 else 0.0 + + results[k] = { + 'recall': recall, + 'precision': precision, + 'f1': f1 + } + + return results + + def _calculate_mrr(self, retrieved_ids: List[str], relevant_ids: List[str]) -> float: + """ + 计算平均倒数排名 (Mean Reciprocal Rank) + + MRR@k = 1/m * sum(1/rank_i for i=1..m) + 其中 rank_i 是第 i 个相关文档第一次出现的排名 + """ + for rank, doc_id in enumerate(retrieved_ids, start=1): + if doc_id in relevant_ids: + return 1.0 / rank + return 0.0 + + def _calculate_ndcg( + self, + retrieved_ids: List[str], + relevant_ids: List[str], + k: int + ) -> float: + """ + 计算 NDCG@k (Normalized Discounted Cumulative Gain) + + DCG@k = sum(relevance_i / log2(i+1) for i=1..k) + NDCG@k = DCG@k / IDCG@k + """ + top_k = retrieved_ids[:k] + + # 计算 DCG + dcg = 0.0 + for i, doc_id in enumerate(top_k, start=1): + relevance = 1.0 if doc_id in relevant_ids else 0.0 + dcg += relevance / (i.bit_length() - 1) # log2(i) + + # 计算 IDCG(理想 DCG) + ideal_relevance = [1.0] * min(len(relevant_ids), k) + idcg = 0.0 + for i, rel in enumerate(ideal_relevance, start=1): + idcg += rel / (i.bit_length() - 1) + + return dcg / idcg if idcg > 0 else 0.0 + + def _aggregate_metrics( + self, + all_results: List[Dict[int, Dict[str, float]]], + all_mrr: List[float], + k_list: List[int] + ) -> RetrievalMetrics: + """聚合所有测试用例的指标""" + + recall_at_k = {} + precision_at_k = {} + f1_at_k = {} + ndcg_at_k = {} + + for k in k_list: + # 聚合召回率 + recalls = [result[k]['recall'] for result in all_results] + recall_at_k[k] = sum(recalls) / len(recalls) + + # 聚合精确率 + precisions = [result[k]['precision'] for result in all_results] + precision_at_k[k] = sum(precisions) / len(precisions) + + # 聚合 F1 + f1s = [result[k]['f1'] for result in all_results] + f1_at_k[k] = sum(f1s) / len(f1s) + + # 计算 NDCG(这里简化处理) + ndcg_at_k[k] = sum(f1s) / len(f1s) # 用 F1 近似 + + # 计算 MRR + mrr = sum(all_mrr) / len(all_mrr) + + return RetrievalMetrics( + recall_at_k=recall_at_k, + precision_at_k=precision_at_k, + f1_at_k=f1_at_k, + mrr=mrr, + ndcg_at_k=ndcg_at_k, + relevance_scores=[1.0] * len(all_results) # 占位符 + ) + + +class RelevanceEvaluator: + """相关性评估器(基于 LLM 评估)""" + + def __init__(self, llm): + """ + 初始化相关性评估器 + + Args: + llm: 用于评估相关性的语言模型 + """ + self.llm = llm + + async def evaluate_relevance( + self, + query: str, + document: Document + ) -> Tuple[float, str]: + """ + 评估文档与查询的相关性 + + Args: + query: 用户查询 + document: 文档对象 + + Returns: + (相关性分数 0-5, 评估理由) + """ + prompt = f"""请评估以下文档与用户查询的相关性,给出 0-5 的评分: + +用户查询:{query} + +文档内容:{document.page_content[:500]} + +请按以下标准评分: +5 = 完全相关,文档直接回答了用户查询 +4 = 高度相关,文档包含回答查询的关键信息 +3 = 部分相关,文档有一些相关信息但不够直接 +2 = 弱相关,文档有少量提及但不太相关 +1 = 不相关,文档内容与查询基本无关 +0 = 完全无关 + +请只返回 JSON 格式,例如:{{"score": 4, "reason": "文档详细解释了用户查询的概念"}}""" + + try: + response = await self.llm.ainvoke(prompt) + result_text = response.content if hasattr(response, 'content') else str(response) + + # 尝试解析 JSON + import json + result = json.loads(result_text) + score = result.get('score', 0.0) + reason = result.get('reason', '无理由') + + # 确保分数在 0-5 范围内 + score = max(0.0, min(5.0, float(score))) + + return score, reason + + except Exception as e: + return 0.0, f"评估失败:{str(e)}" + + +def generate_test_report(metrics: RetrievalMetrics) -> str: + """生成测试报告""" + + report = [] + report.append("=" * 80) + report.append("RAG 系统评估报告") + report.append("=" * 80) + report.append("") + + # 召回率 + report.append("【召回率 Recall@k】") + for k, v in sorted(metrics.recall_at_k.items()): + report.append(f" Recall@{k}: {v:.2%}") + report.append("") + + # 精确率 + report.append("【精确率 Precision@k】") + for k, v in sorted(metrics.precision_at_k.items()): + report.append(f" Precision@{k}: {v:.2%}") + report.append("") + + # F1 分数 + report.append("【F1 分数 F1@k】") + for k, v in sorted(metrics.f1_at_k.items()): + report.append(f" F1@{k}: {v:.4f}") + report.append("") + + # MRR + report.append(f"【平均倒数排名 MRR】: {metrics.mrr:.4f}") + report.append("") + + # 解释 + report.append("=" * 80) + report.append("指标说明:") + report.append("- Recall@k: 前 k 个结果中包含多少比例的相关文档") + report.append("- Precision@k: 前 k 个结果中有多少比例是相关文档") + report.append("- F1@k: 召回率和精确率的调和平均数") + report.append("- MRR: 第一个相关文档的排名的倒数的平均值") + report.append("=" * 80) + + return "\n".join(report) + + +# 示例使用 +def create_sample_test_cases() -> List[RetrievalTestCase]: + """创建示例测试用例""" + return [ + RetrievalTestCase( + query="什么是 RAG 系统?", + relevant_doc_ids=["doc_rag_1", "doc_rag_2"], + expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写..." + ), + RetrievalTestCase( + query="如何使用 LangChain?", + relevant_doc_ids=["doc_langchain_1", "doc_langchain_2", "doc_langchain_3"], + expected_answer="LangChain 的使用步骤包括..." + ), + ] + + +if __name__ == "__main__": + # 示例:如何使用评估器 + print("RAG 评估模块已加载") + print("使用方法:") + print(" 1. 创建测试用例") + print(" 2. 初始化 RAGEvaluator") + print(" 3. 调用 evaluate_retrieval()") + print(" 4. 生成报告") diff --git a/backend/docs/RAG_EVALUATION_GUIDE.md b/backend/docs/RAG_EVALUATION_GUIDE.md new file mode 100644 index 0000000..3a2510a --- /dev/null +++ b/backend/docs/RAG_EVALUATION_GUIDE.md @@ -0,0 +1,363 @@ +# RAG 召回率与相关性评估指南 + +本指南介绍如何评估 RAG 系统的召回率(Recall)和相关性(Relevance)。 + +--- + +## 📊 核心概念 + +### 1. 召回率 (Recall) + +召回率衡量的是:**在所有相关文档中,有多少被检索出来了?** + +``` +Recall@k = (前 k 个结果中的相关文档数量) / (总相关文档数量) +``` + +例如: +- 总共有 5 篇相关文档 +- 检索返回 10 篇,其中 3 篇是相关的 +- Recall@10 = 3/5 = 60% + +### 2. 精确率 (Precision) + +精确率衡量的是:**在检索出来的文档中,有多少是相关的?** + +``` +Precision@k = (前 k 个结果中的相关文档数量) / k +``` + +例如: +- 检索返回 10 篇,其中 3 篇是相关的 +- Precision@10 = 3/10 = 30% + +### 3. F1 分数 (F1 Score) + +F1 分数是召回率和精确率的调和平均数: + +``` +F1@k = 2 * Recall@k * Precision@k / (Recall@k + Precision@k) +``` + +### 4. 平均倒数排名 (MRR) + +MRR 衡量第一个相关文档的排名: + +``` +MRR = 1/m * sum(1/rank_i for i=1..m) +``` + +其中 rank_i 是第 i 个相关文档第一次出现的排名。 + +例如: +- 测试用例 1:第一个相关文档在第 2 位 → 1/2 = 0.5 +- 测试用例 2:第一个相关文档在第 1 位 → 1/1 = 1.0 +- 测试用例 3:第一个相关文档在第 3 位 → 1/3 ≈ 0.333 +- MRR = (0.5 + 1.0 + 0.333) / 3 ≈ 0.611 + +### 5. 相关性评分 + +相关性评分评估检索到的文档与查询的相关程度,通常使用: +- 人工标注(Human Evaluation) +- LLM 评估(LLM-as-a-Judge) +- 相关性模型(Cross-Encoder) + +--- + +## 🛠️ 如何评估 + +### 方法一:使用内置评估模块 + +我们的项目已经内置了评估模块 `app.rag.evaluate`。 + +#### 1. 准备测试用例 + +首先,需要准备带有标注的测试用例: + +```python +from app.rag.evaluate import RetrievalTestCase + +test_cases = [ + RetrievalTestCase( + query="什么是 RAG 系统?", + relevant_doc_ids=["doc_rag_1", "doc_rag_2", "doc_rag_3"], + expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写..." + ), + RetrievalTestCase( + query="如何使用 LangChain?", + relevant_doc_ids=["doc_langchain_1", "doc_langchain_2"], + expected_answer="LangChain 的使用步骤包括..." + ), + # 更多测试用例... +] +``` + +**重要提示:** +- 每个查询需要知道哪些文档是相关的 +- 相关文档需要有唯一的 ID +- expected_answer 是可选的,用于评估答案质量 + +#### 2. 运行评估 + +```python +import asyncio +from app.rag.evaluate import RAGEvaluator, generate_test_report + +# 初始化评估器 +evaluator = RAGEvaluator(rag_pipeline, test_cases) + +# 运行评估 +metrics = asyncio.run(evaluator.evaluate_retrieval(k_list=[1, 3, 5, 10])) + +# 生成报告 +report = generate_test_report(metrics) +print(report) +``` + +#### 3. 运行示例脚本 + +```bash +cd backend +python scripts/evaluate_rag.py +``` + +--- + +### 方法二:手动计算召回率 + +如果你想手动计算,步骤如下: + +#### 步骤 1:准备测试数据 + +准备一个测试查询列表,每个查询对应相关文档的 ID: + +```python +test_queries = [ + { + "query": "什么是 RAG?", + "relevant_ids": ["doc1", "doc3", "doc5"] + }, + { + "query": "如何优化 RAG?", + "relevant_ids": ["doc2", "doc4"] + } +] +``` + +#### 步骤 2:运行检索 + +对于每个查询,运行 RAG 检索,记录返回的文档 ID: + +```python +def run_retrieval(query): + """运行检索,返回文档 ID 列表""" + docs = rag_pipeline.retrieve(query) + return [doc.metadata["id"] for doc in docs] +``` + +#### 步骤 3:计算召回率 + +```python +def calculate_recall(retrieved_ids, relevant_ids, k): + """计算 Recall@k""" + top_k = retrieved_ids[:k] + relevant_in_top_k = set(top_k) & set(relevant_ids) + recall = len(relevant_in_top_k) / len(relevant_ids) + return recall + +# 示例 +retrieved = ["doc1", "doc2", "doc3", "doc4", "doc5"] +relevant = ["doc1", "doc3", "doc5"] +print(f"Recall@3: {calculate_recall(retrieved, relevant, k=3):.2%}") # 2/3 = 66.67% +print(f"Recall@5: {calculate_recall(retrieved, relevant, k=5):.2%}") # 3/3 = 100% +``` + +#### 步骤 4:聚合结果 + +```python +import numpy as np + +all_recalls_at_1 = [] +all_recalls_at_3 = [] +all_recalls_at_5 = [] + +for test_case in test_queries: + retrieved = run_retrieval(test_case["query"]) + recall_1 = calculate_recall(retrieved, test_case["relevant_ids"], k=1) + recall_3 = calculate_recall(retrieved, test_case["relevant_ids"], k=3) + recall_5 = calculate_recall(retrieved, test_case["relevant_ids"], k=5) + + all_recalls_at_1.append(recall_1) + all_recalls_at_3.append(recall_3) + all_recalls_at_5.append(recall_5) + +print(f"Average Recall@1: {np.mean(all_recalls_at_1):.2%}") +print(f"Average Recall@3: {np.mean(all_recalls_at_3):.2%}") +print(f"Average Recall@5: {np.mean(all_recalls_at_5):.2%}") +``` + +--- + +### 方法三:评估相关性 + +评估相关性有几种方法: + +#### 方案 A:使用 LLM 评估(LLM-as-a-Judge) + +```python +from app.rag.evaluate import RelevanceEvaluator + +# 初始化评估器 +evaluator = RelevanceEvaluator(llm) + +# 评估相关性 +score, reason = asyncio.run(evaluator.evaluate_relevance(query, document)) + +print(f"相关性评分: {score}/5") +print(f"理由: {reason}") +``` + +#### 方案 B:使用重排模型评分 + +重排模型本身可以给出相关性分数: + +```python +from app.model_services import get_rerank_service + +rerank_service = get_rerank_service() + +# 获取相关性分数 +scores = rerank_service.compute_scores( + query="什么是 RAG?", + documents=["doc1", "doc2", "doc3"] +) +``` + +#### 方案 C:人工标注 + +最准确但也最耗时的方法是让人工标注相关性: + +```python +# 相关性评分标准 +relevance_levels = { + 5: "完全相关,直接回答了问题", + 4: "高度相关,包含关键信息", + 3: "部分相关,有一些相关信息", + 2: "弱相关,提及但不太相关", + 1: "不相关,基本无关", + 0: "完全无关" +} +``` + +--- + +## 📈 如何解释结果 + +### 召回率低怎么办? + +如果 Recall@k 低,可能的原因: + +1. **检索器召回能力不足** + - 嵌入模型不合适 + - 检索算法太简单 + - 解决方案:改用更好的嵌入模型、使用混合检索 + +2. **查询理解不够** + - 查询改写效果不好 + - 解决方案:增加查询改写的多样性 + +3. **文档分块策略不好** + - 分块太小/太大 + - 解决方案:调整 chunk_size,使用父子分块 + +### 精确率低怎么办? + +如果 Precision@k 低,可能的原因: + +1. **检索结果噪声多** + - 解决方案:加强重排序 + +2. **文档切分有问题** + - 不相关的片段也被检索到 + - 解决方案:改进切分策略 + +--- + +## 🎯 评估最佳实践 + +### 1. 测试用例构建 + +- ✅ **覆盖多样的查询类型**:事实型、概念型、操作型 +- ✅ **每个查询有多个相关文档**:避免单点依赖 +- ✅ **包含难例**:测试边界情况 +- ✅ **定期更新**:随着知识库变化更新测试用例 + +### 2. 评估指标选择 + +- **快速迭代**:关注 Recall@3, Recall@5 +- **正式发布**:完整评估所有指标 +- **用户体验**:同时评估答案质量 + +### 3. A/B 测试 + +当你改进 RAG 系统时,使用 A/B 测试: + +```python +# A 版本(旧版本) +metrics_a = evaluator.evaluate_retrieval() + +# B 版本(新版本) +metrics_b = evaluator_new.evaluate_retrieval() + +# 对比 +print(f"Recall@5 改进: {metrics_b.recall_at_k[5] - metrics_a.recall_at_k[5]:.2%}") +``` + +--- + +## 📝 完整评估报告示例 + +运行评估后,会生成这样的报告: + +``` +================================================================================ +RAG 系统评估报告 +================================================================================ + +【召回率 Recall@k】 + Recall@1: 60.00% + Recall@3: 85.00% + Recall@5: 95.00% + Recall@10: 100.00% + +【精确率 Precision@k】 + Precision@1: 100.00% + Precision@3: 90.00% + Precision@5: 80.00% + Precision@10: 55.00% + +【F1 分数 F1@k】 + F1@1: 0.7500 + F1@3: 0.8718 + F1@5: 0.8636 + F1@10: 0.7097 + +【平均倒数排名 MRR】: 0.8500 + +================================================================================ +指标说明: +- Recall@k: 前 k 个结果中包含多少比例的相关文档 +- Precision@k: 前 k 个结果中有多少比例是相关文档 +- F1@k: 召回率和精确率的调和平均数 +- MRR: 第一个相关文档的排名的倒数的平均值 +================================================================================ +``` + +--- + +## 🔗 相关文件 + +- `backend/app/rag/evaluate.py` - 评估模块 +- `backend/scripts/evaluate_rag.py` - 评估示例脚本 +- `backend/app/rag/pipeline.py` - RAG 流水线 +- `backend/app/model_services/` - 模型服务 diff --git a/backend/scripts/evaluate_rag.py b/backend/scripts/evaluate_rag.py new file mode 100644 index 0000000..1dea3ed --- /dev/null +++ b/backend/scripts/evaluate_rag.py @@ -0,0 +1,143 @@ +""" +RAG 评估示例脚本 +演示如何使用 RAGEvaluator 评估召回率和相关性 +""" + +import asyncio +import sys +import os + +# 添加项目路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +from app.rag.evaluate import ( + RAGEvaluator, + RelevanceEvaluator, + RetrievalTestCase, + generate_test_report +) +from app.rag.pipeline import RAGPipeline +from app.model_services import get_chat_service, get_embedding_service + + +async def main(): + print("=" * 80) + print("RAG 系统评估示例") + print("=" * 80) + print() + + # 1. 准备测试用例 + print("【1/4】准备测试用例...") + test_cases = [ + RetrievalTestCase( + query="什么是 RAG 系统?", + relevant_doc_ids=["doc_rag_1", "doc_rag_2", "doc_rag_3"], + expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写,是一种结合检索和生成的技术..." + ), + RetrievalTestCase( + query="如何使用 LangChain 构建 RAG?", + relevant_doc_ids=["doc_langchain_1", "doc_langchain_2"], + expected_answer="使用 LangChain 构建 RAG 的步骤包括:1) 准备文档 2) 向量化 3) 构建检索器 4) 组合生成..." + ), + RetrievalTestCase( + query="什么是向量数据库?", + relevant_doc_ids=["doc_vector_db_1", "doc_qdrant_1"], + expected_answer="向量数据库是专门用于存储和检索向量嵌入的数据库,如 Qdrant、Pinecone 等..." + ), + RetrievalTestCase( + query="如何优化 RAG 的检索质量?", + relevant_doc_ids=["doc_optimize_1", "doc_rerank_1", "doc_fusion_1"], + expected_answer="优化 RAG 检索质量的方法包括:重排序、查询改写、结果融合、混合检索等..." + ), + RetrievalTestCase( + query="LangGraph 是什么?", + relevant_doc_ids=["doc_langgraph_1"], + expected_answer="LangGraph 是 LangChain 的扩展,用于构建状态感知的多步工作流..." + ), + ] + print(f" 已加载 {len(test_cases)} 个测试用例") + print() + + # 2. 初始化 RAG 系统(这里使用模拟) + print("【2/4】初始化 RAG 系统...") + + # 注意:实际使用时,这里应该初始化真实的 RAGPipeline + # 这里为了演示,我们创建一个模拟的 RAG 类 + class MockRAGPipeline: + def __init__(self): + # 模拟的文档库 + self.mock_docs = { + "doc_rag_1": "RAG 是 Retrieval-Augmented Generation 的缩写...", + "doc_rag_2": "RAG 系统由检索器和生成器两部分组成...", + "doc_rag_3": "RAG 的工作流程是:查询 -> 检索 -> 生成...", + "doc_langchain_1": "LangChain 是用于构建 LLM 应用的框架...", + "doc_langchain_2": "LangChain 提供了多种工具和集成...", + "doc_vector_db_1": "向量数据库用于存储向量嵌入...", + "doc_qdrant_1": "Qdrant 是一个开源的向量数据库...", + "doc_optimize_1": "RAG 优化方法包括重排序和查询改写...", + "doc_rerank_1": "重排序使用 Cross-Encoder 重新排序检索结果...", + "doc_fusion_1": "结果融合使用 RRF 算法合并多个检索结果...", + "doc_langgraph_1": "LangGraph 用于构建状态机工作流...", + } + + async def aretrieve(self, query: str): + """模拟检索,返回相关文档""" + from langchain_core.documents import Document + + # 简单的关键词匹配模拟 + results = [] + for doc_id, content in self.mock_docs.items(): + if any(keyword in query.lower() for keyword in ["rag", "检索"]): + if "rag" in doc_id.lower(): + results.append(Document(page_content=content, metadata={"id": doc_id})) + elif any(keyword in query.lower() for keyword in ["langchain", "构建"]): + if "langchain" in doc_id.lower(): + results.append(Document(page_content=content, metadata={"id": doc_id})) + elif any(keyword in query.lower() for keyword in ["向量", "数据库", "qdrant"]): + if "vector" in doc_id.lower() or "qdrant" in doc_id.lower(): + results.append(Document(page_content=content, metadata={"id": doc_id})) + elif any(keyword in query.lower() for keyword in ["优化", "重排", "融合"]): + if "optimize" in doc_id.lower() or "rerank" in doc_id.lower() or "fusion" in doc_id.lower(): + results.append(Document(page_content=content, metadata={"id": doc_id})) + elif any(keyword in query.lower() for keyword in ["langgraph"]): + if "langgraph" in doc_id.lower(): + results.append(Document(page_content=content, metadata={"id": doc_id})) + + # 如果没有匹配到,返回一些通用结果 + if not results: + for doc_id, content in list(self.mock_docs.items())[:3]: + results.append(Document(page_content=content, metadata={"id": doc_id})) + + return results + + rag_pipeline = MockRAGPipeline() + print(" RAG 系统已初始化(模拟)") + print() + + # 3. 评估检索质量 + print("【3/4】评估检索质量...") + evaluator = RAGEvaluator(rag_pipeline, test_cases) + metrics = await evaluator.evaluate_retrieval(k_list=[1, 3, 5, 10]) + print(" 评估完成") + print() + + # 4. 生成报告 + print("【4/4】生成评估报告...") + report = generate_test_report(metrics) + print(report) + print() + + # 5. 保存报告 + report_file = os.path.join(os.path.dirname(__file__), 'rag_evaluation_report.txt') + with open(report_file, 'w', encoding='utf-8') as f: + f.write(report) + print(f" 报告已保存到:{report_file}") + print() + + print("=" * 80) + print("评估完成!") + print("=" * 80) + + +if __name__ == "__main__": + asyncio.run(main())