# ailine/rag_indexer/test/test_validate_index.py
# (file-viewer header: 188 lines, 5.9 KiB, Python; last modified 2026-04-19 22:01:55 +08:00)
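"""
验证 RAG 索引完整性
检查 Qdrant 向量库、PostgreSQL 文档存储及检索功能

Validate RAG index integrity: inspect the Qdrant vector store and the
PostgreSQL document store, then run an end-to-end retrieval test.
"""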
import asyncio
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from dotenv import load_dotenv
load_dotenv()
# Connection settings; each can be overridden via the environment or a .env file.
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
# Never embed real credentials in source: supply the full DSN (including the
# password) through DB_URI in the environment / .env. Like QDRANT_URL, the
# fallback targets a local server.
# NOTE(review): if a real password was ever committed here, it survives in VCS
# history and should be rotated.
DB_URI = os.getenv(
    "DB_URI",
    "postgresql://postgres@127.0.0.1:5432/langgraph_db?sslmode=disable",
)
COLLECTION_NAME = "rag_documents"  # Qdrant collection inspected / searched below
TABLE_NAME = "parent_documents"  # PostgreSQL table inspected by check_postgres
def check_qdrant():
    """Print a summary of the Qdrant vector store.

    Lists every collection on the server. If ``COLLECTION_NAME`` exists, it
    also prints the collection's status, point count and vector configuration,
    plus the ``page_content`` payload (truncated to 100 chars) of up to three
    points. All output goes to stdout; returns ``None``.
    """
    from qdrant_client import QdrantClient

    print("=" * 60)
    print("Qdrant 向量库")
    print("=" * 60)

    client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
    try:
        # All collections on the server.
        collections = client.get_collections().collections
        print(f"\n集合数: {len(collections)}")
        for c in collections:
            print(f" - {c.name}")

        # Details of the target collection.
        if not any(c.name == COLLECTION_NAME for c in collections):
            print(f"\n集合 '{COLLECTION_NAME}' 不存在")
            return

        info = client.get_collection(COLLECTION_NAME)
        print(f"\n集合 '{COLLECTION_NAME}':")
        print(f" 状态: {info.status}")
        print(f" 向量数: {info.points_count}")
        vectors_config = info.config.params.vectors
        if isinstance(vectors_config, dict):
            # Named vectors: one config entry per vector name.
            for name, vc in vectors_config.items():
                print(f" 向量 '{name}': 维度={vc.size}, 距离={vc.distance}")
        else:
            # A single unnamed vector config.
            print(f" 向量维度: {vectors_config.size}")

        # Sample a few points; payloads only, the vectors themselves are not needed.
        print("\n前 3 个向量:")
        points, _next_offset = client.scroll(
            collection_name=COLLECTION_NAME,
            limit=3,
            with_payload=True,
            with_vectors=False,
        )
        for i, point in enumerate(points, start=1):
            print(f"\n {i}. ID: {point.id}")
            payload = point.payload or {}
            # `or ""` also covers an explicit null page_content, which would
            # otherwise crash on slicing.
            content = str(payload.get("page_content") or "")
            print(f" 内容: {content[:100]}...")
    finally:
        # Release the client's underlying connection even on early return/error.
        client.close()
def _extract_page_content(value):
    """Return the ``page_content`` held in one docstore ``value``, or ``None``.

    Accepts:
    - an already-decoded ``dict``;
    - a JSON ``str``/``bytes``/``bytearray``. asyncpg returns json/jsonb
      columns as ``str`` unless a type codec is registered.

    Looks for ``page_content`` at the top level, then under ``kwargs``.
    NOTE(review): the ``kwargs`` form is assumed to be LangChain's serialized
    Document layout (``dumpd``); confirm against how the table is written.
    Anything undecodable (e.g. pickled bytes) or of another shape yields
    ``None``.
    """
    import json

    if isinstance(value, (str, bytes, bytearray)):
        try:
            value = json.loads(value)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass it
            return None
    if not isinstance(value, dict):
        return None
    if "page_content" in value:
        return value["page_content"]
    kwargs = value.get("kwargs")
    if isinstance(kwargs, dict) and "page_content" in kwargs:
        return kwargs["page_content"]
    return None


async def check_postgres():
    """Print a summary of the PostgreSQL table ``TABLE_NAME``.

    Prints three things:
    - the row count;
    - the first three rows ordered by ``key``, with their ``page_content``
      when one can be extracted from ``value``;
    - the ten most common key prefixes, where the prefix is the text before
      the first ``:``.

    Returns early, with a message, when the table does not exist in the
    ``public`` schema. The connection is always closed.
    """
    import asyncpg

    print("\n" + "=" * 60)
    print("PostgreSQL 文档存储")
    print("=" * 60)

    conn = await asyncpg.connect(dsn=DB_URI)
    try:
        # Does the table exist? Ask the catalog directly instead of listing every table.
        exists = await conn.fetchval(
            "SELECT EXISTS ("
            "SELECT 1 FROM information_schema.tables "
            "WHERE table_schema = 'public' AND table_name = $1)",
            TABLE_NAME,
        )
        if not exists:
            print(f"\n'{TABLE_NAME}' 不存在")
            return

        # Row count. TABLE_NAME is a module constant, so interpolating it is safe here.
        count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
        print(f"\n'{TABLE_NAME}': {count} 条记录")

        # Sample rows.
        print("\n前 3 个文档:")
        rows = await conn.fetch(
            f"SELECT key, value FROM {TABLE_NAME} ORDER BY key LIMIT 3"
        )
        for i, row in enumerate(rows, start=1):
            print(f"\n {i}. Key: {row['key']}")
            # `value` may come back as a JSON string rather than a dict.
            content = _extract_page_content(row['value'])
            if content is not None:
                print(f" 内容: {str(content)[:100]}...")

        # Distribution of key prefixes (text before the first ':').
        key_prefixes = await conn.fetch(
            f"""
            SELECT
                CASE
                    WHEN key LIKE '%:%' THEN split_part(key, ':', 1)
                    ELSE 'no_prefix'
                END AS prefix,
                COUNT(*) AS cnt
            FROM {TABLE_NAME}
            GROUP BY prefix
            ORDER BY cnt DESC
            LIMIT 10
            """
        )
        print("\nKey 前缀分布:")
        for row in key_prefixes:
            print(f" {row['prefix']}: {row['cnt']}")
    finally:
        await conn.close()
async def test_search():
    """Interactively exercise retrieval against the built index.

    Builds an ``IndexBuilder`` with the parent/child splitter, the same
    configuration the default build uses. It then prompts for a query; an
    empty answer, or no interactive stdin at all, falls back to ``"你好"``.
    Two result sets are printed:

    1. the retriever's results. Per this script's comments the retriever is a
       ParentDocumentRetriever, so these are parent chunks;
    2. the top-3 raw hits from the underlying LangChain vector store, each with
       the parent id found in its metadata, if any.

    Returns early, with a message, if the builder exposes no retriever.
    """
    from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
    from rag_indexer.splitters import SplitterType

    print("\n" + "=" * 60)
    print("检索测试")
    print("=" * 60)

    # Initialize from a config object, matching the default build path.
    config = IndexBuilderConfig(
        collection_name=COLLECTION_NAME,
        splitter_type=SplitterType.PARENT_CHILD,
    )
    builder = IndexBuilder(config)
    # Guard: the retriever may be missing for some splitter strategies.
    if builder.retriever is None:
        print("错误: 检索器未初始化,请检查切分策略")
        return

    try:
        query = input("\n查询 (回车使用默认): ").strip()
    except (EOFError, OSError):
        # No interactive stdin (closed pipe, CI, or pytest's captured stdin,
        # which raises OSError): use the default query instead of crashing.
        query = ""
    query = query or "你好"
    print(f"\n查询: {query}")

    # Standard retrieval: child-chunk hits are mapped back to their parents.
    print("\n--- 标准检索 (返回父块) ---")
    results = await builder.retriever.ainvoke(query)
    for i, doc in enumerate(results, start=1):
        content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n {i}. {content}...")
        if hasattr(doc, 'metadata'):
            source = doc.metadata.get('source', '')
            if source:
                print(f" 来源: {source}")

    # Raw child chunks: query the vector store directly, skipping the
    # parent-lookup step the retriever performs.
    print("\n--- 检索子块 (直接查询向量库) ---")
    # A LangChain ParentDocumentRetriever records the parent id under its
    # `id_key` (LangChain default "doc_id"). Check that key first, then fall
    # back to "parent_id".
    id_key = getattr(builder.retriever, "id_key", None)
    parent_keys = [key for key in (id_key, "parent_id") if key]
    vectorstore = builder.vector_store.get_langchain_vectorstore()
    sub_results = await vectorstore.asimilarity_search(query, k=3)
    for i, doc in enumerate(sub_results, start=1):
        content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n {i}. {content}...")
        if hasattr(doc, 'metadata'):
            parent_id = next(
                (doc.metadata[key] for key in parent_keys if doc.metadata.get(key)),
                '',
            )
            if parent_id:
                print(f" 父块 ID: {parent_id}")
async def main():
    """Run every validation step in order: Qdrant, then PostgreSQL, then retrieval.

    The first step is a plain synchronous function. The remaining two are
    coroutines and are awaited one after another.
    """
    check_qdrant()
    for step in (check_postgres, test_search):
        await step()


if __name__ == "__main__":
    asyncio.run(main())