2026-05-04 02:01:22 +08:00
|
|
|
|
"""
|
2026-05-05 23:17:00 +08:00
|
|
|
|
RAG 检索流水线
|
|
|
|
|
|
流程: 检索子文档 → 重排 → 获取父文档 → 返回
|
2026-05-04 02:01:22 +08:00
|
|
|
|
"""
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
|
|
|
|
|
import asyncio
|
2026-05-05 23:17:00 +08:00
|
|
|
|
import logging
|
|
|
|
|
|
from typing import List
|
2026-04-21 11:02:16 +08:00
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
|
|
2026-05-05 23:17:00 +08:00
|
|
|
|
from ..model_services import get_rerank_service, get_small_llm_service
|
|
|
|
|
|
from ..rag.rerank import create_document_reranker
|
|
|
|
|
|
from ..rag.query_transform import MultiQueryGenerator
|
|
|
|
|
|
from ..rag.fusion import reciprocal_rank_fusion
|
|
|
|
|
|
from ..rag.retriever import create_parent_hybrid_retriever
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
2026-05-04 02:01:22 +08:00
|
|
|
|
|
2026-04-21 11:02:16 +08:00
|
|
|
|
|
|
|
|
|
|
class RAGPipeline:
    """Retrieve child chunks, optionally rerank them, and resolve parent documents.

    The pipeline runs three stages in ``aretrieve``:

    1. retrieval (optionally fanned out over several LLM-generated query
       variants and merged with ``reciprocal_rank_fusion``),
    2. reranking through the object returned by ``create_document_reranker``,
    3. replacing the child chunks with their parent documents from the
       docstore.

    Stages 2 and 3 are best-effort: on failure a warning is logged and the
    pipeline continues with the documents it already has.
    """

    def __init__(
        self,
        retriever=None,
        llm: BaseLanguageModel | str = "default_small",
        num_queries: int = 3,
        rerank_top_n: int = 5,
        collection_name: str = "rag_documents",
        use_rerank: bool = True,
        return_parent_docs: bool = True,
    ):
        """Configure the pipeline.

        Args:
            retriever: Object exposing ``ainvoke(query)``. When falsy, one is
                built with ``create_parent_hybrid_retriever`` for
                ``collection_name`` with ``search_k = rerank_top_n * 4``.
            llm: Model used for multi-query generation. The sentinel string
                ``"default_small"`` loads ``get_small_llm_service()``; a falsy
                value disables multi-query generation.
            num_queries: Forwarded to ``MultiQueryGenerator``.
            rerank_top_n: Number of documents kept after reranking (also the
                cut-off applied when reranking fails).
            collection_name: Only used when ``retriever`` is not supplied.
            use_rerank: Whether to create and use a reranker.
            return_parent_docs: Whether ``aretrieve`` swaps child chunks for
                their parent documents.
        """
        self.retriever = retriever or create_parent_hybrid_retriever(
            collection_name=collection_name, search_k=rerank_top_n * 4
        )
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        self.use_rerank = use_rerank
        self.return_parent_docs = return_parent_docs

        if isinstance(llm, str) and llm == "default_small":
            try:
                self.llm = get_small_llm_service()
            except Exception as e:
                # Multi-query generation is optional, so degrade to
                # single-query retrieval, but leave a trace of why.
                logger.warning("[Pipeline] 小模型不可用,禁用多查询生成: %s", e)
                self.llm = None
        else:
            self.llm = llm or None

        self.query_generator = (
            MultiQueryGenerator(self.llm, num_queries) if self.llm else None
        )
        self.reranker = create_document_reranker() if use_rerank else None
        logger.info(
            "[Pipeline] init: rerank=%s, return_parent=%s",
            use_rerank,
            return_parent_docs,
        )

    async def aretrieve(self, query: str) -> List[Document]:
        """Run retrieval, reranking and parent lookup for ``query``.

        Returns:
            Parent documents when ``return_parent_docs`` is set and they can
            be resolved, otherwise the (possibly reranked) child documents.
        """
        # Step 1: retrieve child chunks.
        child_docs = await self._retrieve(query)
        logger.info("[Pipeline] 检索到 %d 个子文档", len(child_docs))
        if not child_docs:
            # Nothing to rerank or to resolve parents for.
            return child_docs

        # Debug aid: log the length of the first few child chunks.
        for i, doc in enumerate(child_docs[:5]):
            logger.info("[Pipeline] 子文档[%d] 长度=%d字符", i, len(doc.page_content))

        # Step 2: rerank.
        if self.reranker:
            try:
                # NOTE(review): compress_documents is called synchronously (it
                # is not awaited). If it performs blocking I/O it will hold up
                # the event loop; consider asyncio.to_thread. TODO confirm.
                child_docs = self.reranker.compress_documents(
                    child_docs, query, self.rerank_top_n
                )
                logger.info("[Pipeline] 重排后 %d 个", len(child_docs))
            except Exception as e:
                logger.warning("[Pipeline] 重排失败: %s", e)
                # Keep the retriever's own ordering, truncated to the budget.
                child_docs = child_docs[: self.rerank_top_n]

        # Step 3: swap child chunks for their parent documents.
        if self.return_parent_docs:
            return await self._get_parents(child_docs)
        return child_docs

    async def _retrieve(self, query: str) -> List[Document]:
        """Retrieve documents, using multi-query fusion when a generator exists.

        If query generation raises, a warning is logged and retrieval falls
        back to the original query alone.
        """
        if not self.query_generator:
            return await self.retriever.ainvoke(query)

        try:
            generated = await self.query_generator.agenerate(query)
        except Exception as e:
            logger.warning("[Pipeline] 多查询生成失败,回退到单查询: %s", e)
            return await self.retriever.ainvoke(query)

        # Original query first, then the distinct non-empty variants.
        # dict.fromkeys de-duplicates while preserving order, so a repeated
        # variant is not counted twice during fusion.
        variants = [q for q in (generated or []) if q]
        queries = list(dict.fromkeys([query, *variants]))
        doc_lists = await asyncio.gather(
            *(self.retriever.ainvoke(q) for q in queries)
        )
        return reciprocal_rank_fusion(doc_lists)

    async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
        """Replace child chunks with their parent documents, ranked by score.

        Returns ``child_docs`` unchanged when no child carries a ``parent_id``,
        when the docstore lookup fails, or when no parent could be loaded.
        """
        # parent_id -> (embedding_score, rerank_score). The first child seen
        # for each parent decides that parent's scores.
        parent_map = {}
        for doc in child_docs:
            pid = doc.metadata.get("parent_id")
            if pid and pid not in parent_map:
                # ``or 0.0`` also covers a key that is present but None, which
                # would otherwise break the score arithmetic below.
                parent_map[pid] = (
                    doc.metadata.get("score") or 0.0,
                    doc.metadata.get("rerank_score") or 0.0,
                )

        if not parent_map:
            logger.warning("[Pipeline] 未找到 parent_id,返回子文档")
            return child_docs

        try:
            from backend.rag_core import create_docstore

            docstore, _ = create_docstore()
            parent_ids = list(parent_map)
            # NOTE(review): mget is a synchronous call inside a coroutine.
            parent_docs = docstore.mget(parent_ids)
            docs = self._rank_parents(parent_map, parent_ids, parent_docs)
        except Exception as e:
            logger.warning("[Pipeline] 获取父文档失败: %s", e)
            return child_docs

        if not docs:
            # Every lookup missed; the child chunks are still better than
            # returning nothing.
            logger.warning("[Pipeline] 父文档全部缺失,返回子文档")
            return child_docs

        logger.info("[Pipeline] 获取到 %d 个父文档", len(docs))
        return docs

    @staticmethod
    def _rank_parents(parent_map, parent_ids, parent_docs):
        """Attach scores to the loaded parents and sort them best-first.

        Args:
            parent_map: ``parent_id -> (embedding_score, rerank_score)``.
            parent_ids: Keys in the order they were passed to ``mget``.
            parent_docs: ``mget`` result; falsy entries (missing keys) are
                skipped.

        Scores are matched to documents by position rather than by an ``id``
        field in the parent's metadata, so parents stored without such a field
        still get their scores. This assumes ``mget`` returns one entry per
        key in key order (the LangChain ``BaseStore`` contract) — TODO confirm
        for the project's docstore.
        """
        ranked = []
        for pid, doc in zip(parent_ids, parent_docs):
            if not doc:
                continue
            embedding_score, rerank_score = parent_map[pid]
            doc.metadata["embedding_score"] = embedding_score
            doc.metadata["rerank_score"] = rerank_score
            # Combined score: the rerank score is weighted twice as heavily.
            ranked.append((doc, embedding_score + rerank_score * 2))

        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked]

    def format_context(self, documents: List[Document]) -> str:
        """Render documents as numbered, source-labelled context sections.

        Returns an empty string when ``documents`` is empty.
        """
        if not documents:
            return ""
        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-05-05 23:17:00 +08:00
|
|
|
|
def create_rag_pipeline(**kwargs) -> RAGPipeline:
    """Build a ``RAGPipeline``, forwarding every keyword argument unchanged."""
    pipeline = RAGPipeline(**kwargs)
    return pipeline
|