# rag/query_transform.py

from typing import Any, List

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate

# Prompt that asks the model to rewrite `question` into `num_queries`
# alternative queries, one per line. The template text is runtime data sent to
# the model and is kept verbatim.
MULTI_QUERY_PROMPT = PromptTemplate.from_template(
    "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n"
    "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n"
    "\n"
    "原始问题: {question}\n"
    "\n"
    "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n"
    "确保每个版本都是独立、完整的查询语句。\n"
    "\n"
    "生成 {num_queries} 个查询:"
)


class MultiQueryGenerator:
    """Multi-query generator that does not depend on LangChain's MultiQueryRetriever.

    The prompt is formatted with the user's question, sent to the model, and
    the reply is split into one query per non-empty line.

    Args:
        llm: Language model used to produce the rewrites. Replies may be a
            plain ``str`` or an object exposing ``.content`` (either a string
            or a list of content blocks).
        num_queries: Number of rewrites requested in the prompt; also the
            maximum number of queries returned.
    """

    def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
        self.llm = llm
        self.num_queries = num_queries
        self.prompt = MULTI_QUERY_PROMPT

    def generate(self, query: str) -> List[str]:
        """Synchronously generate query variants for *query*.

        Returns:
            Up to ``num_queries`` stripped, non-empty lines from the model's
            reply, or ``[query]`` when the reply contains no usable line.
        """
        response = self.llm.invoke(self._build_prompt(query))
        return self._parse_response(response, query)

    async def agenerate(self, query: str) -> List[str]:
        """Asynchronously generate query variants for *query*.

        Returns:
            Same contract as :meth:`generate`.
        """
        response = await self.llm.ainvoke(self._build_prompt(query))
        return self._parse_response(response, query)

    def _build_prompt(self, query: str) -> str:
        """Render the prompt for *query* with the configured variant count."""
        return self.prompt.format(num_queries=self.num_queries, question=query)

    def _parse_response(self, response: Any, fallback: str) -> List[str]:
        """Split a model reply into at most ``num_queries`` cleaned queries."""
        text = self._response_text(response)
        # One query per line; drop blank lines and surrounding whitespace.
        queries = [line.strip() for line in text.splitlines() if line.strip()]
        # Guarantee at least the original query when the model returned
        # nothing usable.
        return queries[: self.num_queries] if queries else [fallback]

    @staticmethod
    def _response_text(response: Any) -> str:
        """Return the textual payload of a model reply.

        Plain-text LLMs return a ``str`` directly, while chat models return a
        message whose ``content`` is either a string or a list of content
        blocks (strings or ``{"type": "text", "text": ...}`` dicts).
        """
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif (
                    isinstance(block, dict)
                    and block.get("type") == "text"
                    and isinstance(block.get("text"), str)
                ):
                    parts.append(block["text"])
            return "".join(parts)
        return str(content)