Files
ailine/rag_indexer/test/test_validate_index.py
root 933d418d77
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s
检索器重构
2026-04-19 22:01:55 +08:00

188 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
验证 RAG 索引完整性。
检查 Qdrant 向量库、PostgreSQL 文档存储及检索功能。
"""
import asyncio
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from dotenv import load_dotenv
load_dotenv()
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
COLLECTION_NAME = "rag_documents"
TABLE_NAME = "parent_documents"
def check_qdrant():
    """Inspect the Qdrant vector store.

    Lists all collections, prints status/size/vector config of the target
    collection, and samples a few stored points (payload only).
    """
    from qdrant_client import QdrantClient

    banner = "=" * 60
    print(banner)
    print("Qdrant 向量库")
    print(banner)

    qc = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

    # Enumerate every collection on the server.
    collections = qc.get_collections().collections
    print(f"\n集合数: {len(collections)}")
    for col in collections:
        print(f" - {col.name}")

    # Bail out early when the target collection is absent.
    if COLLECTION_NAME not in {col.name for col in collections}:
        print(f"\n集合 '{COLLECTION_NAME}' 不存在")
        return

    info = qc.get_collection(COLLECTION_NAME)
    print(f"\n集合 '{COLLECTION_NAME}':")
    print(f" 状态: {info.status}")
    print(f" 向量数: {info.points_count}")

    # Named-vector collections expose a dict of configs; single-vector
    # collections expose one config object.
    vec_cfg = info.config.params.vectors
    if isinstance(vec_cfg, dict):
        for name, vc in vec_cfg.items():
            print(f" 向量 '{name}': 维度={vc.size}, 距离={vc.distance}")
    else:
        print(f" 向量维度: {vec_cfg.size}")

    # Sample a handful of points for a sanity check; scroll() returns
    # (records, next_page_offset).
    print(f"\n前 3 个向量:")
    records, _next_offset = qc.scroll(
        collection_name=COLLECTION_NAME,
        limit=3,
        with_payload=True,
        with_vectors=False,
    )
    for i, point in enumerate(records):
        print(f"\n {i+1}. ID: {point.id}")
        payload = point.payload or {}
        print(f" 内容: {payload.get('page_content', '')[:100]}...")
async def check_postgres():
    """Inspect the PostgreSQL document store.

    Verifies the target table exists, reports its row count, samples a few
    documents, and summarizes the distribution of key prefixes.
    """
    import asyncpg

    print("\n" + "=" * 60)
    print("PostgreSQL 文档存储")
    print("=" * 60)

    conn = await asyncpg.connect(dsn=DB_URI)
    try:
        # Does the target table exist in the public schema?
        table_rows = await conn.fetch(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
        )
        if TABLE_NAME not in {r['table_name'] for r in table_rows}:
            print(f"\n'{TABLE_NAME}' 不存在")
            return

        # TABLE_NAME is a module constant (never user input), so the
        # f-string interpolation into SQL below is safe.
        count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
        print(f"\n'{TABLE_NAME}': {count} 条记录")

        # Sample the first few documents in key order.
        print(f"\n前 3 个文档:")
        sample = await conn.fetch(
            f"SELECT key, value FROM {TABLE_NAME} ORDER BY key LIMIT 3"
        )
        for i, row in enumerate(sample):
            print(f"\n {i+1}. Key: {row['key']}")
            val = row['value']
            # NOTE(review): asyncpg returns jsonb as str unless a JSON codec
            # is registered — this dict branch may never fire; confirm codec.
            if isinstance(val, dict) and 'page_content' in val:
                print(f" 内容: {val['page_content'][:100]}...")

        # Group keys by their "prefix:" namespace (keys without ':' are
        # bucketed under 'no_prefix').
        key_prefixes = await conn.fetch(
            f"""
            SELECT
                CASE
                    WHEN key LIKE '%:%' THEN split_part(key, ':', 1)
                    ELSE 'no_prefix'
                END AS prefix,
                COUNT(*) AS cnt
            FROM {TABLE_NAME}
            GROUP BY prefix
            ORDER BY cnt DESC
            LIMIT 10
            """
        )
        print(f"\nKey 前缀分布:")
        for row in key_prefixes:
            print(f" {row['prefix']}: {row['cnt']}")
    finally:
        # Always release the connection, even on early return or error.
        await conn.close()
async def test_search():
    """Interactive retrieval smoke test.

    Builds an IndexBuilder with the default parent/child config, then runs
    the same query two ways: through the retriever (parent chunks) and
    directly against the vector store (child chunks).
    """
    from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig
    from rag_indexer.splitters import SplitterType

    print("\n" + "=" * 60)
    print("检索测试")
    print("=" * 60)

    # Mirror the default build configuration.
    cfg = IndexBuilderConfig(
        collection_name=COLLECTION_NAME,
        splitter_type=SplitterType.PARENT_CHILD,
    )
    builder = IndexBuilder(cfg)

    # Guard: the retriever only exists for compatible splitter strategies.
    if builder.retriever is None:
        print("错误: 检索器未初始化,请检查切分策略")
        return

    query = input("\n查询 (回车使用默认): ").strip() or "你好"
    print(f"\n查询: {query}")

    # Parent-chunk retrieval — ParentDocumentRetriever resolves child hits
    # back to their parent documents.
    print("\n--- 标准检索 (返回父块) ---")
    parent_docs = await builder.retriever.ainvoke(query)
    for i, doc in enumerate(parent_docs):
        content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n {i+1}. {content}...")
        if hasattr(doc, 'metadata'):
            source = doc.metadata.get('source', '')
            if source:
                print(f" 来源: {source}")

    # Child-chunk retrieval — simplest path is to query the underlying
    # vector store directly rather than cloning the retriever with a
    # different search_type.
    print("\n--- 检索子块 (通过修改检索器参数) ---")
    vectorstore = builder.vector_store.get_langchain_vectorstore()
    child_docs = await vectorstore.asimilarity_search(query, k=3)
    for i, doc in enumerate(child_docs):
        content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
        print(f"\n {i+1}. {content}...")
        if hasattr(doc, 'metadata'):
            parent_id = doc.metadata.get('parent_id', '')
            if parent_id:
                print(f" 父块 ID: {parent_id}")
async def main():
    """Run the three validation stages in sequence."""
    check_qdrant()          # synchronous: Qdrant server state
    await check_postgres()  # async: PostgreSQL document store
    await test_search()     # async: interactive retrieval check


if __name__ == "__main__":
    asyncio.run(main())