154 lines
4.4 KiB
Python
154 lines
4.4 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
简单测试脚本:检查 Qdrant 内容,测试各种检索方式
|
||
"""
|
||
|
||
import asyncio
|
||
import os
|
||
import sys
|
||
|
||
# Directory two levels above this script (assumed to be the project root).
# Its "backend" sub-directory is inserted at the front of sys.path so the
# project imports that follow (presumably rag_core and app.*, which live
# there — confirm) resolve without installing the backend as a package.
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
|
||
|
||
from qdrant_client import models
|
||
from rag_core import QdrantVectorStore, get_sparse_embedder
|
||
from app.model_services import get_embedding_service
|
||
|
||
|
||
def check_qdrant_content(*, collection_name="rag_documents", limit=5):
    """Print the first few points stored in a Qdrant collection.

    Scrolls *collection_name* (payloads only, no vectors) and prints, for
    each point, its ID, its payload keys and every payload value. String
    values longer than 150 characters are truncated with a trailing "...".

    Args:
        collection_name: Collection to inspect. The same name is passed to
            ``QdrantVectorStore`` and to ``client.scroll`` so the two calls
            cannot drift apart.
        limit: Maximum number of points to fetch.
    """
    print("=" * 70)
    print("检查 Qdrant 内容...")
    print("=" * 70)

    embeddings = get_embedding_service()
    vs = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
    client = vs.get_qdrant_client()

    # Fetch the first `limit` points; vectors are skipped to keep output small.
    points, _ = client.scroll(
        collection_name=collection_name,
        limit=limit,
        with_payload=True,
        with_vectors=False,
    )

    print(f"\n找到 {len(points)} 个文档\n")
    for i, point in enumerate(points):
        # Qdrant declares a record's payload as optional; treat a missing
        # payload as empty instead of crashing on `.keys()`.
        payload = point.payload or {}
        print(f"--- 文档 {i+1} ---")
        print(f"ID: {point.id}")
        print(f"Payload 键: {list(payload.keys())}")

        # Print the full payload, truncating long string values.
        for key, value in payload.items():
            if isinstance(value, str) and len(value) > 150:
                shown = value[:150] + "..."
            else:
                shown = value
            print(f" {key}: {shown}")
        print()
|
||
|
||
|
||
def test_dense_retrieval():
    """Run a dense (embedding-based) similarity search and print the hits.

    Builds a ``QdrantVectorStore`` over the "rag_documents" collection,
    searches it for a fixed Chinese query with k=3, and prints the first
    200 characters of each returned document's ``page_content``.
    """
    banner = "=" * 70
    for line in (banner, "测试稠密检索...", banner):
        print(line)

    store = QdrantVectorStore(
        collection_name="rag_documents",
        embeddings=get_embedding_service(),
    )

    query_text = "蚂蚁"  # query in Chinese
    print(f"\n查询: {query_text}")

    hits = store.similarity_search(query_text, k=3)

    print(f"\n找到 {len(hits)} 个结果\n")
    for rank, hit in enumerate(hits, start=1):
        print(f"--- 结果 {rank} ---")
        print(hit.page_content[:200])
        print()
|
||
|
||
|
||
def test_sparse_retrieval(*, query="冬天", collection_name="rag_documents", limit=3):
    """Run a sparse-vector query directly against Qdrant and print the hits.

    The query is encoded with the project's sparse embedder (labelled BM25
    in the output) and sent to the collection's ``"sparse"`` named vector
    via ``client.query_points``. For each hit, the score and the first 200
    characters of its text are printed.

    Args:
        query: Text to search for.
        collection_name: Collection to query; used for both the vector store
            and the raw client call so the two cannot drift apart.
        limit: Maximum number of hits to return.
    """
    print("=" * 70)
    print("测试稀疏检索(BM25)...")
    print("=" * 70)

    embeddings = get_embedding_service()
    vs = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
    client = vs.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()

    print(f"\n查询: {query}")

    # NOTE(review): assumes embed_query returns a mapping with "indices" and
    # "values" sequences — confirm against rag_core's sparse embedder.
    sparse_query = sparse_embedder.embed_query(query)
    sparse_vec = models.SparseVector(
        indices=sparse_query["indices"],
        values=sparse_query["values"],
    )

    response = client.query_points(
        collection_name=collection_name,
        query=sparse_vec,
        using="sparse",
        limit=limit,
        with_payload=True,
    )

    print(f"\n找到 {len(response.points)} 个结果\n")
    for i, point in enumerate(response.points):
        # ScoredPoint.payload is optional; fall back to an empty dict.
        payload = point.payload or {}
        # Prefer "page_content", then "text". A missing/None value becomes "",
        # and str() guards against non-string payload values before slicing.
        text = payload.get("page_content") or payload.get("text") or ""
        print(f"--- 结果 {i+1} ---")
        print(f"分数: {point.score:.4f}")
        print(str(text)[:200])
        print()
|
||
|
||
|
||
def test_hybrid_retrieval(
    *,
    query="蚂蚁和蚱蜢",
    collection_name="rag_documents",
    limit=3,
    prefetch_limit=None,
):
    """Run a hybrid (dense + sparse, RRF-fused) query and print the hits.

    The query is embedded twice: once with the dense embedding service and
    once with the sparse embedder. Each vector drives a ``Prefetch`` against
    the ``"dense"`` / ``"sparse"`` named vectors, and Qdrant fuses the two
    candidate lists with Reciprocal Rank Fusion. For each fused hit, the
    score and the first 200 characters of its text are printed.

    Args:
        query: Text to search for.
        collection_name: Collection to query; used for both the vector store
            and the raw client call so the two cannot drift apart.
        limit: Maximum number of fused hits to return.
        prefetch_limit: Number of candidates each prefetch branch returns
            before fusion. Defaults to ``limit``; a larger value gives the
            fusion step more candidates to rank.
    """
    if prefetch_limit is None:
        prefetch_limit = limit

    print("=" * 70)
    print("测试混合检索(稠密+稀疏 RRF 融合)...")
    print("=" * 70)

    embeddings = get_embedding_service()
    vs = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
    client = vs.get_qdrant_client()
    sparse_embedder = get_sparse_embedder()

    print(f"\n查询: {query}")

    dense_query = embeddings.embed_query(query)
    # NOTE(review): assumes embed_query returns a mapping with "indices" and
    # "values" sequences — confirm against rag_core's sparse embedder.
    sparse_query = sparse_embedder.embed_query(query)
    sparse_vec = models.SparseVector(
        indices=sparse_query["indices"],
        values=sparse_query["values"],
    )

    response = client.query_points(
        collection_name=collection_name,
        prefetch=[
            models.Prefetch(query=dense_query, using="dense", limit=prefetch_limit),
            models.Prefetch(query=sparse_vec, using="sparse", limit=prefetch_limit),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=limit,
        with_payload=True,
    )

    print(f"\n找到 {len(response.points)} 个结果\n")
    for i, point in enumerate(response.points):
        # ScoredPoint.payload is optional; fall back to an empty dict.
        payload = point.payload or {}
        # Prefer "page_content", then "text". A missing/None value becomes "",
        # and str() guards against non-string payload values before slicing.
        text = payload.get("page_content") or payload.get("text") or ""
        print(f"--- 结果 {i+1} ---")
        print(f"分数: {point.score:.4f}")
        print(str(text)[:200])
        print()
|
||
|
||
|
||
if __name__ == "__main__":
    # Run every check in a fixed order; an exception in one check propagates
    # immediately and the remaining checks are not run.
    for check in (
        check_qdrant_content,
        test_dense_retrieval,
        test_sparse_retrieval,
        test_hybrid_retrieval,
    ):
        check()
|