"""
Query transformer module.

Implements multi-query rewriting for RAG-Fusion.
"""
|
|
|
|
|
|
|
2026-04-19 22:01:55 +08:00
|
|
|
|
from typing import List, Optional
|
2026-04-18 16:31:48 +08:00
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
2026-04-19 22:01:55 +08:00
|
|
|
|
# from langchain.retrievers.multi_query import MultiQueryRetriever
|
2026-04-18 16:31:48 +08:00
|
|
|
|
from langchain_core.prompts import PromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiQueryTransformer:
    """Multi-query rewriter for RAG-Fusion.

    NOTE: multi-query rewriting is currently disabled (see
    ``create_multi_query_retriever``). ``llm`` and ``num_queries`` are stored
    on the instance but are not yet used by any method of this class.
    """

    def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
        """Initialize the multi-query rewriter.

        Args:
            llm: Language model instance intended to generate the rewritten
                query variants.
            num_queries: Number of rewritten queries to generate.
        """
        self.llm = llm
        self.num_queries = num_queries

    def create_multi_query_retriever(self, base_retriever):
        """Wrap ``base_retriever`` for multi-query retrieval.

        Multi-query retrieval is not implemented yet. Because the LangChain
        version in use does not provide ``MultiQueryRetriever``, this method
        currently returns ``base_retriever`` unchanged. It does NOT return a
        ``MultiQueryRetriever``. No query rewriting happens, and ``self.llm``
        and ``self.num_queries`` are ignored.

        Args:
            base_retriever: The underlying retriever to wrap.

        Returns:
            The same ``base_retriever`` object that was passed in.
        """
        # TODO: restore multi-query retrieval once a LangChain version that
        # ships ``MultiQueryRetriever`` is available. The earlier draft (kept
        # in version-control history) built
        # ``MultiQueryRetriever.from_llm(retriever=base_retriever,
        # llm=self.llm, include_original=True)`` with a custom prompt asking
        # the LLM for ``self.num_queries`` rewrites of the question, one per
        # line.
        return base_retriever
|