From 8db63e7a8de6dd3b6e672d24bd055dd41d800b24 Mon Sep 17 00:00:00 2001
From: root <953994191@qq.com>
Date: Fri, 24 Apr 2026 22:52:36 +0800
Subject: [PATCH] Refactor: add a model-services module with automatic
 fallback for embedding and rerank services
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

New features:
- Add the app/model_services module, a unified entry point for obtaining model services
- Implement the BaseServiceProvider base class and the FallbackServiceChain fallback chain
- Implement get_embedding_service(): prefer local llama.cpp, fall back to the Zhipu API
- Implement get_rerank_service(): prefer local llama.cpp, fall back to the Zhipu API
- Manage services as singletons so only one instance exists globally

Changes:
- Update app/config.py: add Zhipu API settings
- Update rag_core/vector_store.py: accept externally supplied embeddings
- Update rag_core/retriever_factory.py: accept externally supplied embeddings
- Update app/agent/rag_initializer.py: use get_embedding_service()
- Update app/rag/pipeline.py: use get_rerank_service()
- Update app/memory/mem0_client.py: pick an available service when configuring mem0
- Update rag_indexer/index_builder.py: support the new services, stay backward compatible
- Update rag_indexer/config.py: add Zhipu settings

Environment variables:
- ZHIPUAI_API_KEY: Zhipu API key (required)
- ZHIPU_EMBEDDING_MODEL: optional, defaults to embedding-3
- ZHIPU_RERANK_MODEL: optional, defaults to rerank-2
- ZHIPU_API_BASE: optional, defaults to https://open.bigmodel.cn/api/paas/v4
---
 backend/app/agent/rag_initializer.py           |   4 +
 backend/app/config.py                          |   9 +
 backend/app/memory/mem0_client.py              | 105 +++++---
 backend/app/model_services/README.md           |  31 +++
 backend/app/model_services/__init__.py         |  14 ++
 backend/app/model_services/base.py             | 139 +++++++++++
 .../app/model_services/embedding_services.py   | 213 ++++++++++++++++
 backend/app/model_services/rerank_services.py  | 233 ++++++++++++++++++
 backend/app/rag/pipeline.py                    |  15 +-
 backend/rag_core/retriever_factory.py          |  15 +-
 backend/rag_core/vector_store.py               |  25 +-
 rag_indexer/config.py                          |  17 +-
 rag_indexer/index_builder.py                   |  32 ++-
 13 files changed, 794 insertions(+), 58 deletions(-)
 create mode 100644 backend/app/model_services/README.md
 create mode 100644 backend/app/model_services/__init__.py
 create mode 100644 backend/app/model_services/base.py
 create mode 100644 backend/app/model_services/embedding_services.py
 create mode 100644 backend/app/model_services/rerank_services.py

diff --git a/backend/app/agent/rag_initializer.py b/backend/app/agent/rag_initializer.py
index b637fc8..95b8b59 100644
--- a/backend/app/agent/rag_initializer.py
+++ b/backend/app/agent/rag_initializer.py
@@ -1,15 +1,19 @@
 # app/rag_initializer.py
 from ..rag.tools import create_rag_tool_sync
 from rag_core import create_parent_retriever
+from ..model_services import get_embedding_service
 from ..logger import info, warning
 
 async def init_rag_tool(local_llm_creator):
     """Initialize the RAG tool; return None on failure"""
     try:
         info("🔄 Initializing the RAG retrieval system...")
+        # Use the unified embedding-service accessor
+        embeddings = get_embedding_service()
         retriever = create_parent_retriever(
             collection_name="rag_documents",
             search_k=5,
+            embeddings=embeddings
         )
         rewrite_llm = local_llm_creator()
         rag_tool = create_rag_tool_sync(

diff --git a/backend/app/config.py b/backend/app/config.py
index c946cf9..5eebd4e 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -41,6 +41,15 @@
 ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY")
 DEEPSEEK_API_KEY = _get_str("DEEPSEEK_API_KEY")
 
+# ========== Zhipu API settings ==========
+# Embedding models: see https://docs.bigmodel.cn/cn/guide/start/model-overview
+# Options: embedding-2, embedding-3
+ZHIPU_EMBEDDING_MODEL = _get_str("ZHIPU_EMBEDDING_MODEL") or "embedding-3"
+# Rerank models: rerank-1 or rerank-2
+ZHIPU_RERANK_MODEL = _get_str("ZHIPU_RERANK_MODEL") or "rerank-2"
+ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4"
+
+
 # ========== llama.cpp service settings (URL + API key pairs) ==========
 # Primary LLM service
 VLLM_BASE_URL = _get_str("VLLM_BASE_URL")
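A quick sanity check on how these `or`-defaults resolve. The sketch below stands in for `_get_str` (defined earlier in config.py); the assumption that it returns a falsy value for unset or empty variables is mine, not shown in the patch:

```python
import os

def _get_str(key: str) -> str | None:
    # Simplified stand-in for the helper in app/config.py (assumption:
    # unset or empty environment variables come back falsy)
    value = os.environ.get(key, "").strip()
    return value or None

# Unset variable: the right-hand default wins
os.environ.pop("ZHIPU_EMBEDDING_MODEL", None)
assert (_get_str("ZHIPU_EMBEDDING_MODEL") or "embedding-3") == "embedding-3"

# Set variable: the environment overrides the default
os.environ["ZHIPU_EMBEDDING_MODEL"] = "embedding-2"
assert (_get_str("ZHIPU_EMBEDDING_MODEL") or "embedding-3") == "embedding-2"
```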
_get_str("ZHIPU_RERANK_MODEL") or "rerank-2" +ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4" + + # ========== llama.cpp 服务配置(URL + API密钥 配对) ========== # 主 LLM 服务 VLLM_BASE_URL = _get_str("VLLM_BASE_URL") diff --git a/backend/app/memory/mem0_client.py b/backend/app/memory/mem0_client.py index 4a54ec2..ad0b7cc 100644 --- a/backend/app/memory/mem0_client.py +++ b/backend/app/memory/mem0_client.py @@ -1,5 +1,11 @@ -from ..config import LLM_API_KEY -from ..config import VLLM_BASE_URL +from ..config import ( + LLM_API_KEY, ZHIPUAI_API_KEY, + VLLM_BASE_URL, QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, + LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY, + ZHIPU_EMBEDDING_MODEL, ZHIPU_API_BASE +) +from ..model_services import get_embedding_service +from ..logger import info, warning, error import time """ Mem0 记忆层客户端封装模块 @@ -10,13 +16,6 @@ import asyncio from typing import Optional, List, Dict from mem0 import AsyncMemory -from ..config import ( - QDRANT_URL,QDRANT_COLLECTION_NAME,QDRANT_API_KEY, - VLLM_BASE_URL, LLM_API_KEY, - LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY -) -from ..logger import info, warning, error - class Mem0Client: """Mem0 异步客户端封装类""" @@ -35,17 +34,66 @@ class Mem0Client: """异步初始化 Mem0 客户端,并进行实际连接测试""" if self._initialized: return - + try: + # 获取可用的 embedding 服务并确定维度 + embeddings = get_embedding_service() + test_embedding = embeddings.embed_query("test") + embedding_dim = len(test_embedding) + + # 构建正确的 embedder 配置 - 根据我们的降级机制 + # 首先我们需要判断哪个服务实际可用 + from ..model_services.embedding_services import LocalLlamaCppEmbeddingProvider, ZhipuEmbeddingProvider + + embedder_config = None + # 检查本地服务 + local_provider = LocalLlamaCppEmbeddingProvider() + if local_provider.is_available(): + info("✅ 使用本地 llama.cpp 作为 mem0 embedder") + embedder_config = { + "provider": "openai", + "config": { + "model": "Qwen3-Embedding-0.6B-Q8_0", + "api_key": LLAMACPP_API_KEY or "dummy", + "openai_base_url": LLAMACPP_EMBEDDING_URL, + } + } + else: + # 尝试使用智谱 + zhipu_provider = ZhipuEmbeddingProvider() + if zhipu_provider.is_available(): + info("✅ 使用智谱 API 作为 mem0 embedder") + # 注意:mem0 可能不直接支持智谱,这里我们暂时还是用 openai 兼容方式 + # 或者需要自定义 embedder + embedder_config = { + "provider": "openai", + "config": { + "model": ZHIPU_EMBEDDING_MODEL, + "api_key": ZHIPUAI_API_KEY, + "openai_base_url": ZHIPU_API_BASE, + } + } + else: + # 都不可用,使用 dummy 配置 + warning("⚠️ 没有可用的 embedder,使用 dummy 配置") + embedder_config = { + "provider": "openai", + "config": { + "model": "dummy", + "api_key": "dummy", + "openai_base_url": "http://localhost:8080/v1", + } + } + # Mem0 配置 config = { "vector_store": { "provider": "qdrant", "config": { - "url": QDRANT_URL, # 直接使用完整 URL + "url": QDRANT_URL, "api_key": QDRANT_API_KEY, "collection_name": QDRANT_COLLECTION_NAME, - "embedding_model_dims": 1024, + "embedding_model_dims": embedding_dim, } }, "llm": { @@ -53,33 +101,30 @@ class Mem0Client: "config": { "model": "LLM_MODEL", "api_key": LLM_API_KEY, - "openai_base_url": VLLM_BASE_URL, + "openai_base_url": VLLM_BASE_URL, "temperature": 0.1, "max_tokens": 2000, } }, - "embedder": { - "provider": "openai", - "config": { - "model": "Qwen3-Embedding-0.6B-Q8_0", - "api_key": LLAMACPP_API_KEY, - "openai_base_url": LLAMACPP_EMBEDDING_URL, - }, - }, + "embedder": embedder_config, "version": "v1.1" } - + self.mem0 = AsyncMemory.from_config(config) info("✅ Mem0 配置加载成功,开始连接测试...") - - # 实际连接测试:调用一次 search 确保 Qdrant 和 Embedding 都可达 - await asyncio.wait_for( - self.mem0.search("ping", user_id="test", limit=1), - timeout=60.0 - ) - 
info("✅ Mem0 实际连接测试成功,初始化完成") + + # 实际连接测试 + try: + await asyncio.wait_for( + self.mem0.search("ping", user_id="test", limit=1), + timeout=30.0 + ) + info("✅ Mem0 实际连接测试成功,初始化完成") + except Exception as e: + warning(f"⚠️ Mem0 连接测试遇到问题,但仍继续初始化: {e}") + self._initialized = True - + except asyncio.TimeoutError: error("❌ Mem0 连接测试超时 (10s),请检查 Qdrant 或 Embedding 服务响应") self.mem0 = None diff --git a/backend/app/model_services/README.md b/backend/app/model_services/README.md new file mode 100644 index 0000000..1107214 --- /dev/null +++ b/backend/app/model_services/README.md @@ -0,0 +1,31 @@ +""" +模型服务模块(model_services) + +提供统一的嵌入和重排模型服务获取接口,支持自动降级: +1. 优先使用本地 llama.cpp 服务 +2. 本地服务不可用时,自动降级到智谱 API 服务 + +使用方法: + +from app.model_services import get_embedding_service, get_rerank_service, BaseReranker + +# 获取嵌入服务(LangChain 兼容的 Embeddings) +embeddings = get_embedding_service() + +# 获取重排服务 +reranker = get_rerank_service() +sorted_docs = reranker.compress_documents(documents, query, top_n=5) + +环境变量配置: + +# 智谱 API 配置 +ZHIPUAI_API_KEY=your_api_key +ZHIPU_EMBEDDING_MODEL=embedding-3 # 可选:embedding-2、embedding-3 +ZHIPU_RERANK_MODEL=rerank-2 # 可选:rerank-1、rerank-2 +ZHIPU_API_BASE=https://open.bigmodel.cn/api/paas/v4 + +# 本地 llama.cpp 服务配置(原有配置保持不变) +LLAMACPP_EMBEDDING_URL=http://localhost:port/v1 +LLAMACPP_RERANKER_URL=http://localhost:port/v1 +LLAMACPP_API_KEY=your_api_key +""" diff --git a/backend/app/model_services/__init__.py b/backend/app/model_services/__init__.py new file mode 100644 index 0000000..3b5fd2c --- /dev/null +++ b/backend/app/model_services/__init__.py @@ -0,0 +1,14 @@ +""" +模型服务模块 + +提供统一的嵌入和重排模型服务获取接口,支持自动降级。 +""" + +from .embedding_services import get_embedding_service +from .rerank_services import get_rerank_service, BaseReranker + +__all__ = [ + "get_embedding_service", + "get_rerank_service", + "BaseReranker" +] diff --git a/backend/app/model_services/base.py b/backend/app/model_services/base.py new file mode 100644 index 0000000..6313ba9 --- /dev/null +++ b/backend/app/model_services/base.py @@ -0,0 +1,139 @@ +""" +模型服务获取器基类和自动降级机制模块 + +本模块提供: +1. 统一的服务获取器基类,支持服务可用性检查和自动降级 +2. 单例模式的服务管理器,确保全局只有一个服务实例 +3. 
diff --git a/backend/app/model_services/embedding_services.py b/backend/app/model_services/embedding_services.py
new file mode 100644
index 0000000..baef33e
--- /dev/null
+++ b/backend/app/model_services/embedding_services.py
@@ -0,0 +1,213 @@
+"""
+Embedding model services.
+
+Unified access to embedding models, with automatic fallback:
+1. Prefer the local llama.cpp embedding service
+2. Fall back to the Zhipu API when the local service is unavailable
+
+Main components:
+- LocalLlamaCppEmbeddingProvider: local llama.cpp embedding provider
+- ZhipuEmbeddingProvider: Zhipu API embedding provider
+- get_embedding_service(): unified accessor for the embedding service
+"""
+
+import logging
+from typing import List
+import httpx
+from langchain_core.embeddings import Embeddings
+
+from .base import (
+    BaseServiceProvider,
+    FallbackServiceChain,
+    SingletonServiceManager
+)
+from ..config import (
+    LLAMACPP_EMBEDDING_URL,
+    LLAMACPP_API_KEY,
+    ZHIPUAI_API_KEY,
+    ZHIPU_EMBEDDING_MODEL,
+    ZHIPU_API_BASE
+)
+
+logger = logging.getLogger(__name__)
+
+
+class LocalLlamaCppEmbeddingProvider(BaseServiceProvider[Embeddings]):
+    """
+    Local llama.cpp embedding service provider
+    """
+
+    def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"):
+        super().__init__("local_llamacpp_embedding")
+        self._model = model
+
+    def is_available(self) -> bool:
+        """
+        Check whether the local llama.cpp embedding service is available
+
+        Returns:
+            bool: whether the service is available
+        """
+        if not LLAMACPP_EMBEDDING_URL:
+            logger.warning("LLAMACPP_EMBEDDING_URL is not configured")
+            return False
+
+        try:
+            # Try embedding a test string
+            embedder = LocalLlamaCppEmbedder(model=self._model)
+            test_embedding = embedder.embed_query("test")
+            logger.info(f"Local llama.cpp embedding service available, dimension: {len(test_embedding)}")
+            return True
+        except Exception as e:
+            logger.warning(f"Local llama.cpp embedding service unavailable: {e}")
+            return False
+
+    def get_service(self) -> Embeddings:
+        """
+        Get the local llama.cpp embedding service
+
+        Returns:
+            Embeddings: a LangChain-compatible embeddings instance
+        """
+        if self._service_instance is None:
+            embedder = LocalLlamaCppEmbedder(model=self._model)
+            self._service_instance = embedder.as_langchain_embeddings()
+        return self._service_instance
+
+
+class ZhipuEmbeddingProvider(BaseServiceProvider[Embeddings]):
+    """
+    Zhipu API embedding service provider
+    """
+
+    def __init__(self, model: str | None = None):
+        super().__init__("zhipu_embedding")
+        self._model = model or ZHIPU_EMBEDDING_MODEL
+
+    def is_available(self) -> bool:
+        """
+        Check whether the Zhipu API embedding service is available
+
+        Returns:
+            bool: whether the service is available
+        """
+        if not ZHIPUAI_API_KEY:
+            logger.warning("ZHIPUAI_API_KEY is not configured")
+            return False
+
+        try:
+            # Check that the Zhipu API is reachable
+            from zhipuai import ZhipuAI
+            client = ZhipuAI(api_key=ZHIPUAI_API_KEY)
+            response = client.embeddings.create(
+                model=self._model,
+                input=["test"]
+            )
+            logger.info(f"Zhipu embedding service available, dimension: {len(response.data[0].embedding)}")
+            return True
+        except ImportError:
+            logger.warning("The zhipuai package is not installed")
+            return False
+        except Exception as e:
+            logger.warning(f"Zhipu embedding service unavailable: {e}")
+            return False
+
+    def get_service(self) -> Embeddings:
+        """
+        Get the Zhipu API embedding service
+
+        Returns:
+            Embeddings: a LangChain-compatible embeddings instance
+        """
+        if self._service_instance is None:
+            from langchain_zhipu import ZhipuAIEmbeddings
+            self._service_instance = ZhipuAIEmbeddings(
+                model=self._model,
+                api_key=ZHIPUAI_API_KEY
+            )
+        return self._service_instance
+
+
+class LocalLlamaCppEmbedder:
+    """
+    Wraps the llama.cpp embedding service behind its OpenAI-compatible API
+    """
+
+    def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"):
+        self.base_url = LLAMACPP_EMBEDDING_URL
+        self.api_key = LLAMACPP_API_KEY
+        self.model = model
+
+    def as_langchain_embeddings(self) -> Embeddings:
+        """Create a LangChain-compatible embeddings instance"""
+        return _LlamaCppLangchainAdapter(self)
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a batch of documents"""
+        return self._call_embedding_api(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a single query"""
+        return self._call_embedding_api([text])[0]
+
+    def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
+        """Call the llama.cpp embeddings API directly"""
+        headers = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+
+        base = self.base_url.rstrip("/")
+        if not base.endswith("/v1"):
+            base = base + "/v1"
+
+        payload = {
+            "input": texts,
+            "model": self.model,
+        }
+
+        with httpx.Client(timeout=120) as client:
+            response = client.post(
+                f"{base}/embeddings",
+                headers=headers,
+                json=payload,
+            )
+            response.raise_for_status()
+            data = response.json()
+
+        if isinstance(data, list):
+            return [item["embedding"] for item in data]
+        elif isinstance(data, dict) and "data" in data:
+            return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])]
+        else:
+            raise ValueError(f"Unknown embeddings API response format: {data}")
+
+
+class _LlamaCppLangchainAdapter(Embeddings):
+    """
+    Adapts LocalLlamaCppEmbedder to the LangChain Embeddings interface
+    """
+
+    def __init__(self, embedder: "LocalLlamaCppEmbedder"):
+        self._embedder = embedder
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        return self._embedder.embed_documents(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        return self._embedder.embed_query(text)
+
+
+def get_embedding_service() -> Embeddings:
+    """
+    Get the embedding service (with automatic fallback)
+
+    Returns:
+        Embeddings: a LangChain-compatible embeddings instance
+    """
+    def _create_chain():
+        primary = LocalLlamaCppEmbeddingProvider()
+        fallback = ZhipuEmbeddingProvider()
+        return FallbackServiceChain(primary, [fallback])
+
+    chain = SingletonServiceManager.get_or_create("embedding_service_chain", _create_chain)
+    return chain.get_available_service()
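For context, the dict branch that `_call_embedding_api` takes corresponds to the standard OpenAI embeddings wire format, which llama.cpp's server emits. A representative response (hand-written here, not captured from a real server) and the same index-sorting step:

```python
# Shape of an OpenAI-style /v1/embeddings response (illustrative values)
response = {
    "object": "list",
    "data": [
        {"object": "embedding", "index": 1, "embedding": [0.03, 0.04]},
        {"object": "embedding", "index": 0, "embedding": [0.01, -0.02]},
    ],
    "model": "Qwen3-Embedding-0.6B-Q8_0",
}

# Sorting by "index", as the code above does, keeps the output vectors
# aligned with the order of the input texts even if the server reorders them
vectors = [item["embedding"] for item in sorted(response["data"], key=lambda x: x["index"])]
assert vectors == [[0.01, -0.02], [0.03, 0.04]]
```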
diff --git a/backend/app/model_services/rerank_services.py b/backend/app/model_services/rerank_services.py
new file mode 100644
index 0000000..5de0a1c
--- /dev/null
+++ b/backend/app/model_services/rerank_services.py
@@ -0,0 +1,233 @@
+"""
+Rerank model services.
+
+Unified access to rerank models, with automatic fallback:
+1. Prefer the local llama.cpp rerank service
+2. Fall back to the Zhipu API when the local service is unavailable
+
+Main components:
+- LocalLlamaCppRerankProvider: local llama.cpp rerank provider
+- ZhipuRerankProvider: Zhipu API rerank provider
+- get_rerank_service(): unified accessor for the rerank service
+"""
+
+import logging
+from typing import List
+import requests
+from langchain_core.documents import Document
+
+from .base import (
+    BaseServiceProvider,
+    FallbackServiceChain,
+    SingletonServiceManager
+)
+from ..config import (
+    LLAMACPP_RERANKER_URL,
+    LLAMACPP_API_KEY,
+    ZHIPUAI_API_KEY,
+    ZHIPU_RERANK_MODEL,
+    ZHIPU_API_BASE
+)
+
+logger = logging.getLogger(__name__)
+
+
+class BaseReranker:
+    """
+    Base reranker class defining the common interface
+    """
+
+    def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
+        """
+        Rerank documents
+
+        Args:
+            documents: documents to be reranked
+            query: the query string
+            top_n: number of results to return
+
+        Returns:
+            the reranked documents
+        """
+        raise NotImplementedError
+
+
+class LocalLlamaCppReranker(BaseReranker):
+    """
+    Reranks retrieval results via a remote llama.cpp service
+    """
+
+    def __init__(self, base_url: str, api_key: str, model: str = "bge-reranker-v2-m3", timeout: int = 60):
+        self.base_url = base_url
+        self.api_key = api_key
+        self.model = model
+        self.timeout = timeout
+        self.endpoint = f"{self.base_url}/rerank"
+
+    def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
+        """
+        Rerank documents
+        """
+        if not documents:
+            return []
+
+        # Build the request body
+        payload = {
+            "model": self.model,
+            "query": query,
+            "documents": [doc.page_content for doc in documents],
+            "top_n": top_n
+        }
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}"
+        }
+
+        try:
+            response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout)
+            response.raise_for_status()
+            results = response.json()
+
+            # Parse the response
+            sorted_indices = [item["index"] for item in results["results"]]
+            sorted_docs = [documents[idx] for idx in sorted_indices]
+            return sorted_docs
+        except Exception as e:
+            logger.warning(f"Remote rerank failed, returning the first {top_n} original results: {e}")
+            return documents[:top_n]
+
+
+class ZhipuReranker(BaseReranker):
+    """
+    Reranks retrieval results via the Zhipu API
+    """
+
+    def __init__(self, model: str | None = None):
+        self.model = model or ZHIPU_RERANK_MODEL
+        self.api_key = ZHIPUAI_API_KEY
+
+    def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]:
+        """
+        Rerank documents
+        """
+        if not documents:
+            return []
+
+        try:
+            from zhipuai import ZhipuAI
+            client = ZhipuAI(api_key=self.api_key)
+
+            response = client.rerank.create(
+                model=self.model,
+                query=query,
+                documents=[doc.page_content for doc in documents],
+                top_n=top_n
+            )
+
+            sorted_indices = [item.index for item in response.results]
+            sorted_docs = [documents[idx] for idx in sorted_indices]
+            return sorted_docs
+        except Exception as e:
+            logger.warning(f"Zhipu rerank failed, returning the first {top_n} original results: {e}")
+            return documents[:top_n]
+
+
+class LocalLlamaCppRerankProvider(BaseServiceProvider[BaseReranker]):
+    """
+    Local llama.cpp rerank service provider
+    """
+
+    def __init__(self, model: str = "bge-reranker-v2-m3"):
+        super().__init__("local_llamacpp_rerank")
+        self._model = model
+
+    def is_available(self) -> bool:
+        """
+        Check whether the local llama.cpp rerank service is available
+        """
+        if not LLAMACPP_RERANKER_URL:
+            logger.warning("LLAMACPP_RERANKER_URL is not configured")
+            return False
+
+        try:
+            # Smoke-test the rerank service
+            test_docs = [Document(page_content="test document 1"), Document(page_content="test document 2")]
+            reranker = LocalLlamaCppReranker(
+                base_url=LLAMACPP_RERANKER_URL,
+                api_key=LLAMACPP_API_KEY,
+                model=self._model
+            )
+            result = reranker.compress_documents(test_docs, "test query", top_n=1)
+            logger.info("Local llama.cpp rerank service available")
+            return True
+        except Exception as e:
+            logger.warning(f"Local llama.cpp rerank service unavailable: {e}")
+            return False
+
+    def get_service(self) -> BaseReranker:
+        """
+        Get the local llama.cpp rerank service
+        """
+        if self._service_instance is None:
+            self._service_instance = LocalLlamaCppReranker(
+                base_url=LLAMACPP_RERANKER_URL,
+                api_key=LLAMACPP_API_KEY,
+                model=self._model
+            )
+        return self._service_instance
+
+
+class ZhipuRerankProvider(BaseServiceProvider[BaseReranker]):
+    """
+    Zhipu API rerank service provider
+    """
+
+    def __init__(self, model: str | None = None):
+        super().__init__("zhipu_rerank")
+        self._model = model or ZHIPU_RERANK_MODEL
+
+    def is_available(self) -> bool:
+        """
+        Check whether the Zhipu API rerank service is available
+        """
+        if not ZHIPUAI_API_KEY:
+            logger.warning("ZHIPUAI_API_KEY is not configured")
+            return False
+
+        try:
+            # Smoke-test the rerank service
+            test_docs = [Document(page_content="test document 1"), Document(page_content="test document 2")]
+            reranker = ZhipuReranker(model=self._model)
+            result = reranker.compress_documents(test_docs, "test query", top_n=1)
+            logger.info("Zhipu rerank service available")
+            return True
+        except ImportError:
+            logger.warning("The zhipuai package is not installed")
+            return False
+        except Exception as e:
+            logger.warning(f"Zhipu rerank service unavailable: {e}")
+            return False
+
+    def get_service(self) -> BaseReranker:
+        """
+        Get the Zhipu API rerank service
+        """
+        if self._service_instance is None:
+            self._service_instance = ZhipuReranker(model=self._model)
+        return self._service_instance
+
+
+def get_rerank_service() -> BaseReranker:
+    """
+    Get the rerank service (with automatic fallback)
+
+    Returns:
+        BaseReranker: a reranker instance
+    """
+    def _create_chain():
+        primary = LocalLlamaCppRerankProvider()
+        fallback = ZhipuRerankProvider()
+        return FallbackServiceChain(primary, [fallback])
+
+    chain = SingletonServiceManager.get_or_create("rerank_service_chain", _create_chain)
+    return chain.get_available_service()
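Typical use of the rerank accessor. One caveat worth knowing: both `compress_documents` implementations swallow backend errors and fall back to `documents[:top_n]`, so the smoke tests in the providers above effectively report "available" whenever the URL or key is configured, not whenever the backend actually works.

```python
from langchain_core.documents import Document
from app.model_services import get_rerank_service

docs = [
    Document(page_content="llama.cpp serves GGUF models over an HTTP API"),
    Document(page_content="Qdrant is a vector database"),
    Document(page_content="Rerankers score query-document pairs directly"),
]

reranker = get_rerank_service()  # local llama.cpp if reachable, else Zhipu
top = reranker.compress_documents(docs, query="what does a reranker do?", top_n=2)
# On any backend failure this degrades gracefully to docs[:2]
```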
f"Bearer {self.api_key}" + + base = self.base_url.rstrip("/") + if not base.endswith("/v1"): + base = base + "/v1" + + payload = { + "input": texts, + "model": self.model, + } + + with httpx.Client(timeout=120) as client: + response = client.post( + f"{base}/embeddings", + headers=headers, + json=payload, + ) + response.raise_for_status() + data = response.json() + + if isinstance(data, list): + return [item["embedding"] for item in data] + elif isinstance(data, dict) and "data" in data: + return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])] + else: + raise ValueError(f"未知的嵌入 API 响应格式: {data}") + + +class _LlamaCppLangchainAdapter(Embeddings): + """ + 将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口 + """ + + def __init__(self, embedder: "LocalLlamaCppEmbedder"): + self._embedder = embedder + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return self._embedder.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + return self._embedder.embed_query(text) + + +def get_embedding_service() -> Embeddings: + """ + 获取嵌入服务(带自动降级) + + Returns: + Embeddings: LangChain 兼容的嵌入实例 + """ + def _create_chain(): + primary = LocalLlamaCppEmbeddingProvider() + fallback = ZhipuEmbeddingProvider() + return FallbackServiceChain(primary, [fallback]) + + chain = SingletonServiceManager.get_or_create("embedding_service_chain", _create_chain) + return chain.get_available_service() diff --git a/backend/app/model_services/rerank_services.py b/backend/app/model_services/rerank_services.py new file mode 100644 index 0000000..5de0a1c --- /dev/null +++ b/backend/app/model_services/rerank_services.py @@ -0,0 +1,233 @@ +""" +重排模型服务模块 + +本模块提供统一的重排模型服务获取接口,支持自动降级: +1. 优先使用本地 llama.cpp 重排服务 +2. 本地服务不可用时,自动降级到智谱 API 重排服务 + +主要功能: +- LocalLlamaCppRerankProvider:本地 llama.cpp 重排服务提供者 +- ZhipuRerankProvider:智谱 API 重排服务提供者 +- get_rerank_service():获取重排服务的统一接口 +""" + +import logging +from typing import List +import requests +from langchain_core.documents import Document + +from .base import ( + BaseServiceProvider, + FallbackServiceChain, + SingletonServiceManager +) +from ..config import ( + LLAMACPP_RERANKER_URL, + LLAMACPP_API_KEY, + ZHIPUAI_API_KEY, + ZHIPU_RERANK_MODEL, + ZHIPU_API_BASE +) + +logger = logging.getLogger(__name__) + + +class BaseReranker: + """ + 重排器基类,定义统一的接口 + """ + + def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]: + """ + 对文档进行重排序 + + Args: + documents: 待排序的文档列表 + query: 查询字符串 + top_n: 返回前 N 个结果 + + Returns: + 排序后的文档列表 + """ + raise NotImplementedError + + +class LocalLlamaCppReranker(BaseReranker): + """ + 使用远程 llama.cpp 服务对检索结果重排序 + """ + + def __init__(self, base_url: str, api_key: str, model: str = "bge-reranker-v2-m3", timeout: int = 60): + self.base_url = base_url + self.api_key = api_key + self.model = model + self.timeout = timeout + self.endpoint = f"{self.base_url}/rerank" + + def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]: + """ + 对文档进行重排序 + """ + if not documents: + return [] + + # 准备请求体 + payload = { + "model": self.model, + "query": query, + "documents": [doc.page_content for doc in documents], + "top_n": top_n + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + + try: + response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout) + response.raise_for_status() + results = response.json() + + # 解析返回结果 + sorted_indices = 
[item["index"] for item in results["results"]] + sorted_docs = [documents[idx] for idx in sorted_indices] + return sorted_docs + except Exception as e: + logger.warning(f"远程重排序过程出错,返回原始前 {top_n} 个结果: {e}") + return documents[:top_n] + + +class ZhipuReranker(BaseReranker): + """ + 使用智谱 API 对检索结果重排序 + """ + + def __init__(self, model: str | None = None): + self.model = model or ZHIPU_RERANK_MODEL + self.api_key = ZHIPUAI_API_KEY + + def compress_documents(self, documents: List[Document], query: str, top_n: int = 5) -> List[Document]: + """ + 对文档进行重排序 + """ + if not documents: + return [] + + try: + from zhipuai import ZhipuAI + client = ZhipuAI(api_key=self.api_key) + + response = client.rerank.create( + model=self.model, + query=query, + documents=[doc.page_content for doc in documents], + top_n=top_n + ) + + sorted_indices = [item.index for item in response.results] + sorted_docs = [documents[idx] for idx in sorted_indices] + return sorted_docs + except Exception as e: + logger.warning(f"智谱重排序过程出错,返回原始前 {top_n} 个结果: {e}") + return documents[:top_n] + + +class LocalLlamaCppRerankProvider(BaseServiceProvider[BaseReranker]): + """ + 本地 llama.cpp 重排服务提供者 + """ + + def __init__(self, model: str = "bge-reranker-v2-m3"): + super().__init__("local_llamacpp_rerank") + self._model = model + + def is_available(self) -> bool: + """ + 检查本地 llama.cpp 重排服务是否可用 + """ + if not LLAMACPP_RERANKER_URL: + logger.warning("LLAMACPP_RERANKER_URL 未配置") + return False + + try: + # 测试重排服务 + test_docs = [Document(page_content="test document 1"), Document(page_content="test document 2")] + reranker = LocalLlamaCppReranker( + base_url=LLAMACPP_RERANKER_URL, + api_key=LLAMACPP_API_KEY, + model=self._model + ) + result = reranker.compress_documents(test_docs, "test query", top_n=1) + logger.info(f"本地 llama.cpp 重排服务可用") + return True + except Exception as e: + logger.warning(f"本地 llama.cpp 重排服务不可用: {e}") + return False + + def get_service(self) -> BaseReranker: + """ + 获取本地 llama.cpp 重排服务 + """ + if self._service_instance is None: + self._service_instance = LocalLlamaCppReranker( + base_url=LLAMACPP_RERANKER_URL, + api_key=LLAMACPP_API_KEY, + model=self._model + ) + return self._service_instance + + +class ZhipuRerankProvider(BaseServiceProvider[BaseReranker]): + """ + 智谱 API 重排服务提供者 + """ + + def __init__(self, model: str | None = None): + super().__init__("zhipu_rerank") + self._model = model or ZHIPU_RERANK_MODEL + + def is_available(self) -> bool: + """ + 检查智谱 API 重排服务是否可用 + """ + if not ZHIPUAI_API_KEY: + logger.warning("ZHIPUAI_API_KEY 未配置") + return False + + try: + # 测试重排服务 + test_docs = [Document(page_content="test document 1"), Document(page_content="test document 2")] + reranker = ZhipuReranker(model=self._model) + result = reranker.compress_documents(test_docs, "test query", top_n=1) + logger.info(f"智谱重排服务可用") + return True + except ImportError: + logger.warning("zhipuai 库未安装") + return False + except Exception as e: + logger.warning(f"智谱重排服务不可用: {e}") + return False + + def get_service(self) -> BaseReranker: + """ + 获取智谱 API 重排服务 + """ + if self._service_instance is None: + self._service_instance = ZhipuReranker(model=self._model) + return self._service_instance + + +def get_rerank_service() -> BaseReranker: + """ + 获取重排服务(带自动降级) + + Returns: + BaseReranker: 重排服务实例 + """ + def _create_chain(): + primary = LocalLlamaCppRerankProvider() + fallback = ZhipuRerankProvider() + return FallbackServiceChain(primary, [fallback]) + + chain = SingletonServiceManager.get_or_create("rerank_service_chain", _create_chain) + 
diff --git a/backend/rag_core/vector_store.py b/backend/rag_core/vector_store.py
index 7848157..88cc518 100644
--- a/backend/rag_core/vector_store.py
+++ b/backend/rag_core/vector_store.py
@@ -8,6 +8,7 @@ import time
 from typing import List, Optional, Dict, Any
 
 from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
 from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
 from qdrant_client import QdrantClient
 from qdrant_client.http.models import Distance, VectorParams
@@ -23,18 +24,25 @@ logger = logging.getLogger(__name__)
 class QdrantVectorStore:
     """Wrapper around Qdrant vector-database operations."""
 
-    def __init__(self, collection_name: str):
+    def __init__(self, collection_name: str, embeddings: Optional[Embeddings] = None):
         """
         Args:
             collection_name: name of the Qdrant collection.
+            embeddings: embedding model instance, default None (uses the internal LlamaCppEmbedder).
         """
         self.collection_name = collection_name
         self._client: Optional[QdrantClient] = None
         self._connection_attempts = 0
         self._last_connection_time: Optional[float] = None
-
-        embedder = LlamaCppEmbedder()
-        self.embeddings = embedder.as_langchain_embeddings()
+
+        # Embedding model
+        if embeddings is None:
+            embedder = LlamaCppEmbedder()
+            self.embeddings = embedder.as_langchain_embeddings()
+            self._embedder = embedder
+        else:
+            self.embeddings = embeddings
+            self._embedder = None
 
         self.create_collection()
 
@@ -90,8 +98,13 @@
     def create_collection(self, force_recreate: bool = False):
         """Create the collection with an appropriate vector dimension."""
-        embedder = LlamaCppEmbedder()
-        vector_size = embedder.get_embedding_dimension()
+        if self._embedder is not None:
+            # Use the internal embedder to determine the dimension
+            vector_size = self._embedder.get_embedding_dimension()
+        else:
+            # External embeddings: probe the dimension with a test call
+            test_embedding = self.embeddings.embed_query("test")
+            vector_size = len(test_embedding)
 
         max_retries = 3
         base_delay = 2
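Usage sketch for the updated constructor. Note that the vector dimension is fixed at collection-creation time, so if the fallback target changes later (say, local 1024-dim vectors to Zhipu's 2048-dim ones), an existing collection is not resized and writes will fail until it is recreated:

```python
from app.model_services import get_embedding_service
from rag_core import QdrantVectorStore

embeddings = get_embedding_service()
store = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
# create_collection() already ran in __init__; with external embeddings the
# vector size came from len(embeddings.embed_query("test")), one extra API call.
```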
diff --git a/rag_indexer/config.py b/rag_indexer/config.py
index a4da69b..77ad2b6 100644
--- a/rag_indexer/config.py
+++ b/rag_indexer/config.py
@@ -34,17 +34,28 @@ def _get_list_str(key: str, default: list[str] | None = None) -> list[str]:
     return default or []
 
 
-# ========== Vector-database settings (URL + API-key pairs) ==========
+# ========== Third-party API keys ==========
+ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY")
+
+
+# ========== Zhipu API settings ==========
+# Embedding models: see https://docs.bigmodel.cn/cn/guide/start/model-overview
+# Options: embedding-2, embedding-3
+ZHIPU_EMBEDDING_MODEL = _get_str("ZHIPU_EMBEDDING_MODEL") or "embedding-3"
+ZHIPU_API_BASE = _get_str("ZHIPU_API_BASE") or "https://open.bigmodel.cn/api/paas/v4"
+
+
+# ========== Vector-database settings (URL + API key pairs) ==========
 QDRANT_URL = _get_str("QDRANT_URL")
 QDRANT_API_KEY = _get_str("QDRANT_API_KEY")
 
 
-# ========== Embedding service settings (URL + API-key pairs) ==========
+# ========== Embedding service settings (URL + API key pairs) ==========
 LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL")
 LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY")
 
 
-# ========== Document store settings (split config + full-URI) ==========
+# ========== Document store settings (split config + full URI) ==========
 # Split settings (take precedence)
 DB_HOST = _get_str("DB_HOST")
 DB_PORT = _get_int("DB_PORT")
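These values are read once at import time, so overrides must be in the environment before the module is imported. A sketch, assuming `rag_indexer` is importable as a package from the working directory:

```python
import os

# Export overrides first, then import the config module
os.environ.setdefault("ZHIPU_EMBEDDING_MODEL", "embedding-2")
from rag_indexer import config

assert config.ZHIPU_EMBEDDING_MODEL in ("embedding-2", "embedding-3")
```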
diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py
index 666fef7..8970351 100644
--- a/rag_indexer/index_builder.py
+++ b/rag_indexer/index_builder.py
@@ -31,6 +31,13 @@
 sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
 
 from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
 
+# Try to import the new model_services (if available)
+try:
+    from app.model_services import get_embedding_service
+    HAS_MODEL_SERVICES = True
+except ImportError:
+    HAS_MODEL_SERVICES = False
+
 logger = logging.getLogger(__name__)
 
 # ---------- Configuration dataclasses ----------
@@ -69,10 +76,11 @@ class IndexBuilderConfig:
 class IndexBuilder:
     """Main RAG index-building pipeline; supports single-chunk and parent-child splitting."""
 
-    def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs):
+    def __init__(self, config: Optional[IndexBuilderConfig] = None, embeddings: Optional[Embeddings] = None, **kwargs):
         """
         Args:
             config: index-builder configuration object; takes precedence over kwargs
+            embeddings: optional external embedding model instance; used when provided
             **kwargs: configuration values passed directly, merged into config (kept for convenience)
         """
         if config is None:
@@ -88,12 +96,29 @@
         # Initialize basic components
         self.loader = DocumentLoader()
-        self.embedder = LlamaCppEmbedder()
-        self.embeddings: Embeddings = self.embedder.as_langchain_embeddings()
+
+        # Set up the embedding model: prefer an externally supplied instance,
+        # then the new services, and finally fall back to the legacy path
+        if embeddings is not None:
+            self.embeddings = embeddings
+            self.embedder = None
+            logger.info("Using externally supplied embedding model")
+        elif HAS_MODEL_SERVICES:
+            try:
+                self.embeddings = get_embedding_service()
+                self.embedder = None
+                logger.info("Using the embedding service from model_services")
+            except Exception as e:
+                logger.warning(f"Failed to get an embedding service, falling back to LlamaCppEmbedder: {e}")
+                self.embedder = LlamaCppEmbedder()
+                self.embeddings = self.embedder.as_langchain_embeddings()
+        else:
+            self.embedder = LlamaCppEmbedder()
+            self.embeddings = self.embedder.as_langchain_embeddings()
 
         # Initialize the vector store
         self.vector_store = QdrantVectorStore(
             collection_name=config.collection_name,
+            embeddings=self.embeddings if self.embedder is None else None,
         )
 
         # Initialize splitter-specific components
@@ -149,6 +174,7 @@
             child_splitter=self.child_splitter,
             docstore=self.docstore,
             search_k=cfg.search_k,
+            embeddings=self.embeddings if self.embedder is None else None,
         )
         logger.info("ParentDocumentRetriever initialized")
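Finally, the three embedding paths of `IndexBuilder` in order of precedence, as a usage sketch. Constructing `IndexBuilderConfig` with only `collection_name` assumes its remaining fields have defaults, which the diff does not show:

```python
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
from app.model_services import get_embedding_service

# 1. Explicit injection: highest precedence, bypasses both fallback paths
builder = IndexBuilder(
    config=IndexBuilderConfig(collection_name="rag_documents"),
    embeddings=get_embedding_service(),
)

# 2./3. Built-in resolution: model_services when importable (HAS_MODEL_SERVICES),
#       otherwise the legacy LlamaCppEmbedder
builder = IndexBuilder(config=IndexBuilderConfig(collection_name="rag_documents"))
```

Passing `embeddings=None` to `QdrantVectorStore` when the legacy embedder is in use keeps the store's old internal-embedder path intact, which is what makes this change backward compatible.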