""" 查询转换器模块 实现多路查询改写功能,用于 RAG-Fusion。 """ from typing import List, Optional from langchain_core.language_models import BaseLanguageModel # from langchain.retrievers.multi_query import MultiQueryRetriever from langchain_core.prompts import PromptTemplate class MultiQueryTransformer: """多路查询改写器,用于 RAG-Fusion。""" def __init__(self, llm: BaseLanguageModel, num_queries: int = 3): """ 初始化多路查询改写器。 Args: llm: 语言模型实例 num_queries: 生成的查询数量 """ self.llm = llm self.num_queries = num_queries def create_multi_query_retriever(self, base_retriever): """ 创建多路查询检索器。 Args: base_retriever: 基础检索器 Returns: MultiQueryRetriever 实例 """ # 由于当前 LangChain 版本不支持 MultiQueryRetriever,暂时返回基础检索器 # retriever = MultiQueryRetriever.from_llm( # retriever=base_retriever, # llm=self.llm, # include_original=True # ) # # # 自定义提示词 # retriever.llm_chain.prompt = PromptTemplate.from_template( # "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n" # "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n" # "原始问题: {question}\n\n" # "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n" # "确保每个版本都是独立、完整的查询语句。\n\n" # "生成 {num_queries} 个查询:" # ) # # # 修改调用参数以包含 num_queries # original_ainvoke = retriever.llm_chain.ainvoke # async def new_ainvoke(input_dict): # input_dict["num_queries"] = self.num_queries # return await original_ainvoke(input_dict) # retriever.llm_chain.ainvoke = new_ainvoke # # return retriever return base_retriever