Files
ailine/backend/app/rag/pipeline.py
root d09b0d16ce
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 10m46s
修改日志用项目统一的 logger
2026-05-06 16:15:09 +08:00

192 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 检索流水线
流程: 检索子文档 → 重排 → 获取父文档 → 返回
"""
import asyncio
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from backend.app.logger import info, warning
from ..model_services import get_rerank_service, get_small_llm_service
from ..rag.rerank import create_document_reranker
from ..rag.query_transform import MultiQueryGenerator
from ..rag.fusion import reciprocal_rank_fusion
from ..rag.retriever import create_parent_hybrid_retriever
class RAGPipeline:
    """Retrieval pipeline: retrieve child documents, rerank them, resolve parents.

    The documents and score information produced by the most recent
    ``aretrieve`` call are exposed through ``last_docs`` and ``last_scores``.
    """

    def __init__(
        self,
        retriever=None,
        llm: BaseLanguageModel | str = "default_small",
        num_queries: int = 3,
        rerank_top_n: int = 5,
        collection_name: str = "rag_documents",
        use_rerank: bool = True,
        return_parent_docs: bool = True,
    ):
        """Configure the pipeline.

        Args:
            retriever: Retriever to use. When falsy, one is built with
                ``create_parent_hybrid_retriever`` for ``collection_name``
                with ``search_k = rerank_top_n * 4``.
            llm: Model used for multi-query generation. The string
                ``"default_small"`` requests ``get_small_llm_service()``;
                a falsy value disables multi-query generation.
            num_queries: Passed to ``MultiQueryGenerator``.
            rerank_top_n: Number of documents requested from the reranker
                (also the truncation length when reranking fails).
            collection_name: Collection used when building the default retriever.
            use_rerank: Whether to create and apply a document reranker.
            return_parent_docs: Whether ``aretrieve`` resolves child documents
                to their parent documents before returning.
        """
        self.retriever = retriever or create_parent_hybrid_retriever(
            collection_name=collection_name, search_k=rerank_top_n * 4
        )
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        self.use_rerank = use_rerank
        self.return_parent_docs = return_parent_docs
        self._last_docs = []  # documents returned by the latest aretrieve call
        self._last_scores = []  # score dicts matching _last_docs

        if llm == "default_small":
            try:
                self.llm = get_small_llm_service()
            except Exception as e:
                # Best effort: the pipeline still works with a single query,
                # but record why multi-query generation was disabled.
                warning(f"[Pipeline] small LLM unavailable, multi-query disabled: {e}")
                self.llm = None
        else:
            self.llm = llm or None

        self.query_generator = (
            MultiQueryGenerator(self.llm, num_queries) if self.llm else None
        )
        self.reranker = create_document_reranker() if use_rerank else None
        info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")

    @property
    def last_docs(self) -> List[Document]:
        """Documents returned by the most recent ``aretrieve`` call."""
        return self._last_docs

    @property
    def last_scores(self) -> List[dict]:
        """Score dicts for the documents of the most recent ``aretrieve`` call."""
        return self._last_scores

    @staticmethod
    def _score_pair(doc: Document) -> tuple:
        """Return ``(embedding_score, rerank_score)`` read from ``doc.metadata``.

        ``embedding_score`` falls back to the ``score`` key, and both values
        default to ``0.0`` when absent.
        """
        metadata = doc.metadata
        embedding_score = metadata.get("embedding_score", metadata.get("score", 0.0))
        rerank_score = metadata.get("rerank_score", 0.0)
        return embedding_score, rerank_score

    def _remember(self, docs: List[Document]) -> None:
        """Store ``docs`` and their scores for ``last_docs`` / ``last_scores``."""
        self._last_docs = list(docs)
        self._last_scores = self._extract_scores(docs)

    async def aretrieve(self, query: str) -> List[Document]:
        """Run the full pipeline for ``query`` and return the resulting documents.

        Steps: retrieve child documents, rerank them when a reranker is
        configured, then resolve parent documents when ``return_parent_docs``
        is set. The returned documents and their scores are recorded on the
        instance.
        """
        info(f"[Pipeline] aretrieve 开始: query={query[:50]}...")

        # Step 1: retrieval
        info("[Pipeline] Step 1: 调用 _retrieve")
        child_docs = await self._retrieve(query)
        info(f"[Pipeline] Step 1 完成: 检索到 {len(child_docs)} 个子文档")

        # Debug aid: log the length of the first few child documents.
        for i, doc in enumerate(child_docs[:5]):
            content_len = len(doc.page_content)
            info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符")

        # Step 2: rerank
        info("[Pipeline] Step 2: 开始重排")
        if self.reranker:
            try:
                # NOTE(review): this is a plain (non-awaited) call inside a
                # coroutine; if the reranker does blocking I/O it will stall
                # the event loop — confirm against the reranker implementation.
                child_docs = self.reranker.compress_documents(
                    child_docs, query, self.rerank_top_n
                )
                info(f"[Pipeline] Step 2 完成: 重排后 {len(child_docs)}")
            except Exception as e:
                # Reranking is best effort: keep retrieval order, just truncate.
                warning(f"[Pipeline] 重排失败: {e}")
                child_docs = child_docs[:self.rerank_top_n]
        else:
            info("[Pipeline] Step 2 跳过: 未启用 reranker")

        # Step 3: resolve parent documents
        info("[Pipeline] Step 3: 开始获取父文档")
        if self.return_parent_docs:
            parent_docs = await self._get_parents(child_docs)
            info(f"[Pipeline] Step 3 完成: 获取到 {len(parent_docs)} 个父文档")
            # Keep documents and scores available through the properties.
            self._remember(parent_docs)
            info("[Pipeline] aretrieve 结束: 返回父文档")
            return parent_docs

        self._remember(child_docs)
        info("[Pipeline] aretrieve 结束: 返回子文档")
        return child_docs

    def _extract_scores(self, docs: List[Document]) -> List[dict]:
        """Return one ``{"embedding_score", "rerank_score"}`` dict per document."""
        scores = []
        for doc in docs:
            embedding_score, rerank_score = self._score_pair(doc)
            scores.append({
                "embedding_score": embedding_score,
                "rerank_score": rerank_score,
            })
        return scores

    async def _retrieve(self, query: str) -> List[Document]:
        """Retrieve child documents for ``query``.

        With a query generator, the original query plus the generated variants
        are retrieved concurrently and merged with ``reciprocal_rank_fusion``;
        otherwise the retriever is invoked once with ``query``.
        """
        info(f"[Pipeline] _retrieve 开始: query={query[:50]}...")

        if self.query_generator:
            info("[Pipeline] _retrieve: 调用 query_generator.agenerate")
            queries = await self.query_generator.agenerate(query)
            # The original query always goes first and is never duplicated.
            queries = [query] + [q for q in queries if q != query]
            info(f"[Pipeline] _retrieve: 生成 {len(queries)} 个查询: {queries}")

            info("[Pipeline] _retrieve: 开始 asyncio.gather 并行检索")
            doc_lists = await asyncio.gather(
                *[self.retriever.ainvoke(q) for q in queries]
            )
            info(f"[Pipeline] _retrieve: asyncio.gather 完成，得到 {len(doc_lists)} 组结果")

            info("[Pipeline] _retrieve: 开始 reciprocal_rank_fusion")
            result = reciprocal_rank_fusion(doc_lists)
            info(f"[Pipeline] _retrieve: RRF 完成，得到 {len(result)} 个文档")
            info("[Pipeline] _retrieve 结束")
            return result

        info("[Pipeline] _retrieve: query_generator 未启用，直接单次检索")
        result = await self.retriever.ainvoke(query)
        info("[Pipeline] _retrieve 结束")
        return result

    async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
        """Resolve ``child_docs`` to their parent documents, ordered by score.

        Each parent inherits the scores of the first child that references it
        and is ranked by ``embedding_score + 2 * rerank_score``. Falls back to
        returning ``child_docs`` when no child carries a ``parent_id``, when
        the docstore yields no parent document, or when the lookup raises.
        """
        info(f"[Pipeline] _get_parents 开始: {len(child_docs)} 个子文档")

        # parent_id -> (embedding_score, rerank_score) of the first child seen
        parent_map = {}
        for doc in child_docs:
            pid = doc.metadata.get("parent_id")
            if pid and pid not in parent_map:
                parent_map[pid] = self._score_pair(doc)
        info(f"[Pipeline] _get_parents: 收集到 {len(parent_map)} 个 unique parent_id")

        if not parent_map:
            warning("[Pipeline] 未找到 parent_id返回子文档")
            return child_docs

        try:
            info("[Pipeline] _get_parents: 调用 create_docstore")
            from backend.rag_core import create_docstore
            docstore, _ = create_docstore()

            info("[Pipeline] _get_parents: 调用 docstore.amget")
            parent_ids = list(parent_map)
            parent_docs = await docstore.amget(parent_ids)
            info(f"[Pipeline] _get_parents: docstore.amget 返回 {len(parent_docs)} 个结果")

            # Attach scores to each parent and rank them.
            ranked = []
            for position, doc in enumerate(parent_docs):
                if not doc:
                    continue
                # Prefer the id stored on the parent itself; when it is missing
                # or unknown, fall back to the key requested at this position
                # (assumes amget returns values in key order — TODO confirm).
                scores = parent_map.get(doc.metadata.get("id"))
                if scores is None and position < len(parent_ids):
                    scores = parent_map[parent_ids[position]]
                if scores is None:
                    scores = (0.0, 0.0)
                embedding_score, rerank_score = scores
                doc.metadata["embedding_score"] = embedding_score
                doc.metadata["rerank_score"] = rerank_score
                # Combined score: the rerank score is weighted double.
                ranked.append((doc, embedding_score + rerank_score * 2))

            if not ranked:
                # Same fallback as the other failure paths: children are
                # better context than nothing at all.
                warning("[Pipeline] docstore returned no parent documents, falling back to child documents")
                return child_docs

            ranked.sort(key=lambda item: item[1], reverse=True)
            docs = [d for d, _ in ranked]
            info(f"[Pipeline] _get_parents: 最终得到 {len(docs)} 个父文档")
            info("[Pipeline] _get_parents 结束")
            return docs
        except Exception as e:
            warning(f"[Pipeline] 获取父文档失败: {e}", exc_info=True)
            return child_docs

    def format_context(self, documents: List[Document]) -> str:
        """Render ``documents`` as a numbered context string.

        Each entry shows its 1-based index, the ``source`` metadata value (or a
        placeholder when absent) and the page content. Returns ``""`` for an
        empty input.
        """
        info(f"[Pipeline] format_context 开始: {len(documents)} 个文档")
        if not documents:
            info("[Pipeline] format_context: 无文档，返回空字符串")
            return ""

        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源：{source}\n{doc.page_content}\n---\n")

        result = "\n".join(parts)
        info(f"[Pipeline] format_context 结束: 结果长度={len(result)} 字符")
        return result
def create_rag_pipeline(**kwargs) -> RAGPipeline:
    """Build a ``RAGPipeline``, forwarding all keyword arguments to its constructor."""
    pipeline = RAGPipeline(**kwargs)
    return pipeline