refactor: 重构RAG核心组件,简化代码结构和测试文件
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s

This commit is contained in:
2026-05-04 17:58:10 +08:00
parent a07e398739
commit 9841f47432
31 changed files with 578 additions and 1496 deletions

View File

@@ -13,7 +13,7 @@ from typing import List, Optional
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from app.model_services import get_rerank_service
from app.model_services import get_rerank_service, get_small_llm_service
from app.rag.rerank import create_document_reranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
@@ -31,7 +31,7 @@ class RAGPipeline:
def __init__(
self,
retriever=None,
llm: Optional[BaseLanguageModel] = None,
llm: Optional[BaseLanguageModel] = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
@@ -41,6 +41,9 @@ class RAGPipeline:
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
如果不提供,会自动创建默认的父子文档混合检索器。
llm: 用于生成多路查询的语言模型。
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量。
rerank_top_n: 最终返回的文档数量。
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
@@ -53,13 +56,26 @@ class RAGPipeline:
)
else:
self.retriever = retriever
# 处理 llm 参数
if llm == "default_small":
try:
self.llm = get_small_llm_service()
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"小模型初始化失败,将不做查询改写: {e}")
self.llm = None
elif llm in (None, False):
self.llm = None
else:
self.llm = llm
self.llm = llm
self.num_queries = num_queries
self.rerank_top_n = rerank_top_n
# 初始化组件 - 使用统一的重排服务获取接口
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None
self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None
self.reranker = create_document_reranker()
async def aretrieve(self, query: str) -> List[Document]:
@@ -102,11 +118,7 @@ class RAGPipeline:
final_docs = fused_docs[:self.rerank_top_n]
return final_docs
def retrieve(self, query: str) -> List[Document]:
"""同步检索入口(内部调用异步方法)"""
return asyncio.run(self.aretrieve(query))
def format_context(self, documents: List[Document]) -> str:
"""
将文档列表格式化为上下文字符串
@@ -129,7 +141,7 @@ class RAGPipeline:
def create_rag_pipeline(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,
llm: Optional[BaseLanguageModel] = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
) -> RAGPipeline:
@@ -138,7 +150,10 @@ def create_rag_pipeline(
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型
llm: 用于生成多路查询的语言模型
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量