Files
ailine/backend/app/rag/pipeline.py
root ef6fbc1521
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m36s
推理优化
2026-05-06 04:26:06 +08:00

163 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 检索流水线
流程: 检索子文档 → 重排 → 获取父文档 → 返回
"""
import asyncio
import logging
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from ..model_services import get_rerank_service, get_small_llm_service
from ..rag.rerank import create_document_reranker
from ..rag.query_transform import MultiQueryGenerator
from ..rag.fusion import reciprocal_rank_fusion
from ..rag.retriever import create_parent_hybrid_retriever
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class RAGPipeline:
    """Retrieval pipeline: (multi-query) retrieve child chunks -> rerank -> fetch parent docs.

    Each stage degrades gracefully: query-generation, rerank and parent-lookup
    failures are logged and the pipeline continues with what it already has.
    """

    def __init__(
        self,
        retriever=None,
        llm: BaseLanguageModel | str = "default_small",
        num_queries: int = 3,
        rerank_top_n: int = 5,
        collection_name: str = "rag_documents",
        use_rerank: bool = True,
        return_parent_docs: bool = True,
    ):
        """Build the pipeline.

        Args:
            retriever: Object exposing ``ainvoke(query)`` returning documents. When
                omitted, a parent/hybrid retriever over ``collection_name`` is created
                with ``search_k = rerank_top_n * 4``.
            llm: LLM used for multi-query expansion. The sentinel ``"default_small"``
                loads the small LLM service (expansion is disabled if that fails);
                any falsy value disables expansion.
            num_queries: Number of query variants requested from the generator.
            rerank_top_n: Number of child documents kept after reranking.
            collection_name: Collection used only when ``retriever`` is not given.
            use_rerank: Whether to create and apply the document reranker.
            return_parent_docs: Whether to replace child chunks with parent documents.
        """
        self.retriever = retriever or create_parent_hybrid_retriever(
            collection_name=collection_name, search_k=rerank_top_n * 4
        )
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        self.use_rerank = use_rerank
        self.return_parent_docs = return_parent_docs
        self._last_docs: List[Document] = []  # documents returned by the last aretrieve()
        self._last_scores: List[dict] = []  # score info for those documents, same order

        if llm == "default_small":
            try:
                self.llm = get_small_llm_service()
            except Exception as e:
                # Best effort: without an LLM we simply skip multi-query expansion.
                logger.warning(f"[Pipeline] 小模型加载失败, 禁用多查询: {e}")
                self.llm = None
        else:
            self.llm = llm if llm else None
        self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None
        self.reranker = create_document_reranker() if use_rerank else None
        logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")

    @property
    def last_docs(self) -> List[Document]:
        """Documents returned by the most recent ``aretrieve`` call."""
        return self._last_docs

    @property
    def last_scores(self) -> List[dict]:
        """Score info (``embedding_score`` / ``rerank_score``) for ``last_docs``."""
        return self._last_scores

    async def aretrieve(self, query: str) -> List[Document]:
        """Run the full pipeline for ``query`` and return the final documents.

        The returned documents and their score info are also stored so they can
        be read back through ``last_docs`` and ``last_scores``.
        """
        # Step 1: retrieve child chunks (optionally multi-query + RRF fusion).
        child_docs = await self._retrieve(query)
        logger.info(f"[Pipeline] 检索到 {len(child_docs)} 个子文档")
        # Debug aid: lengths of the first few child chunks.
        for i, doc in enumerate(child_docs[:5]):
            logger.debug("[Pipeline] 子文档[%d] 长度=%d字符", i, len(doc.page_content))

        # Step 2: rerank; on failure keep the retrieval order, truncated to top-n.
        if self.reranker:
            try:
                # NOTE(review): compress_documents is a synchronous call inside a
                # coroutine; if it does network/model work it blocks the event loop.
                child_docs = self.reranker.compress_documents(child_docs, query, self.rerank_top_n)
                logger.info(f"[Pipeline] 重排后 {len(child_docs)}")
            except Exception as e:
                logger.warning(f"[Pipeline] 重排失败: {e}")
                child_docs = child_docs[:self.rerank_top_n]

        # Step 3: optionally swap child chunks for their parent documents.
        docs = await self._get_parents(child_docs) if self.return_parent_docs else child_docs

        # Record results so last_docs / last_scores reflect this call.
        self._last_docs = docs
        self._last_scores = self._extract_scores(docs)
        return docs

    @staticmethod
    def _embedding_score(doc: Document) -> float:
        """Return ``embedding_score`` from metadata, else ``score``, else 0.0."""
        return doc.metadata.get("embedding_score", doc.metadata.get("score", 0.0))

    def _extract_scores(self, docs: List[Document]) -> List[dict]:
        """Return one ``{"embedding_score", "rerank_score"}`` dict per doc, in order."""
        return [
            {
                "embedding_score": self._embedding_score(doc),
                "rerank_score": doc.metadata.get("rerank_score", 0.0),
            }
            for doc in docs
        ]

    async def _retrieve(self, query: str) -> List[Document]:
        """Retrieve child docs for ``query``.

        With a query generator, the original query plus its unique variants are
        retrieved concurrently and the result lists are merged with reciprocal
        rank fusion. If query generation fails, falls back to the single query.
        """
        if not self.query_generator:
            return await self.retriever.ainvoke(query)
        try:
            generated = await self.query_generator.agenerate(query)
        except Exception as e:
            # Expansion is an optimisation; degrade to plain single-query retrieval.
            logger.warning(f"[Pipeline] 多查询生成失败: {e}")
            return await self.retriever.ainvoke(query)
        # Original query first, then variants; dict.fromkeys drops duplicates while
        # keeping order, so no query is retrieved (and RRF-weighted) twice.
        queries = list(dict.fromkeys([query, *generated]))
        doc_lists = await asyncio.gather(*(self.retriever.ainvoke(q) for q in queries))
        return reciprocal_rank_fusion(doc_lists)

    async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
        """Replace child chunks with their parent documents, carrying scores over.

        Each parent takes the scores of its first child in ``child_docs`` order
        (the highest-ranked one when the input is ranked). Parents are sorted by
        ``embedding_score + 2 * rerank_score``. Falls back to ``child_docs`` when
        no child has a ``parent_id``, when no parent is found, or on any error.
        """
        # parent_id -> (embedding_score, rerank_score); first occurrence wins.
        parent_map: dict = {}
        for doc in child_docs:
            pid = doc.metadata.get("parent_id")
            if pid and pid not in parent_map:
                parent_map[pid] = (
                    self._embedding_score(doc),
                    doc.metadata.get("rerank_score", 0.0),
                )
        if not parent_map:
            logger.warning("[Pipeline] 未找到 parent_id返回子文档")
            return child_docs

        try:
            from backend.rag_core import create_docstore

            docstore, _ = create_docstore()
            parent_ids = list(parent_map)
            parent_docs = await docstore.amget(parent_ids)

            scored = []
            for key, doc in zip(parent_ids, parent_docs):
                if not doc:
                    continue
                pid = doc.metadata.get("id")
                if pid not in parent_map:
                    # Parent doc doesn't carry its own id; amget results are aligned
                    # with the requested keys, so use the positional key instead.
                    pid = key
                embedding_score, rerank_score = parent_map[pid]
                doc.metadata["embedding_score"] = embedding_score
                doc.metadata["rerank_score"] = rerank_score
                # Combined score: rerank is weighted higher than embedding similarity.
                scored.append((doc, embedding_score + rerank_score * 2))

            if not scored:
                logger.warning("[Pipeline] 父文档均未找到, 返回子文档")
                return child_docs

            scored.sort(key=lambda item: item[1], reverse=True)
            docs = [d for d, _ in scored]
            logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")
            return docs
        except Exception as e:
            logger.warning(f"[Pipeline] 获取父文档失败: {e}")
            return child_docs

    def format_context(self, documents: List[Document]) -> str:
        """Render documents as numbered, source-labelled context blocks for a prompt."""
        if not documents:
            return ""
        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)
def create_rag_pipeline(**kwargs) -> RAGPipeline:
    """Factory for :class:`RAGPipeline`; every keyword argument is forwarded as-is."""
    pipeline = RAGPipeline(**kwargs)
    return pipeline