193 lines
5.8 KiB
Python
193 lines
5.8 KiB
Python
"""
|
|
查询改写器
|
|
|
|
基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围。
|
|
"""
|
|
|
|
from typing import List, Optional, Any
|
|
from langchain.retrievers.multi_query import MultiQueryRetriever
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
from langchain_core.prompts import PromptTemplate
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
|
|
|
|
class MultiQueryTransformer:
|
|
"""
|
|
多路查询改写器
|
|
|
|
将单个查询改写成多个相关查询,用于 RAG-Fusion。
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
llm: BaseLanguageModel,
|
|
num_queries: int = 3,
|
|
prompt_template: Optional[str] = None,
|
|
):
|
|
"""
|
|
初始化查询改写器
|
|
|
|
Args:
|
|
llm: 语言模型实例
|
|
num_queries: 生成的查询数量
|
|
prompt_template: 提示词模板
|
|
"""
|
|
self.llm = llm
|
|
self.num_queries = num_queries
|
|
|
|
# 默认提示词模板
|
|
self.prompt_template = prompt_template or """
|
|
你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
|
|
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
|
|
|
|
原始问题: {question}
|
|
|
|
请生成 {num_queries} 个不同版本的查询,每个版本一行。
|
|
确保每个版本都是独立、完整的查询语句。
|
|
|
|
生成 {num_queries} 个查询:
|
|
"""
|
|
|
|
def transform_query(self, query: str) -> List[str]:
|
|
"""
|
|
将单个查询改写成多个查询
|
|
|
|
Args:
|
|
query: 原始查询
|
|
|
|
Returns:
|
|
改写后的查询列表
|
|
"""
|
|
prompt = PromptTemplate.from_template(self.prompt_template)
|
|
|
|
chain = prompt | self.llm | StrOutputParser()
|
|
|
|
response = chain.invoke({
|
|
"question": query,
|
|
"num_queries": self.num_queries,
|
|
})
|
|
|
|
# 解析响应,每行一个查询
|
|
queries = [
|
|
q.strip()
|
|
for q in response.strip().split('\n')
|
|
if q.strip()
|
|
]
|
|
|
|
# 确保数量正确,如果不够则添加原始查询
|
|
if len(queries) < self.num_queries:
|
|
queries.extend([query] * (self.num_queries - len(queries)))
|
|
elif len(queries) > self.num_queries:
|
|
queries = queries[:self.num_queries]
|
|
|
|
# 确保包含原始查询
|
|
if query not in queries:
|
|
queries = [query] + queries[:self.num_queries-1]
|
|
|
|
return queries
|
|
|
|
def create_multi_query_retriever(
|
|
self,
|
|
base_retriever: Any,
|
|
include_original: bool = True,
|
|
) -> MultiQueryRetriever:
|
|
"""
|
|
创建多路查询检索器
|
|
|
|
Args:
|
|
base_retriever: 基础检索器
|
|
include_original: 是否包含原始查询
|
|
|
|
Returns:
|
|
MultiQueryRetriever 实例
|
|
"""
|
|
retriever = MultiQueryRetriever.from_llm(
|
|
retriever=base_retriever,
|
|
llm=self.llm,
|
|
include_original=include_original,
|
|
)
|
|
|
|
# 设置生成的查询数量
|
|
retriever.llm_chain.prompt = PromptTemplate.from_template(
|
|
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
|
|
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
|
|
"原始问题: {question}\n\n"
|
|
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
|
|
"确保每个版本都是独立、完整的查询语句。\n\n"
|
|
"生成 {num_queries} 个查询:"
|
|
)
|
|
|
|
# 修改调用参数以包含 num_queries
|
|
original_invoke = retriever.llm_chain.invoke
|
|
|
|
def new_invoke(input_dict):
|
|
input_dict["num_queries"] = self.num_queries
|
|
return original_invoke(input_dict)
|
|
|
|
retriever.llm_chain.invoke = new_invoke
|
|
|
|
return retriever
|
|
|
|
@classmethod
|
|
def create_from_config(
|
|
cls,
|
|
llm: BaseLanguageModel,
|
|
config: Optional[dict] = None,
|
|
) -> "MultiQueryTransformer":
|
|
"""
|
|
从配置创建查询改写器
|
|
|
|
Args:
|
|
llm: 语言模型实例
|
|
config: 配置字典
|
|
|
|
Returns:
|
|
MultiQueryTransformer 实例
|
|
"""
|
|
config = config or {}
|
|
return cls(
|
|
llm=llm,
|
|
num_queries=config.get("num_queries", 3),
|
|
prompt_template=config.get("prompt_template", None),
|
|
)
|
|
|
|
|
|
def create_rag_fusion_pipeline(
|
|
base_retriever: Any,
|
|
llm: BaseLanguageModel,
|
|
reranker: Optional[Any] = None,
|
|
num_queries: int = 3,
|
|
) -> Any:
|
|
"""
|
|
创建完整的 RAG-Fusion 流水线
|
|
|
|
Args:
|
|
base_retriever: 基础检索器
|
|
llm: 语言模型(用于查询改写)
|
|
reranker: 重排序器(可选)
|
|
num_queries: 查询改写数量
|
|
|
|
Returns:
|
|
检索器实例
|
|
"""
|
|
# 创建多路查询改写器
|
|
query_transformer = MultiQueryTransformer(
|
|
llm=llm,
|
|
num_queries=num_queries,
|
|
)
|
|
|
|
# 创建多路查询检索器
|
|
multi_query_retriever = query_transformer.create_multi_query_retriever(
|
|
base_retriever=base_retriever,
|
|
include_original=True,
|
|
)
|
|
|
|
# 如果提供了重排序器,则应用重排序
|
|
if reranker is not None:
|
|
from langchain.retrievers import ContextualCompressionRetriever
|
|
return ContextualCompressionRetriever(
|
|
base_compressor=reranker,
|
|
base_retriever=multi_query_retriever,
|
|
)
|
|
|
|
return multi_query_retriever |