This commit is contained in:
@@ -1,137 +1,114 @@
|
||||
"""
|
||||
RAG 检索流水线模块
|
||||
|
||||
提供固定流程的 RAG 检索:
|
||||
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
||||
|
||||
默认使用混合检索(稠密+稀疏)+ 父子文档模式。
|
||||
RAG 检索流水线
|
||||
流程: 检索子文档 → 重排 → 获取父文档 → 返回
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from typing import List
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from app.model_services import get_rerank_service, get_small_llm_service
|
||||
from app.rag.rerank import create_document_reranker
|
||||
from app.rag.query_transform import MultiQueryGenerator
|
||||
from app.rag.fusion import reciprocal_rank_fusion
|
||||
from app.rag.retriever import create_parent_hybrid_retriever
|
||||
from ..model_services import get_rerank_service, get_small_llm_service
|
||||
from ..rag.rerank import create_document_reranker
|
||||
from ..rag.query_transform import MultiQueryGenerator
|
||||
from ..rag.fusion import reciprocal_rank_fusion
|
||||
from ..rag.retriever import create_parent_hybrid_retriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RAGPipeline:
|
||||
"""
|
||||
固定流程的 RAG 检索流水线:
|
||||
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
||||
|
||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever=None,
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
llm: BaseLanguageModel | str = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
collection_name: str = "rag_documents",
|
||||
use_rerank: bool = True,
|
||||
return_parent_docs: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
|
||||
如果不提供,会自动创建默认的父子文档混合检索器。
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量。
|
||||
rerank_top_n: 最终返回的文档数量。
|
||||
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
|
||||
"""
|
||||
# 如果没有提供 retriever,自动创建默认的混合检索器
|
||||
if retriever is None:
|
||||
self.retriever = create_parent_hybrid_retriever(
|
||||
collection_name=collection_name,
|
||||
search_k=rerank_top_n * 2 # 多取一些给重排序用
|
||||
)
|
||||
else:
|
||||
self.retriever = retriever
|
||||
|
||||
# 处理 llm 参数
|
||||
self.retriever = retriever or create_parent_hybrid_retriever(
|
||||
collection_name=collection_name, search_k=rerank_top_n * 4
|
||||
)
|
||||
self.num_queries = num_queries
|
||||
self.rerank_top_n = rerank_top_n
|
||||
self.use_rerank = use_rerank
|
||||
self.return_parent_docs = return_parent_docs
|
||||
|
||||
if llm == "default_small":
|
||||
try:
|
||||
self.llm = get_small_llm_service()
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"小模型初始化失败,将不做查询改写: {e}")
|
||||
except Exception:
|
||||
self.llm = None
|
||||
elif llm in (None, False):
|
||||
self.llm = None
|
||||
else:
|
||||
self.llm = llm
|
||||
|
||||
self.num_queries = num_queries
|
||||
self.rerank_top_n = rerank_top_n
|
||||
|
||||
# 初始化组件 - 使用统一的重排服务获取接口
|
||||
self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None
|
||||
self.reranker = create_document_reranker()
|
||||
|
||||
self.llm = llm if llm else None
|
||||
|
||||
self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None
|
||||
self.reranker = create_document_reranker() if use_rerank else None
|
||||
logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")
|
||||
|
||||
async def aretrieve(self, query: str) -> List[Document]:
|
||||
"""
|
||||
异步执行完整检索流程
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
|
||||
Returns:
|
||||
检索到的相关文档列表
|
||||
"""
|
||||
# 如果有 query_generator,做多路改写
|
||||
if self.query_generator and self.llm:
|
||||
# Step 1: 生成多路查询
|
||||
# Step 1: 检索
|
||||
child_docs = await self._retrieve(query)
|
||||
logger.info(f"[Pipeline] 检索到 {len(child_docs)} 个子文档")
|
||||
# 调试:打印子文档长度
|
||||
for i, doc in enumerate(child_docs[:5]):
|
||||
content_len = len(doc.page_content)
|
||||
logger.info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符")
|
||||
|
||||
# Step 2: 重排
|
||||
if self.reranker:
|
||||
try:
|
||||
child_docs = self.reranker.compress_documents(child_docs, query, self.rerank_top_n)
|
||||
logger.info(f"[Pipeline] 重排后 {len(child_docs)} 个")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Pipeline] 重排失败: {e}")
|
||||
child_docs = child_docs[:self.rerank_top_n]
|
||||
|
||||
# Step 3: 获取父文档
|
||||
if self.return_parent_docs:
|
||||
return await self._get_parents(child_docs)
|
||||
return child_docs
|
||||
|
||||
async def _retrieve(self, query: str) -> List[Document]:
|
||||
if self.query_generator:
|
||||
queries = await self.query_generator.agenerate(query)
|
||||
# 包含原始查询,确保至少有一条
|
||||
if query not in queries:
|
||||
queries.insert(0, query)
|
||||
else:
|
||||
# 如果原始查询已在列表中,将其移至首位
|
||||
queries.remove(query)
|
||||
queries.insert(0, query)
|
||||
|
||||
# Step 2: 并行检索(每个查询获取文档列表)
|
||||
tasks = [self.retriever.ainvoke(q) for q in queries]
|
||||
doc_lists = await asyncio.gather(*tasks)
|
||||
|
||||
# Step 3: RRF 融合
|
||||
fused_docs = reciprocal_rank_fusion(doc_lists)
|
||||
else:
|
||||
# 没有 LLM 做查询改写,直接用原始查询检索
|
||||
fused_docs = await self.retriever.ainvoke(query)
|
||||
|
||||
# Step 4: 重排序
|
||||
queries = [query] + [q for q in queries if q != query]
|
||||
doc_lists = await asyncio.gather(*[self.retriever.ainvoke(q) for q in queries])
|
||||
return reciprocal_rank_fusion(doc_lists)
|
||||
return await self.retriever.ainvoke(query)
|
||||
|
||||
async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
|
||||
parent_map = {}
|
||||
for doc in child_docs:
|
||||
pid = doc.metadata.get("parent_id")
|
||||
if pid and pid not in parent_map:
|
||||
parent_map[pid] = doc.metadata.get("score", 0.0)
|
||||
|
||||
if not parent_map:
|
||||
logger.warning("[Pipeline] 未找到 parent_id,返回子文档")
|
||||
return child_docs
|
||||
|
||||
try:
|
||||
final_docs = self.reranker.compress_documents(fused_docs, query, top_n=self.rerank_top_n)
|
||||
except Exception:
|
||||
# 若重排序器不可用,直接返回融合后的前 N 个结果
|
||||
final_docs = fused_docs[:self.rerank_top_n]
|
||||
|
||||
return final_docs
|
||||
from backend.rag_core import create_docstore
|
||||
docstore, _ = create_docstore()
|
||||
# 同步获取(异步版本不存在)
|
||||
parent_docs = docstore.mget(list(parent_map.keys()))
|
||||
parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d}
|
||||
result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2]
|
||||
result.sort(key=lambda x: x[1], reverse=True)
|
||||
docs = [d for d, _ in result]
|
||||
logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")
|
||||
return docs
|
||||
except Exception as e:
|
||||
logger.warning(f"[Pipeline] 获取父文档失败: {e}")
|
||||
return child_docs
|
||||
|
||||
def format_context(self, documents: List[Document]) -> str:
|
||||
"""
|
||||
将文档列表格式化为上下文字符串
|
||||
|
||||
Args:
|
||||
documents: 文档列表
|
||||
|
||||
Returns:
|
||||
格式化后的上下文字符串
|
||||
"""
|
||||
if not documents:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
for i, doc in enumerate(documents, 1):
|
||||
source = doc.metadata.get("source", "未知来源")
|
||||
@@ -139,30 +116,5 @@ class RAGPipeline:
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def create_rag_pipeline(
|
||||
collection_name: str = "rag_documents",
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
) -> RAGPipeline:
|
||||
"""
|
||||
创建 RAG 检索流水线的便捷函数
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
|
||||
Returns:
|
||||
RAGPipeline 实例
|
||||
"""
|
||||
return RAGPipeline(
|
||||
llm=llm,
|
||||
num_queries=num_queries,
|
||||
rerank_top_n=rerank_top_n,
|
||||
collection_name=collection_name
|
||||
)
|
||||
def create_rag_pipeline(**kwargs) -> RAGPipeline:
|
||||
return RAGPipeline(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user