Files
ailine/backend/app/memory/mem0_client.py
root 8db63e7a8d 重构:添加模型服务模块,支持嵌入和重排服务的自动降级
新增功能:
- 创建 app/model_services 模块,提供统一的模型服务获取接口
- 实现 BaseServiceProvider 基类和 FallbackServiceChain 降级链
- 实现 get_embedding_service():优先本地 llama.cpp,降级到智谱 API
- 实现 get_rerank_service():优先本地 llama.cpp,降级到智谱 API
- 支持单例管理,确保全局只有一个服务实例

修改内容:
- 更新 app/config.py,添加智谱 API 相关配置
- 修改 rag_core/vector_store.py:支持接受外部传入的 embeddings
- 修改 rag_core/retriever_factory.py:支持接受外部传入的 embeddings
- 修改 app/agent/rag_initializer.py:使用 get_embedding_service()
- 修改 app/rag/pipeline.py:使用 get_rerank_service()
- 修改 app/memory/mem0_client.py:智能判断可用服务配置 mem0
- 修改 rag_indexer/index_builder.py:支持使用新服务,保持向后兼容
- 修改 rag_indexer/config.py:添加智谱配置

环境变量:
- ZHIPUAI_API_KEY:智谱 API 密钥(必选)
- ZHIPU_EMBEDDING_MODEL:可选,默认 embedding-3
- ZHIPU_RERANK_MODEL:可选,默认 rerank-2
- ZHIPU_API_BASE:可选,默认 https://open.bigmodel.cn/api/paas/v4
2026-04-24 22:52:36 +08:00

191 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ..config import (
LLM_API_KEY, ZHIPUAI_API_KEY,
VLLM_BASE_URL, QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY,
LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY,
ZHIPU_EMBEDDING_MODEL, ZHIPU_API_BASE
)
from ..model_services import get_embedding_service
from ..logger import info, warning, error
import time
"""
Mem0 记忆层客户端封装模块
负责 Mem0 的初始化、检索和存储
"""
import asyncio
from typing import Optional, List, Dict
from mem0 import AsyncMemory
class Mem0Client:
"""Mem0 异步客户端封装类"""
def __init__(self, llm_instance):
"""
初始化 Mem0 客户端
Args:
llm_instance: LangChain LLM 实例(用于事实提取)
"""
self.llm = llm_instance
self.mem0: Optional[AsyncMemory] = None
self._initialized = False
async def initialize(self):
    """Asynchronously initialize Mem0 and run a live connection test.

    Selects an embedder backend (local llama.cpp preferred, Zhipu API as
    fallback, dummy placeholder as last resort), probes the embedding
    dimension, builds the mem0 configuration, then performs a best-effort
    search as a connectivity check. On unrecoverable failure ``self.mem0``
    is reset to ``None`` and ``self._initialized`` stays ``False`` so
    callers degrade gracefully. Idempotent: repeat calls are no-ops once
    initialized.
    """
    if self._initialized:
        return
    try:
        # Probe the active embedding service once; Qdrant needs the exact
        # vector dimension ("embedding_model_dims") up front.
        embeddings = get_embedding_service()
        test_embedding = embeddings.embed_query("test")
        embedding_dim = len(test_embedding)
        # Decide which embedder backend is actually reachable, mirroring
        # the fallback chain implemented in model_services.
        from ..model_services.embedding_services import LocalLlamaCppEmbeddingProvider, ZhipuEmbeddingProvider
        embedder_config = None
        # Prefer the local llama.cpp server.
        local_provider = LocalLlamaCppEmbeddingProvider()
        if local_provider.is_available():
            info("✅ 使用本地 llama.cpp 作为 mem0 embedder")
            embedder_config = {
                "provider": "openai",
                "config": {
                    "model": "Qwen3-Embedding-0.6B-Q8_0",
                    "api_key": LLAMACPP_API_KEY or "dummy",
                    "openai_base_url": LLAMACPP_EMBEDDING_URL,
                }
            }
        else:
            # Fall back to the Zhipu API.
            zhipu_provider = ZhipuEmbeddingProvider()
            if zhipu_provider.is_available():
                info("✅ 使用智谱 API 作为 mem0 embedder")
                # NOTE: mem0 may not support Zhipu natively; we rely on its
                # OpenAI-compatible endpoint here. A custom embedder may be
                # needed if this stops working.
                embedder_config = {
                    "provider": "openai",
                    "config": {
                        "model": ZHIPU_EMBEDDING_MODEL,
                        "api_key": ZHIPUAI_API_KEY,
                        "openai_base_url": ZHIPU_API_BASE,
                    }
                }
            else:
                # Neither backend reachable: keep a placeholder config so
                # initialization can still proceed (and fail loudly later).
                warning("⚠️ 没有可用的 embedder使用 dummy 配置")
                embedder_config = {
                    "provider": "openai",
                    "config": {
                        "model": "dummy",
                        "api_key": "dummy",
                        "openai_base_url": "http://localhost:8080/v1",
                    }
                }
        # Assemble the full mem0 configuration.
        config = {
            "vector_store": {
                "provider": "qdrant",
                "config": {
                    "url": QDRANT_URL,
                    "api_key": QDRANT_API_KEY,
                    "collection_name": QDRANT_COLLECTION_NAME,
                    "embedding_model_dims": embedding_dim,
                }
            },
            "llm": {
                "provider": "openai",
                "config": {
                    # NOTE(review): "LLM_MODEL" is a literal string — this
                    # looks like it was meant to be an LLM_MODEL config
                    # constant; confirm against app/config.py before fixing.
                    "model": "LLM_MODEL",
                    "api_key": LLM_API_KEY,
                    "openai_base_url": VLLM_BASE_URL,
                    "temperature": 0.1,
                    "max_tokens": 2000,
                }
            },
            "embedder": embedder_config,
            "version": "v1.1"
        }
        self.mem0 = AsyncMemory.from_config(config)
        info("✅ Mem0 配置加载成功,开始连接测试...")
        # Best-effort live connection test: any failure here — including the
        # 30s asyncio.TimeoutError, which this broad handler also absorbs —
        # is only logged, and initialization continues.
        try:
            await asyncio.wait_for(
                self.mem0.search("ping", user_id="test", limit=1),
                timeout=30.0
            )
            info("✅ Mem0 实际连接测试成功,初始化完成")
        except Exception as e:
            warning(f"⚠️ Mem0 连接测试遇到问题,但仍继续初始化: {e}")
        self._initialized = True
    except asyncio.TimeoutError:
        # Effectively unreachable for the search timeout (the inner handler
        # catches Exception first); kept as a safety net. Message fixed to
        # report the actual 30s timeout (previously said "10s").
        error("❌ Mem0 连接测试超时 (30s),请检查 Qdrant 或 Embedding 服务响应")
        self.mem0 = None
        self._initialized = False
    except Exception as e:
        error(f"❌ Mem0 初始化或连接测试失败: {e}")
        import traceback
        error(f"详细错误信息:\n{traceback.format_exc()}")
        self.mem0 = None
        self._initialized = False
