修改引用逻辑,修改长期记忆bug

This commit is contained in:
2026-04-20 15:55:58 +08:00
parent 4e981e9dcf
commit 3143e0e4e6
39 changed files with 444 additions and 246 deletions

View File

@@ -2,7 +2,7 @@
AI Agent 应用模块
"""
from .agent import AIAgentService
from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.agent import AIAgentService
from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]

7
app/agent/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
Agent 子模块
"""
from app.agent.service import AIAgentService
__all__ = ["AIAgentService"]

View File

@@ -3,11 +3,9 @@
利用 LangGraph 的 checkpointer 获取对话历史和摘要
"""
from typing import List, Dict, Any, Optional
import logging
from typing import List, Dict, Any
from app.logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService:
"""线程历史查询服务"""

View File

@@ -1,6 +1,5 @@
# app/prompts.py
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import BaseTool
def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
"""
@@ -11,7 +10,7 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
tool_descs = []
for tool in tools:
# 提取工具名称和描述的第一行
name = getattr(tool, 'name', tool.__name__)
name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
desc = (tool.description or "").split('\n')[0]
tool_descs.append(f"- {name}: {desc}")
tools_section = "\n".join(tool_descs)

View File

@@ -3,27 +3,17 @@ AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import os
import json
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import SecretStr
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
# 本地模块
from app.graph_builder import GraphBuilder, GraphContext
from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.rag import RAGPipeline
from app.rag.tools import create_rag_tool_sync
from rag_core import create_parent_retriever
from app.llm_factory import LLMFactory
from app.rag_initializer import init_rag_tool
from app.logger import debug, info, warning, error
from app.graph.graph_builder import GraphBuilder, GraphContext
from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.agent.llm_factory import LLMFactory
from app.agent.rag_initializer import init_rag_tool
from app.logger import info, warning
load_dotenv()
class AIAgentService:
def __init__(self, checkpointer):
self.checkpointer = checkpointer

View File

