"""
验证 RAG 索引完整性。

检查 Qdrant 向量库、PostgreSQL 文档存储及检索功能。
"""

import asyncio
import os
import sys
# Make the project root and its parent importable when this file runs as a script.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))

from dotenv import load_dotenv

# Pull configuration from a local .env file, if one exists.
load_dotenv()

# Qdrant endpoint; defaults to a local instance when QDRANT_URL is unset.
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
# Optional API key; None means the client connects unauthenticated.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
# SECURITY NOTE(review): the fallback DSN embeds a real-looking password and a
# public IP address. Prefer requiring DB_URI from the environment and failing
# fast when it is missing, rather than shipping credentials in source.
DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
# Qdrant collection holding the chunk vectors.
COLLECTION_NAME = "rag_documents"
# PostgreSQL table used as the parent-document key/value store.
TABLE_NAME = "parent_documents"
def check_qdrant():
    """Inspect the Qdrant vector store: list collections, show config, sample points."""
    from qdrant_client import QdrantClient

    banner = "=" * 60
    print(banner)
    print("Qdrant 向量库")
    print(banner)

    qdrant = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

    # Enumerate every collection on the server.
    collections = qdrant.get_collections().collections
    print(f"\n集合数: {len(collections)}")
    for col in collections:
        print(f"  - {col.name}")

    # Bail out early if the target collection is missing.
    if COLLECTION_NAME not in {col.name for col in collections}:
        print(f"\n集合 '{COLLECTION_NAME}' 不存在")
        return

    detail = qdrant.get_collection(COLLECTION_NAME)
    print(f"\n集合 '{COLLECTION_NAME}':")
    print(f"  状态: {detail.status}")
    print(f"  向量数: {detail.points_count}")

    # Named-vector collections expose a dict of configs; single-vector
    # collections expose one config object directly.
    vec_cfg = detail.config.params.vectors
    if isinstance(vec_cfg, dict):
        for vec_name, cfg in vec_cfg.items():
            print(f"  向量 '{vec_name}': 维度={cfg.size}, 距离={cfg.distance}")
    else:
        print(f"  向量维度: {vec_cfg.size}")

    # Peek at the first few points; payload only, vectors skipped.
    print(f"\n前 3 个向量:")
    records, _next_offset = qdrant.scroll(
        collection_name=COLLECTION_NAME,
        limit=3,
        with_payload=True,
        with_vectors=False,
    )
    for rank, record in enumerate(records, start=1):
        print(f"\n  {rank}. ID: {record.id}")
        payload = record.payload or {}
        print(f"     内容: {payload.get('page_content', '')[:100]}...")
async def check_postgres():
    """Inspect the PostgreSQL document store: table presence, counts, samples, key prefixes."""
    import asyncpg

    print("\n" + "=" * 60)
    print("PostgreSQL 文档存储")
    print("=" * 60)

    connection = await asyncpg.connect(dsn=DB_URI)

    try:
        # Confirm the target table exists in the public schema.
        table_rows = await connection.fetch(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
        )
        if TABLE_NAME not in {row['table_name'] for row in table_rows}:
            print(f"\n表 '{TABLE_NAME}' 不存在")
            return

        # Total row count. TABLE_NAME is a module constant, not user input.
        total = await connection.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
        print(f"\n表 '{TABLE_NAME}': {total} 条记录")

        # Show a small sample of stored documents.
        print(f"\n前 3 个文档:")
        sample = await connection.fetch(
            f"SELECT key, value FROM {TABLE_NAME} ORDER BY key LIMIT 3"
        )
        for rank, record in enumerate(sample, start=1):
            print(f"\n  {rank}. Key: {record['key']}")
            doc = record['value']
            if isinstance(doc, dict) and 'page_content' in doc:
                print(f"     内容: {doc['page_content'][:100]}...")

        # Break keys down by their ':'-separated prefix to see how the
        # store is partitioned.
        prefix_rows = await connection.fetch(
            f"""
            SELECT
                CASE
                    WHEN key LIKE '%:%' THEN split_part(key, ':', 1)
                    ELSE 'no_prefix'
                END AS prefix,
                COUNT(*) AS cnt
            FROM {TABLE_NAME}
            GROUP BY prefix
            ORDER BY cnt DESC
            LIMIT 10
            """
        )
        print(f"\nKey 前缀分布:")
        for record in prefix_rows:
            print(f"  {record['prefix']}: {record['cnt']}")

    finally:
        await connection.close()
async def test_search():
    """Exercise retrieval: parent-level results plus raw child-chunk similarity search."""
    from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
    from rag_indexer.splitters import SplitterType

    print("\n" + "=" * 60)
    print("检索测试")
    print("=" * 60)

    # Build the index handle the same way the default pipeline does.
    builder = IndexBuilder(
        IndexBuilderConfig(
            collection_name=COLLECTION_NAME,
            splitter_type=SplitterType.PARENT_CHILD,
        )
    )

    # The retriever is only wired up for the parent/child splitting strategy.
    if builder.retriever is None:
        print("错误: 检索器未初始化,请检查切分策略")
        return

    query = input("\n查询 (回车使用默认): ").strip() or "你好"
    print(f"\n查询: {query}")

    # Parent-level retrieval: ParentDocumentRetriever resolves child-chunk
    # hits back to their parent documents.
    print("\n--- 标准检索 (返回父块) ---")
    parent_docs = await builder.retriever.ainvoke(query)
    for rank, doc in enumerate(parent_docs, start=1):
        snippet = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n  {rank}. {snippet}...")
        if hasattr(doc, 'metadata'):
            origin = doc.metadata.get('source', '')
            if origin:
                print(f"     来源: {origin}")

    # Child-level retrieval: query the vector store directly so we see the
    # raw chunks rather than their parents.
    print("\n--- 检索子块 (通过修改检索器参数) ---")
    child_store = builder.vector_store.get_langchain_vectorstore()
    child_docs = await child_store.asimilarity_search(query, k=3)
    for rank, doc in enumerate(child_docs, start=1):
        snippet = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n  {rank}. {snippet}...")
        if hasattr(doc, 'metadata'):
            parent_key = doc.metadata.get('parent_id', '')
            if parent_key:
                print(f"     父块 ID: {parent_key}")
async def main():
    """Run all verification stages in order: vectors, documents, retrieval."""
    check_qdrant()          # synchronous Qdrant inspection
    await check_postgres()  # async document-store inspection
    await test_search()     # interactive retrieval smoke test
# Script entry point: drive the whole async verification pipeline.
if __name__ == "__main__":
    asyncio.run(main())