async def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[str]:
    """Look up memories relevant to *query* for one user.

    Args:
        query: Free-text search query.
        user_id: Identifier scoping the memory store.
        limit: Maximum number of results requested from mem0.

    Returns:
        List[str]: Extracted memory facts; empty when mem0 is not
        initialized, nothing matches, the 30s timeout fires, or any
        retrieval error occurs (all failures are best-effort and logged).
    """
    # Guard clause: without an initialized client there is nothing to query.
    if not self.mem0:
        warning("⚠️ Mem0 未初始化,跳过记忆检索")
        return []
    try:
        response = await asyncio.wait_for(
            self.mem0.search(query, user_id=user_id, limit=limit),
            timeout=30.0,
        )
        if response and "results" in response:
            # Keep only entries that actually carry a "memory" payload.
            facts = [entry["memory"] for entry in response["results"] if entry.get("memory")]
            if facts:
                info(f"🔍 [记忆检索] Mem0 返回 {len(facts)} 条记忆")
                return facts
        info("🔍 [记忆检索] 未找到相关记忆")
        return []
    except asyncio.TimeoutError:
        warning("⚠️ Mem0 检索超时 (30s),跳过本次记忆检索")
        return []
    except Exception as e:
        warning(f"⚠️ Mem0 检索失败: {e}")
        return []
async def add_memories(self, messages, user_id):
    """Persist conversation messages into Mem0 for one user.

    Args:
        messages: Conversation messages forwarded verbatim to
            ``AsyncMemory.add`` (tagged with metadata type "conversation").
        user_id: User whose memory store receives the messages.

    Returns:
        bool: True on success; False when mem0 is not initialized or the
        add operation exceeds the 60s timeout.
    """
    # Without an initialized client this is a silent no-op.
    if not self.mem0:
        return False
    try:
        start = time.time()
        info(f"📝 开始 Mem0 add消息数: {len(messages)}")
        await asyncio.wait_for(
            self.mem0.add(messages, user_id=user_id, metadata={"type": "conversation"}),
            timeout=60.0
        )
        info(f"✅ Mem0 add 完成,耗时: {time.time() - start:.2f}s")
        return True
    except asyncio.TimeoutError:
        error(f"❌ Mem0 记忆添加超时 (60s),已等待 {time.time() - start:.2f}s")
        return False