本地RAG尝试
This commit is contained in:
341
app/rag/pipeline.py
Normal file
341
app/rag/pipeline.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""
|
||||
RAG 检索流水线
|
||||
|
||||
组合检索器、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from .retriever import (
|
||||
create_base_retriever,
|
||||
create_hybrid_retriever,
|
||||
create_ensemble_retriever,
|
||||
create_qdrant_client,
|
||||
)
|
||||
from .reranker import CrossEncoderReranker
|
||||
from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline
|
||||
|
||||
|
||||
class RAGLevel(Enum):
|
||||
"""RAG 功能级别"""
|
||||
BASIC = 1 # 基础向量搜索
|
||||
HYBRID = 2 # 混合检索 + 重排序
|
||||
FUSION = 3 # RAG-Fusion
|
||||
AGENTIC = 4 # Agentic RAG
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGConfig:
|
||||
"""RAG 配置"""
|
||||
# Qdrant 配置
|
||||
collection_name: str = "documents"
|
||||
qdrant_url: Optional[str] = None
|
||||
qdrant_api_key: Optional[str] = None
|
||||
|
||||
# 检索配置
|
||||
rag_level: RAGLevel = RAGLevel.FUSION
|
||||
dense_k: int = 10 # 向量检索数量
|
||||
sparse_k: int = 10 # BM25 检索数量
|
||||
total_k: int = 20 # 总检索数量
|
||||
rerank_top_n: int = 5 # 重排序返回数量
|
||||
|
||||
# 查询改写配置
|
||||
num_queries: int = 3 # RAG-Fusion 查询数量
|
||||
|
||||
# 模型配置
|
||||
reranker_model: str = "BAAI/bge-reranker-base"
|
||||
device: Optional[str] = None
|
||||
|
||||
# 性能配置
|
||||
enable_cache: bool = True
|
||||
verbose: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
"""检索结果"""
|
||||
documents: List[Document]
|
||||
query_time: float
|
||||
level: RAGLevel
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class RAGPipeline:
|
||||
"""
|
||||
RAG 检索流水线
|
||||
|
||||
支持从 Level 1 到 Level 4 的所有功能。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embeddings: Embeddings,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
config: Optional[RAGConfig] = None,
|
||||
):
|
||||
"""
|
||||
初始化 RAG 流水线
|
||||
|
||||
Args:
|
||||
embeddings: 嵌入模型
|
||||
llm: 语言模型(用于查询改写,Level 3+ 需要)
|
||||
config: 配置
|
||||
"""
|
||||
self.embeddings = embeddings
|
||||
self.llm = llm
|
||||
self.config = config or RAGConfig()
|
||||
|
||||
# 初始化组件
|
||||
self._client = None
|
||||
self._reranker = None
|
||||
self._query_transformer = None
|
||||
self._retriever = None
|
||||
|
||||
# 缓存
|
||||
self._cache = {}
|
||||
|
||||
def _get_client(self):
|
||||
"""获取 Qdrant 客户端"""
|
||||
if self._client is None:
|
||||
self._client = create_qdrant_client(
|
||||
url=self.config.qdrant_url,
|
||||
api_key=self.config.qdrant_api_key,
|
||||
)
|
||||
return self._client
|
||||
|
||||
def _get_reranker(self):
|
||||
"""获取重排序器"""
|
||||
if self._reranker is None:
|
||||
self._reranker = CrossEncoderReranker(
|
||||
model_name=self.config.reranker_model,
|
||||
top_n=self.config.rerank_top_n,
|
||||
device=self.config.device,
|
||||
)
|
||||
return self._reranker
|
||||
|
||||
def _get_query_transformer(self):
|
||||
"""获取查询改写器"""
|
||||
if self._query_transformer is None and self.llm is not None:
|
||||
self._query_transformer = MultiQueryTransformer(
|
||||
llm=self.llm,
|
||||
num_queries=self.config.num_queries,
|
||||
)
|
||||
return self._query_transformer
|
||||
|
||||
def _create_basic_retriever(self):
|
||||
"""创建基础检索器(Level 1)"""
|
||||
return create_base_retriever(
|
||||
collection_name=self.config.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
search_kwargs={"k": self.config.total_k},
|
||||
client=self._get_client(),
|
||||
)
|
||||
|
||||
def _create_hybrid_retriever(self):
|
||||
"""创建混合检索器(Level 2)"""
|
||||
base_retriever = create_hybrid_retriever(
|
||||
collection_name=self.config.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
dense_k=self.config.dense_k,
|
||||
sparse_k=self.config.sparse_k,
|
||||
client=self._get_client(),
|
||||
)
|
||||
|
||||
# 应用重排序
|
||||
reranker = self._get_reranker()
|
||||
return reranker.create_contextual_compression_retriever(base_retriever)
|
||||
|
||||
def _create_fusion_retriever(self):
|
||||
"""创建 RAG-Fusion 检索器(Level 3)"""
|
||||
if self.llm is None:
|
||||
raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写")
|
||||
|
||||
# 创建基础混合检索器
|
||||
base_retriever = create_hybrid_retriever(
|
||||
collection_name=self.config.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
dense_k=self.config.dense_k,
|
||||
sparse_k=self.config.sparse_k,
|
||||
client=self._get_client(),
|
||||
)
|
||||
|
||||
# 创建 RAG-Fusion 流水线
|
||||
reranker = self._get_reranker()
|
||||
return create_rag_fusion_pipeline(
|
||||
base_retriever=base_retriever,
|
||||
llm=self.llm,
|
||||
reranker=reranker,
|
||||
num_queries=self.config.num_queries,
|
||||
)
|
||||
|
||||
def _get_retriever(self):
|
||||
"""根据配置级别获取检索器"""
|
||||
if self._retriever is None:
|
||||
if self.config.rag_level == RAGLevel.BASIC:
|
||||
self._retriever = self._create_basic_retriever()
|
||||
elif self.config.rag_level == RAGLevel.HYBRID:
|
||||
self._retriever = self._create_hybrid_retriever()
|
||||
elif self.config.rag_level == RAGLevel.FUSION:
|
||||
self._retriever = self._create_fusion_retriever()
|
||||
elif self.config.rag_level == RAGLevel.AGENTIC:
|
||||
# Agentic RAG 使用 Fusion 作为基础,在 tools.py 中包装
|
||||
self._retriever = self._create_fusion_retriever()
|
||||
else:
|
||||
raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}")
|
||||
|
||||
return self._retriever
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> RetrievalResult:
|
||||
"""
|
||||
执行检索
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
use_cache: 是否使用缓存
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
检索结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 检查缓存
|
||||
if use_cache is None:
|
||||
use_cache = self.config.enable_cache
|
||||
|
||||
cache_key = f"{query}:{self.config.rag_level.value}"
|
||||
if use_cache and cache_key in self._cache:
|
||||
if self.config.verbose:
|
||||
print(f"使用缓存结果: {query}")
|
||||
return self._cache[cache_key]
|
||||
|
||||
# 获取检索器并执行检索
|
||||
retriever = self._get_retriever()
|
||||
documents = retriever.invoke(query, **kwargs)
|
||||
|
||||
# 计算查询时间
|
||||
query_time = time.time() - start_time
|
||||
|
||||
# 创建结果
|
||||
result = RetrievalResult(
|
||||
documents=documents,
|
||||
query_time=query_time,
|
||||
level=self.config.rag_level,
|
||||
metadata={
|
||||
"query": query,
|
||||
"collection": self.config.collection_name,
|
||||
"doc_count": len(documents),
|
||||
},
|
||||
)
|
||||
|
||||
# 缓存结果
|
||||
if use_cache:
|
||||
self._cache[cache_key] = result
|
||||
|
||||
if self.config.verbose:
|
||||
print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s")
|
||||
|
||||
return result
|
||||
|
||||
def format_context(
|
||||
self,
|
||||
documents: List[Document],
|
||||
max_length: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
格式化检索到的文档为上下文文本
|
||||
|
||||
Args:
|
||||
documents: 文档列表
|
||||
max_length: 最大长度(字符数)
|
||||
|
||||
Returns:
|
||||
格式化后的上下文文本
|
||||
"""
|
||||
context_parts = []
|
||||
total_length = 0
|
||||
|
||||
for i, doc in enumerate(documents):
|
||||
# 提取内容和元数据
|
||||
content = doc.page_content.strip()
|
||||
metadata = doc.metadata
|
||||
|
||||
# 格式化文档
|
||||
doc_text = f"[文档 {i+1}]\n"
|
||||
if metadata.get("source"):
|
||||
doc_text += f"来源: {metadata['source']}\n"
|
||||
if metadata.get("page"):
|
||||
doc_text += f"页码: {metadata['page']}\n"
|
||||
|
||||
doc_text += f"内容: {content}\n\n"
|
||||
|
||||
# 检查长度限制
|
||||
if max_length is not None:
|
||||
if total_length + len(doc_text) > max_length:
|
||||
# 如果添加这个文档会超限,则截断并添加说明
|
||||
remaining = max_length - total_length
|
||||
if remaining > 100: # 至少保留100字符
|
||||
doc_text = doc_text[:remaining] + "...\n\n[内容已截断]"
|
||||
context_parts.append(doc_text)
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
context_parts.append(doc_text)
|
||||
total_length += len(doc_text)
|
||||
|
||||
return "".join(context_parts).strip()
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self._cache.clear()
|
||||
|
||||
@classmethod
|
||||
def create_from_config(
|
||||
cls,
|
||||
embeddings: Embeddings,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
config_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> "RAGPipeline":
|
||||
"""
|
||||
从配置字典创建流水线
|
||||
|
||||
Args:
|
||||
embeddings: 嵌入模型
|
||||
llm: 语言模型
|
||||
config_dict: 配置字典
|
||||
|
||||
Returns:
|
||||
RAGPipeline 实例
|
||||
"""
|
||||
config_dict = config_dict or {}
|
||||
|
||||
# 创建配置对象
|
||||
config = RAGConfig(
|
||||
collection_name=config_dict.get("collection_name", "documents"),
|
||||
qdrant_url=config_dict.get("qdrant_url"),
|
||||
qdrant_api_key=config_dict.get("qdrant_api_key"),
|
||||
rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)),
|
||||
dense_k=config_dict.get("dense_k", 10),
|
||||
sparse_k=config_dict.get("sparse_k", 10),
|
||||
total_k=config_dict.get("total_k", 20),
|
||||
rerank_top_n=config_dict.get("rerank_top_n", 5),
|
||||
num_queries=config_dict.get("num_queries", 3),
|
||||
reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"),
|
||||
device=config_dict.get("device"),
|
||||
enable_cache=config_dict.get("enable_cache", True),
|
||||
verbose=config_dict.get("verbose", True),
|
||||
)
|
||||
|
||||
return cls(embeddings=embeddings, llm=llm, config=config)
|
||||
Reference in New Issue
Block a user