Files
ailine/app/rag/query_transform.py
root 933d418d77
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 17m12s
检索器重构
2026-04-19 22:01:55 +08:00

63 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
查询转换器模块
实现多路查询改写功能,用于 RAG-Fusion。
"""
from typing import List, Optional
from langchain_core.language_models import BaseLanguageModel
# from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.prompts import PromptTemplate
class MultiQueryTransformer:
"""多路查询改写器,用于 RAG-Fusion。"""
def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
"""
初始化多路查询改写器。
Args:
llm: 语言模型实例
num_queries: 生成的查询数量
"""
self.llm = llm
self.num_queries = num_queries
def create_multi_query_retriever(self, base_retriever):
"""
创建多路查询检索器。
Args:
base_retriever: 基础检索器
Returns:
MultiQueryRetriever 实例
"""
# 由于当前 LangChain 版本不支持 MultiQueryRetriever暂时返回基础检索器
# retriever = MultiQueryRetriever.from_llm(
# retriever=base_retriever,
# llm=self.llm,
# include_original=True
# )
#
# # 自定义提示词
# retriever.llm_chain.prompt = PromptTemplate.from_template(
# "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
# "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
# "原始问题: {question}\n\n"
# "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
# "确保每个版本都是独立、完整的查询语句。\n\n"
# "生成 {num_queries} 个查询:"
# )
#
# # 修改调用参数以包含 num_queries
# original_ainvoke = retriever.llm_chain.ainvoke
# async def new_ainvoke(input_dict):
# input_dict["num_queries"] = self.num_queries
# return await original_ainvoke(input_dict)
# retriever.llm_chain.ainvoke = new_ainvoke
#
# return retriever
return base_retriever