修改rag,实现混合检索
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m42s

This commit is contained in:
2026-05-04 04:28:32 +08:00
parent d0590240f9
commit 82dde7113e
15 changed files with 536 additions and 65 deletions

View File

@@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""
检查 Qdrant 集合里的数据结构
"""
import asyncio
import os
import sys
# 添加项目根目录到 Python 路径
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
sys.path.insert(0, project_root)
from rag_core import QdrantVectorStore
from app.model_services import get_embedding_service
def check_qdrant_data():
"""检查 Qdrant 中的数据结构"""
print("="*70)
print("检查 Qdrant 中的数据结构...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
# 先获取几个点看看 payload 结构
print("\n获取 5 个随机文档:")
results = client.scroll(
collection_name="rag_documents",
limit=5,
with_payload=True,
with_vectors=True
)
for i, point in enumerate(results[0], 1):
print(f"\n{i}. ID: {point.id}")
print(f" Payload: {point.payload}")
print(f" Payload 键: {list(point.payload.keys())}")
if "text" in point.payload:
text = point.payload["text"]
print(f" Text 长度: {len(text)}")
print(f" Text 预览: {text[:150]}...")
if "page_content" in point.payload:
print(f" page_content: {point.payload['page_content'][:150]}...")
# 看看向量
if point.vector:
print(f" 向量存在: {type(point.vector)}")
if isinstance(point.vector, dict):
print(f" 向量键: {list(point.vector.keys())}")
def check_sparse_embedder():
"""检查稀疏嵌入器"""
from rag_core import get_sparse_embedder
print("\n" + "="*70)
print("检查稀疏嵌入器...")
print("="*70)
sparse_embedder = get_sparse_embedder()
print(f"\n稀疏嵌入器: {sparse_embedder}")
print(f"Vocabulary 大小: {len(sparse_embedder.model.vocab)}")
print(f"示例查询: '冬天 食物'")
# 用中文试试
sparse_vec = sparse_embedder.embed_query("冬天 食物")
print(f"\n生成的稀疏向量:")
print(f" 索引数量: {len(sparse_vec['indices'])}")
print(f" 索引: {sparse_vec['indices'][:10]}")
print(f" 值: {sparse_vec['values'][:10]}")
if __name__ == "__main__":
check_qdrant_data()
check_sparse_embedder()

40
tools/test/quick_test.py Normal file
View File

@@ -0,0 +1,40 @@
#!/usr/bin/env python3
"""
简单测试脚本:测试文档里真正有的内容
"""
import asyncio
import os
import sys
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
from qdrant_client import models
from rag_core import QdrantVectorStore, get_sparse_embedder
from app.model_services import get_embedding_service
def test_dense_retrieval():
"""测试稠密检索"""
print("="*70)
print("测试稠密检索...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
query = "黄双银" # 用文档里真正有的名字查询
print(f"\n查询: {query}")
results = vs.similarity_search(query, k=3)
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:200])
print()
if __name__ == "__main__":
test_dense_retrieval()

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
"""
删除 Qdrant 集合并重新索引
"""
import asyncio
import os
import sys
# 添加项目根目录到 Python 路径
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
sys.path.insert(0, project_root)
from rag_core import QdrantVectorStore
from app.model_services import get_embedding_service
async def delete_and_recreate():
"""删除并重新创建集合"""
print("="*70)
print("删除旧集合并重新创建...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
# 删除旧集合
try:
vs.delete_collection()
print("✅ 旧集合已删除")
except Exception as e:
print(f"⚠️ 删除集合时出错(可能不存在): {e}")
# 重新创建
vs.create_collection()
print("✅ 新集合已创建")
if __name__ == "__main__":
asyncio.run(delete_and_recreate())

View File

@@ -0,0 +1,30 @@
#!/usr/bin/env python3
"""
简单删除 Qdrant 集合
"""
import sys
import os
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
from rag_core.client import create_qdrant_client
def delete_collection():
print("="*70)
print("删除 rag_documents 集合...")
print("="*70)
client = create_qdrant_client()
try:
client.delete_collection("rag_documents")
print("✅ 删除成功")
except Exception as e:
print(f"⚠️ 删除失败: {e}")
if __name__ == "__main__":
delete_collection()

153
tools/test/simple_test.py Normal file
View File

@@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
简单测试脚本:检查 Qdrant 内容,测试各种检索方式
"""
import asyncio
import os
import sys
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
from qdrant_client import models
from rag_core import QdrantVectorStore, get_sparse_embedder
from app.model_services import get_embedding_service
def check_qdrant_content():
"""检查 Qdrant 里的内容"""
print("="*70)
print("检查 Qdrant 内容...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
# 滚动获取前 5 个点
points, _ = client.scroll(
collection_name="rag_documents",
limit=5,
with_payload=True,
with_vectors=False
)
print(f"\n找到 {len(points)} 个文档\n")
for i, point in enumerate(points):
print(f"--- 文档 {i+1} ---")
print(f"ID: {point.id}")
print(f"Payload 键: {list(point.payload.keys())}")
# 打印完整 payload
for k, v in point.payload.items():
if isinstance(v, str) and len(v) > 150:
v = v[:150] + "..."
print(f" {k}: {v}")
print()
def test_dense_retrieval():
"""测试稠密检索"""
print("="*70)
print("测试稠密检索...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
query = "蚂蚁" # 用中文查询
print(f"\n查询: {query}")
results = vs.similarity_search(query, k=3)
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:200])
print()
def test_sparse_retrieval():
"""测试稀疏检索"""
print("="*70)
print("测试稀疏检索BM25...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
sparse_embedder = get_sparse_embedder()
query = "冬天"
print(f"\n查询: {query}")
sparse_query = sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
response = client.query_points(
collection_name="rag_documents",
query=sparse_vec,
using="sparse",
limit=3,
with_payload=True
)
print(f"\n找到 {len(response.points)} 个结果\n")
for i, point in enumerate(response.points):
print(f"--- 结果 {i+1} ---")
print(f"分数: {point.score:.4f}")
text = point.payload.get("page_content", point.payload.get("text", ""))
print(text[:200])
print()
def test_hybrid_retrieval():
"""测试混合检索"""
print("="*70)
print("测试混合检索(稠密+稀疏 RRF 融合)...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
sparse_embedder = get_sparse_embedder()
query = "蚂蚁和蚱蜢"
print(f"\n查询: {query}")
dense_query = embeddings.embed_query(query)
sparse_query = sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
response = client.query_points(
collection_name="rag_documents",
prefetch=[
models.Prefetch(query=dense_query, using="dense", limit=3),
models.Prefetch(query=sparse_vec, using="sparse", limit=3)
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
limit=3,
with_payload=True
)
print(f"\n找到 {len(response.points)} 个结果\n")
for i, point in enumerate(response.points):
print(f"--- 结果 {i+1} ---")
print(f"分数: {point.score:.4f}")
text = point.payload.get("page_content", point.payload.get("text", ""))
print(text[:200])
print()
if __name__ == "__main__":
check_qdrant_content()
test_dense_retrieval()
test_sparse_retrieval()
test_hybrid_retrieval()

View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python3
"""
测试 app/rag/retriever.py 里的混合检索函数
"""
import asyncio
import os
import sys
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
from app.rag.retriever import create_hybrid_retriever, create_parent_hybrid_retriever
def test_hybrid_retriever():
"""测试混合检索器"""
print("="*70)
print("测试 HybridRetriever...")
print("="*70)
retriever = create_hybrid_retriever(collection_name="rag_documents", search_k=3)
results = retriever.invoke("黄双银")
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:200])
print()
def test_parent_hybrid_retriever():
"""测试父子混合检索器"""
print("\n" + "="*70)
print("测试 ParentHybridRetriever...")
print("="*70)
retriever = create_parent_hybrid_retriever(
collection_name="rag_documents",
search_k=3,
use_docstore=False
)
results = retriever.invoke("黄双银")
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:300])
print()
if __name__ == "__main__":
test_hybrid_retriever()
test_parent_hybrid_retriever()