Files
ailine/backend/scripts/evaluate_rag.py

144 lines
6.0 KiB
Python
Raw Normal View History

"""
RAG 评估示例脚本
演示如何使用 RAGEvaluator 评估召回率和相关性
"""
import asyncio
import sys
import os
# 添加项目路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from app.rag.evaluate import (
RAGEvaluator,
RelevanceEvaluator,
RetrievalTestCase,
generate_test_report
)
from app.rag.pipeline import RAGPipeline
from app.model_services import get_chat_service, get_embedding_service
async def main():
print("=" * 80)
print("RAG 系统评估示例")
print("=" * 80)
print()
# 1. 准备测试用例
print("【1/4】准备测试用例...")
test_cases = [
RetrievalTestCase(
query="什么是 RAG 系统?",
relevant_doc_ids=["doc_rag_1", "doc_rag_2", "doc_rag_3"],
expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写,是一种结合检索和生成的技术..."
),
RetrievalTestCase(
query="如何使用 LangChain 构建 RAG",
relevant_doc_ids=["doc_langchain_1", "doc_langchain_2"],
expected_answer="使用 LangChain 构建 RAG 的步骤包括1) 准备文档 2) 向量化 3) 构建检索器 4) 组合生成..."
),
RetrievalTestCase(
query="什么是向量数据库?",
relevant_doc_ids=["doc_vector_db_1", "doc_qdrant_1"],
expected_answer="向量数据库是专门用于存储和检索向量嵌入的数据库,如 Qdrant、Pinecone 等..."
),
RetrievalTestCase(
query="如何优化 RAG 的检索质量?",
relevant_doc_ids=["doc_optimize_1", "doc_rerank_1", "doc_fusion_1"],
expected_answer="优化 RAG 检索质量的方法包括:重排序、查询改写、结果融合、混合检索等..."
),
RetrievalTestCase(
query="LangGraph 是什么?",
relevant_doc_ids=["doc_langgraph_1"],
expected_answer="LangGraph 是 LangChain 的扩展,用于构建状态感知的多步工作流..."
),
]
print(f" 已加载 {len(test_cases)} 个测试用例")
print()
# 2. 初始化 RAG 系统(这里使用模拟)
print("【2/4】初始化 RAG 系统...")
# 注意:实际使用时,这里应该初始化真实的 RAGPipeline
# 这里为了演示,我们创建一个模拟的 RAG 类
class MockRAGPipeline:
def __init__(self):
# 模拟的文档库
self.mock_docs = {
"doc_rag_1": "RAG 是 Retrieval-Augmented Generation 的缩写...",
"doc_rag_2": "RAG 系统由检索器和生成器两部分组成...",
"doc_rag_3": "RAG 的工作流程是:查询 -> 检索 -> 生成...",
"doc_langchain_1": "LangChain 是用于构建 LLM 应用的框架...",
"doc_langchain_2": "LangChain 提供了多种工具和集成...",
"doc_vector_db_1": "向量数据库用于存储向量嵌入...",
"doc_qdrant_1": "Qdrant 是一个开源的向量数据库...",
"doc_optimize_1": "RAG 优化方法包括重排序和查询改写...",
"doc_rerank_1": "重排序使用 Cross-Encoder 重新排序检索结果...",
"doc_fusion_1": "结果融合使用 RRF 算法合并多个检索结果...",
"doc_langgraph_1": "LangGraph 用于构建状态机工作流...",
}
async def aretrieve(self, query: str):
"""模拟检索,返回相关文档"""
from langchain_core.documents import Document
# 简单的关键词匹配模拟
results = []
for doc_id, content in self.mock_docs.items():
if any(keyword in query.lower() for keyword in ["rag", "检索"]):
if "rag" in doc_id.lower():
results.append(Document(page_content=content, metadata={"id": doc_id}))
elif any(keyword in query.lower() for keyword in ["langchain", "构建"]):
if "langchain" in doc_id.lower():
results.append(Document(page_content=content, metadata={"id": doc_id}))
elif any(keyword in query.lower() for keyword in ["向量", "数据库", "qdrant"]):
if "vector" in doc_id.lower() or "qdrant" in doc_id.lower():
results.append(Document(page_content=content, metadata={"id": doc_id}))
elif any(keyword in query.lower() for keyword in ["优化", "重排", "融合"]):
if "optimize" in doc_id.lower() or "rerank" in doc_id.lower() or "fusion" in doc_id.lower():
results.append(Document(page_content=content, metadata={"id": doc_id}))
elif any(keyword in query.lower() for keyword in ["langgraph"]):
if "langgraph" in doc_id.lower():
results.append(Document(page_content=content, metadata={"id": doc_id}))
# 如果没有匹配到,返回一些通用结果
if not results:
for doc_id, content in list(self.mock_docs.items())[:3]:
results.append(Document(page_content=content, metadata={"id": doc_id}))
return results
rag_pipeline = MockRAGPipeline()
print(" RAG 系统已初始化(模拟)")
print()
# 3. 评估检索质量
print("【3/4】评估检索质量...")
evaluator = RAGEvaluator(rag_pipeline, test_cases)
metrics = await evaluator.evaluate_retrieval(k_list=[1, 3, 5, 10])
print(" 评估完成")
print()
# 4. 生成报告
print("【4/4】生成评估报告...")
report = generate_test_report(metrics)
print(report)
print()
# 5. 保存报告
report_file = os.path.join(os.path.dirname(__file__), 'rag_evaluation_report.txt')
with open(report_file, 'w', encoding='utf-8') as f:
f.write(report)
print(f" 报告已保存到:{report_file}")
print()
print("=" * 80)
print("评估完成!")
print("=" * 80)
if __name__ == "__main__":
asyncio.run(main())