144 lines
6.0 KiB
Python
144 lines
6.0 KiB
Python
"""
|
||
RAG 评估示例脚本
|
||
演示如何使用 RAGEvaluator 评估召回率和相关性
|
||
"""
|
||
|
||
import asyncio
|
||
import sys
|
||
import os
|
||
|
||
# 添加项目路径
|
||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||
|
||
from app.rag.evaluate import (
|
||
RAGEvaluator,
|
||
RelevanceEvaluator,
|
||
RetrievalTestCase,
|
||
generate_test_report
|
||
)
|
||
from app.rag.pipeline import RAGPipeline
|
||
from app.model_services import get_chat_service, get_embedding_service
|
||
|
||
|
||
async def main():
|
||
print("=" * 80)
|
||
print("RAG 系统评估示例")
|
||
print("=" * 80)
|
||
print()
|
||
|
||
# 1. 准备测试用例
|
||
print("【1/4】准备测试用例...")
|
||
test_cases = [
|
||
RetrievalTestCase(
|
||
query="什么是 RAG 系统?",
|
||
relevant_doc_ids=["doc_rag_1", "doc_rag_2", "doc_rag_3"],
|
||
expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写,是一种结合检索和生成的技术..."
|
||
),
|
||
RetrievalTestCase(
|
||
query="如何使用 LangChain 构建 RAG?",
|
||
relevant_doc_ids=["doc_langchain_1", "doc_langchain_2"],
|
||
expected_answer="使用 LangChain 构建 RAG 的步骤包括:1) 准备文档 2) 向量化 3) 构建检索器 4) 组合生成..."
|
||
),
|
||
RetrievalTestCase(
|
||
query="什么是向量数据库?",
|
||
relevant_doc_ids=["doc_vector_db_1", "doc_qdrant_1"],
|
||
expected_answer="向量数据库是专门用于存储和检索向量嵌入的数据库,如 Qdrant、Pinecone 等..."
|
||
),
|
||
RetrievalTestCase(
|
||
query="如何优化 RAG 的检索质量?",
|
||
relevant_doc_ids=["doc_optimize_1", "doc_rerank_1", "doc_fusion_1"],
|
||
expected_answer="优化 RAG 检索质量的方法包括:重排序、查询改写、结果融合、混合检索等..."
|
||
),
|
||
RetrievalTestCase(
|
||
query="LangGraph 是什么?",
|
||
relevant_doc_ids=["doc_langgraph_1"],
|
||
expected_answer="LangGraph 是 LangChain 的扩展,用于构建状态感知的多步工作流..."
|
||
),
|
||
]
|
||
print(f" 已加载 {len(test_cases)} 个测试用例")
|
||
print()
|
||
|
||
# 2. 初始化 RAG 系统(这里使用模拟)
|
||
print("【2/4】初始化 RAG 系统...")
|
||
|
||
# 注意:实际使用时,这里应该初始化真实的 RAGPipeline
|
||
# 这里为了演示,我们创建一个模拟的 RAG 类
|
||
class MockRAGPipeline:
|
||
def __init__(self):
|
||
# 模拟的文档库
|
||
self.mock_docs = {
|
||
"doc_rag_1": "RAG 是 Retrieval-Augmented Generation 的缩写...",
|
||
"doc_rag_2": "RAG 系统由检索器和生成器两部分组成...",
|
||
"doc_rag_3": "RAG 的工作流程是:查询 -> 检索 -> 生成...",
|
||
"doc_langchain_1": "LangChain 是用于构建 LLM 应用的框架...",
|
||
"doc_langchain_2": "LangChain 提供了多种工具和集成...",
|
||
"doc_vector_db_1": "向量数据库用于存储向量嵌入...",
|
||
"doc_qdrant_1": "Qdrant 是一个开源的向量数据库...",
|
||
"doc_optimize_1": "RAG 优化方法包括重排序和查询改写...",
|
||
"doc_rerank_1": "重排序使用 Cross-Encoder 重新排序检索结果...",
|
||
"doc_fusion_1": "结果融合使用 RRF 算法合并多个检索结果...",
|
||
"doc_langgraph_1": "LangGraph 用于构建状态机工作流...",
|
||
}
|
||
|
||
async def aretrieve(self, query: str):
|
||
"""模拟检索,返回相关文档"""
|
||
from langchain_core.documents import Document
|
||
|
||
# 简单的关键词匹配模拟
|
||
results = []
|
||
for doc_id, content in self.mock_docs.items():
|
||
if any(keyword in query.lower() for keyword in ["rag", "检索"]):
|
||
if "rag" in doc_id.lower():
|
||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||
elif any(keyword in query.lower() for keyword in ["langchain", "构建"]):
|
||
if "langchain" in doc_id.lower():
|
||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||
elif any(keyword in query.lower() for keyword in ["向量", "数据库", "qdrant"]):
|
||
if "vector" in doc_id.lower() or "qdrant" in doc_id.lower():
|
||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||
elif any(keyword in query.lower() for keyword in ["优化", "重排", "融合"]):
|
||
if "optimize" in doc_id.lower() or "rerank" in doc_id.lower() or "fusion" in doc_id.lower():
|
||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||
elif any(keyword in query.lower() for keyword in ["langgraph"]):
|
||
if "langgraph" in doc_id.lower():
|
||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||
|
||
# 如果没有匹配到,返回一些通用结果
|
||
if not results:
|
||
for doc_id, content in list(self.mock_docs.items())[:3]:
|
||
results.append(Document(page_content=content, metadata={"id": doc_id}))
|
||
|
||
return results
|
||
|
||
rag_pipeline = MockRAGPipeline()
|
||
print(" RAG 系统已初始化(模拟)")
|
||
print()
|
||
|
||
# 3. 评估检索质量
|
||
print("【3/4】评估检索质量...")
|
||
evaluator = RAGEvaluator(rag_pipeline, test_cases)
|
||
metrics = await evaluator.evaluate_retrieval(k_list=[1, 3, 5, 10])
|
||
print(" 评估完成")
|
||
print()
|
||
|
||
# 4. 生成报告
|
||
print("【4/4】生成评估报告...")
|
||
report = generate_test_report(metrics)
|
||
print(report)
|
||
print()
|
||
|
||
# 5. 保存报告
|
||
report_file = os.path.join(os.path.dirname(__file__), 'rag_evaluation_report.txt')
|
||
with open(report_file, 'w', encoding='utf-8') as f:
|
||
f.write(report)
|
||
print(f" 报告已保存到:{report_file}")
|
||
print()
|
||
|
||
print("=" * 80)
|
||
print("评估完成!")
|
||
print("=" * 80)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
asyncio.run(main())
|