Files
ailine/backend/app/rag/pipeline.py
root 1260bef5cb
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m31s
添加rag置信度判断
2026-05-06 01:15:52 +08:00

136 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 检索流水线
流程: 检索子文档 → 重排 → 获取父文档 → 返回
"""
import asyncio
import logging
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from ..model_services import get_rerank_service, get_small_llm_service
from ..rag.rerank import create_document_reranker
from ..rag.query_transform import MultiQueryGenerator
from ..rag.fusion import reciprocal_rank_fusion
from ..rag.retriever import create_parent_hybrid_retriever
# Module-level logger, named after this module via ``__name__``; every
# "[Pipeline]" message emitted below goes through it.
logger = logging.getLogger(__name__)
class RAGPipeline:
    """Multi-stage retrieval pipeline for RAG.

    ``aretrieve`` runs the stages in order:

    1. Retrieve child chunks. When a query generator is available, the
       original query is expanded into several variants, each variant is
       retrieved concurrently, and the result lists are merged with
       ``reciprocal_rank_fusion``.
    2. Optionally rerank the chunks and keep at most ``rerank_top_n``.
    3. Optionally replace the chunks with their parent documents.

    Every optional stage is best-effort: when it fails, a warning is logged
    and the pipeline continues with the documents it already has.
    """

    def __init__(
        self,
        retriever=None,
        llm: BaseLanguageModel | str = "default_small",
        num_queries: int = 3,
        rerank_top_n: int = 5,
        collection_name: str = "rag_documents",
        use_rerank: bool = True,
        return_parent_docs: bool = True,
    ):
        """Configure the pipeline stages.

        Args:
            retriever: Object exposing an async ``ainvoke(query)``. When
                ``None``, ``create_parent_hybrid_retriever`` builds one for
                ``collection_name`` with ``search_k = rerank_top_n * 4``.
            llm: Model used for multi-query expansion. The sentinel string
                ``"default_small"`` asks ``get_small_llm_service()`` for one;
                any falsy value disables expansion.
            num_queries: Passed to ``MultiQueryGenerator``.
            rerank_top_n: Maximum number of chunks kept after reranking.
            collection_name: Collection used when the default retriever is
                built.
            use_rerank: Whether to create and apply a document reranker.
            return_parent_docs: Whether ``aretrieve`` resolves chunks to
                their parent documents.
        """
        # Compare against None explicitly so a caller-supplied retriever that
        # happens to be falsy (e.g. defines __len__) is not silently replaced.
        if retriever is None:
            retriever = create_parent_hybrid_retriever(
                collection_name=collection_name, search_k=rerank_top_n * 4
            )
        self.retriever = retriever
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        self.use_rerank = use_rerank
        self.return_parent_docs = return_parent_docs

        if llm == "default_small":
            try:
                self.llm = get_small_llm_service()
            except Exception as exc:
                # Query expansion is optional, so keep going without it, but
                # leave a trace instead of swallowing the failure silently.
                logger.warning("[Pipeline] 小模型服务不可用, 跳过多查询扩展: %s", exc)
                self.llm = None
        else:
            self.llm = llm or None

        self.query_generator = (
            MultiQueryGenerator(self.llm, num_queries) if self.llm else None
        )
        self.reranker = create_document_reranker() if use_rerank else None
        logger.info(
            "[Pipeline] init: rerank=%s, return_parent=%s",
            use_rerank,
            return_parent_docs,
        )

    async def aretrieve(self, query: str) -> List[Document]:
        """Run the full pipeline for ``query`` and return the final documents.

        Returns parent documents when ``return_parent_docs`` is set (falling
        back to the child chunks if parents cannot be resolved), otherwise the
        (possibly reranked) child chunks.
        """
        # Step 1: retrieve child chunks.
        child_docs = await self._retrieve(query)
        logger.info("[Pipeline] 检索到 %d 个子文档", len(child_docs))
        # Debug aid: log the size of the first few chunks.
        for i, doc in enumerate(child_docs[:5]):
            logger.info("[Pipeline] 子文档[%d] 长度=%d字符", i, len(doc.page_content))

        # Step 2: rerank. Skipped when nothing was retrieved, since there is
        # nothing to order.
        # NOTE(review): compress_documents is called synchronously inside this
        # coroutine; if it performs network or model I/O it blocks the event
        # loop for the duration of the call - confirm whether that matters.
        if self.reranker and child_docs:
            try:
                child_docs = self.reranker.compress_documents(
                    child_docs, query, self.rerank_top_n
                )
                logger.info("[Pipeline] 重排后 %d", len(child_docs))
            except Exception as e:
                logger.warning("[Pipeline] 重排失败: %s", e)
                # Fall back to retrieval order, still honouring the cap.
                child_docs = child_docs[: self.rerank_top_n]

        # Step 3: resolve parent documents.
        if self.return_parent_docs:
            return await self._get_parents(child_docs)
        return child_docs

    async def _retrieve(self, query: str) -> List[Document]:
        """Retrieve child chunks, using multi-query expansion when available.

        Without a query generator the retriever is invoked once with
        ``query``. Otherwise the original query plus its generated variants
        are retrieved concurrently and merged via ``reciprocal_rank_fusion``.
        If variant generation fails, only the original query is used.
        """
        if not self.query_generator:
            return await self.retriever.ainvoke(query)

        try:
            generated = await self.query_generator.agenerate(query)
        except Exception as exc:
            # Expansion is an optimisation; losing it must not lose the answer.
            logger.warning("[Pipeline] 多查询生成失败, 仅使用原始查询: %s", exc)
            generated = []

        # Original query first; dict.fromkeys removes duplicates (including
        # duplicates among the generated variants) while preserving order.
        candidates = [query, *(q for q in (generated or []) if q)]
        queries = list(dict.fromkeys(candidates))

        doc_lists = await asyncio.gather(
            *(self.retriever.ainvoke(q) for q in queries)
        )
        return reciprocal_rank_fusion(doc_lists)

    @staticmethod
    def _collect_parent_scores(child_docs: List[Document]) -> dict:
        """Map each parent id to the scores of the first chunk that references it.

        Returns ``{parent_id: (embedding_score, rerank_score)}``. Missing or
        ``None`` scores are treated as ``0.0``. Chunks without a ``parent_id``
        in their metadata are ignored.
        """
        parent_scores = {}
        for doc in child_docs:
            pid = doc.metadata.get("parent_id")
            if pid and pid not in parent_scores:
                parent_scores[pid] = (
                    doc.metadata.get("score") or 0.0,
                    doc.metadata.get("rerank_score") or 0.0,
                )
        return parent_scores

    @staticmethod
    def _rank_parents(parent_ids, parent_docs, parent_scores) -> list:
        """Attach scores to fetched parents and order them best-first.

        A parent is matched to its scores through ``metadata["id"]`` when
        that id is one of the requested parents. Otherwise, if the store
        returned exactly one slot per requested key, the key at the same
        position is used. Entries the store could not find (falsy values)
        are dropped.
        """
        parent_docs = list(parent_docs)
        # Positional matching is only trustworthy when the store returned one
        # slot per requested key.
        aligned = len(parent_docs) == len(parent_ids)

        ranked = []
        for position, doc in enumerate(parent_docs):
            if not doc:
                continue
            key = doc.metadata.get("id")
            if key not in parent_scores and aligned:
                key = parent_ids[position]
            embedding_score, rerank_score = parent_scores.get(key, (0.0, 0.0))
            # Expose both scores on the parent document's metadata.
            doc.metadata["embedding_score"] = embedding_score
            doc.metadata["rerank_score"] = rerank_score
            # Combined score; the rerank score carries double weight.
            ranked.append((doc, embedding_score + rerank_score * 2))

        # list.sort is stable, so ties keep the order the store returned.
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked]

    async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
        """Replace child chunks with their parent documents, best-first.

        Falls back to returning ``child_docs`` unchanged when no chunk carries
        a ``parent_id``, when the docstore lookup raises, or when the docstore
        finds none of the requested parents.
        """
        parent_scores = self._collect_parent_scores(child_docs)
        if not parent_scores:
            logger.warning("[Pipeline] 未找到 parent_id返回子文档")
            return child_docs

        try:
            from backend.rag_core import create_docstore

            docstore, _ = create_docstore()
            parent_ids = list(parent_scores)
            parent_docs = docstore.mget(parent_ids)
            docs = self._rank_parents(parent_ids, parent_docs, parent_scores)
        except Exception as e:
            logger.warning("[Pipeline] 获取父文档失败: %s", e)
            return child_docs

        if not docs:
            # None of the parents exist in the store; the chunks are still
            # the best context available.
            logger.warning("[Pipeline] 父文档均未找到, 返回子文档")
            return child_docs

        logger.info("[Pipeline] 获取到 %d 个父文档", len(docs))
        return docs

    def format_context(self, documents: List[Document]) -> str:
        """Render documents as a numbered, source-attributed context string.

        Returns an empty string when ``documents`` is empty. Numbering starts
        at 1, and a document without a ``source`` in its metadata is labelled
        with a placeholder.
        """
        if not documents:
            return ""
        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)
def create_rag_pipeline(**kwargs) -> RAGPipeline:
    """Build a ``RAGPipeline``, forwarding every keyword argument to its constructor."""
    pipeline = RAGPipeline(**kwargs)
    return pipeline