检索器重构
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s

This commit is contained in:
2026-04-19 22:01:55 +08:00
parent cc8ef41ef9
commit 933d418d77
26 changed files with 1694 additions and 1717 deletions

View File

@@ -1,193 +1,62 @@
"""
查询改写器
查询转换器模块
基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围
实现多路查询改写功能,用于 RAG-Fusion
"""
from typing import List, Optional, Any
from langchain.retrievers.multi_query import MultiQueryRetriever
from typing import List, Optional
from langchain_core.language_models import BaseLanguageModel
# from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
class MultiQueryTransformer:
"""
多路查询改写器
"""多路查询改写器,用于 RAG-Fusion。"""
将单个查询改写成多个相关查询,用于 RAG-Fusion。
"""
def __init__(
self,
llm: BaseLanguageModel,
num_queries: int = 3,
prompt_template: Optional[str] = None,
):
def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
"""
初始化查询改写器
初始化多路查询改写器
Args:
llm: 语言模型实例
num_queries: 生成的查询数量
prompt_template: 提示词模板
"""
self.llm = llm
self.num_queries = num_queries
# 默认提示词模板
self.prompt_template = prompt_template or """
你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
原始问题: {question}
请生成 {num_queries} 个不同版本的查询,每个版本一行。
确保每个版本都是独立、完整的查询语句。
生成 {num_queries} 个查询:
"""
def transform_query(self, query: str) -> List[str]:
def create_multi_query_retriever(self, base_retriever):
"""
将单个查询改写成多个查询
Args:
query: 原始查询
Returns:
改写后的查询列表
"""
prompt = PromptTemplate.from_template(self.prompt_template)
chain = prompt | self.llm | StrOutputParser()
response = chain.invoke({
"question": query,
"num_queries": self.num_queries,
})
# 解析响应,每行一个查询
queries = [
q.strip()
for q in response.strip().split('\n')
if q.strip()
]
# 确保数量正确,如果不够则添加原始查询
if len(queries) < self.num_queries:
queries.extend([query] * (self.num_queries - len(queries)))
elif len(queries) > self.num_queries:
queries = queries[:self.num_queries]
# 确保包含原始查询
if query not in queries:
queries = [query] + queries[:self.num_queries-1]
return queries
def create_multi_query_retriever(
self,
base_retriever: Any,
include_original: bool = True,
) -> MultiQueryRetriever:
"""
创建多路查询检索器
创建多路查询检索器。
Args:
base_retriever: 基础检索器
include_original: 是否包含原始查询
Returns:
MultiQueryRetriever 实例
"""
retriever = MultiQueryRetriever.from_llm(
retriever=base_retriever,
llm=self.llm,
include_original=include_original,
)
# 设置生成的查询数量
retriever.llm_chain.prompt = PromptTemplate.from_template(
"你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
"这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
"原始问题: {question}\n\n"
"请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
"确保每个版本都是独立、完整的查询语句。\n\n"
"生成 {num_queries} 个查询:"
)
# 修改调用参数以包含 num_queries
original_invoke = retriever.llm_chain.invoke
def new_invoke(input_dict):
input_dict["num_queries"] = self.num_queries
return original_invoke(input_dict)
retriever.llm_chain.invoke = new_invoke
return retriever
@classmethod
def create_from_config(
cls,
llm: BaseLanguageModel,
config: Optional[dict] = None,
) -> "MultiQueryTransformer":
"""
从配置创建查询改写器
Args:
llm: 语言模型实例
config: 配置字典
Returns:
MultiQueryTransformer 实例
"""
config = config or {}
return cls(
llm=llm,
num_queries=config.get("num_queries", 3),
prompt_template=config.get("prompt_template", None),
)
def create_rag_fusion_pipeline(
base_retriever: Any,
llm: BaseLanguageModel,
reranker: Optional[Any] = None,
num_queries: int = 3,
) -> Any:
"""
创建完整的 RAG-Fusion 流水线
Args:
base_retriever: 基础检索器
llm: 语言模型(用于查询改写)
reranker: 重排序器(可选)
num_queries: 查询改写数量
Returns:
检索器实例
"""
# 创建多路查询改写器
query_transformer = MultiQueryTransformer(
llm=llm,
num_queries=num_queries,
)
# 创建多路查询检索器
multi_query_retriever = query_transformer.create_multi_query_retriever(
base_retriever=base_retriever,
include_original=True,
)
# 如果提供了重排序器,则应用重排序
if reranker is not None:
from langchain.retrievers import ContextualCompressionRetriever
return ContextualCompressionRetriever(
base_compressor=reranker,
base_retriever=multi_query_retriever,
)
return multi_query_retriever
# 由于当前 LangChain 版本不支持 MultiQueryRetriever,暂时返回基础检索器
# retriever = MultiQueryRetriever.from_llm(
# retriever=base_retriever,
# llm=self.llm,
# include_original=True
# )
#
# # 自定义提示词
# retriever.llm_chain.prompt = PromptTemplate.from_template(
# "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
# "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
# "原始问题: {question}\n\n"
# "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
# "确保每个版本都是独立、完整的查询语句。\n\n"
# "生成 {num_queries} 个查询:"
# )
#
# # 修改调用参数以包含 num_queries
# original_ainvoke = retriever.llm_chain.ainvoke
# async def new_ainvoke(input_dict):
# input_dict["num_queries"] = self.num_queries
# return await original_ainvoke(input_dict)
# retriever.llm_chain.ainvoke = new_ainvoke
#
# return retriever
return base_retriever