This commit is contained in:
@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
from app.agent import AIAgentService
|
from app.agent import AIAgentService
|
||||||
from app.history import ThreadHistoryService
|
from app.agent.history import ThreadHistoryService
|
||||||
from app.logger import info, error
|
from app.logger import info, error
|
||||||
|
|
||||||
# 加载 .env 文件
|
# 加载 .env 文件
|
||||||
|
|||||||
@@ -20,6 +20,11 @@ QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
|||||||
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
|
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
|
||||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
|
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
|
||||||
|
|
||||||
|
# ========== llm 配置 ==========
|
||||||
|
# LLM 模型配置
|
||||||
|
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
|
||||||
|
LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key")
|
||||||
|
|
||||||
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
|
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
|
||||||
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
|
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
|
||||||
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key")
|
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key")
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
from app.config import LLM_API_KEY
|
||||||
|
from app.config import VLLM_BASE_URL
|
||||||
|
import time
|
||||||
"""
|
"""
|
||||||
Mem0 记忆层客户端封装模块
|
Mem0 记忆层客户端封装模块
|
||||||
负责 Mem0 的初始化、检索和存储
|
负责 Mem0 的初始化、检索和存储
|
||||||
@@ -7,7 +10,11 @@ import asyncio
|
|||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
from mem0 import AsyncMemory
|
from mem0 import AsyncMemory
|
||||||
|
|
||||||
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
|
from app.config import (
|
||||||
|
QDRANT_URL,QDRANT_COLLECTION_NAME,QDRANT_API_KEY,
|
||||||
|
VLLM_BASE_URL, LLM_API_KEY,
|
||||||
|
LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
|
||||||
|
)
|
||||||
from app.logger import info, warning, error
|
from app.logger import info, warning, error
|
||||||
|
|
||||||
class Mem0Client:
|
class Mem0Client:
|
||||||
@@ -42,9 +49,13 @@ class Mem0Client:
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"llm": {
|
"llm": {
|
||||||
"provider": "langchain",
|
"provider": "openai",
|
||||||
"config": {
|
"config": {
|
||||||
"model": self.llm
|
"model": "LLM_MODEL",
|
||||||
|
"api_key": LLM_API_KEY,
|
||||||
|
"openai_base_url": VLLM_BASE_URL,
|
||||||
|
"temperature": 0.1,
|
||||||
|
"max_tokens": 2000,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embedder": {
|
"embedder": {
|
||||||
@@ -118,36 +129,18 @@ class Mem0Client:
|
|||||||
warning(f"⚠️ Mem0 检索失败: {e}")
|
warning(f"⚠️ Mem0 检索失败: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def add_memories(self, messages: List[Dict[str, str]], user_id: str) -> bool:
|
async def add_memories(self, messages, user_id):
|
||||||
"""
|
if not self.mem0:
|
||||||
添加记忆(自动提取事实并存储)
|
return False
|
||||||
|
try:
|
||||||
Args:
|
start = time.time()
|
||||||
messages: 消息列表,格式为 [{"role": "user/assistant/system", "content": "..."}]
|
info(f"📝 开始 Mem0 add,消息数: {len(messages)}")
|
||||||
user_id: 用户 ID
|
await asyncio.wait_for(
|
||||||
|
self.mem0.add(messages, user_id=user_id, metadata={"type": "conversation"}),
|
||||||
Returns:
|
timeout=60.0
|
||||||
bool: 是否成功
|
)
|
||||||
"""
|
info(f"✅ Mem0 add 完成,耗时: {time.time() - start:.2f}s")
|
||||||
if not self.mem0:
|
return True
|
||||||
warning("⚠️ Mem0 未初始化,跳过记忆添加")
|
except asyncio.TimeoutError:
|
||||||
return False
|
error(f"❌ Mem0 记忆添加超时 (60s),已等待 {time.time() - start:.2f}s")
|
||||||
|
return False
|
||||||
try:
|
|
||||||
await asyncio.wait_for(
|
|
||||||
self.mem0.add(
|
|
||||||
messages,
|
|
||||||
user_id=user_id,
|
|
||||||
metadata={"type": "conversation"}
|
|
||||||
),
|
|
||||||
timeout=60.0
|
|
||||||
)
|
|
||||||
info("📝 [记忆添加] 已提交给 Mem0 进行事实提取")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
error("❌ Mem0 记忆添加超时 (60s)")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
error(f"❌ Mem0 记忆添加失败: {e}")
|
|
||||||
return False
|
|
||||||
@@ -31,8 +31,8 @@ async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> D
|
|||||||
if not _mem0_client._initialized:
|
if not _mem0_client._initialized:
|
||||||
await _mem0_client.initialize()
|
await _mem0_client.initialize()
|
||||||
# 将用户消息作为事实来源提交给 Mem0
|
# 将用户消息作为事实来源提交给 Mem0
|
||||||
|
info(f"📌 检测到记忆指令,已主动触发 Mem0 存储")
|
||||||
mem0_messages = [{"role": "user", "content": content}]
|
mem0_messages = [{"role": "user", "content": content}]
|
||||||
await _mem0_client.add_memories(mem0_messages, user_id=user_id)
|
await _mem0_client.add_memories(mem0_messages, user_id=user_id)
|
||||||
info(f"📌 检测到记忆指令,已主动触发 Mem0 存储")
|
|
||||||
|
|
||||||
return {} # 不修改状态
|
return {} # 不修改状态
|
||||||
@@ -3,6 +3,7 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
|
|
||||||
|
|
||||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
from langchain_core.stores import BaseStore
|
from langchain_core.stores import BaseStore
|
||||||
from rag_core.store.postgres import PostgresDocStore
|
from rag_core.store.postgres import PostgresDocStore
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 默认连接字符串(从环境变量读取)
|
# 默认连接字符串(从环境变量读取)
|
||||||
DEFAULT_DB_URI = os.getenv(
|
DEFAULT_DB_URI = os.getenv(
|
||||||
|
|||||||
Reference in New Issue
Block a user