@@ -16,7 +16,7 @@ from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
from app.history import ThreadHistoryService
from app.logger import debug, info, warning, error
from app.logger import info, error
# 加载 .env 文件
load_dotenv()
@@ -28,7 +28,6 @@ DB_URI = os.getenv(
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务"""
@@ -53,7 +52,6 @@ async def lifespan(app: FastAPI):
# 5. 关闭时自动清理数据库连接async with 负责)
info("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan)
# CORS 中间件(允许前端跨域)
@@ -65,14 +63,12 @@ app.add_middleware(
allow_headers=["*"],
)
# ========== 健康检查端点 ==========
@app.get("/health")
async def health_check():
"""健康检查端点,用于 Docker 和 CI/CD 监控"""
return {"status": "ok", "service": "ai-agent-backend"}
# ========== Pydantic 模型 ==========
class ChatRequest(BaseModel):
message: str
@@ -80,7 +76,6 @@ class ChatRequest(BaseModel):
model: str = "zhipu"
user_id: str = "default_user"
class ChatResponse(BaseModel):
reply: str
thread_id: str
@@ -90,18 +85,15 @@ class ChatResponse(BaseModel):
total_tokens: int = 0
elapsed_time: float = 0.0
# ========== 依赖注入函数 ==========
def get_agent_service(request: Request) -> AIAgentService:
"""从 app.state 中获取全局 AIAgentService 实例"""
return request.app.state.agent_service
def get_history_service(request: Request) -> ThreadHistoryService:
"""从 app.state 中获取全局 ThreadHistoryService 实例"""
return request.app.state.history_service
# ========== HTTP 端点 ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
@@ -135,7 +127,6 @@ async def chat_endpoint(
elapsed_time=elapsed_time
)
# ========== 历史查询接口 ==========
@app.get("/threads")
async def list_threads(
@@ -147,7 +138,6 @@ async def list_threads(
threads = await history_service.get_user_threads(user_id, limit)
return {"threads": threads}
@app.get("/thread/{thread_id}/messages")
async def get_thread_messages(
thread_id: str,
@@ -158,7 +148,6 @@ async def get_thread_messages(
messages = await history_service.get_thread_messages(thread_id)
return {"messages": messages}
@app.get("/thread/{thread_id}/summary")
async def get_thread_summary(
thread_id: str,
@@ -169,7 +158,6 @@ async def get_thread_summary(
summary = await history_service.get_thread_summary(thread_id)
return summary
# ========== 流式对话接口 ==========
@app.post("/chat/stream")
async def chat_stream_endpoint(
@@ -204,7 +192,6 @@ async def chat_stream_endpoint(
}
)
# ========== WebSocket 端点(可选) ==========
@app.websocket("/ws")
async def websocket_endpoint(
@@ -228,7 +215,6 @@ async def websocket_endpoint(
except WebSocketDisconnect:
pass
if __name__ == "__main__":
import uvicorn
# 使用环境变量或默认端口 8079避免与 llama.cpp 的 8081 端口冲突)

View File

@@ -18,6 +18,7 @@ MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
# Qdrant 向量数据库地址
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")

8
app/graph/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
"""
Graph 子模块
"""
from app.graph.graph_builder import GraphBuilder
from app.graph.state import MessagesState, GraphContext
__all__ = ["GraphBuilder", "MessagesState", "GraphContext"]

View File

@@ -5,18 +5,17 @@ LangGraph 状态图构建模块 - 精简版,仅负责组装图
from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.nodes import (
should_continue,
create_llm_call_node,
create_tool_call_node,
create_retrieve_memory_node,
create_summarize_node,
should_continue
finalize_node,
)
from app.nodes.memory_trigger import memory_trigger_node, set_mem0_client
from app.memory import Mem0Client
from app.nodes.finalize import finalize_node
class GraphBuilder:
@@ -45,6 +44,9 @@ class GraphBuilder:
Returns:
StateGraph 实例
"""
# 注入全局客户端
set_mem0_client(self.mem0_client)
builder = StateGraph(MessagesState, context_schema=GraphContext)
# ⭐ 通过工厂函数创建节点(依赖注入)
@@ -55,6 +57,7 @@ class GraphBuilder:
# 添加节点
builder.add_node("retrieve_memory", retrieve_memory_node)
builder.add_node("memory_trigger", memory_trigger_node)
builder.add_node("llm_call", llm_call_node)
builder.add_node("tool_node", tool_call_node)
builder.add_node("summarize", summarize_node)
@@ -62,7 +65,8 @@ class GraphBuilder:
# 添加边
builder.add_edge(START, "retrieve_memory")
builder.add_edge("retrieve_memory", "llm_call")
builder.add_edge("retrieve_memory", "memory_trigger")
builder.add_edge("memory_trigger", "llm_call")
builder.add_conditional_edges(
"llm_call",
should_continue,

View File

@@ -3,7 +3,6 @@
"""
# 标准库
import os
from pathlib import Path
# 第三方库
@@ -13,7 +12,6 @@ import requests
from bs4 import BeautifulSoup
from langchain_core.tools import tool
def _file_allow_check(filename: str) -> Path:
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
allowed_dir = Path("./user_docs").resolve()
@@ -28,13 +26,11 @@ def _file_allow_check(filename: str) -> Path:
return file_path
@tool
def get_current_temperature(location: str) -> str:
    """获取指定地点的当前温度。"""
    # Stub implementation: always reports a fixed 25°C reading for any location.
    result = f'当前{location}的温度为25℃'
    return result
@tool
def read_local_file(filename: str) -> str:
"""读取用户指定名称的本地文本文件内容并返回摘要。"""
@@ -46,7 +42,6 @@ def read_local_file(filename: str) -> str:
except Exception as e:
return f"读取文件时出错:{str(e)}"
@tool
def read_pdf_summary(filename: str) -> str:
"""读取PDF文件并返回内容文本摘要。"""
@@ -61,7 +56,6 @@ def read_pdf_summary(filename: str) -> str:
except Exception as e:
return f"读取PDF出错{e}"
@tool
def read_excel_as_markdown(filename: str) -> str:
"""读取Excel文件并将其主要数据转换为Markdown表格格式。"""
@@ -73,7 +67,6 @@ def read_excel_as_markdown(filename: str) -> str:
except Exception as e:
return f"读取Excel出错{e}"
@tool
def fetch_webpage_content(url: str) -> str:
"""抓取给定URL的网页正文内容并返回清晰的纯文本。"""
@@ -91,7 +84,6 @@ def fetch_webpage_content(url: str) -> str:
except Exception as e:
return f"抓取网页时出错:{str(e)}"
# 工具列表和映射(全局常量)
AVAILABLE_TOOLS = [
get_current_temperature,

View File

@@ -4,15 +4,13 @@
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆检索节点

View File

@@ -4,12 +4,11 @@ LangGraph 状态定义模块
"""
import operator
from typing import Annotated, Any
from typing import Annotated
from typing_extensions import TypedDict
from dataclasses import dataclass
from langchain_core.messages import AnyMessage
class MessagesState(TypedDict):
"""对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add]
@@ -19,7 +18,6 @@ class MessagesState(TypedDict):
last_elapsed_time: float # 本次调用耗时(秒)
turns_since_last_summary: int # 距离上次生成摘要的轮数
@dataclass
class GraphContext:
"""图执行上下文"""

View File

@@ -1,79 +0,0 @@
"""
LangGraph 状态图构建模块 - 精简版,仅负责组装图
所有节点逻辑已拆分到独立模块
"""
from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.nodes import (
create_llm_call_node,
create_tool_call_node,
create_retrieve_memory_node,
create_summarize_node,
should_continue
)
from app.memory import Mem0Client
from app.nodes.finalize import finalize_node
class GraphBuilder:
    """LangGraph state-graph builder — only assembles the graph.

    All node logic lives in the app.nodes factory modules; this class wires
    the factory-created nodes and edges into an (uncompiled) StateGraph.
    """

    def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
        """Initialize the builder.

        Args:
            llm: large language model instance used by the llm_call node.
            tools: list of tools bound to the LLM.
            tools_by_name: mapping from tool name to the tool callable.
        """
        self.llm = llm
        self.tools = tools
        self.tools_by_name = tools_by_name
        # Mem0 client is constructed here but initializes lazily on first use.
        self.mem0_client = Mem0Client(llm)

    def build(self) -> StateGraph:
        """Build the uncompiled state graph.

        Returns:
            StateGraph wired as:
            START -> retrieve_memory -> llm_call
            llm_call -(should_continue)-> tool_node | summarize | finalize
            tool_node -> llm_call; summarize -> finalize -> END
        """
        builder = StateGraph(MessagesState, context_schema=GraphContext)
        # Create nodes via factory functions (dependency injection).
        retrieve_memory_node = create_retrieve_memory_node(self.mem0_client)
        llm_call_node = create_llm_call_node(self.llm, self.tools)
        tool_call_node = create_tool_call_node(self.tools_by_name)
        summarize_node = create_summarize_node(self.mem0_client)
        # Register nodes.
        builder.add_node("retrieve_memory", retrieve_memory_node)
        builder.add_node("llm_call", llm_call_node)
        builder.add_node("tool_node", tool_call_node)
        builder.add_node("summarize", summarize_node)
        builder.add_node("finalize", finalize_node)
        # Wire edges.
        builder.add_edge(START, "retrieve_memory")
        builder.add_edge("retrieve_memory", "llm_call")
        builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                "tool_node": "tool_node",
                "summarize": "summarize",
                "finalize": "finalize"
            }
        )
        builder.add_edge("tool_node", "llm_call")
        builder.add_edge("summarize", "finalize")
        builder.add_edge("finalize", END)
        return builder

View File

@@ -4,13 +4,12 @@ Mem0 记忆层客户端封装模块
"""
import asyncio
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict
from mem0 import AsyncMemory
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
from app.logger import info, warning, error
class Mem0Client:
"""Mem0 异步客户端封装类"""
@@ -37,8 +36,9 @@ class Mem0Client:
"provider": "qdrant",
"config": {
"url": QDRANT_URL, # 直接使用完整 URL
"api_key": QDRANT_API_KEY,
"collection_name": QDRANT_COLLECTION_NAME,
"embedding_model_dims": 768,
"embedding_model_dims": 1024,
}
},
"llm": {
@@ -50,7 +50,7 @@ class Mem0Client:
"embedder": {
"provider": "openai",
"config": {
"model": "embeddinggemma-300M-Q8_0",
"model": "Qwen3-Embedding-0.6B-Q8_0",
"api_key": LLAMACPP_API_KEY,
"openai_base_url": LLAMACPP_EMBEDDING_URL,
},

View File

@@ -4,15 +4,13 @@
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.utils.logging import log_state_change
from app.logger import info, error
from langchain_core.runnables.config import RunnableConfig
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:

View File

@@ -3,21 +3,17 @@ LLM 调用节点模块
负责调用大语言模型并处理响应
"""
import asyncio
import time
from typing import Any, Dict
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from langgraph.runtime import Runtime
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.prompts import create_system_prompt
from app.utils.logging import log_state_change, print_llm_input
from app.graph.state import MessagesState
from app.agent.prompts import create_system_prompt
from app.utils.logging import log_state_change
from app.logger import debug, info, error
def create_llm_call_node(llm: BaseLLM, tools: list):
"""
工厂函数:创建 LLM 调用节点

View File

@@ -0,0 +1,38 @@
from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.logger import info
# Module-level Mem0 client, injected once by GraphBuilder via set_mem0_client().
# NOTE(review): annotated as Mem0Client but defaults to None — should be
# Optional[Mem0Client]; confirm against the project's typing conventions.
_mem0_client: Mem0Client = None

def set_mem0_client(client: Mem0Client):
    """Inject the shared Mem0 client used by memory_trigger_node."""
    global _mem0_client
    _mem0_client = client
async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """Detect an explicit memory instruction in the latest user message and,
    when one is found, proactively store the message via Mem0.

    Returns an empty dict in every case: this node never mutates graph state.
    """
    if _mem0_client is None:
        # Client was never injected via set_mem0_client(); nothing to do.
        return {}
    messages = state.get("messages", [])
    if not messages:
        return {}
    last_msg = messages[-1]
    # The last entry may be a message object (with .content) or a raw value.
    content = last_msg.content if hasattr(last_msg, 'content') else str(last_msg)
    # Trigger words (extend as needed).
    trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"]
    if any(word in content for word in trigger_words):
        # NOTE(review): assumes user_id is exposed under config["metadata"];
        # LangGraph puts caller-supplied values under "configurable" (and may
        # mirror them into metadata) — confirm this key against the caller.
        user_id = config.get("metadata", {}).get("user_id", "default_user")
        # Lazy-initialize Mem0 before first write.
        # NOTE(review): reaches into Mem0Client's private `_initialized` flag;
        # consider a public is_initialized property on the client.
        if not _mem0_client._initialized:
            await _mem0_client.initialize()
        # Submit the raw user message to Mem0 as the source of facts.
        mem0_messages = [{"role": "user", "content": content}]
        await _mem0_client.add_memories(mem0_messages, user_id=user_id)
        info(f"📌 检测到记忆指令,已主动触发 Mem0 存储")
    return {}  # never modifies state

View File

@@ -4,15 +4,13 @@
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆存储节点

View File

@@ -6,15 +6,13 @@
import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.utils.logging import log_state_change
from app.logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
"""
工厂函数:创建工具执行节点

View File

@@ -13,7 +13,7 @@ RAG 检索与生成模块
用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
示例用法:
>>> from app.rag import RAGPipeline, create_rag_tool
>>> from app.rag.rag import RAGPipeline, create_rag_tool
>>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
>>> from langchain_openai import ChatOpenAI
>>>
@@ -34,16 +34,16 @@ RAG 检索与生成模块
>>> rag_tool = create_rag_tool(retriever=retriever, llm=llm)
"""
from .retriever import (
from app.rag.retriever import (
create_base_retriever,
create_hybrid_retriever,
create_qdrant_client,
)
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from .pipeline import RAGPipeline
from .tools import create_rag_tool, create_rag_tool_sync
from app.rag.reranker import LLaMaCPPReranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
from app.rag.pipeline import RAGPipeline
from app.rag.tools import create_rag_tool_sync
__all__ = [
@@ -65,6 +65,5 @@ __all__ = [
"RAGPipeline",
# 工具创建(供 Agent 使用)
"create_rag_tool",
"create_rag_tool_sync",
]

View File

@@ -1,6 +1,6 @@
# rag/fusion.py
from typing import List, Dict, Tuple
from typing import List, Dict
from langchain_core.documents import Document
def reciprocal_rank_fusion(

View File

@@ -2,15 +2,13 @@
import asyncio
import os
from typing import List, Optional
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from .retriever import create_qdrant_client # 可能不需要直接使用
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from app.rag.reranker import LLaMaCPPReranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
class RAGPipeline:
"""

View File

@@ -2,9 +2,8 @@
重排序器模块 (适配版)
使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder
"""
import os
import requests
from typing import List, Optional
from typing import List
from langchain_core.documents import Document
class LLaMaCPPReranker:

View File

@@ -11,7 +11,6 @@ RAG 系统使用示例(重构版)
import asyncio
import sys
import os
from pathlib import Path
from dotenv import load_dotenv
@@ -19,12 +18,12 @@ from dotenv import load_dotenv
load_dotenv()
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
from rag_indexer.index_builder import IndexBuilderConfig
from rag_indexer.splitters import SplitterType
from rag.pipeline import RAGPipeline
from rag.tools import create_rag_tool
from app.rag.pipeline import RAGPipeline
from app.rag.tools import create_rag_tool_sync
from pydantic import SecretStr
# 使用本地 LLM通过 OpenAI 兼容接口)
from langchain_openai import ChatOpenAI
@@ -32,7 +31,6 @@ from rag_core.retriever_factory import create_parent_retriever
load_dotenv()
def create_llm():
"""创建本地 vLLM 服务 LLM"""
vllm_base_url = os.getenv(
@@ -60,8 +58,7 @@ async def demonstrate_full_pipeline():
print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
print("=" * 60)
retriever = retriever = create_parent_retriever(collection_name="my_docs", search_k=5)
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
if retriever is None:
print("错误:检索器未初始化,请确保索引已构建。")
@@ -103,7 +100,6 @@ async def demonstrate_full_pipeline():
import traceback
traceback.print_exc()
async def demonstrate_tool_creation():
"""
演示创建 RAG 工具(供 Agent 使用)
@@ -119,12 +115,11 @@ async def demonstrate_tool_creation():
)
retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
# 2. 创建 LLM
llm = create_llm()
# 3. 创建工具
rag_tool = create_rag_tool(
rag_tool = create_rag_tool_sync(
retriever=retriever,
llm=llm,
num_queries=3,
@@ -136,18 +131,16 @@ async def demonstrate_tool_creation():
print(f"工具描述: {rag_tool.description[:100]}...")
# 4. 模拟 Agent 调用工具
query = "请告诉我 RAG 系统的核心组件有哪些"
query = "请告诉我 打虎英雄是谁"
print(f"\n模拟调用: {query}")
print("-" * 40)
result = await rag_tool.ainvoke({"query": query})
print(result[:800] + "..." if len(result) > 800 else result)
async def main():
await demonstrate_full_pipeline()
await demonstrate_tool_creation()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -4,11 +4,11 @@ RAG 工具模块
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
"""
from typing import Optional, Callable
from typing import Callable
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from .pipeline import RAGPipeline
from app.rag.pipeline import RAGPipeline
def create_rag_tool_sync(
retriever: BaseRetriever,

307
app/test_backend.py Normal file
View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
"""
完整后端测试 - 验证 Agent 所有功能
包括:短期记忆、长期记忆、工具调用、流式对话、历史查询
"""
import asyncio
import os
import sys
import uuid
from dotenv import load_dotenv
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
load_dotenv()
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
from app.agent.history import ThreadHistoryService
from app.logger import info, warning, error
# PostgreSQL 连接字符串
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable"
)
async def print_section(title):
    """Print a banner heading for one test section."""
    bar = "=" * 70
    print("\n" + bar)
    print(f" {title}")
    print(bar)
async def test_short_term_memory(agent_service):
    """Short-term memory test: two turns on the SAME thread_id must share context.

    Args:
        agent_service: initialized AIAgentService instance.

    Returns:
        bool: True when the second reply recalls a fact from the first turn.
    """
    await print_section("测试 1: 短期记忆Short-term Memory")
    thread_id = str(uuid.uuid4())
    user_id = "test_user_memory"
    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")
    # Turn 1: state two personal facts (name and age).
    print("\n[第一轮] 发送消息: '我叫张三今年28岁'")
    result1 = await agent_service.process_message(
        "我叫张三今年28岁", thread_id, "local", user_id
    )
    print(f"回复: {result1['reply'][:100]}...")
    # Turn 2: ask the model to recall both facts on the same thread.
    print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
    result2 = await agent_service.process_message(
        "我叫什么名字?今年多大?", thread_id, "local", user_id
    )
    print(f"回复: {result2['reply']}")
    # Pass if either remembered fact surfaces in the reply.
    if "张三" in result2['reply'] or "28" in result2['reply']:
        print("\n✅ 短期记忆测试通过!")
        return True
    else:
        print("\n❌ 短期记忆测试失败!")
        return False
async def test_tool_calling(agent_service):
    """Tool-calling test: a question that should route through the RAG tool.

    Args:
        agent_service: initialized AIAgentService instance.

    Returns:
        True on a confident pass, None when the result needs manual review.
    """
    await print_section("测试 2: 工具调用Tool Calling")
    thread_id = str(uuid.uuid4())
    user_id = "test_user_tools"
    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")
    # Ask a question that requires RAG retrieval from the indexed corpus.
    print("\n发送消息: '请告诉我,打虎英雄是谁?'")
    result = await agent_service.process_message(
        "请告诉我,打虎英雄是谁?", thread_id, "local", user_id
    )
    print(f"回复: {result['reply'][:200]}...")
    # Heuristic check: the answer should mention Water Margin characters.
    if "武松" in result['reply'] or "李忠" in result['reply'] or "水浒传" in result['reply']:
        print("\n✅ 工具调用测试通过!")
        return True
    else:
        print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
        return None
async def test_streaming(agent_service):
    """Streaming test: process_message_stream must yield llm_token chunks.

    Args:
        agent_service: initialized AIAgentService instance.

    Returns:
        bool: True when tokens were streamed and the assembled reply is non-trivial.
    """
    await print_section("测试 3: 流式对话Streaming")
    thread_id = str(uuid.uuid4())
    user_id = "test_user_stream"
    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")
    print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...")
    print("流式输出: ", end="", flush=True)
    full_reply = ""
    chunk_count = 0
    try:
        async for chunk in agent_service.process_message_stream(
            "用100字介绍一下AI人工智能", thread_id, "local", user_id
        ):
            chunk_count += 1
            if chunk.get("type") == "llm_token":
                # Echo each token as it arrives and accumulate the full reply.
                token = chunk.get("token", "")
                print(token, end="", flush=True)
                full_reply += token
            elif chunk.get("type") == "state_update":
                pass  # state updates are intentionally not displayed
        print(f"\n\n共收到 {chunk_count} 个 chunk")
        print(f"完整回复长度: {len(full_reply)}")
        # Require at least one chunk and a reply longer than 10 characters.
        if chunk_count > 0 and len(full_reply) > 10:
            print("\n✅ 流式对话测试通过!")
            return True
        else:
            print("\n❌ 流式对话测试失败!")
            return False
    except Exception as e:
        print(f"\n❌ 流式对话异常: {e}")
        return False
async def test_history_service(agent_service, history_service):
    """History-service test: thread listing, per-thread messages, and summary.

    Args:
        agent_service: initialized AIAgentService instance (used to seed data).
        history_service: ThreadHistoryService under test.

    Returns:
        bool: True when at least the three seeded threads are listed.
    """
    await print_section("测试 4: 历史查询服务History Service")
    user_id = "test_user_history"
    # Seed three short conversations for this user.
    print(f"\n为 user_id={user_id} 创建测试对话...")
    thread_ids = []
    for i in range(3):
        thread_id = str(uuid.uuid4())
        thread_ids.append(thread_id)
        await agent_service.process_message(
            f"这是第 {i+1} 个测试对话", thread_id, "local", user_id
        )
        print(f" 创建线程 {i+1}: {thread_id[:8]}...")
    # 1. User thread listing.
    print("\n[4.1] 测试获取用户线程列表...")
    threads = await history_service.get_user_threads(user_id, limit=10)
    print(f" 找到 {len(threads)} 个线程")
    if len(threads) >= 3:
        print(" ✅ 线程列表查询通过")
    else:
        print(" ⚠️ 线程数量少于预期")
    # 2. Message history for a single thread.
    if thread_ids:
        test_thread_id = thread_ids[0]
        print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)")
        messages = await history_service.get_thread_messages(test_thread_id)
        print(f" 找到 {len(messages)} 条消息")
        if len(messages) >= 2:  # at least one question + one answer
            print(" ✅ 消息历史查询通过")
        else:
            print(" ⚠️ 消息数量少于预期")
        # 3. Thread summary — kept inside the guard: it needs test_thread_id.
        print(f"\n[4.3] 测试获取线程摘要...")
        summary = await history_service.get_thread_summary(test_thread_id)
        print(f" 摘要: {summary.get('summary', '')[:50]}...")
        print(f" 消息数: {summary.get('message_count', 0)}")
        if summary.get('message_count', 0) > 0:
            print(" ✅ 线程摘要查询通过")
        else:
            print(" ⚠️ 摘要查询结果不确定")
    return len(threads) >= 3
async def test_long_term_memory(agent_service):
    """Long-term memory (mem0) test: a fact saved in one thread must be
    retrievable from a DIFFERENT thread for the same user.

    Args:
        agent_service: initialized AIAgentService instance.

    Returns:
        True on a confident pass, None when the result needs manual review.
    """
    await print_section("测试 5: 长期记忆Long-term Memory - mem0")
    thread_id1 = str(uuid.uuid4())
    thread_id2 = str(uuid.uuid4())  # a different thread, on purpose
    user_id = "test_user_longterm"
    print(f"\n使用 user_id: {user_id}")
    print(f"线程 1: {thread_id1[:8]}...")
    print(f"线程 2: {thread_id2[:8]}...")
    # Thread 1: ask the agent to remember a fact (should trigger mem0 storage).
    print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'")
    result1 = await agent_service.process_message(
        "记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id
    )
    print(f"回复: {result1['reply'][:100]}...")
    # Give mem0 a moment to persist asynchronously.
    await asyncio.sleep(1)
    # Thread 2: ask for the fact from a brand-new thread.
    print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'")
    result2 = await agent_service.process_message(
        "我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id
    )
    print(f"回复: {result2['reply']}")
    # BUG FIX: the original condition was `"小白" in reply or "" in reply` —
    # the empty string is a substring of every string, so the check was always
    # True and the test passed vacuously. Check for the pet's name or species.
    if "小白" in result2['reply'] or "猫" in result2['reply']:
        print("\n✅ 长期记忆测试通过!")
        return True
    else:
        print("\n⚠️ 长期记忆可能未启用,或需要手动验证")
        return None
async def main():
    """Entry point: run all backend feature tests and print a summary.

    Returns:
        int: 0 when every scored test passed, 1 otherwise (including on
        any unexpected exception).
    """
    print("\n" + "=" * 70)
    print(" 后端完整功能测试")
    print("=" * 70)
    # test name -> True (pass) / False (fail) / None (needs manual check)
    results = {}
    try:
        # Open the Postgres checkpointer for the whole test session; the
        # async context manager releases the connection pool on exit.
        print("\n正在初始化数据库连接...")
        async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
            await checkpointer.setup()
            print("✅ 数据库连接成功")
            # Build the agent and history services on top of the checkpointer.
            print("\n正在初始化 Agent 服务...")
            agent_service = AIAgentService(checkpointer)
            await agent_service.initialize()
            print("✅ Agent 服务初始化成功")
            history_service = ThreadHistoryService(checkpointer)
            print("✅ 历史服务初始化成功")
            print(f"\n可用模型: {list(agent_service.graphs.keys())}")
            # Run the scenarios sequentially, pausing briefly between them.
            results["短期记忆"] = await test_short_term_memory(agent_service)
            await asyncio.sleep(1)
            results["工具调用"] = await test_tool_calling(agent_service)
            await asyncio.sleep(1)
            results["流式对话"] = await test_streaming(agent_service)
            await asyncio.sleep(1)
            results["历史查询"] = await test_history_service(agent_service, history_service)
            await asyncio.sleep(1)
            results["长期记忆"] = await test_long_term_memory(agent_service)
            await asyncio.sleep(1)
            # Summary table.
            await print_section("测试总结")
            print("\n测试结果:")
            print("-" * 40)
            pass_count = 0
            fail_count = 0
            skip_count = 0
            for test_name, result in results.items():
                if result is True:
                    status = "✅ 通过"
                    pass_count += 1
                elif result is False:
                    status = "❌ 失败"
                    fail_count += 1
                else:
                    status = "⚠️ 待验证"
                    skip_count += 1
                print(f" {test_name:12s}: {status}")
            print("-" * 40)
            print(f"总计: {len(results)} 个测试")
            print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}")
            if fail_count == 0:
                print("\n🎉 所有核心测试通过!")
            else:
                print(f"\n⚠️ 有 {fail_count} 个测试失败")
    except Exception as e:
        error(f"\n❌ 测试运行异常: {e}")
        import traceback
        traceback.print_exc()
        return 1
    # NOTE(review): fail_count is bound inside the try block; every exception
    # path returns 1 above, so this is safe — but brittle if the flow changes.
    return 0 if fail_count == 0 else 1
# Script entry point: run the async test suite and propagate its exit code
# to the shell (0 = all tests passed, 1 = at least one failure/exception).
if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)

View File

@@ -3,7 +3,7 @@ AI Agent 前端模块
采用分层架构设计包含配置、状态、API客户端和UI组件
"""
from .logger import debug, info, warning, error
from frontend.logger import debug, info, warning, error
__version__ = "2.0.0"
__all__ = ["debug", "info", "warning", "error"]

View File

@@ -9,8 +9,6 @@ from datetime import datetime
# 使用绝对导入
from frontend.state import AppState
from frontend.api_client import api_client
from frontend.config import config
def render_sidebar():
"""渲染左侧栏"""
@@ -25,7 +23,6 @@ def render_sidebar():
st.divider()
_render_user_section()
def _render_user_section():
"""渲染用户登录区域"""
# st.header("👤 用户") # 移除显眼的标题,改用更柔和的 caption
@@ -36,7 +33,6 @@ def _render_user_section():
else:
_render_user_info()
def _render_login_form():
"""渲染登录表单"""
username = st.text_input(
@@ -54,7 +50,6 @@ def _render_login_form():
# st.info("💡 建议登录以隔离对话历史") # 移除多余色块
def _render_user_info():
"""渲染用户信息"""
st.markdown(f"**当前用户**: `{AppState.get_user_id()}`")
@@ -64,7 +59,6 @@ def _render_user_info():
_refresh_threads()
st.rerun()
def _render_history_section():
"""渲染历史对话列表"""
col1, col2 = st.columns([3, 1])
@@ -76,7 +70,6 @@ def _render_history_section():
_render_thread_list()
def _render_history_actions():
"""渲染历史操作按钮"""
# 移除了 type="primary",让它变成普通的线框按钮,不再是大红块
@@ -84,7 +77,6 @@ def _render_history_actions():
AppState.start_new_thread()
st.rerun()
def _render_thread_list():
"""渲染线程列表"""
# 仅在初次加载时拉取,或由外部主动调用 _refresh_threads() 更新
@@ -101,7 +93,6 @@ def _render_thread_list():
for thread in threads:
_render_thread_item(thread)
def _render_thread_item(thread: dict):
"""
渲染单个线程项
@@ -130,7 +121,6 @@ def _render_thread_item(thread: dict):
):
_load_thread(thread_id)
def _format_time(time_str: str) -> str:
"""
格式化时间字符串
@@ -150,13 +140,11 @@ def _format_time(time_str: str) -> str:
except Exception:
return time_str[:10]
def _refresh_threads():
"""刷新历史线程列表"""
threads = api_client.get_user_threads(AppState.get_user_id())
AppState.set_threads(threads)
def _load_thread(thread_id: str):
"""
加载指定线程的消息历史

View File

@@ -7,7 +7,7 @@ import uuid
from typing import List, Dict, Any
import streamlit as st
from .config import config
from frontend.config import config
class AppState:

View File

@@ -4,10 +4,10 @@ RAG Core - 公共 RAG 组件包
提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。
"""
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever
from rag_core.embedders import LlamaCppEmbedder
from rag_core.vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from rag_core.store import PostgresDocStore, create_docstore
from rag_core.retriever_factory import create_parent_retriever
__all__ = [

View File

@@ -5,11 +5,9 @@
import os
import httpx
from typing import List, Optional
from urllib.parse import urljoin
from langchain_core.embeddings import Embeddings
class LlamaCppEmbedder:
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
@@ -17,7 +15,7 @@ class LlamaCppEmbedder:
self,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
model: str = "embeddinggemma-300M-Q8_0",
model: str = "Qwen3-Embedding-0.6B-Q8_0",
):
self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
@@ -71,7 +69,6 @@ class LlamaCppEmbedder:
else:
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
class _LlamaCppLangchainAdapter(Embeddings):
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""

View File

@@ -2,14 +2,7 @@
from langchain_core.embeddings import Embeddings
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter
from rag_indexer.splitters import SplitterType, get_splitter
import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
@@ -17,7 +10,6 @@ from langchain_classic.retrievers import ParentDocumentRetriever
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever(
collection_name: str = "rag_documents",
embeddings: Optional[Embeddings] = None,

View File

@@ -15,8 +15,8 @@
"""
from .postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
from rag_core.store.postgres import PostgresDocStore
from rag_core.store.factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
__version__ = "2.0.0"

View File

@@ -9,7 +9,7 @@ import logging
from typing import Optional, Tuple
from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore
from rag_core.store.postgres import PostgresDocStore
logger = logging.getLogger(__name__)

View File

@@ -4,12 +4,10 @@
使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence, cast
from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
@@ -18,7 +16,6 @@ import asyncpg
logger = logging.getLogger(__name__)
class PostgresDocStore(BaseStore[str, Any]):
"""
异步 PostgreSQL 文档存储实现。

View File

@@ -13,7 +13,7 @@ from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from .client import create_qdrant_client
from rag_core.client import create_qdrant_client
logger = logging.getLogger(__name__)
@@ -35,7 +35,7 @@ class QdrantVectorStore:
self._last_connection_time: Optional[float] = None
if embeddings is None:
from .embedders import LlamaCppEmbedder
from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
else:
@@ -96,7 +96,7 @@ class QdrantVectorStore:
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
"""创建集合,设置合适的向量维度。"""
if vector_size is None:
from .embedders import LlamaCppEmbedder
from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()

View File

@@ -23,9 +23,9 @@ Offline RAG Indexer module.
>>> await builder.build_from_file("document.pdf")
"""
from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from rag_indexer.loaders import DocumentLoader
from rag_indexer.splitters import SplitterType, get_splitter
# 从 rag_core 重新导出常用组件
from rag_core import (

View File

@@ -8,24 +8,21 @@ import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from typing import List, Union, Optional, Any, Dict
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from qdrant_client.http.exceptions import ResponseHandlingException
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter
from rag_indexer.loaders import DocumentLoader
from rag_indexer.splitters import SplitterType, get_splitter
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
logger = logging.getLogger(__name__)
# ---------- 配置数据类 ----------
@dataclass
class DocstoreConfig:
@@ -36,7 +33,6 @@ class DocstoreConfig:
# 若要从外部注入已创建好的 docstore可直接设置此字段
instance: Optional[BaseStore] = None
@dataclass
class IndexBuilderConfig:
"""索引构建器配置。"""
@@ -60,7 +56,6 @@ class IndexBuilderConfig:
# 其他切分器参数(当 splitter_type 非父子块时使用)
extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)
# ---------- 索引构建器 ----------
class IndexBuilder:
"""RAG 索引构建主流水线,支持单块切分与父子块切分。"""

View File

@@ -250,7 +250,7 @@ start_embedding() {
echo -e "${BLUE}🚀 启动 llama.cpp Embedding 容器...${NC}"
# 检查模型文件
if [ ! -f "/home/huang/Study/AIModel/GGUF/embeddinggemma-300M-Q8_0.gguf" ]; then
if [ ! -f "/home/huang/Study/AIModel/GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf" ]; then
echo -e "${RED}✗ 错误Embedding 模型文件不存在${NC}"
exit 1
fi
@@ -263,13 +263,16 @@ start_embedding() {
--device=/dev/dri \
-v /home/huang/Study/AIModel/GGUF:/models \
-p 8082:8080 \
-e LLAMA_ARG_CTX_SIZE=16384 \
-e LLAMA_ARG_N_PARALLEL=1 \
-e LLAMA_ARG_BATCH=512 \
-e LLAMA_ARG_N_GPU_LAYERS=99 \
-e LLAMA_ARG_API_KEY=huang1998 \
ghcr.io/ggml-org/llama.cpp:server-rocm \
-m /models/embeddinggemma-300M-Q8_0.gguf \
-m /models/Qwen3-Embedding-0.6B-Q8_0.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 99 \
--embeddings \
-c 512
--embeddings
echo -e "${GREEN}✓ llama.cpp Embedding 容器已启动 (端口 8082)${NC}"
sleep 5