文件变更
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# rag/pipeline.py
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
@@ -23,7 +24,6 @@ class RAGPipeline:
|
||||
llm: BaseLanguageModel,
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
rerank_model: str = "BAAI/bge-reranker-base",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -41,9 +41,9 @@ class RAGPipeline:
|
||||
# 初始化组件
|
||||
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
|
||||
self.reranker = LLaMaCPPReranker(
|
||||
base_url="http://127.0.0.1:8083",
|
||||
base_url=os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083"),
|
||||
api_key=os.getenv("LLAMACPP_API_KEY", "huang1998"),
|
||||
top_n=rerank_top_n,
|
||||
api_key="huang1998"
|
||||
)
|
||||
|
||||
async def aretrieve(self, query: str) -> List[Document]:
|
||||
@@ -68,9 +68,9 @@ class RAGPipeline:
|
||||
fused_docs = reciprocal_rank_fusion(doc_lists)
|
||||
|
||||
# Step 4: 重排序
|
||||
if self.reranker.model is not None:
|
||||
try:
|
||||
final_docs = self.reranker.compress_documents(fused_docs, query)
|
||||
else:
|
||||
except Exception:
|
||||
# 若重排序器不可用,直接返回融合后的前 N 条
|
||||
final_docs = fused_docs[:self.rerank_top_n]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user