本地RAG尝试
This commit is contained in:
193
app/rag/query_transform.py
Normal file
193
app/rag/query_transform.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
查询改写器
|
||||
|
||||
基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围。
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Any
|
||||
from langchain.retrievers.multi_query import MultiQueryRetriever
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
||||
|
||||
class MultiQueryTransformer:
    """Multi-query rewriter.

    Rewrites a single user query into several related queries (for
    RAG-Fusion style retrieval), widening the search net by covering
    different phrasings and keywords of the same intent.
    """

    def __init__(
        self,
        llm: BaseLanguageModel,
        num_queries: int = 3,
        prompt_template: Optional[str] = None,
    ):
        """Initialize the query rewriter.

        Args:
            llm: Language model instance used to generate the rewrites.
            num_queries: Number of rewritten queries to produce.
            prompt_template: Optional custom prompt template; must expose
                ``{question}`` and ``{num_queries}`` variables.
        """
        self.llm = llm
        self.num_queries = num_queries

        # Default prompt template (Chinese): asks the LLM for
        # {num_queries} rewrites of {question}, one per line.
        self.prompt_template = prompt_template or """
你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。

原始问题: {question}

请生成 {num_queries} 个不同版本的查询,每个版本一行。
确保每个版本都是独立、完整的查询语句。

生成 {num_queries} 个查询:
"""

    def transform_query(self, query: str) -> List[str]:
        """Rewrite a single query into ``num_queries`` related queries.

        Args:
            query: The original query.

        Returns:
            A list of exactly ``num_queries`` queries that always contains
            the original query.
        """
        prompt = PromptTemplate.from_template(self.prompt_template)

        chain = prompt | self.llm | StrOutputParser()

        response = chain.invoke({
            "question": query,
            "num_queries": self.num_queries,
        })

        # Parse the response: one query per non-empty line.
        queries = [
            q.strip()
            for q in response.strip().split('\n')
            if q.strip()
        ]

        # Normalize the count: pad with the original query when the LLM
        # returned too few lines, truncate when it returned too many.
        if len(queries) < self.num_queries:
            queries.extend([query] * (self.num_queries - len(queries)))
        elif len(queries) > self.num_queries:
            queries = queries[:self.num_queries]

        # Guarantee the original query is part of the result set.
        if query not in queries:
            queries = [query] + queries[:self.num_queries - 1]

        return queries

    def create_multi_query_retriever(
        self,
        base_retriever: Any,
        include_original: bool = True,
    ) -> MultiQueryRetriever:
        """Create a multi-query retriever around ``base_retriever``.

        Args:
            base_retriever: The underlying retriever to fan rewritten
                queries into.
            include_original: Whether the original query is also issued.

        Returns:
            A configured ``MultiQueryRetriever`` instance.
        """
        # Bind num_queries into the prompt up front via partial variables
        # and hand the prompt to ``from_llm`` directly. The previous
        # approach — assigning ``retriever.llm_chain.prompt`` and
        # monkey-patching ``retriever.llm_chain.invoke`` after
        # construction — fails at runtime: ``llm_chain`` is an LCEL
        # RunnableSequence (a pydantic object) with no mutable ``prompt``
        # attribute, and MultiQueryRetriever never supplies the
        # ``{num_queries}`` variable the prompt requires.
        prompt = PromptTemplate.from_template(
            "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
            "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
            "原始问题: {question}\n\n"
            "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
            "确保每个版本都是独立、完整的查询语句。\n\n"
            "生成 {num_queries} 个查询:"
        ).partial(num_queries=str(self.num_queries))

        return MultiQueryRetriever.from_llm(
            retriever=base_retriever,
            llm=self.llm,
            prompt=prompt,
            include_original=include_original,
        )

    @classmethod
    def create_from_config(
        cls,
        llm: BaseLanguageModel,
        config: Optional[dict] = None,
    ) -> "MultiQueryTransformer":
        """Create a query rewriter from a configuration dict.

        Args:
            llm: Language model instance.
            config: Optional dict with keys ``num_queries`` and
                ``prompt_template``; missing keys fall back to defaults.

        Returns:
            A ``MultiQueryTransformer`` instance.
        """
        config = config or {}
        return cls(
            llm=llm,
            num_queries=config.get("num_queries", 3),
            prompt_template=config.get("prompt_template", None),
        )
|
||||
|
||||
|
||||
def create_rag_fusion_pipeline(
    base_retriever: Any,
    llm: BaseLanguageModel,
    reranker: Optional[Any] = None,
    num_queries: int = 3,
) -> Any:
    """Build a complete RAG-Fusion retrieval pipeline.

    Args:
        base_retriever: Underlying retriever the rewritten queries are
            fanned into.
        llm: Language model used for query rewriting.
        reranker: Optional reranker (document compressor) applied on top
            of the multi-query retrieval results.
        num_queries: Number of query rewrites to generate.

    Returns:
        A retriever instance — the multi-query retriever itself, or that
        retriever wrapped in a ``ContextualCompressionRetriever`` when a
        reranker is supplied.
    """
    # Multi-query rewriter driving the fan-out.
    transformer = MultiQueryTransformer(
        llm=llm,
        num_queries=num_queries,
    )

    # Retriever that issues all rewritten queries (plus the original).
    fused_retriever = transformer.create_multi_query_retriever(
        base_retriever=base_retriever,
        include_original=True,
    )

    # Without a reranker the fused retriever is the whole pipeline.
    if reranker is None:
        return fused_retriever

    # Otherwise layer the reranker on top as a compression stage.
    from langchain.retrievers import ContextualCompressionRetriever
    return ContextualCompressionRetriever(
        base_compressor=reranker,
        base_retriever=fused_retriever,
    )
|
||||
Reference in New Issue
Block a user