#!/usr/bin/env python3
"""
Exercise the refactored IndexBuilder and the RAG retrieval paths.

Covers: index building, dense retrieval, sparse retrieval, hybrid retrieval
and parent-child document retrieval against the ``rag_documents`` Qdrant
collection.
"""
import asyncio
import os

from rag_indexer.index_builder import IndexBuilder
from rag_indexer.splitters import SplitterType
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
from backend.app.model_services import get_embedding_service
from qdrant_client import models

# Collection that every test below builds into / queries from.
COLLECTION_NAME = "rag_documents"
SEPARATOR = "=" * 70


def _print_section(title: str, *, leading_newline: bool = True) -> None:
    """Print *title* framed by separator lines (optionally preceded by a blank line)."""
    print(("\n" if leading_newline else "") + SEPARATOR)
    print(title)
    print(SEPARATOR)


def _truncate(text: str, limit: int) -> str:
    """Strip *text* and cut it to *limit* characters, appending '...' when cut."""
    content = (text or "").strip()
    return content[:limit] + "..." if len(content) > limit else content


def _print_payload(payload, limit: int = 200) -> None:
    """Print a Qdrant point payload: metadata (every key but 'text') and truncated text."""
    payload = payload or {}
    text = payload.get("text") or ""
    metadata = {k: v for k, v in payload.items() if k != "text"}
    print(f" 元数据: {metadata}")
    print(f" 内容: {_truncate(text, limit)}")


def _hybrid_clients():
    """Return ``(embeddings, qdrant_client, sparse_embedder)`` for COLLECTION_NAME."""
    embeddings = get_embedding_service()
    vs = QdrantHybridStore(collection_name=COLLECTION_NAME, embeddings=embeddings)
    return embeddings, vs.get_qdrant_client(), get_sparse_embedder()


def _sparse_vector(sparse_embedder, query: str) -> models.SparseVector:
    """Embed *query* with the sparse embedder and wrap it as a Qdrant SparseVector."""
    sparse_query = sparse_embedder.embed_query(query)
    return models.SparseVector(
        indices=sparse_query["indices"],
        values=sparse_query["values"],
    )


