74 lines
1.8 KiB
Python
74 lines
1.8 KiB
Python
# app/rag_initializer.py
|
|
import sys
from typing import Any, Callable, Optional

from ...logger import info, warning
from ...model_services import get_embedding_service
from ...rag.retriever import create_parent_hybrid_retriever
from ...rag.tools import create_rag_tool
|
|
|
|
# Module-level RAG state, written in this module by init_rag_tool() and reset():
#   _rag_tool    -- tool built by the last successful init_rag_tool(), else None
#   _initialized -- True once init_rag_tool() has succeeded (until reset())
_rag_tool, _initialized = None, False
|
|
|
|
|
|
def get_rag_tool() -> Optional[Callable[..., Any]]:
    """Return the module-level RAG tool.

    Returns:
        The tool stored by the last successful ``init_rag_tool()`` call, or
        ``None`` if no initialization has succeeded yet (or ``reset()`` has
        cleared it since).
    """
    # NOTE: the previous annotation ``-> callable`` referred to the builtin
    # function ``callable``, which is not a type; ``None`` is also a valid result.
    return _rag_tool
|
|
|
|
|
|
def is_initialized() -> bool:
    """Report whether the RAG tool has been successfully initialized.

    Returns:
        The module flag set by a successful ``init_rag_tool()`` call; it is
        False before that and again after ``reset()``.
    """
    flag = _initialized
    return flag
|
|
|
|
|
|
async def init_rag_tool(force: bool = False) -> Optional[Callable[..., Any]]:
    """Build the RAG tool and cache it in this module's globals.

    The function obtains its own dependencies: the embedding service, a parent
    hybrid retriever over the ``rag_documents`` collection, and a chat LLM used
    for query rewriting. It then caches the resulting tool so that
    ``get_rag_tool()`` can return it later.

    Args:
        force: Rebuild the tool even if a previous call already succeeded.

    Returns:
        The RAG tool (the ``@tool``-decorated object from ``create_rag_tool``)
        on success. If already initialized and ``force`` is False, the cached
        tool. ``None`` if building failed. A failed forced rebuild leaves any
        previously cached tool and the initialized flag untouched.
    """
    global _rag_tool, _initialized

    # Skip redundant work: a successful earlier call already cached the tool.
    if _initialized and not force:
        info("[RAG] 已初始化,跳过")
        return _rag_tool

    try:
        # Imported lazily inside the function (presumably to avoid an import
        # cycle at module load -- TODO confirm). Kept relative, like the other
        # model_services import in this module. An absolute
        # ``backend.app...`` path breaks, or loads a duplicate copy of the
        # package, when the app is started with ``app`` as the import root.
        from ...model_services.chat_services import get_chat_service

        info("🔄 正在初始化 RAG 检索系统...")
        embeddings = get_embedding_service()
        retriever = create_parent_hybrid_retriever(
            collection_name="rag_documents",
            search_k=5,
            embeddings=embeddings,
        )
        rewrite_llm = get_chat_service()

        rag_tool = create_rag_tool(
            retriever=retriever,
            llm=rewrite_llm,
            num_queries=3,
            rerank_top_n=5,
        )
    except Exception as e:
        # Deliberate best-effort: failures are logged and reported as None
        # rather than propagated. Include the exception type because str(e)
        # alone can be empty or ambiguous.
        warning(f"⚠️ RAG 检索工具初始化失败: {type(e).__name__}: {e}")
        return None
    else:
        # Only publish the tool once every construction step has succeeded.
        _rag_tool = rag_tool
        _initialized = True
        info(f"✅ RAG 检索工具初始化成功 (id={id(rag_tool)})")
        return rag_tool
|
|
|
|
|
|
def reset():
    """Clear the cached RAG tool and the initialized flag (intended for tests).

    Afterwards ``get_rag_tool()`` returns None and ``is_initialized()`` returns
    False. The next ``init_rag_tool()`` call therefore rebuilds the tool.
    """
    global _rag_tool, _initialized
    _rag_tool, _initialized = None, False
|