"""RAG retrieval pipeline.

Flow: retrieve child documents -> rerank -> fetch parent documents -> return.
"""
import asyncio
import logging
from typing import List

from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel

from ..model_services import get_rerank_service, get_small_llm_service
from ..rag.rerank import create_document_reranker
from ..rag.query_transform import MultiQueryGenerator
from ..rag.fusion import reciprocal_rank_fusion
from ..rag.retriever import create_parent_hybrid_retriever

logger = logging.getLogger(__name__)


class RAGPipeline:
    """Retrieve child documents, optionally rerank them, and resolve parents.

    Every optional stage (multi-query expansion, reranking, parent lookup)
    degrades to a simpler behaviour with a warning instead of failing the
    whole retrieval.
    """

    def __init__(
        self,
        retriever=None,
        llm: BaseLanguageModel | str = "default_small",
        num_queries: int = 3,
        rerank_top_n: int = 5,
        collection_name: str = "rag_documents",
        use_rerank: bool = True,
        return_parent_docs: bool = True,
    ):
        """Build the pipeline.

        Args:
            retriever: Retriever to use. When falsy, one is created with
                ``create_parent_hybrid_retriever`` using ``collection_name``
                and ``search_k = rerank_top_n * 4``.
            llm: Model used for multi-query generation. The string
                ``"default_small"`` requests ``get_small_llm_service()``;
                a falsy value disables multi-query generation.
            num_queries: Passed to ``MultiQueryGenerator``.
            rerank_top_n: Number of documents kept by the rerank stage.
            collection_name: Collection used when building the default
                retriever.
            use_rerank: Whether to create a reranker.
            return_parent_docs: Whether ``aretrieve`` resolves parent
                documents instead of returning the child documents.
        """
        self.retriever = retriever or create_parent_hybrid_retriever(
            collection_name=collection_name,
            search_k=rerank_top_n * 4,
        )
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        self.use_rerank = use_rerank
        self.return_parent_docs = return_parent_docs
        self._last_docs = []  # documents returned by the most recent aretrieve()
        self._last_scores = []  # score dicts matching _last_docs

        if llm == "default_small":
            try:
                self.llm = get_small_llm_service()
            except Exception as e:
                # Best effort: without the small LLM the pipeline still works,
                # it just skips multi-query expansion. Log it so the
                # degradation is visible instead of silent.
                logger.warning(f"[Pipeline] 小模型初始化失败,已禁用多查询: {e}")
                self.llm = None
        else:
            self.llm = llm or None

        self.query_generator = (
            MultiQueryGenerator(self.llm, num_queries) if self.llm else None
        )
        self.reranker = create_document_reranker() if use_rerank else None
        logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")

    @property
    def last_docs(self) -> List[Document]:
        """Documents returned by the most recent ``aretrieve`` call."""
        return self._last_docs

    @property
    def last_scores(self) -> List[dict]:
        """Score information for the most recent ``aretrieve`` call."""
        return self._last_scores

    async def aretrieve(self, query: str) -> List[Document]:
        """Run the full pipeline for ``query``.

        Returns:
            Parent documents when ``return_parent_docs`` is set (or the child
            documents if parents cannot be resolved), otherwise the child
            documents. The result and its scores are also stored for
            ``last_docs`` / ``last_scores``.
        """
        # Step 1: retrieve
        child_docs = await self._retrieve(query)
        logger.info(f"[Pipeline] 检索到 {len(child_docs)} 个子文档")

        # Debug aid: log the length of the first few child documents.
        for i, doc in enumerate(child_docs[:5]):
            content_len = len(doc.page_content)
            logger.info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符")

        # Step 2: rerank (skipped when there is nothing to rerank)
        if self.reranker and child_docs:
            try:
                child_docs = self.reranker.compress_documents(
                    child_docs, query, self.rerank_top_n
                )
                logger.info(f"[Pipeline] 重排后 {len(child_docs)} 个")
            except Exception as e:
                # Fall back to the retrieval order, truncated to top-n.
                logger.warning(f"[Pipeline] 重排失败: {e}")
                child_docs = child_docs[:self.rerank_top_n]

        # Step 3: resolve parent documents
        if self.return_parent_docs:
            docs = await self._get_parents(child_docs)
        else:
            docs = child_docs

        # Keep a snapshot so last_docs / last_scores describe the same result
        # even if the caller later mutates the returned list.
        self._last_docs = list(docs)
        self._last_scores = self._extract_scores(docs)
        return docs

    def _extract_scores(self, docs: List[Document]) -> List[dict]:
        """Return one ``{"embedding_score", "rerank_score"}`` dict per document.

        ``embedding_score`` falls back to the ``score`` metadata key; both
        values default to ``0.0``.
        """
        return [
            {
                "embedding_score": doc.metadata.get(
                    "embedding_score", doc.metadata.get("score", 0.0)
                ),
                "rerank_score": doc.metadata.get("rerank_score", 0.0),
            }
            for doc in docs
        ]

    async def _retrieve(self, query: str) -> List[Document]:
        """Retrieve child documents, fusing multi-query results when enabled."""
        if not self.query_generator:
            return await self.retriever.ainvoke(query)

        try:
            generated = await self.query_generator.agenerate(query)
        except Exception as e:
            # Query expansion is an optimisation; fall back to the raw query.
            logger.warning(f"[Pipeline] 多查询生成失败,使用原始查询: {e}")
            return await self.retriever.ainvoke(query)

        # Original query first, then generated ones; drop empties and
        # duplicates while keeping order so no query is retrieved twice.
        queries = list(dict.fromkeys([query, *(q for q in generated or [] if q)]))
        doc_lists = await asyncio.gather(
            *[self.retriever.ainvoke(q) for q in queries]
        )
        return reciprocal_rank_fusion(doc_lists)

    async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
        """Resolve the parent documents of ``child_docs``.

        The first child seen for each ``parent_id`` supplies that parent's
        scores. Parents are ordered by ``embedding_score + 2 * rerank_score``
        (descending). Falls back to ``child_docs`` when no ``parent_id`` is
        present, when no parent can be loaded, or when the lookup raises.
        """
        # parent_id -> (embedding_score, rerank_score) of its first child
        parent_scores = {}
        for doc in child_docs:
            pid = doc.metadata.get("parent_id")
            if pid and pid not in parent_scores:
                embedding_score = doc.metadata.get(
                    "embedding_score", doc.metadata.get("score", 0.0)
                )
                rerank_score = doc.metadata.get("rerank_score", 0.0)
                parent_scores[pid] = (embedding_score, rerank_score)

        if not parent_scores:
            logger.warning("[Pipeline] 未找到 parent_id,返回子文档")
            return child_docs

        try:
            from backend.rag_core import create_docstore

            docstore, _ = create_docstore()
            parent_ids = list(parent_scores)
            parent_docs = await docstore.amget(parent_ids)

            ranked = []
            for position, doc in enumerate(parent_docs):
                if not doc:
                    continue
                pid = doc.metadata.get("id")
                if pid not in parent_scores and position < len(parent_ids):
                    # The parent does not carry a usable "id"; fall back to
                    # the key requested at this position.
                    # NOTE(review): assumes amget returns values in key order
                    # (the LangChain BaseStore contract) - confirm for the
                    # store returned by create_docstore.
                    pid = parent_ids[position]
                embedding_score, rerank_score = parent_scores.get(pid, (0.0, 0.0))
                # Expose the scores on the parent document's metadata.
                doc.metadata["embedding_score"] = embedding_score
                doc.metadata["rerank_score"] = rerank_score
                # Combined score: the rerank score carries double weight.
                ranked.append((doc, embedding_score + rerank_score * 2))

            if not ranked:
                # Nothing could be loaded; child documents beat an empty result.
                logger.warning("[Pipeline] 未获取到任何父文档,返回子文档")
                return child_docs

            ranked.sort(key=lambda pair: pair[1], reverse=True)
            docs = [doc for doc, _ in ranked]
            logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")
            return docs
        except Exception as e:
            logger.warning(f"[Pipeline] 获取父文档失败: {e}")
            return child_docs

    def format_context(self, documents: List[Document]) -> str:
        """Render ``documents`` as numbered context sections for a prompt.

        Each section names the document's ``source`` metadata (with a
        placeholder when absent) followed by its content. Returns an empty
        string for an empty input.
        """
        if not documents:
            return ""

        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)


def create_rag_pipeline(**kwargs) -> RAGPipeline:
    """Create a ``RAGPipeline``, forwarding all keyword arguments to it."""
    return RAGPipeline(**kwargs)