Files
ailine/app/rag/pipeline.py
root 933d418d77
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s
检索器重构
2026-04-19 22:01:55 +08:00

169 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 检索流水线
整合基础检索、重排序和 RAG-Fusion 功能。
"""
from enum import Enum
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryTransformer
from rag_core import QDRANT_URL, QDRANT_API_KEY
class RAGLevel(str, Enum):
    """Retrieval strategies supported by the RAG pipeline.

    The ``str`` mixin makes every member compare equal to its string value
    (``RAGLevel.BASIC == "basic"``), so a configuration may carry either the
    enum member or the raw string and equality checks against ``.value``
    behave the same way for both.
    """

    BASIC = "basic"    # plain vector retrieval
    RERANK = "rerank"  # base retrieval followed by reranking
    FUSION = "fusion"  # RAG-Fusion: multi-query retrieval + RRF
class RAGPipeline:
    """Retrieval pipeline combining base retrieval, reranking and RAG-Fusion.

    The strategy is selected by the ``rag_level`` configuration entry; see
    ``_create_retriever`` for how each level is wired up.
    """

    # Number of candidates requested from the base retriever when the
    # configuration does not provide ``recall_k``.
    DEFAULT_RECALL_K = 20

    def __init__(
        self,
        embeddings,
        llm: Optional["BaseLanguageModel"] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Build the retrievers described by ``config``.

        Args:
            embeddings: Embedding model handed to ``create_base_retriever``.
            llm: Language model; only required for the RAG-Fusion level.
            config: Optional settings. Recognised keys and their defaults:
                ``collection_name`` ("rag_documents"),
                ``rag_level`` ("rerank"; a ``RAGLevel`` member or its
                string value),
                ``num_queries`` (3),
                ``rerank_top_n`` (5),
                ``recall_k`` (20, the ``k`` passed to the base retriever).

        Raises:
            ValueError: If the RAG-Fusion level is selected without ``llm``.
        """
        self.embeddings = embeddings
        self.llm = llm
        self.config = config or {}
        self.collection_name = self.config.get("collection_name", "rag_documents")

        # Accept both RAGLevel members and raw strings; the level is stored
        # as a plain string so the comparisons in _create_retriever work for
        # either form.
        level = self.config.get("rag_level", RAGLevel.RERANK.value)
        self.rag_level = level.value if isinstance(level, RAGLevel) else level

        self.num_queries = self.config.get("num_queries", 3)
        self.rerank_top_n = self.config.get("rerank_top_n", 5)
        self.recall_k = self.config.get("recall_k", self.DEFAULT_RECALL_K)

        # Base retriever: recalls a wide candidate set that later stages
        # narrow down.
        self.base_retriever = create_base_retriever(
            collection_name=self.collection_name,
            embeddings=self.embeddings,
            search_kwargs={"k": self.recall_k},
        )

        # The reranker is best-effort: if it cannot be constructed the
        # pipeline keeps working and falls back to truncating the base
        # results (see _retrieve_and_rerank).
        try:
            self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n)
        except Exception as e:
            print(f"警告: 无法创建重排序器,将使用基础检索。错误: {e}")
            self.reranker = None

        self.retriever = self._create_retriever()

    def _retrieve_and_rerank(self, query):
        """Fetch candidates from the base retriever and narrow them down.

        Uses the reranker when one is available; otherwise returns the first
        ``rerank_top_n`` candidates in the order the base retriever produced.
        """
        candidates = self.base_retriever.invoke(query)
        if self.reranker is not None:
            return self.reranker.compress_documents(candidates, query)
        return candidates[:self.rerank_top_n]

    def _create_retriever(self):
        """Return the retriever object matching ``self.rag_level``.

        * ``basic``: the base retriever itself.
        * ``fusion``: the multi-query retriever built by
          ``MultiQueryTransformer`` on top of the retrieve-and-rerank step.
        * ``rerank`` and any unrecognised level: the retrieve-and-rerank
          step wrapped in ``SimpleRetriever``.

        Raises:
            ValueError: If the level is ``fusion`` and no ``llm`` was given.
        """
        if self.rag_level == RAGLevel.BASIC.value:
            return self.base_retriever

        reranking_retriever = SimpleRetriever(self._retrieve_and_rerank)

        if self.rag_level == RAGLevel.FUSION.value:
            if self.llm is None:
                raise ValueError("RAG-Fusion 需要提供 llm 参数")
            transformer = MultiQueryTransformer(
                llm=self.llm,
                num_queries=self.num_queries,
            )
            return transformer.create_multi_query_retriever(
                base_retriever=reranking_retriever
            )

        # RERANK, and the fallback for unknown levels.
        return reranking_retriever

    def retrieve(self, query: str) -> List["Document"]:
        """Run retrieval for ``query``.

        Args:
            query: The query string.

        Returns:
            Whatever the configured retriever's ``invoke`` returns
            (annotated as a list of documents).
        """
        return self.retriever.invoke(query)

    async def aretrieve(self, query: str) -> List["Document"]:
        """Run retrieval for ``query`` via the retriever's ``ainvoke``.

        Args:
            query: The query string.

        Returns:
            Whatever the configured retriever's ``ainvoke`` returns
            (annotated as a list of documents).
        """
        return await self.retriever.ainvoke(query)

    def format_context(self, documents: List["Document"]) -> str:
        """Render documents as a numbered context string.

        Each document becomes a section with a ``【资料 N】`` header, its
        ``source`` metadata (``未知来源`` when absent) and its page content,
        terminated by a ``---`` line.

        Args:
            documents: Documents to render; each must expose
                ``page_content`` and ``metadata`` (which may be ``None``).

        Returns:
            The concatenated sections, or an empty string when
            ``documents`` is empty.
        """
        if not documents:
            return ""

        sections = []
        for index, doc in enumerate(documents, 1):
            source = (doc.metadata or {}).get("source", "未知来源")
            sections.append(
                f"【资料 {index}】\n"
                f"来源: {source}\n"
                f"内容: {doc.page_content}\n"
                "---\n"
            )
        return "".join(sections)
class SimpleRetriever:
    """Wrap a plain ``query -> documents`` callable behind ``invoke``/``ainvoke``."""

    def __init__(self, retrieve_func):
        """Store the synchronous callable that performs the retrieval.

        Args:
            retrieve_func: Callable taking a single query argument and
                returning the retrieved documents.
        """
        self.retrieve_func = retrieve_func

    def invoke(self, query, config=None, **kwargs):
        """Run the wrapped callable synchronously.

        ``config`` and any extra keyword arguments are accepted and ignored
        so that callers using the LangChain Runnable calling convention
        (``invoke(input, config=...)``) do not fail with ``TypeError``.

        Args:
            query: Query forwarded to the wrapped callable.
            config: Ignored.

        Returns:
            The wrapped callable's result.
        """
        return self.retrieve_func(query)

    async def ainvoke(self, query, config=None, **kwargs):
        """Run the wrapped callable without blocking the event loop.

        The callable is synchronous, so it is executed in the running loop's
        default executor instead of directly inside the coroutine.
        ``config`` and extra keyword arguments are accepted and ignored, as
        in ``invoke``.

        Args:
            query: Query forwarded to the wrapped callable.
            config: Ignored.

        Returns:
            The wrapped callable's result.
        """
        # Local import keeps this block self-contained; asyncio is stdlib.
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.retrieve_func, query)