修改引用逻辑,修改长期记忆bug

This commit is contained in:
2026-04-20 15:55:58 +08:00
parent 4e981e9dcf
commit 3143e0e4e6
39 changed files with 444 additions and 246 deletions

View File

@@ -2,7 +2,7 @@
AI Agent 应用模块 AI Agent 应用模块
""" """
from .agent import AIAgentService from app.agent import AIAgentService
from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"] __all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]

7
app/agent/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
Agent 子模块
"""
from app.agent.service import AIAgentService
__all__ = ["AIAgentService"]

View File

@@ -3,11 +3,9 @@
利用 LangGraph 的 checkpointer 获取对话历史和摘要 利用 LangGraph 的 checkpointer 获取对话历史和摘要
""" """
from typing import List, Dict, Any, Optional from typing import List, Dict, Any
import logging
from app.logger import error # 保持兼容,或者替换为 logger from app.logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService: class ThreadHistoryService:
"""线程历史查询服务""" """线程历史查询服务"""

View File

@@ -1,6 +1,5 @@
# app/prompts.py # app/prompts.py
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import BaseTool
def create_system_prompt(tools: list = None) -> ChatPromptTemplate: def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
""" """
@@ -11,7 +10,7 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
tool_descs = [] tool_descs = []
for tool in tools: for tool in tools:
# 提取工具名称和描述的第一行 # 提取工具名称和描述的第一行
name = getattr(tool, 'name', tool.__name__) name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
desc = (tool.description or "").split('\n')[0] desc = (tool.description or "").split('\n')[0]
tool_descs.append(f"- {name}: {desc}") tool_descs.append(f"- {name}: {desc}")
tools_section = "\n".join(tool_descs) tools_section = "\n".join(tool_descs)

View File

@@ -3,27 +3,17 @@ AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期 接收外部传入的 checkpointer不负责管理连接生命周期
""" """
import os
import json import json
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import SecretStr
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
# 本地模块 # 本地模块
from app.graph_builder import GraphBuilder, GraphContext from app.graph.graph_builder import GraphBuilder, GraphContext
from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.rag import RAGPipeline from app.agent.llm_factory import LLMFactory
from app.rag.tools import create_rag_tool_sync from app.agent.rag_initializer import init_rag_tool
from rag_core import create_parent_retriever from app.logger import info, warning
from app.llm_factory import LLMFactory
from app.rag_initializer import init_rag_tool
from app.logger import debug, info, warning, error
load_dotenv() load_dotenv()
class AIAgentService: class AIAgentService:
def __init__(self, checkpointer): def __init__(self, checkpointer):
self.checkpointer = checkpointer self.checkpointer = checkpointer

View File

@@ -16,7 +16,7 @@ from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService from app.agent import AIAgentService
from app.history import ThreadHistoryService from app.history import ThreadHistoryService
from app.logger import debug, info, warning, error from app.logger import info, error
# 加载 .env 文件 # 加载 .env 文件
load_dotenv() load_dotenv()
@@ -28,7 +28,6 @@ DB_URI = os.getenv(
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable" "postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
) )
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务""" """应用生命周期管理:创建并注入全局服务"""
@@ -53,7 +52,6 @@ async def lifespan(app: FastAPI):
# 5. 关闭时自动清理数据库连接async with 负责) # 5. 关闭时自动清理数据库连接async with 负责)
info("🛑 应用关闭,数据库连接池已释放") info("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
# CORS 中间件(允许前端跨域) # CORS 中间件(允许前端跨域)
@@ -65,14 +63,12 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# ========== 健康检查端点 ========== # ========== 健康检查端点 ==========
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
"""健康检查端点,用于 Docker 和 CI/CD 监控""" """健康检查端点,用于 Docker 和 CI/CD 监控"""
return {"status": "ok", "service": "ai-agent-backend"} return {"status": "ok", "service": "ai-agent-backend"}
# ========== Pydantic 模型 ========== # ========== Pydantic 模型 ==========
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
message: str message: str
@@ -80,7 +76,6 @@ class ChatRequest(BaseModel):
model: str = "zhipu" model: str = "zhipu"
user_id: str = "default_user" user_id: str = "default_user"
class ChatResponse(BaseModel): class ChatResponse(BaseModel):
reply: str reply: str
thread_id: str thread_id: str
@@ -90,18 +85,15 @@ class ChatResponse(BaseModel):
total_tokens: int = 0 total_tokens: int = 0
elapsed_time: float = 0.0 elapsed_time: float = 0.0
# ========== 依赖注入函数 ========== # ========== 依赖注入函数 ==========
def get_agent_service(request: Request) -> AIAgentService: def get_agent_service(request: Request) -> AIAgentService:
"""从 app.state 中获取全局 AIAgentService 实例""" """从 app.state 中获取全局 AIAgentService 实例"""
return request.app.state.agent_service return request.app.state.agent_service
def get_history_service(request: Request) -> ThreadHistoryService: def get_history_service(request: Request) -> ThreadHistoryService:
"""从 app.state 中获取全局 ThreadHistoryService 实例""" """从 app.state 中获取全局 ThreadHistoryService 实例"""
return request.app.state.history_service return request.app.state.history_service
# ========== HTTP 端点 ========== # ========== HTTP 端点 ==========
@app.post("/chat", response_model=ChatResponse) @app.post("/chat", response_model=ChatResponse)
async def chat_endpoint( async def chat_endpoint(
@@ -135,7 +127,6 @@ async def chat_endpoint(
elapsed_time=elapsed_time elapsed_time=elapsed_time
) )
# ========== 历史查询接口 ========== # ========== 历史查询接口 ==========
@app.get("/threads") @app.get("/threads")
async def list_threads( async def list_threads(
@@ -147,7 +138,6 @@ async def list_threads(
threads = await history_service.get_user_threads(user_id, limit) threads = await history_service.get_user_threads(user_id, limit)
return {"threads": threads} return {"threads": threads}
@app.get("/thread/{thread_id}/messages") @app.get("/thread/{thread_id}/messages")
async def get_thread_messages( async def get_thread_messages(
thread_id: str, thread_id: str,
@@ -158,7 +148,6 @@ async def get_thread_messages(
messages = await history_service.get_thread_messages(thread_id) messages = await history_service.get_thread_messages(thread_id)
return {"messages": messages} return {"messages": messages}
@app.get("/thread/{thread_id}/summary") @app.get("/thread/{thread_id}/summary")
async def get_thread_summary( async def get_thread_summary(
thread_id: str, thread_id: str,
@@ -169,7 +158,6 @@ async def get_thread_summary(
summary = await history_service.get_thread_summary(thread_id) summary = await history_service.get_thread_summary(thread_id)
return summary return summary
# ========== 流式对话接口 ========== # ========== 流式对话接口 ==========
@app.post("/chat/stream") @app.post("/chat/stream")
async def chat_stream_endpoint( async def chat_stream_endpoint(
@@ -204,7 +192,6 @@ async def chat_stream_endpoint(
} }
) )
# ========== WebSocket 端点(可选) ========== # ========== WebSocket 端点(可选) ==========
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint( async def websocket_endpoint(
@@ -228,7 +215,6 @@ async def websocket_endpoint(
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
# 使用环境变量或默认端口 8079避免与 llama.cpp 的 8081 端口冲突) # 使用环境变量或默认端口 8079避免与 llama.cpp 的 8081 端口冲突)

View File

@@ -18,6 +18,7 @@ MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
# Qdrant 向量数据库地址 # Qdrant 向量数据库地址
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories") QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化) # llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1") LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")

8
app/graph/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
"""
Graph 子模块
"""
from app.graph.graph_builder import GraphBuilder
from app.graph.state import MessagesState, GraphContext
__all__ = ["GraphBuilder", "MessagesState", "GraphContext"]

View File

@@ -5,18 +5,17 @@ LangGraph 状态图构建模块 - 精简版,仅负责组装图
from langchain_core.language_models import BaseLLM from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END from langgraph.graph import StateGraph, START, END
# 本地模块
from app.graph.state import MessagesState, GraphContext from app.graph.state import MessagesState, GraphContext
from app.nodes import ( from app.nodes import (
should_continue,
create_llm_call_node, create_llm_call_node,
create_tool_call_node, create_tool_call_node,
create_retrieve_memory_node, create_retrieve_memory_node,
create_summarize_node, create_summarize_node,
should_continue finalize_node,
) )
from app.nodes.memory_trigger import memory_trigger_node, set_mem0_client
from app.memory import Mem0Client from app.memory import Mem0Client
from app.nodes.finalize import finalize_node
class GraphBuilder: class GraphBuilder:
@@ -45,6 +44,9 @@ class GraphBuilder:
Returns: Returns:
StateGraph 实例 StateGraph 实例
""" """
# 注入全局客户端
set_mem0_client(self.mem0_client)
builder = StateGraph(MessagesState, context_schema=GraphContext) builder = StateGraph(MessagesState, context_schema=GraphContext)
# ⭐ 通过工厂函数创建节点(依赖注入) # ⭐ 通过工厂函数创建节点(依赖注入)
@@ -55,6 +57,7 @@ class GraphBuilder:
# 添加节点 # 添加节点
builder.add_node("retrieve_memory", retrieve_memory_node) builder.add_node("retrieve_memory", retrieve_memory_node)
builder.add_node("memory_trigger", memory_trigger_node)
builder.add_node("llm_call", llm_call_node) builder.add_node("llm_call", llm_call_node)
builder.add_node("tool_node", tool_call_node) builder.add_node("tool_node", tool_call_node)
builder.add_node("summarize", summarize_node) builder.add_node("summarize", summarize_node)
@@ -62,7 +65,8 @@ class GraphBuilder:
# 添加边 # 添加边
builder.add_edge(START, "retrieve_memory") builder.add_edge(START, "retrieve_memory")
builder.add_edge("retrieve_memory", "llm_call") builder.add_edge("retrieve_memory", "memory_trigger")
builder.add_edge("memory_trigger", "llm_call")
builder.add_conditional_edges( builder.add_conditional_edges(
"llm_call", "llm_call",
should_continue, should_continue,

View File

@@ -3,7 +3,6 @@
""" """
# 标准库 # 标准库
import os
from pathlib import Path from pathlib import Path
# 第三方库 # 第三方库
@@ -13,7 +12,6 @@ import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from langchain_core.tools import tool from langchain_core.tools import tool
def _file_allow_check(filename: str) -> Path: def _file_allow_check(filename: str) -> Path:
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。""" """检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
allowed_dir = Path("./user_docs").resolve() allowed_dir = Path("./user_docs").resolve()
@@ -28,13 +26,11 @@ def _file_allow_check(filename: str) -> Path:
return file_path return file_path
@tool @tool
def get_current_temperature(location: str) -> str: def get_current_temperature(location: str) -> str:
"""获取指定地点的当前温度。""" """获取指定地点的当前温度。"""
return f'当前{location}的温度为25℃' return f'当前{location}的温度为25℃'
@tool @tool
def read_local_file(filename: str) -> str: def read_local_file(filename: str) -> str:
"""读取用户指定名称的本地文本文件内容并返回摘要。""" """读取用户指定名称的本地文本文件内容并返回摘要。"""
@@ -46,7 +42,6 @@ def read_local_file(filename: str) -> str:
except Exception as e: except Exception as e:
return f"读取文件时出错:{str(e)}" return f"读取文件时出错:{str(e)}"
@tool @tool
def read_pdf_summary(filename: str) -> str: def read_pdf_summary(filename: str) -> str:
"""读取PDF文件并返回内容文本摘要。""" """读取PDF文件并返回内容文本摘要。"""
@@ -61,7 +56,6 @@ def read_pdf_summary(filename: str) -> str:
except Exception as e: except Exception as e:
return f"读取PDF出错{e}" return f"读取PDF出错{e}"
@tool @tool
def read_excel_as_markdown(filename: str) -> str: def read_excel_as_markdown(filename: str) -> str:
"""读取Excel文件并将其主要数据转换为Markdown表格格式。""" """读取Excel文件并将其主要数据转换为Markdown表格格式。"""
@@ -73,7 +67,6 @@ def read_excel_as_markdown(filename: str) -> str:
except Exception as e: except Exception as e:
return f"读取Excel出错{e}" return f"读取Excel出错{e}"
@tool @tool
def fetch_webpage_content(url: str) -> str: def fetch_webpage_content(url: str) -> str:
"""抓取给定URL的网页正文内容并返回清晰的纯文本。""" """抓取给定URL的网页正文内容并返回清晰的纯文本。"""
@@ -91,7 +84,6 @@ def fetch_webpage_content(url: str) -> str:
except Exception as e: except Exception as e:
return f"抓取网页时出错:{str(e)}" return f"抓取网页时出错:{str(e)}"
# 工具列表和映射(全局常量) # 工具列表和映射(全局常量)
AVAILABLE_TOOLS = [ AVAILABLE_TOOLS = [
get_current_temperature, get_current_temperature,

View File

@@ -4,15 +4,13 @@
""" """
from typing import Any, Dict from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块 # 本地模块
from app.graph.state import MessagesState, GraphContext from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change from app.utils.logging import log_state_change
from app.logger import debug from app.logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client): def create_retrieve_memory_node(mem0_client: Mem0Client):
""" """
工厂函数:创建记忆检索节点 工厂函数:创建记忆检索节点

View File

@@ -4,12 +4,11 @@ LangGraph 状态定义模块
""" """
import operator import operator
from typing import Annotated, Any from typing import Annotated
from typing_extensions import TypedDict from typing_extensions import TypedDict
from dataclasses import dataclass from dataclasses import dataclass
from langchain_core.messages import AnyMessage from langchain_core.messages import AnyMessage
class MessagesState(TypedDict): class MessagesState(TypedDict):
"""对话状态类型定义""" """对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add] messages: Annotated[list[AnyMessage], operator.add]
@@ -19,7 +18,6 @@ class MessagesState(TypedDict):
last_elapsed_time: float # 本次调用耗时(秒) last_elapsed_time: float # 本次调用耗时(秒)
turns_since_last_summary: int # 距离上次生成摘要的轮数 turns_since_last_summary: int # 距离上次生成摘要的轮数
@dataclass @dataclass
class GraphContext: class GraphContext:
"""图执行上下文""" """图执行上下文"""

View File

@@ -1,79 +0,0 @@
"""
LangGraph 状态图构建模块 - 精简版,仅负责组装图
所有节点逻辑已拆分到独立模块
"""
from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.nodes import (
create_llm_call_node,
create_tool_call_node,
create_retrieve_memory_node,
create_summarize_node,
should_continue
)
from app.memory import Mem0Client
from app.nodes.finalize import finalize_node
class GraphBuilder:
    """Assembles the LangGraph state graph; node logic lives in other modules."""

    def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
        """
        Store the collaborators needed to assemble the graph.

        Args:
            llm: Language model instance; handed to the LLM-call node factory
                and to the Mem0 client.
            tools: Tool list handed to the LLM-call node factory.
            tools_by_name: Mapping from tool name to tool, handed to the
                tool-call node factory.
        """
        self.llm, self.tools, self.tools_by_name = llm, tools, tools_by_name
        # Mem0 client shared by the memory-retrieval and summarize nodes.
        # Per the original note it initialises lazily on first use.
        self.mem0_client = Mem0Client(llm)

    def build(self) -> StateGraph:
        """
        Assemble the uncompiled state graph.

        Returns:
            The StateGraph with every node and edge registered.
        """
        graph = StateGraph(MessagesState, context_schema=GraphContext)

        # Nodes come from factory functions (dependency injection); the dict
        # keeps the original creation and registration order.
        nodes = {
            "retrieve_memory": create_retrieve_memory_node(self.mem0_client),
            "llm_call": create_llm_call_node(self.llm, self.tools),
            "tool_node": create_tool_call_node(self.tools_by_name),
            "summarize": create_summarize_node(self.mem0_client),
            "finalize": finalize_node,
        }
        for node_name, node in nodes.items():
            graph.add_node(node_name, node)

        # Linear entry path: START -> retrieve_memory -> llm_call.
        graph.add_edge(START, "retrieve_memory")
        graph.add_edge("retrieve_memory", "llm_call")

        # should_continue picks the next node after each LLM call; every
        # route key maps to the node of the same name.
        routes = {target: target for target in ("tool_node", "summarize", "finalize")}
        graph.add_conditional_edges("llm_call", should_continue, routes)

        # Tool results loop back to the LLM; summarize flows into finalize,
        # which ends the run.
        graph.add_edge("tool_node", "llm_call")
        graph.add_edge("summarize", "finalize")
        graph.add_edge("finalize", END)

        return graph

View File

@@ -4,13 +4,12 @@ Mem0 记忆层客户端封装模块
""" """
import asyncio import asyncio
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict
from mem0 import AsyncMemory from mem0 import AsyncMemory
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
from app.logger import info, warning, error from app.logger import info, warning, error
class Mem0Client: class Mem0Client:
"""Mem0 异步客户端封装类""" """Mem0 异步客户端封装类"""
@@ -37,8 +36,9 @@ class Mem0Client:
"provider": "qdrant", "provider": "qdrant",
"config": { "config": {
"url": QDRANT_URL, # 直接使用完整 URL "url": QDRANT_URL, # 直接使用完整 URL
"api_key": QDRANT_API_KEY,
"collection_name": QDRANT_COLLECTION_NAME, "collection_name": QDRANT_COLLECTION_NAME,
"embedding_model_dims": 768, "embedding_model_dims": 1024,
} }
}, },
"llm": { "llm": {
@@ -50,7 +50,7 @@ class Mem0Client:
"embedder": { "embedder": {
"provider": "openai", "provider": "openai",
"config": { "config": {
"model": "embeddinggemma-300M-Q8_0", "model": "Qwen3-Embedding-0.6B-Q8_0",
"api_key": LLAMACPP_API_KEY, "api_key": LLAMACPP_API_KEY,
"openai_base_url": LLAMACPP_EMBEDDING_URL, "openai_base_url": LLAMACPP_EMBEDDING_URL,
}, },

View File

@@ -4,15 +4,13 @@
""" """
from typing import Any, Dict from typing import Any, Dict
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
# 本地模块 # 本地模块
from app.graph.state import MessagesState, GraphContext from app.graph.state import MessagesState
from app.utils.logging import log_state_change from app.utils.logging import log_state_change
from app.logger import info, error from app.logger import info, error
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:

View File

@@ -3,21 +3,17 @@ LLM 调用节点模块
负责调用大语言模型并处理响应 负责调用大语言模型并处理响应
""" """
import asyncio
import time import time
from typing import Any, Dict from typing import Any, Dict
from langchain_core.language_models import BaseLLM from langchain_core.language_models import BaseLLM
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from langgraph.runtime import Runtime
# 本地模块 # 本地模块
from app.graph.state import MessagesState, GraphContext from app.graph.state import MessagesState
from app.prompts import create_system_prompt from app.agent.prompts import create_system_prompt
from app.utils.logging import log_state_change, print_llm_input from app.utils.logging import log_state_change
from app.logger import debug, info, error from app.logger import debug, info, error
def create_llm_call_node(llm: BaseLLM, tools: list): def create_llm_call_node(llm: BaseLLM, tools: list):
""" """
工厂函数:创建 LLM 调用节点 工厂函数:创建 LLM 调用节点

View File

@@ -0,0 +1,38 @@
from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.logger import info
# Module-level Mem0 client. Stays None until set_mem0_client() is called
# (GraphBuilder.build() does this); memory_trigger_node is a no-op while it is None.
_mem0_client: Mem0Client = None
def set_mem0_client(client: Mem0Client):
    """Install ``client`` as the module-wide Mem0 client used by memory_trigger_node."""
    global _mem0_client
    _mem0_client = client
def _message_text(message: Any) -> str:
    """Return the plain text of a message, flattening multi-part content."""
    content = message.content if hasattr(message, "content") else str(message)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Multi-part content: keep bare strings and {"type": "text"} parts,
        # drop everything else (images, audio, ...).
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(texts)
    return str(content)


async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """
    Detect an explicit "remember this" instruction and store it in Mem0.

    Only the last message in the state is inspected. When its text contains
    one of the trigger words, that text is submitted to Mem0 as a user
    message for the user id found in ``config["metadata"]["user_id"]``
    (falling back to ``"default_user"``).

    Args:
        state: Graph state; only ``messages`` is read.
        config: Runnable config; only ``metadata.user_id`` is read.

    Returns:
        Always an empty dict - this node never modifies the state.
    """
    if _mem0_client is None:
        return {}

    messages = state.get("messages", [])
    if not messages:
        return {}

    # NOTE(review): assumes the last message is the user's current input
    # (this node is wired right after retrieve_memory) - confirm that
    # retrieve_memory never appends messages of its own.
    content = _message_text(messages[-1])

    # Trigger words (extend as needed).
    trigger_words = ("记住", "记下", "保存", "备忘", "记录下", "别忘了")
    if not any(word in content for word in trigger_words):
        return {}

    # "metadata" may be absent or explicitly None.
    metadata = config.get("metadata") or {}
    user_id = metadata.get("user_id", "default_user")

    # Make sure Mem0 is initialised before the first write.
    if not _mem0_client._initialized:
        await _mem0_client.initialize()

    # Submit the user's text to Mem0 as the source of facts.
    mem0_messages = [{"role": "user", "content": content}]
    await _mem0_client.add_memories(mem0_messages, user_id=user_id)
    info("📌 检测到记忆指令,已主动触发 Mem0 存储")

    return {}  # state is left untouched

View File

@@ -4,15 +4,13 @@
""" """
from typing import Any, Dict from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块 # 本地模块
from app.graph.state import MessagesState, GraphContext from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning from app.logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client): def create_summarize_node(mem0_client: Mem0Client):
""" """
工厂函数:创建记忆存储节点 工厂函数:创建记忆存储节点

View File

@@ -6,15 +6,13 @@
import asyncio import asyncio
from typing import Any, Dict from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
# 本地模块 # 本地模块
from app.graph.state import MessagesState, GraphContext from app.graph.state import MessagesState
from app.utils.logging import log_state_change from app.utils.logging import log_state_change
from app.logger import debug, info from app.logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]): def create_tool_call_node(tools_by_name: Dict[str, Any]):
""" """
工厂函数:创建工具执行节点 工厂函数:创建工具执行节点

View File

@@ -13,7 +13,7 @@ RAG 检索与生成模块
用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
示例用法: 示例用法:
>>> from app.rag import RAGPipeline, create_rag_tool >>> from app.rag.rag import RAGPipeline, create_rag_tool
>>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig >>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
>>> from langchain_openai import ChatOpenAI >>> from langchain_openai import ChatOpenAI
>>> >>>
@@ -34,16 +34,16 @@ RAG 检索与生成模块
>>> rag_tool = create_rag_tool(retriever=retriever, llm=llm) >>> rag_tool = create_rag_tool(retriever=retriever, llm=llm)
""" """
from .retriever import ( from app.rag.retriever import (
create_base_retriever, create_base_retriever,
create_hybrid_retriever, create_hybrid_retriever,
create_qdrant_client, create_qdrant_client,
) )
from .reranker import LLaMaCPPReranker from app.rag.reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator from app.rag.query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion from app.rag.fusion import reciprocal_rank_fusion
from .pipeline import RAGPipeline from app.rag.pipeline import RAGPipeline
from .tools import create_rag_tool, create_rag_tool_sync from app.rag.tools import create_rag_tool_sync
__all__ = [ __all__ = [
@@ -65,6 +65,5 @@ __all__ = [
"RAGPipeline", "RAGPipeline",
# 工具创建(供 Agent 使用) # 工具创建(供 Agent 使用)
"create_rag_tool",
"create_rag_tool_sync", "create_rag_tool_sync",
] ]

View File

@@ -1,6 +1,6 @@
# rag/fusion.py # rag/fusion.py
from typing import List, Dict, Tuple from typing import List, Dict
from langchain_core.documents import Document from langchain_core.documents import Document
def reciprocal_rank_fusion( def reciprocal_rank_fusion(

View File

@@ -2,15 +2,13 @@
import asyncio import asyncio
import os import os
from typing import List, Optional from typing import List
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from .retriever import create_qdrant_client # 可能不需要直接使用 from app.rag.reranker import LLaMaCPPReranker
from .reranker import LLaMaCPPReranker from app.rag.query_transform import MultiQueryGenerator
from .query_transform import MultiQueryGenerator from app.rag.fusion import reciprocal_rank_fusion
from .fusion import reciprocal_rank_fusion
class RAGPipeline: class RAGPipeline:
""" """

View File

@@ -2,9 +2,8 @@
重排序器模块 (适配版) 重排序器模块 (适配版)
使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder 使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder
""" """
import os
import requests import requests
from typing import List, Optional from typing import List
from langchain_core.documents import Document from langchain_core.documents import Document
class LLaMaCPPReranker: class LLaMaCPPReranker:

View File

@@ -11,7 +11,6 @@ RAG 系统使用示例(重构版)
import asyncio import asyncio
import sys import sys
import os import os
from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -19,12 +18,12 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
# 添加项目根目录到路径 # 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig from rag_indexer.index_builder import IndexBuilderConfig
from rag_indexer.splitters import SplitterType from rag_indexer.splitters import SplitterType
from rag.pipeline import RAGPipeline from app.rag.pipeline import RAGPipeline
from rag.tools import create_rag_tool from app.rag.tools import create_rag_tool_sync
from pydantic import SecretStr from pydantic import SecretStr
# 使用本地 LLM通过 OpenAI 兼容接口) # 使用本地 LLM通过 OpenAI 兼容接口)
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
@@ -32,7 +31,6 @@ from rag_core.retriever_factory import create_parent_retriever
load_dotenv() load_dotenv()
def create_llm(): def create_llm():
"""创建本地 vLLM 服务 LLM""" """创建本地 vLLM 服务 LLM"""
vllm_base_url = os.getenv( vllm_base_url = os.getenv(
@@ -60,8 +58,7 @@ async def demonstrate_full_pipeline():
print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)") print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
print("=" * 60) print("=" * 60)
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
retriever = retriever = create_parent_retriever(collection_name="my_docs", search_k=5)
if retriever is None: if retriever is None:
print("错误:检索器未初始化,请确保索引已构建。") print("错误:检索器未初始化,请确保索引已构建。")
@@ -103,7 +100,6 @@ async def demonstrate_full_pipeline():
import traceback import traceback
traceback.print_exc() traceback.print_exc()
async def demonstrate_tool_creation(): async def demonstrate_tool_creation():
""" """
演示创建 RAG 工具(供 Agent 使用) 演示创建 RAG 工具(供 Agent 使用)
@@ -119,12 +115,11 @@ async def demonstrate_tool_creation():
) )
retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
# 2. 创建 LLM # 2. 创建 LLM
llm = create_llm() llm = create_llm()
# 3. 创建工具 # 3. 创建工具
rag_tool = create_rag_tool( rag_tool = create_rag_tool_sync(
retriever=retriever, retriever=retriever,
llm=llm, llm=llm,
num_queries=3, num_queries=3,
@@ -136,18 +131,16 @@ async def demonstrate_tool_creation():
print(f"工具描述: {rag_tool.description[:100]}...") print(f"工具描述: {rag_tool.description[:100]}...")
# 4. 模拟 Agent 调用工具 # 4. 模拟 Agent 调用工具
query = "请告诉我 RAG 系统的核心组件有哪些" query = "请告诉我 打虎英雄是谁"
print(f"\n模拟调用: {query}") print(f"\n模拟调用: {query}")
print("-" * 40) print("-" * 40)
result = await rag_tool.ainvoke({"query": query}) result = await rag_tool.ainvoke({"query": query})
print(result[:800] + "..." if len(result) > 800 else result) print(result[:800] + "..." if len(result) > 800 else result)
async def main(): async def main():
await demonstrate_full_pipeline() await demonstrate_full_pipeline()
await demonstrate_tool_creation() await demonstrate_tool_creation()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -4,11 +4,11 @@ RAG 工具模块
将检索功能封装为 LangChain Tool供 Agent 调用。 将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
""" """
from typing import Optional, Callable from typing import Callable
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from .pipeline import RAGPipeline from app.rag.pipeline import RAGPipeline
def create_rag_tool_sync( def create_rag_tool_sync(
retriever: BaseRetriever, retriever: BaseRetriever,

307
app/test_backend.py Normal file
View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
"""
完整后端测试 - 验证 Agent 所有功能
包括:短期记忆、长期记忆、工具调用、流式对话、历史查询
"""
import asyncio
import os
import sys
import uuid
from dotenv import load_dotenv
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
load_dotenv()
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
from app.agent.history import ThreadHistoryService
from app.logger import info, warning, error
# PostgreSQL connection string: taken from the DB_URI environment variable,
# falling back to the literal below when it is unset.
DB_URI = os.getenv(
    "DB_URI",
    "postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable"
)
async def print_section(title):
    """Print a section banner: a blank line, a 70-char '=' rule, the title, and a closing rule."""
    rule = "=" * 70
    print(f"\n{rule}\n {title}\n{rule}")
async def test_short_term_memory(agent_service):
    """
    Check short-term memory: two turns on the same thread_id.

    The first turn states a name and age; the second asks for them back.

    Returns:
        True when the second reply mentions the name or the age, else False.
    """
    await print_section("测试 1: 短期记忆Short-term Memory")

    session_id = str(uuid.uuid4())
    tester = "test_user_memory"

    print(f"\n使用 thread_id: {session_id[:8]}...")
    print(f"使用 user_id: {tester}")

    # Turn 1: give the agent facts to remember.
    print("\n[第一轮] 发送消息: '我叫张三今年28岁'")
    first_turn = await agent_service.process_message(
        "我叫张三今年28岁", session_id, "local", tester
    )
    print(f"回复: {first_turn['reply'][:100]}...")

    # Turn 2: ask for those facts on the same thread.
    print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
    second_turn = await agent_service.process_message(
        "我叫什么名字?今年多大?", session_id, "local", tester
    )
    answer = second_turn['reply']
    print(f"回复: {answer}")

    # Either remembered fact in the answer counts as a pass.
    if not any(fact in answer for fact in ("张三", "28")):
        print("\n❌ 短期记忆测试失败!")
        return False

    print("\n✅ 短期记忆测试通过!")
    return True
async def test_tool_calling(agent_service):
    """
    Check tool calling (RAG search) with a question that needs retrieval.

    Returns:
        True when the reply contains one of the expected keywords, otherwise
        None to signal that the result needs manual verification.
    """
    await print_section("测试 2: 工具调用Tool Calling")

    session_id = str(uuid.uuid4())
    tester = "test_user_tools"

    print(f"\n使用 thread_id: {session_id[:8]}...")
    print(f"使用 user_id: {tester}")

    # Ask a question that should require a RAG search.
    print("\n发送消息: '请告诉我,打虎英雄是谁?'")
    outcome = await agent_service.process_message(
        "请告诉我,打虎英雄是谁?", session_id, "local", tester
    )
    answer = outcome['reply']
    print(f"回复: {answer[:200]}...")

    # A reply containing any of these keywords is taken as evidence that the RAG tool ran.
    expected_keywords = ("武松", "李忠", "水浒传")
    if any(keyword in answer for keyword in expected_keywords):
        print("\n✅ 工具调用测试通过!")
        return True

    print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
    return None
async def test_streaming(agent_service):
    """Check the streaming chat path.

    Iterates the async stream, echoing every ``llm_token`` chunk to stdout
    as it arrives and concatenating the tokens into the full reply. Chunks
    of any other type are counted but not displayed.

    Args:
        agent_service: Object exposing
            ``process_message_stream(message, thread_id, "local", user_id)``
            that returns an async iterable of dict-like chunks.

    Returns:
        True when at least one chunk arrived and the assembled reply is
        longer than 10 characters; False otherwise, including when any
        exception is raised while streaming.
    """
    await print_section("测试 3: 流式对话Streaming")

    user_id = "test_user_stream"
    thread_id = str(uuid.uuid4())
    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")
    print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...")
    print("流式输出: ", end="", flush=True)

    assembled = ""
    received = 0
    try:
        async for chunk in agent_service.process_message_stream(
            "用100字介绍一下AI人工智能", thread_id, "local", user_id
        ):
            received += 1
            # Only token chunks are shown; other chunk types are ignored.
            if chunk.get("type") != "llm_token":
                continue
            piece = chunk.get("token", "")
            print(piece, end="", flush=True)
            assembled += piece

        print(f"\n\n共收到 {received} 个 chunk")
        print(f"完整回复长度: {len(assembled)}")

        if received > 0 and len(assembled) > 10:
            print("\n✅ 流式对话测试通过!")
            return True
        print("\n❌ 流式对话测试失败!")
        return False
    except Exception as e:
        # Any streaming failure is reported as a failed test, not a crash.
        print(f"\n❌ 流式对话异常: {e}")
        return False
async def test_history_service(agent_service, history_service):
    """Check the thread-history query service.

    Creates three conversations for a fixed user, then exercises three
    queries: the user's thread list, one thread's message history, and that
    thread's summary. Only the thread-list check decides the return value;
    the other two checks just print a pass or warning line.

    Args:
        agent_service: Object exposing an awaitable
            ``process_message(message, thread_id, "local", user_id)``.
        history_service: Object exposing awaitable
            ``get_user_threads(user_id, limit=...)``,
            ``get_thread_messages(thread_id)`` and
            ``get_thread_summary(thread_id)``; the summary result is read
            with ``.get('summary')`` and ``.get('message_count')``.

    Returns:
        True when at least 3 threads are listed for the user, else False.
    """
    await print_section("测试 4: 历史查询服务History Service")
    user_id = "test_user_history"

    # Seed a few conversations so there is history to query.
    print(f"\n为 user_id={user_id} 创建测试对话...")
    thread_ids = []
    for i in range(3):
        thread_id = str(uuid.uuid4())
        thread_ids.append(thread_id)
        await agent_service.process_message(
            f"这是第 {i+1} 个测试对话", thread_id, "local", user_id
        )
        print(f" 创建线程 {i+1}: {thread_id[:8]}...")

    # 1. Thread list for the user.
    print("\n[4.1] 测试获取用户线程列表...")
    threads = await history_service.get_user_threads(user_id, limit=10)
    print(f" 找到 {len(threads)} 个线程")
    if len(threads) >= 3:
        print(" ✅ 线程列表查询通过")
    else:
        print(" ⚠️ 线程数量少于预期")

    if thread_ids:
        # 2. Message history of the first created thread.
        test_thread_id = thread_ids[0]
        print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)")
        messages = await history_service.get_thread_messages(test_thread_id)
        print(f" 找到 {len(messages)} 条消息")
        if len(messages) >= 2:  # at least one question and one answer
            print(" ✅ 消息历史查询通过")
        else:
            print(" ⚠️ 消息数量少于预期")

        # 3. Summary of the same thread.
        print("\n[4.3] 测试获取线程摘要...")
        summary = await history_service.get_thread_summary(test_thread_id)
        # `.get('summary', '')` only covers a missing key; a key that is
        # present with value None would make the slice raise TypeError.
        summary_text = summary.get('summary') or ''
        print(f" 摘要: {summary_text[:50]}...")
        print(f" 消息数: {summary.get('message_count', 0)}")
        if summary.get('message_count', 0) > 0:
            print(" ✅ 线程摘要查询通过")
        else:
            print(" ⚠️ 摘要查询结果不确定")

    return len(threads) >= 3
async def test_long_term_memory(agent_service):
    """Check long-term memory (mem0, per the original test description).

    A fact is stated on one thread and asked for on a different thread with
    the same user_id, so a correct answer cannot come from the per-thread
    conversation history.

    Args:
        agent_service: Object exposing an awaitable
            ``process_message(message, thread_id, "local", user_id)`` whose
            result supports ``result['reply']`` returning a string.

    Returns:
        True when the second thread's reply mentions the pet's name ("小白")
        or species ("猫"); None when it mentions neither, meaning long-term
        memory may be disabled or needs manual verification.
    """
    await print_section("测试 5: 长期记忆Long-term Memory - mem0")
    thread_id1 = str(uuid.uuid4())
    thread_id2 = str(uuid.uuid4())  # deliberately a different thread
    user_id = "test_user_longterm"
    print(f"\n使用 user_id: {user_id}")
    print(f"线程 1: {thread_id1[:8]}...")
    print(f"线程 2: {thread_id2[:8]}...")

    # Thread 1: state the fact that should be remembered.
    print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'")
    result1 = await agent_service.process_message(
        "记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id
    )
    print(f"回复: {result1['reply'][:100]}...")

    # Pause so the memory backend has time to persist the fact.
    await asyncio.sleep(1)

    # Thread 2: ask for the fact back under a different thread_id.
    print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'")
    result2 = await agent_service.process_message(
        "我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id
    )
    reply = result2['reply']
    print(f"回复: {reply}")

    # The original tested `"" in reply`, which is True for every string and
    # made this check unable to fail; test for the stated species instead.
    if "小白" in reply or "猫" in reply:
        print("\n✅ 长期记忆测试通过!")
        return True
    print("\n⚠️ 长期记忆可能未启用,或需要手动验证")
    return None
async def main():
    """Run every backend test in sequence and print a summary.

    Opens the Postgres checkpointer, builds the agent and history services,
    runs the five tests with a one-second pause after each, then prints a
    per-test status table and totals.

    Returns:
        0 when no test returned False, otherwise 1. Also 1 when any exception
        escapes setup or a test, in which case the error is logged and the
        traceback printed.
    """
    rule = "=" * 70
    print("\n" + rule)
    print(" 后端完整功能测试")
    print(rule)

    results = {}

    def classify(outcome):
        # Identity checks on purpose: only the literal True/False singletons
        # count as pass/fail; anything else is "needs verification".
        if outcome is True:
            return "pass"
        if outcome is False:
            return "fail"
        return "skip"

    try:
        # Database connection and services.
        print("\n正在初始化数据库连接...")
        async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
            await checkpointer.setup()
            print("✅ 数据库连接成功")

            print("\n正在初始化 Agent 服务...")
            agent_service = AIAgentService(checkpointer)
            await agent_service.initialize()
            print("✅ Agent 服务初始化成功")

            history_service = ThreadHistoryService(checkpointer)
            print("✅ 历史服务初始化成功")
            print(f"\n可用模型: {list(agent_service.graphs.keys())}")

            # Each entry is started only when its turn comes, so the tests
            # run strictly one after another.
            suites = (
                ("短期记忆", lambda: test_short_term_memory(agent_service)),
                ("工具调用", lambda: test_tool_calling(agent_service)),
                ("流式对话", lambda: test_streaming(agent_service)),
                ("历史查询", lambda: test_history_service(agent_service, history_service)),
                ("长期记忆", lambda: test_long_term_memory(agent_service)),
            )
            for label, start in suites:
                results[label] = await start()
                await asyncio.sleep(1)

            # Summary table.
            await print_section("测试总结")
            print("\n测试结果:")
            print("-" * 40)

            status_text = {
                "pass": "✅ 通过",
                "fail": "❌ 失败",
                "skip": "⚠️ 待验证",
            }
            tally = dict.fromkeys(status_text, 0)
            for test_name, outcome in results.items():
                verdict = classify(outcome)
                tally[verdict] += 1
                print(f" {test_name:12s}: {status_text[verdict]}")

            pass_count = tally["pass"]
            fail_count = tally["fail"]
            skip_count = tally["skip"]

            print("-" * 40)
            print(f"总计: {len(results)} 个测试")
            print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}")

            if fail_count == 0:
                print("\n🎉 所有核心测试通过!")
            else:
                print(f"\n⚠️ 有 {fail_count} 个测试失败")
    except Exception as e:
        # Top-level boundary: log, show the traceback, exit non-zero.
        error(f"\n❌ 测试运行异常: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 1 if fail_count else 0
if __name__ == "__main__":
    # Run the async test driver and use its return value as the exit status.
    sys.exit(asyncio.run(main()))

View File

@@ -3,7 +3,7 @@ AI Agent 前端模块
采用分层架构设计包含配置、状态、API客户端和UI组件 采用分层架构设计包含配置、状态、API客户端和UI组件
""" """
from .logger import debug, info, warning, error from frontend.logger import debug, info, warning, error
__version__ = "2.0.0" __version__ = "2.0.0"
__all__ = ["debug", "info", "warning", "error"] __all__ = ["debug", "info", "warning", "error"]

View File

@@ -9,8 +9,6 @@ from datetime import datetime
# 使用绝对导入 # 使用绝对导入
from frontend.state import AppState from frontend.state import AppState
from frontend.api_client import api_client from frontend.api_client import api_client
from frontend.config import config
def render_sidebar(): def render_sidebar():
"""渲染左侧栏""" """渲染左侧栏"""
@@ -25,7 +23,6 @@ def render_sidebar():
st.divider() st.divider()
_render_user_section() _render_user_section()
def _render_user_section(): def _render_user_section():
"""渲染用户登录区域""" """渲染用户登录区域"""
# st.header("👤 用户") # 移除显眼的标题,改用更柔和的 caption # st.header("👤 用户") # 移除显眼的标题,改用更柔和的 caption
@@ -36,7 +33,6 @@ def _render_user_section():
else: else:
_render_user_info() _render_user_info()
def _render_login_form(): def _render_login_form():
"""渲染登录表单""" """渲染登录表单"""
username = st.text_input( username = st.text_input(
@@ -54,7 +50,6 @@ def _render_login_form():
# st.info("💡 建议登录以隔离对话历史") # 移除多余色块 # st.info("💡 建议登录以隔离对话历史") # 移除多余色块
def _render_user_info(): def _render_user_info():
"""渲染用户信息""" """渲染用户信息"""
st.markdown(f"**当前用户**: `{AppState.get_user_id()}`") st.markdown(f"**当前用户**: `{AppState.get_user_id()}`")
@@ -64,7 +59,6 @@ def _render_user_info():
_refresh_threads() _refresh_threads()
st.rerun() st.rerun()
def _render_history_section(): def _render_history_section():
"""渲染历史对话列表""" """渲染历史对话列表"""
col1, col2 = st.columns([3, 1]) col1, col2 = st.columns([3, 1])
@@ -76,7 +70,6 @@ def _render_history_section():
_render_thread_list() _render_thread_list()
def _render_history_actions(): def _render_history_actions():
"""渲染历史操作按钮""" """渲染历史操作按钮"""
# 移除了 type="primary",让它变成普通的线框按钮,不再是大红块 # 移除了 type="primary",让它变成普通的线框按钮,不再是大红块
@@ -84,7 +77,6 @@ def _render_history_actions():
AppState.start_new_thread() AppState.start_new_thread()
st.rerun() st.rerun()
def _render_thread_list(): def _render_thread_list():
"""渲染线程列表""" """渲染线程列表"""
# 仅在初次加载时拉取,或由外部主动调用 _refresh_threads() 更新 # 仅在初次加载时拉取,或由外部主动调用 _refresh_threads() 更新
@@ -101,7 +93,6 @@ def _render_thread_list():
for thread in threads: for thread in threads:
_render_thread_item(thread) _render_thread_item(thread)
def _render_thread_item(thread: dict): def _render_thread_item(thread: dict):
""" """
渲染单个线程项 渲染单个线程项
@@ -130,7 +121,6 @@ def _render_thread_item(thread: dict):
): ):
_load_thread(thread_id) _load_thread(thread_id)
def _format_time(time_str: str) -> str: def _format_time(time_str: str) -> str:
""" """
格式化时间字符串 格式化时间字符串
@@ -150,13 +140,11 @@ def _format_time(time_str: str) -> str:
except Exception: except Exception:
return time_str[:10] return time_str[:10]
def _refresh_threads(): def _refresh_threads():
"""刷新历史线程列表""" """刷新历史线程列表"""
threads = api_client.get_user_threads(AppState.get_user_id()) threads = api_client.get_user_threads(AppState.get_user_id())
AppState.set_threads(threads) AppState.set_threads(threads)
def _load_thread(thread_id: str): def _load_thread(thread_id: str):
""" """
加载指定线程的消息历史 加载指定线程的消息历史

View File

@@ -7,7 +7,7 @@ import uuid
from typing import List, Dict, Any from typing import List, Dict, Any
import streamlit as st import streamlit as st
from .config import config from frontend.config import config
class AppState: class AppState:

View File

@@ -4,10 +4,10 @@ RAG Core - 公共 RAG 组件包
提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。 提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。
""" """
from .embedders import LlamaCppEmbedder from rag_core.embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY from rag_core.vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import PostgresDocStore, create_docstore from rag_core.store import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever from rag_core.retriever_factory import create_parent_retriever
__all__ = [ __all__ = [

View File

@@ -5,11 +5,9 @@
import os import os
import httpx import httpx
from typing import List, Optional from typing import List, Optional
from urllib.parse import urljoin
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
class LlamaCppEmbedder: class LlamaCppEmbedder:
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。""" """通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
@@ -17,7 +15,7 @@ class LlamaCppEmbedder:
self, self,
base_url: Optional[str] = None, base_url: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
model: str = "embeddinggemma-300M-Q8_0", model: str = "Qwen3-Embedding-0.6B-Q8_0",
): ):
self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082") self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "") self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
@@ -71,7 +69,6 @@ class LlamaCppEmbedder:
else: else:
raise ValueError(f"未知的嵌入 API 响应格式: {data}") raise ValueError(f"未知的嵌入 API 响应格式: {data}")
class _LlamaCppLangchainAdapter(Embeddings): class _LlamaCppLangchainAdapter(Embeddings):
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。""" """将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""

View File

@@ -2,14 +2,7 @@
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_classic.retrievers import ParentDocumentRetriever from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter
from rag_indexer.splitters import SplitterType, get_splitter from typing import Optional
import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
@@ -17,7 +10,6 @@ from langchain_classic.retrievers import ParentDocumentRetriever
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever( def create_parent_retriever(
collection_name: str = "rag_documents", collection_name: str = "rag_documents",
embeddings: Optional[Embeddings] = None, embeddings: Optional[Embeddings] = None,

View File

@@ -15,8 +15,8 @@
""" """
from .postgres import PostgresDocStore from rag_core.store.postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI from rag_core.store.factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
__version__ = "2.0.0" __version__ = "2.0.0"

View File

@@ -9,7 +9,7 @@ import logging
from typing import Optional, Tuple from typing import Optional, Tuple
from langchain_core.stores import BaseStore from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore from rag_core.store.postgres import PostgresDocStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -4,12 +4,10 @@
使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。 使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。
""" """
from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence, cast from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.stores import BaseStore from langchain_core.stores import BaseStore
@@ -18,7 +16,6 @@ import asyncpg
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PostgresDocStore(BaseStore[str, Any]): class PostgresDocStore(BaseStore[str, Any]):
""" """
异步 PostgreSQL 文档存储实现。 异步 PostgreSQL 文档存储实现。

View File

@@ -13,7 +13,7 @@ from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams from qdrant_client.http.models import Distance, VectorParams
from httpx import RemoteProtocolError from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException from qdrant_client.http.exceptions import ResponseHandlingException
from .client import create_qdrant_client from rag_core.client import create_qdrant_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,7 +35,7 @@ class QdrantVectorStore:
self._last_connection_time: Optional[float] = None self._last_connection_time: Optional[float] = None
if embeddings is None: if embeddings is None:
from .embedders import LlamaCppEmbedder from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder() embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings() self.embeddings = embedder.as_langchain_embeddings()
else: else:
@@ -96,7 +96,7 @@ class QdrantVectorStore:
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False): def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
"""创建集合,设置合适的向量维度。""" """创建集合,设置合适的向量维度。"""
if vector_size is None: if vector_size is None:
from .embedders import LlamaCppEmbedder from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder() embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension() vector_size = embedder.get_embedding_dimension()

View File

@@ -23,9 +23,9 @@ Offline RAG Indexer module.
>>> await builder.build_from_file("document.pdf") >>> await builder.build_from_file("document.pdf")
""" """
from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from .loaders import DocumentLoader from rag_indexer.loaders import DocumentLoader
from .splitters import SplitterType, get_splitter from rag_indexer.splitters import SplitterType, get_splitter
# 从 rag_core 重新导出常用组件 # 从 rag_core 重新导出常用组件
from rag_core import ( from rag_core import (

View File

@@ -8,24 +8,21 @@ import asyncio
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple from typing import List, Union, Optional, Any, Dict
from httpx import RemoteProtocolError from httpx import RemoteProtocolError
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from qdrant_client.http.exceptions import ResponseHandlingException from qdrant_client.http.exceptions import ResponseHandlingException
from .loaders import DocumentLoader from rag_indexer.loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter from rag_indexer.splitters import SplitterType, get_splitter
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---------- 配置数据类 ---------- # ---------- 配置数据类 ----------
@dataclass @dataclass
class DocstoreConfig: class DocstoreConfig:
@@ -36,7 +33,6 @@ class DocstoreConfig:
# 若要从外部注入已创建好的 docstore可直接设置此字段 # 若要从外部注入已创建好的 docstore可直接设置此字段
instance: Optional[BaseStore] = None instance: Optional[BaseStore] = None
@dataclass @dataclass
class IndexBuilderConfig: class IndexBuilderConfig:
"""索引构建器配置。""" """索引构建器配置。"""
@@ -60,7 +56,6 @@ class IndexBuilderConfig:
# 其他切分器参数(当 splitter_type 非父子块时使用) # 其他切分器参数(当 splitter_type 非父子块时使用)
extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict) extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)
# ---------- 索引构建器 ---------- # ---------- 索引构建器 ----------
class IndexBuilder: class IndexBuilder:
"""RAG 索引构建主流水线,支持单块切分与父子块切分。""" """RAG 索引构建主流水线,支持单块切分与父子块切分。"""

View File

@@ -250,7 +250,7 @@ start_embedding() {
echo -e "${BLUE}🚀 启动 llama.cpp Embedding 容器...${NC}" echo -e "${BLUE}🚀 启动 llama.cpp Embedding 容器...${NC}"
# 检查模型文件 # 检查模型文件
if [ ! -f "/home/huang/Study/AIModel/GGUF/embeddinggemma-300M-Q8_0.gguf" ]; then if [ ! -f "/home/huang/Study/AIModel/GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf" ]; then
echo -e "${RED}✗ 错误Embedding 模型文件不存在${NC}" echo -e "${RED}✗ 错误Embedding 模型文件不存在${NC}"
exit 1 exit 1
fi fi
@@ -263,13 +263,16 @@ start_embedding() {
--device=/dev/dri \ --device=/dev/dri \
-v /home/huang/Study/AIModel/GGUF:/models \ -v /home/huang/Study/AIModel/GGUF:/models \
-p 8082:8080 \ -p 8082:8080 \
-e LLAMA_ARG_CTX_SIZE=16384 \
-e LLAMA_ARG_N_PARALLEL=1 \
-e LLAMA_ARG_BATCH=512 \
-e LLAMA_ARG_N_GPU_LAYERS=99 \
-e LLAMA_ARG_API_KEY=huang1998 \
ghcr.io/ggml-org/llama.cpp:server-rocm \ ghcr.io/ggml-org/llama.cpp:server-rocm \
-m /models/embeddinggemma-300M-Q8_0.gguf \ -m /models/Qwen3-Embedding-0.6B-Q8_0.gguf \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 8080 \ --port 8080 \
-ngl 99 \ --embeddings
--embeddings \
-c 512
echo -e "${GREEN}✓ llama.cpp Embedding 容器已启动 (端口 8082)${NC}" echo -e "${GREEN}✓ llama.cpp Embedding 容器已启动 (端口 8082)${NC}"
sleep 5 sleep 5