Files
ailine/backend/app/rag/query_transform.py
root 8b354b7ccc
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 47m14s
重构代码,统一config配置
2026-04-21 11:02:16 +08:00

43 lines
1.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# rag/query_transform.py
import re
from typing import List

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
# Raw text of the multi-query rewrite prompt. Placeholders filled at format time:
#   {num_queries} - how many rewritten versions the model is asked to produce
#   {question}    - the user's original question
_MULTI_QUERY_TEMPLATE = """你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
原始问题: {question}
请生成 {num_queries} 个不同版本的查询,每个版本一行。
确保每个版本都是独立、完整的查询语句。
生成 {num_queries} 个查询:"""

# Compiled prompt template consumed by MultiQueryGenerator.
MULTI_QUERY_PROMPT = PromptTemplate.from_template(_MULTI_QUERY_TEMPLATE)
class MultiQueryGenerator:
    """Multi-query generator (does not depend on LangChain's MultiQueryRetriever).

    Asks an LLM to rewrite a user query into several variants and returns
    them as a list of strings; it only produces the query strings.
    """

    # Leading list markers a model may prepend to each line despite the
    # "one query per line" instruction, e.g. "1. ", "2)", "3、", "- ", "• ".
    # The (?!\d) lookahead keeps decimals such as "3.5版本..." intact, and a
    # bullet must be followed by whitespace so "-5度" or "*args" survive.
    _LIST_MARKER_RE = re.compile(r"^(?:\d+[.)、:：](?!\d)|[-*•]\s)\s*")

    def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
        """Create a generator.

        Args:
            llm: Model used for rewriting. Either a chat model (``invoke``
                returns a message with ``.content``) or a plain LLM
                (``invoke`` returns ``str``) is accepted.
            num_queries: Number of variants requested from the model and the
                maximum number returned. Must be at least 1.

        Raises:
            ValueError: If ``num_queries`` is less than 1.
        """
        if num_queries < 1:
            raise ValueError(f"num_queries must be >= 1, got {num_queries}")
        self.llm = llm
        self.num_queries = num_queries
        self.prompt = MULTI_QUERY_PROMPT

    def generate(self, query: str) -> List[str]:
        """Synchronously generate up to ``num_queries`` variants of *query*.

        Returns ``[query]`` when the model output contains no usable line.
        """
        response = self.llm.invoke(self._build_prompt(query))
        return self._parse_queries(response, query)

    async def agenerate(self, query: str) -> List[str]:
        """Asynchronous counterpart of ``generate`` (uses ``llm.ainvoke``)."""
        response = await self.llm.ainvoke(self._build_prompt(query))
        return self._parse_queries(response, query)

    def _build_prompt(self, query: str) -> str:
        """Render the rewrite prompt for *query*."""
        return self.prompt.format(num_queries=self.num_queries, question=query)

    @staticmethod
    def _response_text(response: object) -> str:
        """Extract plain text from an ``invoke``/``ainvoke`` result.

        Plain LLMs return ``str``. Chat models return a message whose
        ``content`` is either a ``str`` or a list of content blocks (strings
        or ``{"type": "text", "text": ...}`` dicts). Non-text blocks are
        ignored.
        """
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "".join(
                block if isinstance(block, str) else str(block.get("text", ""))
                for block in content
                if isinstance(block, str)
                or (isinstance(block, dict) and block.get("type") == "text")
            )
        return str(content)

    def _parse_queries(self, response: object, query: str) -> List[str]:
        """Split a model response into cleaned, de-duplicated query lines.

        Each line is stripped of surrounding whitespace and any leading list
        marker. Blank lines are dropped. At most ``num_queries`` items are
        kept, falling back to ``[query]`` so callers always get at least the
        original query.
        """
        cleaned = (
            self._LIST_MARKER_RE.sub("", raw.strip()).strip()
            for raw in self._response_text(response).splitlines()
        )
        # dict.fromkeys removes duplicates while keeping first-seen order, so
        # repeated lines do not consume one of the num_queries slots.
        queries = [line for line in dict.fromkeys(cleaned) if line]
        return queries[: self.num_queries] or [query]