Files
ailine/backend/scripts/evaluate_rag.py
root 92863e86dc
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m13s
feat: 添加 RAG 评估模块,支持召回率和相关性评估
2026-04-26 15:39:05 +08:00

144 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 评估示例脚本
演示如何使用 RAGEvaluator 评估召回率和相关性
"""
import asyncio
import sys
import os
# 添加项目路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from app.rag.evaluate import (
RAGEvaluator,
RelevanceEvaluator,
RetrievalTestCase,
generate_test_report
)
from app.rag.pipeline import RAGPipeline
from app.model_services import get_chat_service, get_embedding_service
async def main():
    """Run a demo evaluation of a (mocked) RAG system.

    Workflow: build retrieval test cases, evaluate them with RAGEvaluator
    against a mock pipeline, print the generated report, and save it next
    to this script as ``rag_evaluation_report.txt``.
    """
    print("=" * 80)
    print("RAG 系统评估示例")
    print("=" * 80)
    print()

    # 1. Prepare test cases: each pairs a query with its ground-truth
    #    relevant document ids and a reference answer.
    print("【1/4】准备测试用例...")
    test_cases = [
        RetrievalTestCase(
            query="什么是 RAG 系统?",
            relevant_doc_ids=["doc_rag_1", "doc_rag_2", "doc_rag_3"],
            expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写,是一种结合检索和生成的技术..."
        ),
        RetrievalTestCase(
            query="如何使用 LangChain 构建 RAG",
            relevant_doc_ids=["doc_langchain_1", "doc_langchain_2"],
            expected_answer="使用 LangChain 构建 RAG 的步骤包括1) 准备文档 2) 向量化 3) 构建检索器 4) 组合生成..."
        ),
        RetrievalTestCase(
            query="什么是向量数据库?",
            relevant_doc_ids=["doc_vector_db_1", "doc_qdrant_1"],
            expected_answer="向量数据库是专门用于存储和检索向量嵌入的数据库,如 Qdrant、Pinecone 等..."
        ),
        RetrievalTestCase(
            query="如何优化 RAG 的检索质量?",
            relevant_doc_ids=["doc_optimize_1", "doc_rerank_1", "doc_fusion_1"],
            expected_answer="优化 RAG 检索质量的方法包括:重排序、查询改写、结果融合、混合检索等..."
        ),
        RetrievalTestCase(
            query="LangGraph 是什么?",
            relevant_doc_ids=["doc_langgraph_1"],
            expected_answer="LangGraph 是 LangChain 的扩展,用于构建状态感知的多步工作流..."
        ),
    ]
    print(f" 已加载 {len(test_cases)} 个测试用例")
    print()

    # 2. Initialize the RAG system. A real deployment would construct a
    #    RAGPipeline here; for this demo we use a keyword-matching mock.
    print("【2/4】初始化 RAG 系统...")

    class MockRAGPipeline:
        """Stand-in for RAGPipeline that retrieves by simple keyword matching."""

        def __init__(self):
            # Mock document store: doc id -> content.
            self.mock_docs = {
                "doc_rag_1": "RAG 是 Retrieval-Augmented Generation 的缩写...",
                "doc_rag_2": "RAG 系统由检索器和生成器两部分组成...",
                "doc_rag_3": "RAG 的工作流程是:查询 -> 检索 -> 生成...",
                "doc_langchain_1": "LangChain 是用于构建 LLM 应用的框架...",
                "doc_langchain_2": "LangChain 提供了多种工具和集成...",
                "doc_vector_db_1": "向量数据库用于存储向量嵌入...",
                "doc_qdrant_1": "Qdrant 是一个开源的向量数据库...",
                "doc_optimize_1": "RAG 优化方法包括重排序和查询改写...",
                "doc_rerank_1": "重排序使用 Cross-Encoder 重新排序检索结果...",
                "doc_fusion_1": "结果融合使用 RRF 算法合并多个检索结果...",
                "doc_langgraph_1": "LangGraph 用于构建状态机工作流...",
            }

        async def aretrieve(self, query: str):
            """Simulate retrieval: return the docs of the category the query matches."""
            from langchain_core.documents import Document

            q = query.lower()
            # Classify the query ONCE (it is loop-invariant — the original
            # re-evaluated this chain for every document), and check the most
            # specific categories FIRST: the bare substring "rag" also appears
            # in "如何使用 LangChain 构建 RAG" and "如何优化 RAG 的检索质量?",
            # so testing the generic RAG category first would shadow the
            # langchain/optimize categories and retrieve the wrong documents.
            category_rules = [
                (("langgraph",), ("langgraph",)),
                (("langchain", "构建"), ("langchain",)),
                (("向量", "数据库", "qdrant"), ("vector", "qdrant")),
                (("优化", "重排", "融合"), ("optimize", "rerank", "fusion")),
                (("rag", "检索"), ("rag",)),
            ]
            doc_keywords = ()
            for query_kws, doc_kws in category_rules:
                if any(kw in q for kw in query_kws):
                    doc_keywords = doc_kws
                    break

            results = [
                Document(page_content=content, metadata={"id": doc_id})
                for doc_id, content in self.mock_docs.items()
                if any(kw in doc_id.lower() for kw in doc_keywords)
            ]
            # Fallback: no category matched — return a few generic results.
            if not results:
                results = [
                    Document(page_content=content, metadata={"id": doc_id})
                    for doc_id, content in list(self.mock_docs.items())[:3]
                ]
            return results

    rag_pipeline = MockRAGPipeline()
    print(" RAG 系统已初始化(模拟)")
    print()

    # 3. Evaluate retrieval quality at several cut-off depths.
    print("【3/4】评估检索质量...")
    evaluator = RAGEvaluator(rag_pipeline, test_cases)
    metrics = await evaluator.evaluate_retrieval(k_list=[1, 3, 5, 10])
    print(" 评估完成")
    print()

    # 4. Render the human-readable report.
    print("【4/4】生成评估报告...")
    report = generate_test_report(metrics)
    print(report)
    print()

    # 5. Persist the report next to this script.
    report_file = os.path.join(os.path.dirname(__file__), 'rag_evaluation_report.txt')
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(report)
    print(f" 报告已保存到:{report_file}")
    print()
    print("=" * 80)
    print("评估完成!")
    print("=" * 80)
# Script entry point: run the async evaluation workflow to completion.
if __name__ == "__main__":
    asyncio.run(main())