Files
ailine/tools/test/simple_test.py
root 82dde7113e
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m42s
修改rag,实现混合检索
2026-05-04 04:28:32 +08:00

154 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
简单测试脚本:检查 Qdrant 内容,测试各种检索方式
"""
import asyncio
import os
import sys
# Make the project's ``backend`` directory importable (for ``rag_core`` and
# ``app.model_services`` below) when this script is run directly.
# The script lives at <project_root>/tools/test/, hence the two ".." hops.
# abspath() resolves those hops and anchors the result, so the sys.path entry
# does not depend on the current working directory or on whether __file__
# was given as a relative path.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, os.path.join(project_root, "backend"))
from qdrant_client import models
from rag_core import QdrantVectorStore, get_sparse_embedder
from app.model_services import get_embedding_service
def check_qdrant_content(collection_name="rag_documents", limit=5):
    """Print a sample of the points stored in a Qdrant collection.

    Scrolls the first ``limit`` points (payload only, no vectors) and prints
    each point's ID, its payload keys, and every payload entry. String values
    longer than 150 characters are truncated with "...".

    Args:
        collection_name: Qdrant collection to inspect. The default is the
            collection used by every other check in this script.
        limit: Maximum number of points to fetch from the start of the scroll.
    """
    print("="*70)
    print("检查 Qdrant 内容...")
    print("="*70)
    embeddings = get_embedding_service()
    vs = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
    client = vs.get_qdrant_client()
    # Scroll the first `limit` points; the second return value (next-page
    # offset) is not needed for a one-shot sample.
    points, _ = client.scroll(
        collection_name=collection_name,
        limit=limit,
        with_payload=True,
        with_vectors=False,
    )
    print(f"\n找到 {len(points)} 个文档\n")
    for i, point in enumerate(points, start=1):
        print(f"--- 文档 {i} ---")
        print(f"ID: {point.id}")
        # Record.payload is Optional in qdrant-client; treat a missing payload
        # as empty instead of crashing on .keys()/.items().
        payload = point.payload or {}
        print(f"Payload 键: {list(payload.keys())}")
        # Print the full payload, truncating long strings for readability.
        for key, value in payload.items():
            if isinstance(value, str) and len(value) > 150:
                value = value[:150] + "..."
            print(f" {key}: {value}")
        print()
def test_dense_retrieval():
    """Run a dense (embedding-based) similarity search and print the top hits.

    Builds a vector store over the ``rag_documents`` collection, calls its
    ``similarity_search`` with k=3, and prints the first 200 characters of
    each returned document's ``page_content``.
    """
    banner = "=" * 70
    print(banner)
    print("测试稠密检索...")
    print(banner)

    vector_store = QdrantVectorStore(
        collection_name="rag_documents",
        embeddings=get_embedding_service(),
    )

    query = "蚂蚁"  # Query in Chinese.
    print(f"\n查询: {query}")

    hits = vector_store.similarity_search(query, k=3)
    print(f"\n找到 {len(hits)} 个结果\n")
    for rank, hit in enumerate(hits, start=1):
        print(f"--- 结果 {rank} ---")
        print(hit.page_content[:200])
        print()
def test_sparse_retrieval():
    """Run a sparse-vector query against the ``sparse`` named vector and print hits.

    Encodes the query with the project's sparse embedder, queries the
    ``rag_documents`` collection through its ``sparse`` vector, and prints the
    score and the first 200 characters of text for each of the top 3 points.
    The text is read from the ``page_content`` payload key, falling back to
    ``text``, then to an empty string.
    """
    print("="*70)
    print("测试稀疏检索BM25...")
    print("="*70)
    embeddings = get_embedding_service()
    vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
    client = vs.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()
    query = "冬天"
    print(f"\n查询: {query}")
    # The sparse embedder's output is indexed by "indices"/"values" here —
    # presumably a dict-like mapping; confirm against rag_core.
    sparse_query = sparse_embedder.embed_query(query)
    sparse_vec = models.SparseVector(
        indices=sparse_query["indices"],
        values=sparse_query["values"],
    )
    response = client.query_points(
        collection_name="rag_documents",
        query=sparse_vec,
        using="sparse",
        limit=3,
        with_payload=True,
    )
    print(f"\n找到 {len(response.points)} 个结果\n")
    for i, point in enumerate(response.points, start=1):
        print(f"--- 结果 {i} ---")
        print(f"分数: {point.score:.4f}")
        # payload is Optional on scored points. A key that is present but None
        # would also break slicing, so fall back with `or` rather than relying
        # on .get() defaults (which only apply when the key is missing).
        payload = point.payload or {}
        text = payload.get("page_content") or payload.get("text") or ""
        print(text[:200])
        print()
def test_hybrid_retrieval():
    """Run a hybrid (dense + sparse) query fused with Reciprocal Rank Fusion.

    Issues two prefetch sub-queries against the ``rag_documents`` collection,
    one on the ``dense`` named vector and one on the ``sparse`` named vector.
    Their candidates are fused with RRF, and the score plus the first 200
    characters of text are printed for the top 3 fused points. The text is
    read from ``page_content``, falling back to ``text``, then to an empty
    string.
    """
    # Each prefetch must return a wider candidate pool than the final result
    # size. RRF rewards documents that rank well in *both* lists, and a doc
    # ranked 4th in both lists outscores one ranked 1st in only one. With
    # prefetch limit == final limit, such docs are never fused at all.
    final_limit = 3
    candidate_limit = 20

    print("="*70)
    print("测试混合检索(稠密+稀疏 RRF 融合)...")
    print("="*70)
    embeddings = get_embedding_service()
    vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
    client = vs.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()
    query = "蚂蚁和蚱蜢"
    print(f"\n查询: {query}")
    dense_query = embeddings.embed_query(query)
    # Indexed by "indices"/"values" — presumably a dict-like mapping from
    # rag_core's sparse embedder; confirm there.
    sparse_query = sparse_embedder.embed_query(query)
    sparse_vec = models.SparseVector(
        indices=sparse_query["indices"],
        values=sparse_query["values"],
    )
    response = client.query_points(
        collection_name="rag_documents",
        prefetch=[
            models.Prefetch(query=dense_query, using="dense", limit=candidate_limit),
            models.Prefetch(query=sparse_vec, using="sparse", limit=candidate_limit),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=final_limit,
        with_payload=True,
    )
    print(f"\n找到 {len(response.points)} 个结果\n")
    for i, point in enumerate(response.points, start=1):
        print(f"--- 结果 {i} ---")
        print(f"分数: {point.score:.4f}")
        # payload is Optional; also guard against a present-but-None text value.
        payload = point.payload or {}
        text = payload.get("page_content") or payload.get("text") or ""
        print(text[:200])
        print()
if __name__ == "__main__":
    # Run every diagnostic in order: inspect the stored content first, then
    # exercise each retrieval mode (dense, sparse, hybrid).
    for step in (
        check_qdrant_content,
        test_dense_retrieval,
        test_sparse_retrieval,
        test_hybrid_retrieval,
    ):
        step()