Files
ailine/backend/app/rag/pipeline.py
root 3ae9daa01a
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m44s
导入方式修改
2026-05-05 23:17:00 +08:00

121 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 检索流水线
流程: 检索子文档 → 重排 → 获取父文档 → 返回
"""
import asyncio
import logging
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from ..model_services import get_rerank_service, get_small_llm_service
from ..rag.rerank import create_document_reranker
from ..rag.query_transform import MultiQueryGenerator
from ..rag.fusion import reciprocal_rank_fusion
from ..rag.retriever import create_parent_hybrid_retriever
logger = logging.getLogger(__name__)
class RAGPipeline:
    """Retrieval pipeline: retrieve child chunks, rerank them, resolve parents.

    Every optional stage (query expansion, reranking, parent lookup) degrades
    to a simpler result instead of failing the whole retrieval.
    """

    def __init__(
        self,
        retriever=None,
        llm: BaseLanguageModel | str = "default_small",
        num_queries: int = 3,
        rerank_top_n: int = 5,
        collection_name: str = "rag_documents",
        use_rerank: bool = True,
        return_parent_docs: bool = True,
    ):
        """Configure the pipeline stages.

        Args:
            retriever: Object exposing ``ainvoke(query)``. When falsy, one is
                built with ``create_parent_hybrid_retriever`` for
                ``collection_name`` with ``search_k = rerank_top_n * 4``.
            llm: Model used for multi-query expansion. The sentinel string
                ``"default_small"`` loads ``get_small_llm_service()``; any
                falsy value disables query expansion.
            num_queries: Passed to ``MultiQueryGenerator``.
            rerank_top_n: Number of documents kept by the reranker (and by the
                truncation fallback when reranking fails).
            collection_name: Only used when ``retriever`` is not supplied.
            use_rerank: Whether to create and apply a document reranker.
            return_parent_docs: Whether ``aretrieve`` maps child chunks back
                to their parent documents.
        """
        self.retriever = retriever or create_parent_hybrid_retriever(
            collection_name=collection_name, search_k=rerank_top_n * 4
        )
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        self.use_rerank = use_rerank
        self.return_parent_docs = return_parent_docs

        if llm == "default_small":
            try:
                self.llm = get_small_llm_service()
            except Exception as e:
                # Best effort: without an LLM the pipeline still works, it just
                # skips query expansion. Log it so the degradation is visible.
                logger.warning(f"[Pipeline] 小模型服务不可用, 已禁用多查询: {e}")
                self.llm = None
        else:
            self.llm = llm if llm else None

        self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None
        self.reranker = create_document_reranker() if use_rerank else None
        logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")

    async def aretrieve(self, query: str) -> List[Document]:
        """Run the full pipeline for ``query``.

        Returns:
            Parent documents when ``return_parent_docs`` is set, otherwise the
            (reranked) child documents. Without a reranker the retrieved
            candidates are passed on untruncated.
        """
        # Step 1: retrieve candidate child documents.
        child_docs = await self._retrieve(query)
        logger.info(f"[Pipeline] 检索到 {len(child_docs)} 个子文档")

        # Debug aid: log the length of the first few child documents.
        for i, doc in enumerate(child_docs[:5]):
            content_len = len(doc.page_content)
            logger.info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符")

        # Step 2: rerank.
        if self.reranker:
            try:
                # compress_documents is synchronous; run it in a worker thread
                # so it does not stall the event loop while it computes/waits.
                child_docs = await asyncio.to_thread(
                    self.reranker.compress_documents,
                    child_docs,
                    query,
                    self.rerank_top_n,
                )
                logger.info(f"[Pipeline] 重排后 {len(child_docs)}")
            except Exception as e:
                logger.warning(f"[Pipeline] 重排失败: {e}")
                # Fall back to the retriever's own ordering, same cut-off.
                child_docs = child_docs[:self.rerank_top_n]

        # Step 3: map child chunks to their parent documents.
        if self.return_parent_docs:
            return await self._get_parents(child_docs)
        return child_docs

    async def _retrieve(self, query: str) -> List[Document]:
        """Retrieve candidates, fanning out over generated queries if enabled.

        With a query generator, the original query plus its generated variants
        are retrieved concurrently and merged with ``reciprocal_rank_fusion``.
        If generation fails, only the original query is used.
        """
        if not self.query_generator:
            return await self.retriever.ainvoke(query)

        try:
            generated = await self.query_generator.agenerate(query)
        except Exception as e:
            # Query expansion is an optional enhancement; do not let it abort
            # retrieval.
            logger.warning(f"[Pipeline] 多查询生成失败, 使用原始查询: {e}")
            return await self.retriever.ainvoke(query)

        # Original query first; drop blanks and duplicates (order-preserving)
        # so a repeated variant is not retrieved twice and double-counted in
        # the fusion step.
        queries = list(dict.fromkeys([query, *(q for q in (generated or []) if q)]))
        doc_lists = await asyncio.gather(*(self.retriever.ainvoke(q) for q in queries))
        return reciprocal_rank_fusion(doc_lists)

    async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
        """Replace child chunks with their parent documents.

        Each distinct ``parent_id`` takes the ``score`` of the first child
        that references it; parents are returned sorted by that score,
        highest first. Returns ``child_docs`` unchanged when no child carries
        a ``parent_id``, when the lookup raises, or when no parent resolves.
        """
        parent_scores = {}
        for doc in child_docs:
            pid = doc.metadata.get("parent_id")
            if pid and pid not in parent_scores:
                # Treat a missing or None score as 0.0 so sorting cannot fail
                # on mixed None/float values.
                parent_scores[pid] = doc.metadata.get("score") or 0.0

        if not parent_scores:
            logger.warning("[Pipeline] 未找到 parent_id返回子文档")
            return child_docs

        try:
            # NOTE(review): absolute import, unlike the relative imports at the
            # top of the file - confirm `backend` is importable in deployment.
            from backend.rag_core import create_docstore

            parent_ids = list(parent_scores)

            def _load_parents():
                # Create and query the store in the same worker thread.
                docstore, _ = create_docstore()
                return docstore.mget(parent_ids)

            # mget is synchronous; keep it off the event loop.
            fetched = await asyncio.to_thread(_load_parents)
            docs = self._order_parents(parent_scores, parent_ids, fetched)
        except Exception as e:
            logger.warning(f"[Pipeline] 获取父文档失败: {e}")
            return child_docs

        if not docs:
            # Consistent with the other fallbacks: never discard retrieved
            # context just because the parents could not be resolved.
            logger.warning("[Pipeline] 父文档为空, 返回子文档")
            return child_docs

        logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")
        return docs

    @staticmethod
    def _order_parents(parent_scores, parent_ids, fetched):
        """Pair fetched parents with their scores and sort best-first.

        A parent is keyed by its ``metadata["id"]`` when that is one of the
        requested ids; otherwise by the key at the same position in
        ``parent_ids`` (assumes ``mget`` returns values in key order with
        ``None`` for misses - confirm for the project's docstore).
        """
        resolved = {}
        for pid, parent in zip(parent_ids, fetched):
            if not parent:
                continue
            stored_id = parent.metadata.get("id")
            resolved[stored_id if stored_id in parent_scores else pid] = parent

        ranked = sorted(
            (
                (resolved[pid], score)
                for pid, score in parent_scores.items()
                if pid in resolved
            ),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [doc for doc, _ in ranked]

    def format_context(self, documents: List[Document]) -> str:
        """Render documents as numbered, source-labelled context blocks.

        Returns an empty string for an empty input. Documents without a
        ``source`` metadata entry are labelled with a placeholder.
        """
        if not documents:
            return ""
        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)
def create_rag_pipeline(**kwargs) -> RAGPipeline:
    """Build a ``RAGPipeline``, forwarding all keyword arguments to its constructor."""
    pipeline = RAGPipeline(**kwargs)
    return pipeline