63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
"""
|
||
查询转换器模块
|
||
|
||
实现多路查询改写功能,用于 RAG-Fusion。
|
||
"""
|
||
|
||
from typing import List, Optional
|
||
from langchain_core.language_models import BaseLanguageModel
|
||
# from langchain.retrievers.multi_query import MultiQueryRetriever
|
||
from langchain_core.prompts import PromptTemplate
|
||
|
||
|
||
class MultiQueryTransformer:
|
||
"""多路查询改写器,用于 RAG-Fusion。"""
|
||
|
||
def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
|
||
"""
|
||
初始化多路查询改写器。
|
||
|
||
Args:
|
||
llm: 语言模型实例
|
||
num_queries: 生成的查询数量
|
||
"""
|
||
self.llm = llm
|
||
self.num_queries = num_queries
|
||
|
||
def create_multi_query_retriever(self, base_retriever):
|
||
"""
|
||
创建多路查询检索器。
|
||
|
||
Args:
|
||
base_retriever: 基础检索器
|
||
|
||
Returns:
|
||
MultiQueryRetriever 实例
|
||
"""
|
||
# 由于当前 LangChain 版本不支持 MultiQueryRetriever,暂时返回基础检索器
|
||
# retriever = MultiQueryRetriever.from_llm(
|
||
# retriever=base_retriever,
|
||
# llm=self.llm,
|
||
# include_original=True
|
||
# )
|
||
#
|
||
# # 自定义提示词
|
||
# retriever.llm_chain.prompt = PromptTemplate.from_template(
|
||
# "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
|
||
# "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
|
||
# "原始问题: {question}\n\n"
|
||
# "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
|
||
# "确保每个版本都是独立、完整的查询语句。\n\n"
|
||
# "生成 {num_queries} 个查询:"
|
||
# )
|
||
#
|
||
# # 修改调用参数以包含 num_queries
|
||
# original_ainvoke = retriever.llm_chain.ainvoke
|
||
# async def new_ainvoke(input_dict):
|
||
# input_dict["num_queries"] = self.num_queries
|
||
# return await original_ainvoke(input_dict)
|
||
# retriever.llm_chain.ainvoke = new_ainvoke
|
||
#
|
||
# return retriever
|
||
return base_retriever
|