diff --git a/app/backend.py b/app/backend.py index b60269e..c5b857a 100644 --- a/app/backend.py +++ b/app/backend.py @@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from app.agent import AIAgentService -from app.history import ThreadHistoryService +from app.agent.history import ThreadHistoryService from app.logger import info, error # 加载 .env 文件 diff --git a/app/config.py b/app/config.py index ad70ab2..77b2c7b 100644 --- a/app/config.py +++ b/app/config.py @@ -20,6 +20,11 @@ QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories") QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key") +# ========== llm 配置 ========== +# LLM 模型配置 +VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1") +LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key") + # llama.cpp Embedding 服务地址 (用于 Mem0 的向量化) LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1") LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key") \ No newline at end of file diff --git a/app/memory/mem0_client.py b/app/memory/mem0_client.py index f003610..b2cf89e 100644 --- a/app/memory/mem0_client.py +++ b/app/memory/mem0_client.py @@ -1,3 +1,6 @@ +from app.config import LLM_API_KEY +from app.config import VLLM_BASE_URL +import time """ Mem0 记忆层客户端封装模块 负责 Mem0 的初始化、检索和存储 @@ -7,7 +10,11 @@ import asyncio from typing import Optional, List, Dict from mem0 import AsyncMemory -from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY +from app.config import ( + QDRANT_URL,QDRANT_COLLECTION_NAME,QDRANT_API_KEY, + VLLM_BASE_URL, LLM_API_KEY, + LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY +) from app.logger import info, warning, error class Mem0Client: @@ -42,9 +49,13 @@ class Mem0Client: } }, "llm": { - "provider": "langchain", + "provider": "openai", "config": { - "model": self.llm + "model": "LLM_MODEL", + "api_key": LLM_API_KEY, + "openai_base_url": VLLM_BASE_URL, + "temperature": 0.1, + "max_tokens": 2000, } }, "embedder": { @@ -118,36 +129,18 @@ class Mem0Client: warning(f"⚠️ Mem0 检索失败: {e}") return [] - async def add_memories(self, messages: List[Dict[str, str]], user_id: str) -> bool: - """ - 添加记忆(自动提取事实并存储) - - Args: - messages: 消息列表,格式为 [{"role": "user/assistant/system", "content": "..."}] - user_id: 用户 ID - - Returns: - bool: 是否成功 - """ - if not self.mem0: - warning("⚠️ Mem0 未初始化,跳过记忆添加") - return False - - try: - await asyncio.wait_for( - self.mem0.add( - messages, - user_id=user_id, - metadata={"type": "conversation"} - ), - timeout=60.0 - ) - info("📝 [记忆添加] 已提交给 Mem0 进行事实提取") - return True - - except asyncio.TimeoutError: - error("❌ Mem0 记忆添加超时 (60s)") - return False - except Exception as e: - error(f"❌ Mem0 记忆添加失败: {e}") - return False \ No newline at end of file + async def add_memories(self, messages, user_id): + if not self.mem0: + return False + try: + start = time.time() + info(f"📝 开始 Mem0 add,消息数: {len(messages)}") + await asyncio.wait_for( + self.mem0.add(messages, user_id=user_id, metadata={"type": "conversation"}), + timeout=60.0 + ) + info(f"✅ Mem0 add 完成,耗时: {time.time() - start:.2f}s") + return True + except asyncio.TimeoutError: + error(f"❌ Mem0 记忆添加超时 (60s),已等待 {time.time() - start:.2f}s") + return False \ No newline at end of file diff --git a/app/nodes/memory_trigger.py b/app/nodes/memory_trigger.py index 77078ed..6f02879 100644 --- a/app/nodes/memory_trigger.py +++ b/app/nodes/memory_trigger.py @@ -31,8 +31,8 @@ async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> D if not _mem0_client._initialized: await _mem0_client.initialize() # 将用户消息作为事实来源提交给 Mem0 + info(f"📌 检测到记忆指令,已主动触发 Mem0 存储") mem0_messages = [{"role": "user", "content": content}] await _mem0_client.add_memories(mem0_messages, user_id=user_id) - info(f"📌 检测到记忆指令,已主动触发 Mem0 存储") return {} # 不修改状态 \ No newline at end of file diff --git a/rag_core/client.py b/rag_core/client.py index 109958a..7615ea7 100644 --- a/rag_core/client.py +++ b/rag_core/client.py @@ -3,6 +3,7 @@ import os from typing import Optional from qdrant_client import QdrantClient + QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") diff --git a/rag_core/store/factory.py b/rag_core/store/factory.py index 391b077..43b465e 100644 --- a/rag_core/store/factory.py +++ b/rag_core/store/factory.py @@ -10,8 +10,10 @@ from typing import Optional, Tuple from langchain_core.stores import BaseStore from rag_core.store.postgres import PostgresDocStore +from dotenv import load_dotenv +load_dotenv() -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 默认连接字符串(从环境变量读取) DEFAULT_DB_URI = os.getenv(