参数配置统一

This commit is contained in:
2026-04-21 19:06:34 +08:00
parent e2eaac9498
commit 37e86f3bb1
10 changed files with 120 additions and 166 deletions

View File

@@ -37,7 +37,6 @@ RAG 检索与生成模块
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_qdrant_client,
)
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
@@ -50,7 +49,6 @@ __all__ = [
# 检索器工厂函数
"create_base_retriever",
"create_hybrid_retriever",
"create_qdrant_client",
# 重排序器
"LLaMaCPPReranker",

View File

@@ -25,66 +25,25 @@ Qdrant 向量检索器模块
>>> docs = retriever.invoke("什么是 RAG")
"""
from typing import Optional, Dict, Any
from typing import Dict, Any
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from langchain_qdrant import QdrantVectorStore
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from rag_core import QDRANT_URL, QDRANT_API_KEY
from rag_core import QDRANT_URL, QDRANT_API_KEY, LlamaCppEmbedder
from rag_core.client import create_qdrant_client as create_core_qdrant_client
# 模块级常量
DEFAULT_SEARCH_K = 20
DEFAULT_SCORE_THRESHOLD = 0.3
def create_qdrant_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
timeout: int = 30,
) -> QdrantClient:
"""
创建并返回一个配置好的 Qdrant 客户端。
优先使用传入参数,若未提供则回退到环境变量 QDRANT_URL 和 QDRANT_API_KEY。
Args:
url: Qdrant 服务地址,例如 "http://localhost:6333"
默认从环境变量 QDRANT_URL 读取。
api_key: API 密钥(若 Qdrant 启用了认证)。
默认从环境变量 QDRANT_API_KEY 读取。
timeout: 请求超时时间(秒),默认 30 秒。
Returns:
配置好的 QdrantClient 实例。
Raises:
ValueError: 如果 url 为空且环境变量也未设置。
"""
effective_url = url or QDRANT_URL
if not effective_url:
raise ValueError(
"Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL"
)
effective_api_key = api_key or QDRANT_API_KEY
client_kwargs = {
"url": effective_url,
"timeout": timeout,
}
if effective_api_key:
client_kwargs["api_key"] = effective_api_key
return QdrantClient(**client_kwargs)
def create_base_retriever(
collection_name: str,
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
search_kwargs: Dict[str, Any] | None = None,
client: QdrantClient | None = None,
) -> BaseRetriever:
"""
创建基础向量检索器(仅稠密向量检索)。
@@ -94,7 +53,6 @@ def create_base_retriever(
Args:
collection_name: Qdrant 集合名称(需预先创建并索引)。
embeddings: LangChain 兼容的嵌入模型实例。
search_kwargs: 搜索参数,可包含:
- k (int): 返回的文档数量,默认 20。
- score_threshold (float): 相似度阈值,仅返回高于此分数的文档。
@@ -108,6 +66,10 @@ def create_base_retriever(
Raises:
ValueError: 如果集合不存在或嵌入模型无效。
"""
# 嵌入模型
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 合并默认搜索参数
merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
if search_kwargs:
@@ -115,7 +77,7 @@ def create_base_retriever(
# 创建或复用 Qdrant 客户端
if client is None:
client = create_qdrant_client()
client = create_core_qdrant_client()
# 验证集合是否存在(可选,便于提前发现问题)
try:
@@ -140,11 +102,10 @@ def create_base_retriever(
def create_hybrid_retriever(
collection_name: str,
embeddings: Embeddings,
dense_k: int = 10,
sparse_k: int = 10,
score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD,
client: Optional[QdrantClient] = None,
score_threshold: float | None = DEFAULT_SCORE_THRESHOLD,
client: QdrantClient | None = None,
) -> BaseRetriever:
"""
创建混合检索器(稠密向量 + BM25 稀疏向量)。
@@ -157,7 +118,6 @@ def create_hybrid_retriever(
Args:
collection_name: Qdrant 集合名称。
embeddings: 嵌入模型(用于稠密向量)。
dense_k: 稠密向量检索返回数量,默认 10。
sparse_k: 稀疏向量检索返回数量,默认 10。
score_threshold: 相似度阈值,默认 0.3。
@@ -177,7 +137,6 @@ def create_hybrid_retriever(
# 复用基础检索器创建逻辑,只需调整搜索参数
return create_base_retriever(
collection_name=collection_name,
embeddings=embeddings,
search_kwargs=search_kwargs,
client=client,
)
@@ -186,9 +145,8 @@ def create_hybrid_retriever(
# 可选:提供异步友好的辅助函数
async def acreate_base_retriever(
collection_name: str,
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
search_kwargs: Dict[str, Any] | None = None,
client: QdrantClient | None = None,
) -> BaseRetriever:
"""
异步创建基础向量检索器(与同步版本功能相同)。
@@ -196,4 +154,4 @@ async def acreate_base_retriever(
适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。
"""
# 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可
return create_base_retriever(collection_name, embeddings, search_kwargs, client)
return create_base_retriever(collection_name, search_kwargs, client)