Files
ailine/backend/app/main_graph/utils/rag_initializer.py

74 lines
1.8 KiB
Python
Raw Normal View History

2026-04-21 11:02:16 +08:00
# app/main_graph/utils/rag_initializer.py
2026-05-05 23:17:00 +08:00
import sys
from typing import Any, Callable, Optional

from backend.app.logger import info, warning

from ...rag.tools import create_rag_tool
from ...rag.retriever import create_parent_hybrid_retriever
from ...model_services import get_embedding_service
# Module-level RAG tool; assigned by init_rag_tool() and cleared by reset().
_rag_tool = None
# Set to True by init_rag_tool() after the tool is stored; reset() sets it back to False.
_initialized = False
def get_rag_tool() -> Optional[Callable[..., Any]]:
    """Return the module-level RAG tool.

    Returns:
        The tool most recently stored by ``init_rag_tool``, or None when no
        tool has been stored yet (or ``reset`` has cleared it).
    """
    # The previous annotation used the builtin ``callable`` predicate, which
    # is a function rather than a type, and did not admit the None case.
    return _rag_tool
def is_initialized() -> bool:
    """Report whether the module-level initialization flag is set."""
    return _initialized
2026-05-05 13:30:31 +08:00
async def init_rag_tool(force: bool = False):
    """
    Build the RAG tool and register it in the module-level variables.

    The embedding service and the chat service are looked up inside this
    function instead of being supplied by the caller.

    Args:
        force: When True, rebuild the tool even if the module is already
            marked as initialized.

    Returns:
        The object produced by ``create_rag_tool`` (per the original note,
        a ``@tool``-decorated function), or None if any step raised.
    """
    global _rag_tool, _initialized

    # Avoid repeated initialization: hand back the cached tool unless the
    # caller explicitly requested a rebuild.
    must_build = force or not _initialized
    if not must_build:
        info("[RAG] 已初始化,跳过")
        return _rag_tool

    try:
        # Imported lazily inside the function
        # (presumably to avoid an import cycle — TODO confirm).
        from backend.app.model_services.chat_services import get_chat_service

        info("🔄 正在初始化 RAG 检索系统...")

        hybrid_retriever = create_parent_hybrid_retriever(
            collection_name="rag_documents",
            search_k=5,
            embeddings=get_embedding_service(),
        )

        # The chat service is handed to the tool as its ``llm`` argument.
        rag_tool = create_rag_tool(
            retriever=hybrid_retriever,
            llm=get_chat_service(),
            num_queries=3,
            rerank_top_n=5,
        )

        # Publish the new tool only after every construction step succeeded.
        _rag_tool = rag_tool
        _initialized = True
        info(f"✅ RAG 检索工具初始化成功 (id={id(rag_tool)})")
        return _rag_tool
    except Exception as e:
        # Best-effort initialization: the failure is logged and signalled
        # to the caller as None instead of being propagated.
        warning(f"⚠️ RAG 检索工具初始化失败: {e}")
        return None
def reset():
    """Clear the stored tool and the initialized flag (intended for tests)."""
    global _rag_tool, _initialized
    _rag_tool, _initialized = None, False