重排,多路查询
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 35m37s

This commit is contained in:
2026-04-20 01:10:18 +08:00
parent 933d418d77
commit 3c906e91d9
21 changed files with 728 additions and 635 deletions

View File

@@ -1,62 +1,43 @@
"""
查询转换器模块
# rag/query_transform.py
实现多路查询改写功能,用于 RAG-Fusion。
"""
from typing import List, Optional
from typing import List
from langchain_core.language_models import BaseLanguageModel
# from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.prompts import PromptTemplate
MULTI_QUERY_PROMPT = PromptTemplate.from_template(
"""你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
class MultiQueryTransformer:
"""多路查询改写器,用于 RAG-Fusion。"""
原始问题: {question}
请生成 {num_queries} 个不同版本的查询,每个版本一行。
确保每个版本都是独立、完整的查询语句。
生成 {num_queries} 个查询:"""
)
class MultiQueryGenerator:
"""多路查询生成器(不依赖 LangChain 的 MultiQueryRetriever"""
def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
"""
初始化多路查询改写器。
Args:
llm: 语言模型实例
num_queries: 生成的查询数量
"""
self.llm = llm
self.num_queries = num_queries
self.prompt = MULTI_QUERY_PROMPT
def create_multi_query_retriever(self, base_retriever):
"""
创建多路查询检索器。
Args:
base_retriever: 基础检索器
Returns:
MultiQueryRetriever 实例
"""
# 由于当前 LangChain 版本不支持 MultiQueryRetriever暂时返回基础检索器
# retriever = MultiQueryRetriever.from_llm(
# retriever=base_retriever,
# llm=self.llm,
# include_original=True
# )
#
# # 自定义提示词
# retriever.llm_chain.prompt = PromptTemplate.from_template(
# "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
# "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n"
# "原始问题: {question}\n\n"
# "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
# "确保每个版本都是独立、完整的查询语句。\n\n"
# "生成 {num_queries} 个查询:"
# )
#
# # 修改调用参数以包含 num_queries
# original_ainvoke = retriever.llm_chain.ainvoke
# async def new_ainvoke(input_dict):
# input_dict["num_queries"] = self.num_queries
# return await original_ainvoke(input_dict)
# retriever.llm_chain.ainvoke = new_ainvoke
#
# return retriever
return base_retriever
def generate(self, query: str) -> List[str]:
"""同步生成多个查询变体"""
prompt_str = self.prompt.format(num_queries=self.num_queries, question=query)
response = self.llm.invoke(prompt_str)
# 处理响应内容,按行分割并去除空行和首尾空白
lines = response.content.strip().split('\n')
queries = [line.strip() for line in lines if line.strip()]
# 确保至少返回原始查询
return queries[:self.num_queries] if queries else [query]
async def agenerate(self, query: str) -> List[str]:
"""异步生成多个查询变体"""
prompt_str = self.prompt.format(num_queries=self.num_queries, question=query)
response = await self.llm.ainvoke(prompt_str)
lines = response.content.strip().split('\n')
queries = [line.strip() for line in lines if line.strip()]
return queries[:self.num_queries] if queries else [query]