def _hybrid_query(client, embeddings, sparse_embedder, query: str, limit: int):
    """Run a dense + sparse prefetch fused with RRF and return the query response."""
    dense_query = embeddings.embed_query(query)
    sparse_vec = _sparse_vector(sparse_embedder, query)
    return client.query_points(
        collection_name=COLLECTION_NAME,
        prefetch=[
            models.Prefetch(query=dense_query, using="dense", limit=limit),
            models.Prefetch(query=sparse_vec, using="sparse", limit=limit),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=limit,
        with_payload=True,
    )


async def test_index_builder():
    """Build the index from the sample document and report collection info.

    Returns:
        Whatever ``IndexBuilder.build_from_file`` returned, or 0 when the
        sample file does not exist.
    """
    _print_section("1. 测试索引构建功能...", leading_newline=False)

    builder = IndexBuilder(
        collection_name=COLLECTION_NAME,
        splitter_type=SplitterType.PARENT_CHILD,
        parent_chunk_size=1000,
        child_chunk_size=200,
    )

    # Sample document lives two directories above this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.abspath(os.path.join(script_dir, "..", ".."))
    test_file = os.path.join(project_root, "data", "user_docs", "doublestory.txt")

    # Previously unbound when the file was missing -> UnboundLocalError on return.
    processed = 0
    try:
        if os.path.exists(test_file):
            print(f"正在为文件 {test_file} 构建索引...")
            processed = await builder.build_from_file(test_file)
            print(f"索引构建完成,处理了 {processed} 个文档")

            info = builder.get_collection_info()
            print(f"集合信息: {info}")
        else:
            print(f"测试文件不存在: {test_file}")
    finally:
        # Release the builder's resources even if building fails.
        builder.close()

    print("\n索引构建测试完成")
    return processed


def test_dense_retrieval():
    """Run a dense similarity search through QdrantHybridStore and print the top 3."""
    _print_section("2. 测试稠密检索...")

    embeddings = get_embedding_service()
    vs = QdrantHybridStore(collection_name=COLLECTION_NAME, embeddings=embeddings)

    query = "The Ant and the Grasshopper"
    print(f"查询: {query}")

    results = vs.similarity_search(query, k=3)
    print(f"\n找到 {len(results)} 个结果:")
    for i, doc in enumerate(results, 1):
        print(f"\n{i}. (来源: {doc.metadata.get('source', 'unknown')})")
        print(f" 元数据: {doc.metadata}")
        print(f" 内容: {_truncate(doc.page_content, 200)}")


def test_sparse_retrieval_simple():
    """Query only the 'sparse' (BM25) named vector and print the top 3 points."""
    _print_section("3. 测试稀疏检索(BM25)...")

    _, client, sparse_embedder = _hybrid_clients()

    # Keyword-style query suits lexical (sparse) matching.
    query = "winter work food"
    print(f"查询关键词: {query}")

    response = client.query_points(
        collection_name=COLLECTION_NAME,
        query=_sparse_vector(sparse_embedder, query),
        using="sparse",
        limit=3,
        with_payload=True,
    )

    print(f"\n找到 {len(response.points)} 个结果:")
    for i, point in enumerate(response.points, 1):
        print(f"\n{i}. (分数: {point.score:.4f})")
        _print_payload(point.payload)


def test_hybrid_retrieval_simple():
    """Fuse dense and sparse prefetches with RRF and print the top 3 points."""
    _print_section("4. 测试混合检索(稠密+稀疏 RRF 融合)...")

    embeddings, client, sparse_embedder = _hybrid_clients()

    query = "Ant and Grasshopper story"
    print(f"查询: {query}")

    response = _hybrid_query(client, embeddings, sparse_embedder, query, limit=3)

    print(f"\n找到 {len(response.points)} 个结果:")
    for i, point in enumerate(response.points, 1):
        print(f"\n{i}. (RRF 融合分数: {point.score:.4f})")
        _print_payload(point.payload)


def test_parent_child_retrieval_simple():
    """Hybrid-search child chunks, then resolve and print their parent documents.

    Each hit is grouped by its payload ``parent_id`` (falling back to the
    point's own id), keeping the best child score per parent. Parents are then
    fetched with ``client.retrieve``. Any parent that cannot be found is
    replaced by its best-scoring child. The top 3 results by score are printed.
    """
    _print_section("5. 测试父子文档混合检索...")

    embeddings, client, sparse_embedder = _hybrid_clients()

    query = "The Ant and the Grasshopper story moral"
    print(f"查询: {query}")

    # Hybrid search over child chunks first.
    response = _hybrid_query(client, embeddings, sparse_embedder, query, limit=5)

    # str(id) -> (original parent_id, best score, best child point).
    # Keys are normalized with str() so an int/str mismatch between the payload
    # parent_id and the id Qdrant returns cannot cause a KeyError.
    best = {}
    for point in response.points:
        payload = point.payload or {}
        parent_id = payload.get("parent_id")
        if parent_id is None:
            parent_id = point.id
        key = str(parent_id)
        if key not in best or point.score > best[key][1]:
            best[key] = (parent_id, point.score, point)

    print(f"\n找到 {len(best)} 个不同的 parent_id:")

    if not best:
        print("\n未找到结果")
        return

    parent_docs = client.retrieve(
        collection_name=COLLECTION_NAME,
        ids=[parent_id for parent_id, _, _ in best.values()],
        with_payload=True,
    )

    results = []
    found = set()
    for parent in parent_docs:
        key = str(parent.id)
        if key in best:
            found.add(key)
            results.append((parent, best[key][1]))

    # Parents that could not be retrieved: fall back to the best child chunk.
    for key, (parent_id, score, child_point) in best.items():
        if key not in found:
            print(f"\n注意: parent_id {parent_id} 未找到,使用子文档代替")
            results.append((child_point, score))

    results.sort(key=lambda item: item[1], reverse=True)

    print(f"\n共 {len(results)} 个结果(去重后):")
    for i, (point, score) in enumerate(results[:3], 1):
        print(f"\n{i}. (分数: {score:.4f})")
        _print_payload(point.payload, limit=400)


async def main():
    """Run all tests in order: build the index first, then each retrieval mode."""
    await test_index_builder()
    test_dense_retrieval()
    test_sparse_retrieval_simple()
    test_hybrid_retrieval_simple()
    test_parent_child_retrieval_simple()

    print("\n" + SEPARATOR)
    print("所有测试完成!")
    print(SEPARATOR)


if __name__ == "__main__":
    asyncio.run(main())