92 lines
2.7 KiB
Python
92 lines
2.7 KiB
Python
"""Clear the RAG index data.

Usage:
    python reset_index.py             # clear everything
    python reset_index.py --qdrant    # clear Qdrant only
    python reset_index.py --postgres  # clear PostgreSQL only
"""
|
||
|
||
import asyncio
|
||
import os
|
||
import argparse
|
||
|
||
from dotenv import load_dotenv
|
||
load_dotenv()
|
||
|
||
# Qdrant connection settings, read from the environment.
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# PostgreSQL settings supplied as separate components. They are only used
# when DB_URI itself is not set (see below).
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
DB_NAME = os.getenv("DB_NAME")


def build_db_uri(host, port, user, password, name):
    """Assemble a PostgreSQL connection URI from separate components.

    The user and password are percent-encoded so that reserved characters
    such as ``@``, ``:``, ``/`` or ``#`` cannot be mistaken for URI
    delimiters when the DSN is parsed.

    Args:
        host: Database server host name or address.
        port: Database server port.
        user: Role name, given unencoded.
        password: Password, given unencoded.
        name: Database name.

    Returns:
        A ``postgresql://`` URI with ``sslmode=disable``.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    from urllib.parse import quote

    return (
        f"postgresql://{quote(user, safe='')}:{quote(password, safe='')}"
        f"@{host}:{port}/{name}?sslmode=disable"
    )


# An explicitly configured DB_URI wins; otherwise build one from the
# component settings, but only when every component is present.
DB_URI = os.getenv("DB_URI")
if not DB_URI and all([DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME]):
    DB_URI = build_db_uri(DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME)

# Names of the Qdrant collection and the PostgreSQL table this script resets.
COLLECTION_NAME = "rag_documents"
TABLE_NAME = "parent_documents"
|
||
|
||
|
||
def clear_qdrant():
    """Drop the RAG collection from Qdrant if it is present.

    Connects with ``QDRANT_URL`` / ``QDRANT_API_KEY``, looks for
    ``COLLECTION_NAME`` among the server's collections, deletes it when
    found, and prints which of the two outcomes happened.
    """
    from qdrant_client import QdrantClient

    print("清理 Qdrant...")
    qdrant = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

    known_names = {entry.name for entry in qdrant.get_collections().collections}

    # Nothing to delete: report it and stop.
    if COLLECTION_NAME not in known_names:
        print(f" 集合 '{COLLECTION_NAME}' 不存在")
        return

    qdrant.delete_collection(COLLECTION_NAME)
    print(f" 集合 '{COLLECTION_NAME}' 已删除")
|
||
|
||
|
||
async def clear_postgres():
    """Delete every row from the parent-document table in PostgreSQL.

    Connects using ``DB_URI`` and checks whether ``TABLE_NAME`` exists in
    the ``public`` schema. If it does, the table is emptied with ``DELETE``
    and the number of removed rows is printed; otherwise a notice is
    printed. The connection is closed in all cases.
    """
    import asyncpg

    print("清理 PostgreSQL...")
    conn = await asyncpg.connect(dsn=DB_URI)

    try:
        exists = await conn.fetchval(
            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1)",
            TABLE_NAME,
        )
        if not exists:
            print(f" 表 '{TABLE_NAME}' 不存在")
            return

        # Take the row count from the DELETE command's own status tag
        # (e.g. "DELETE 42") rather than a separate COUNT(*) query, so the
        # reported number is exactly what this statement removed even if
        # other sessions write to the table concurrently.
        status = await conn.execute(f"DELETE FROM {TABLE_NAME}")
        count = status.rpartition(" ")[2]
        print(f" 表 '{TABLE_NAME}' 已清空,删除 {count} 条记录")
    finally:
        await conn.close()
|
||
|
||
|
||
async def main():
    """Parse the command line and run the selected clean-up steps.

    ``--qdrant`` and ``--postgres`` each select one backend; when neither
    flag is given, both backends are cleaned. Qdrant is handled first,
    then PostgreSQL, and a closing hint about rebuilding the index is
    printed.
    """
    cli = argparse.ArgumentParser(description="清理 RAG 索引数据")
    cli.add_argument("--qdrant", action="store_true", help="仅清理 Qdrant")
    cli.add_argument("--postgres", action="store_true", help="仅清理 PostgreSQL")
    opts = cli.parse_args()

    # No selector flag at all means "clean everything".
    clean_everything = not (opts.qdrant or opts.postgres)

    if clean_everything or opts.qdrant:
        clear_qdrant()

    if clean_everything or opts.postgres:
        await clear_postgres()

    print("\n完成。运行 `python -m rag_indexer.cli` 重建索引")
|
||
|
||
|
||
# Run the async entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
|