feat: 添加 RAG 评估模块,支持召回率和相关性评估
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m13s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m13s
This commit is contained in:
143
backend/scripts/evaluate_rag.py
Normal file
143
backend/scripts/evaluate_rag.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
RAG 评估示例脚本
|
||||
演示如何使用 RAGEvaluator 评估召回率和相关性
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from app.rag.evaluate import (
|
||||
RAGEvaluator,
|
||||
RelevanceEvaluator,
|
||||
RetrievalTestCase,
|
||||
generate_test_report
|
||||
)
|
||||
from app.rag.pipeline import RAGPipeline
|
||||
from app.model_services import get_chat_service, get_embedding_service
|
||||
|
||||
|
||||
async def main():
|
||||
print("=" * 80)
|
||||
print("RAG 系统评估示例")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# 1. 准备测试用例
|
||||
print("【1/4】准备测试用例...")
|
||||
test_cases = [
|
||||
RetrievalTestCase(
|
||||
query="什么是 RAG 系统?",
|
||||
relevant_doc_ids=["doc_rag_1", "doc_rag_2", "doc_rag_3"],
|
||||
expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写,是一种结合检索和生成的技术..."
|
||||
),
|
||||
RetrievalTestCase(
|
||||
query="如何使用 LangChain 构建 RAG?",
|
||||
relevant_doc_ids=["doc_langchain_1", "doc_langchain_2"],
|
||||
expected_answer="使用 LangChain 构建 RAG 的步骤包括:1) 准备文档 2) 向量化 3) 构建检索器 4) 组合生成..."
|
||||
),
|
||||
RetrievalTestCase(
|
||||
query="什么是向量数据库?",
|
||||
relevant_doc_ids=["doc_vector_db_1", "doc_qdrant_1"],
|
||||
expected_answer="向量数据库是专门用于存储和检索向量嵌入的数据库,如 Qdrant、Pinecone 等..."
|
||||
),
|
||||
RetrievalTestCase(
|
||||
query="如何优化 RAG 的检索质量?",
|
||||
relevant_doc_ids=["doc_optimize_1", "doc_rerank_1", "doc_fusion_1"],
|
||||
expected_answer="优化 RAG 检索质量的方法包括:重排序、查询改写、结果融合、混合检索等..."
|
||||
),
|
||||
RetrievalTestCase(
|
||||
query="LangGraph 是什么?",
|
||||
relevant_doc_ids=["doc_langgraph_1"],
|
||||
expected_answer="LangGraph 是 LangChain 的扩展,用于构建状态感知的多步工作流..."
|
||||
),
|
||||
]
|
||||
print(f" 已加载 {len(test_cases)} 个测试用例")
|
||||
print()
|
||||
|
||||
# 2. 初始化 RAG 系统(这里使用模拟)
|
||||
print("【2/4】初始化 RAG 系统...")
|
||||
|
||||
# 注意:实际使用时,这里应该初始化真实的 RAGPipeline
|
||||
# 这里为了演示,我们创建一个模拟的 RAG 类
|
||||
class MockRAGPipeline:
|
||||
def __init__(self):
|
||||
# 模拟的文档库
|
||||
self.mock_docs = {
|
||||
"doc_rag_1": "RAG 是 Retrieval-Augmented Generation 的缩写...",
|
||||
"doc_rag_2": "RAG 系统由检索器和生成器两部分组成...",
|
||||
"doc_rag_3": "RAG 的工作流程是:查询 -> 检索 -> 生成...",
|
||||
"doc_langchain_1": "LangChain 是用于构建 LLM 应用的框架...",
|
||||
"doc_langchain_2": "LangChain 提供了多种工具和集成...",
|
||||
"doc_vector_db_1": "向量数据库用于存储向量嵌入...",
|
||||
"doc_qdrant_1": "Qdrant 是一个开源的向量数据库...",
|
||||
"doc_optimize_1": "RAG 优化方法包括重排序和查询改写...",
|
||||
"doc_rerank_1": "重排序使用 Cross-Encoder 重新排序检索结果...",
|
||||
"doc_fusion_1": "结果融合使用 RRF 算法合并多个检索结果...",
|
||||
"doc_langgraph_1": "LangGraph 用于构建状态机工作流...",
|
||||
}
|
||||
|
||||
async def aretrieve(self, query: str):
|
||||
"""模拟检索,返回相关文档"""
|
||||
from langchain_core.documents import Document
|
||||
|
||||
# 简单的关键词匹配模拟
|
||||
results = []
|
||||
for doc_id, content in self.mock_docs.items():
|
||||
if any(keyword in query.lower() for keyword in ["rag", "检索"]):
|
||||
if "rag" in doc_id.lower():
|
||||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||||
elif any(keyword in query.lower() for keyword in ["langchain", "构建"]):
|
||||
if "langchain" in doc_id.lower():
|
||||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||||
elif any(keyword in query.lower() for keyword in ["向量", "数据库", "qdrant"]):
|
||||
if "vector" in doc_id.lower() or "qdrant" in doc_id.lower():
|
||||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||||
elif any(keyword in query.lower() for keyword in ["优化", "重排", "融合"]):
|
||||
if "optimize" in doc_id.lower() or "rerank" in doc_id.lower() or "fusion" in doc_id.lower():
|
||||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||||
elif any(keyword in query.lower() for keyword in ["langgraph"]):
|
||||
if "langgraph" in doc_id.lower():
|
||||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||||
|
||||
# 如果没有匹配到,返回一些通用结果
|
||||
if not results:
|
||||
for doc_id, content in list(self.mock_docs.items())[:3]:
|
||||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||||
|
||||
return results
|
||||
|
||||
rag_pipeline = MockRAGPipeline()
|
||||
print(" RAG 系统已初始化(模拟)")
|
||||
print()
|
||||
|
||||
# 3. 评估检索质量
|
||||
print("【3/4】评估检索质量...")
|
||||
evaluator = RAGEvaluator(rag_pipeline, test_cases)
|
||||
metrics = await evaluator.evaluate_retrieval(k_list=[1, 3, 5, 10])
|
||||
print(" 评估完成")
|
||||
print()
|
||||
|
||||
# 4. 生成报告
|
||||
print("【4/4】生成评估报告...")
|
||||
report = generate_test_report(metrics)
|
||||
print(report)
|
||||
print()
|
||||
|
||||
# 5. 保存报告
|
||||
report_file = os.path.join(os.path.dirname(__file__), 'rag_evaluation_report.txt')
|
||||
with open(report_file, 'w', encoding='utf-8') as f:
|
||||
f.write(report)
|
||||
print(f" 报告已保存到:{report_file}")
|
||||
print()
|
||||
|
||||
print("=" * 80)
|
||||
print("评估完成!")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user