""" 查询改写器 基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围。 """ from typing import List, Optional, Any from langchain.retrievers.multi_query import MultiQueryRetriever from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser class MultiQueryTransformer: """ 多路查询改写器 将单个查询改写成多个相关查询,用于 RAG-Fusion。 """ def __init__( self, llm: BaseLanguageModel, num_queries: int = 3, prompt_template: Optional[str] = None, ): """ 初始化查询改写器 Args: llm: 语言模型实例 num_queries: 生成的查询数量 prompt_template: 提示词模板 """ self.llm = llm self.num_queries = num_queries # 默认提示词模板 self.prompt_template = prompt_template or """ 你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。 这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。 原始问题: {question} 请生成 {num_queries} 个不同版本的查询,每个版本一行。 确保每个版本都是独立、完整的查询语句。 生成 {num_queries} 个查询: """ def transform_query(self, query: str) -> List[str]: """ 将单个查询改写成多个查询 Args: query: 原始查询 Returns: 改写后的查询列表 """ prompt = PromptTemplate.from_template(self.prompt_template) chain = prompt | self.llm | StrOutputParser() response = chain.invoke({ "question": query, "num_queries": self.num_queries, }) # 解析响应,每行一个查询 queries = [ q.strip() for q in response.strip().split('\n') if q.strip() ] # 确保数量正确,如果不够则添加原始查询 if len(queries) < self.num_queries: queries.extend([query] * (self.num_queries - len(queries))) elif len(queries) > self.num_queries: queries = queries[:self.num_queries] # 确保包含原始查询 if query not in queries: queries = [query] + queries[:self.num_queries-1] return queries def create_multi_query_retriever( self, base_retriever: Any, include_original: bool = True, ) -> MultiQueryRetriever: """ 创建多路查询检索器 Args: base_retriever: 基础检索器 include_original: 是否包含原始查询 Returns: MultiQueryRetriever 实例 """ retriever = MultiQueryRetriever.from_llm( retriever=base_retriever, llm=self.llm, include_original=include_original, ) # 设置生成的查询数量 retriever.llm_chain.prompt = PromptTemplate.from_template( "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n" "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n" "原始问题: {question}\n\n" "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n" "确保每个版本都是独立、完整的查询语句。\n\n" "生成 {num_queries} 个查询:" ) # 修改调用参数以包含 num_queries original_invoke = retriever.llm_chain.invoke def new_invoke(input_dict): input_dict["num_queries"] = self.num_queries return original_invoke(input_dict) retriever.llm_chain.invoke = new_invoke return retriever @classmethod def create_from_config( cls, llm: BaseLanguageModel, config: Optional[dict] = None, ) -> "MultiQueryTransformer": """ 从配置创建查询改写器 Args: llm: 语言模型实例 config: 配置字典 Returns: MultiQueryTransformer 实例 """ config = config or {} return cls( llm=llm, num_queries=config.get("num_queries", 3), prompt_template=config.get("prompt_template", None), ) def create_rag_fusion_pipeline( base_retriever: Any, llm: BaseLanguageModel, reranker: Optional[Any] = None, num_queries: int = 3, ) -> Any: """ 创建完整的 RAG-Fusion 流水线 Args: base_retriever: 基础检索器 llm: 语言模型(用于查询改写) reranker: 重排序器(可选) num_queries: 查询改写数量 Returns: 检索器实例 """ # 创建多路查询改写器 query_transformer = MultiQueryTransformer( llm=llm, num_queries=num_queries, ) # 创建多路查询检索器 multi_query_retriever = query_transformer.create_multi_query_retriever( base_retriever=base_retriever, include_original=True, ) # 如果提供了重排序器,则应用重排序 if reranker is not None: from langchain.retrievers import ContextualCompressionRetriever return ContextualCompressionRetriever( base_compressor=reranker, base_retriever=multi_query_retriever, ) return multi_query_retriever