Files
ailine/app/rag/query_transform.py

193 lines
5.8 KiB
Python
Raw Normal View History

2026-04-18 16:31:48 +08:00
"""
查询改写器
基于 MultiQueryRetriever 实现多路查询改写扩大搜索范围
"""
from typing import List, Optional, Any
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
class MultiQueryTransformer:
"""
多路查询改写器
将单个查询改写成多个相关查询用于 RAG-Fusion
"""
def __init__(
self,
llm: BaseLanguageModel,
num_queries: int = 3,
prompt_template: Optional[str] = None,
):
"""
初始化查询改写器
Args:
llm: 语言模型实例
num_queries: 生成的查询数量
prompt_template: 提示词模板
"""
self.llm = llm
self.num_queries = num_queries
# 默认提示词模板
self.prompt_template = prompt_template or """
你是一个专业的查询改写助手你的任务是将用户的问题改写成 {num_queries} 个不同的版本
这些版本应该从不同的角度使用不同的关键词来表达相同或相关的意图
原始问题: {question}
请生成 {num_queries} 个不同版本的查询每个版本一行
确保每个版本都是独立完整的查询语句
生成 {num_queries} 个查询:
"""
def transform_query(self, query: str) -> List[str]:
"""
将单个查询改写成多个查询
Args:
query: 原始查询
Returns:
改写后的查询列表
"""
prompt = PromptTemplate.from_template(self.prompt_template)
chain = prompt | self.llm | StrOutputParser()
response = chain.invoke({
"question": query,
"num_queries": self.num_queries,
})
# 解析响应,每行一个查询
queries = [
q.strip()
for q in response.strip().split('\n')
if q.strip()
]
# 确保数量正确,如果不够则添加原始查询
if len(queries) < self.num_queries:
queries.extend([query] * (self.num_queries - len(queries)))
elif len(queries) > self.num_queries:
queries = queries[:self.num_queries]
# 确保包含原始查询
if query not in queries:
queries = [query] + queries[:self.num_queries-1]
return queries
def create_multi_query_retriever(
self,
base_retriever: Any,
include_original: bool = True,
) -> MultiQueryRetriever:
"""
创建多路查询检索器
Args:
base_retriever: 基础检索器
include_original: 是否包含原始查询
Returns:
MultiQueryRetriever 实例
"""
retriever = MultiQueryRetriever.from_llm(
retriever=base_retriever,
llm=self.llm,
include_original=include_original,
)
# 设置生成的查询数量
retriever.llm_chain.prompt = PromptTemplate.from_template(
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
"原始问题: {question}\n\n"
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
"确保每个版本都是独立、完整的查询语句。\n\n"
"生成 {num_queries} 个查询:"
)
# 修改调用参数以包含 num_queries
original_invoke = retriever.llm_chain.invoke
def new_invoke(input_dict):
input_dict["num_queries"] = self.num_queries
return original_invoke(input_dict)
retriever.llm_chain.invoke = new_invoke
return retriever
@classmethod
def create_from_config(
cls,
llm: BaseLanguageModel,
config: Optional[dict] = None,
) -> "MultiQueryTransformer":
"""
从配置创建查询改写器
Args:
llm: 语言模型实例
config: 配置字典
Returns:
MultiQueryTransformer 实例
"""
config = config or {}
return cls(
llm=llm,
num_queries=config.get("num_queries", 3),
prompt_template=config.get("prompt_template", None),
)
def create_rag_fusion_pipeline(
base_retriever: Any,
llm: BaseLanguageModel,
reranker: Optional[Any] = None,
num_queries: int = 3,
) -> Any:
"""
创建完整的 RAG-Fusion 流水线
Args:
base_retriever: 基础检索器
llm: 语言模型用于查询改写
reranker: 重排序器可选
num_queries: 查询改写数量
Returns:
检索器实例
"""
# 创建多路查询改写器
query_transformer = MultiQueryTransformer(
llm=llm,
num_queries=num_queries,
)
# 创建多路查询检索器
multi_query_retriever = query_transformer.create_multi_query_retriever(
base_retriever=base_retriever,
include_original=True,
)
# 如果提供了重排序器,则应用重排序
if reranker is not None:
from langchain.retrievers import ContextualCompressionRetriever
return ContextualCompressionRetriever(
base_compressor=reranker,
base_retriever=multi_query_retriever,
)
return multi_query_retriever