修改引用逻辑,修改长期记忆bug

This commit is contained in:
2026-04-20 15:55:58 +08:00
parent 4e981e9dcf
commit 3143e0e4e6
39 changed files with 444 additions and 246 deletions

View File

@@ -2,7 +2,7 @@
AI Agent 应用模块
"""
from .agent import AIAgentService
from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.agent import AIAgentService
from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]

7
app/agent/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
Agent 子模块
"""
from app.agent.service import AIAgentService
__all__ = ["AIAgentService"]

View File

@@ -3,11 +3,9 @@
利用 LangGraph 的 checkpointer 获取对话历史和摘要
"""
from typing import List, Dict, Any, Optional
import logging
from typing import List, Dict, Any
from app.logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService:
"""线程历史查询服务"""

View File

@@ -1,6 +1,5 @@
# app/prompts.py
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import BaseTool
def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
"""
@@ -11,7 +10,7 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
tool_descs = []
for tool in tools:
# 提取工具名称和描述的第一行
name = getattr(tool, 'name', tool.__name__)
name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
desc = (tool.description or "").split('\n')[0]
tool_descs.append(f"- {name}: {desc}")
tools_section = "\n".join(tool_descs)

View File

@@ -3,27 +3,17 @@ AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import os
import json
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import SecretStr
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
# 本地模块
from app.graph_builder import GraphBuilder, GraphContext
from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.rag import RAGPipeline
from app.rag.tools import create_rag_tool_sync
from rag_core import create_parent_retriever
from app.llm_factory import LLMFactory
from app.rag_initializer import init_rag_tool
from app.logger import debug, info, warning, error
from app.graph.graph_builder import GraphBuilder, GraphContext
from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.agent.llm_factory import LLMFactory
from app.agent.rag_initializer import init_rag_tool
from app.logger import info, warning
load_dotenv()
class AIAgentService:
def __init__(self, checkpointer):
self.checkpointer = checkpointer

View File

@@ -16,7 +16,7 @@ from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
from app.history import ThreadHistoryService
from app.logger import debug, info, warning, error
from app.logger import info, error
# 加载 .env 文件
load_dotenv()
@@ -28,7 +28,6 @@ DB_URI = os.getenv(
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务"""
@@ -53,7 +52,6 @@ async def lifespan(app: FastAPI):
# 5. 关闭时自动清理数据库连接async with 负责)
info("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan)
# CORS 中间件(允许前端跨域)
@@ -65,14 +63,12 @@ app.add_middleware(
allow_headers=["*"],
)
# ========== 健康检查端点 ==========
@app.get("/health")
async def health_check():
"""健康检查端点,用于 Docker 和 CI/CD 监控"""
return {"status": "ok", "service": "ai-agent-backend"}
# ========== Pydantic 模型 ==========
class ChatRequest(BaseModel):
message: str
@@ -80,7 +76,6 @@ class ChatRequest(BaseModel):
model: str = "zhipu"
user_id: str = "default_user"
class ChatResponse(BaseModel):
reply: str
thread_id: str
@@ -90,18 +85,15 @@ class ChatResponse(BaseModel):
total_tokens: int = 0
elapsed_time: float = 0.0
# ========== 依赖注入函数 ==========
def get_agent_service(request: Request) -> AIAgentService:
"""从 app.state 中获取全局 AIAgentService 实例"""
return request.app.state.agent_service
def get_history_service(request: Request) -> ThreadHistoryService:
"""从 app.state 中获取全局 ThreadHistoryService 实例"""
return request.app.state.history_service
# ========== HTTP 端点 ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
@@ -135,7 +127,6 @@ async def chat_endpoint(
elapsed_time=elapsed_time
)
# ========== 历史查询接口 ==========
@app.get("/threads")
async def list_threads(
@@ -147,7 +138,6 @@ async def list_threads(
threads = await history_service.get_user_threads(user_id, limit)
return {"threads": threads}
@app.get("/thread/{thread_id}/messages")
async def get_thread_messages(
thread_id: str,
@@ -158,7 +148,6 @@ async def get_thread_messages(
messages = await history_service.get_thread_messages(thread_id)
return {"messages": messages}
@app.get("/thread/{thread_id}/summary")
async def get_thread_summary(
thread_id: str,
@@ -169,7 +158,6 @@ async def get_thread_summary(
summary = await history_service.get_thread_summary(thread_id)
return summary
# ========== 流式对话接口 ==========
@app.post("/chat/stream")
async def chat_stream_endpoint(
@@ -204,7 +192,6 @@ async def chat_stream_endpoint(
}
)
# ========== WebSocket 端点(可选) ==========
@app.websocket("/ws")
async def websocket_endpoint(
@@ -228,7 +215,6 @@ async def websocket_endpoint(
except WebSocketDisconnect:
pass
if __name__ == "__main__":
import uvicorn
# 使用环境变量或默认端口 8079避免与 llama.cpp 的 8081 端口冲突)

View File

@@ -18,6 +18,7 @@ MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
# Qdrant 向量数据库地址
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")

8
app/graph/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
"""
Graph 子模块
"""
from app.graph.graph_builder import GraphBuilder
from app.graph.state import MessagesState, GraphContext
__all__ = ["GraphBuilder", "MessagesState", "GraphContext"]

View File

@@ -5,18 +5,17 @@ LangGraph 状态图构建模块 - 精简版,仅负责组装图
from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.nodes import (
should_continue,
create_llm_call_node,
create_tool_call_node,
create_retrieve_memory_node,
create_summarize_node,
should_continue
finalize_node,
)
from app.nodes.memory_trigger import memory_trigger_node, set_mem0_client
from app.memory import Mem0Client
from app.nodes.finalize import finalize_node
class GraphBuilder:
@@ -45,6 +44,9 @@ class GraphBuilder:
Returns:
StateGraph 实例
"""
# 注入全局客户端
set_mem0_client(self.mem0_client)
builder = StateGraph(MessagesState, context_schema=GraphContext)
# ⭐ 通过工厂函数创建节点(依赖注入)
@@ -55,6 +57,7 @@ class GraphBuilder:
# 添加节点
builder.add_node("retrieve_memory", retrieve_memory_node)
builder.add_node("memory_trigger", memory_trigger_node)
builder.add_node("llm_call", llm_call_node)
builder.add_node("tool_node", tool_call_node)
builder.add_node("summarize", summarize_node)
@@ -62,7 +65,8 @@ class GraphBuilder:
# 添加边
builder.add_edge(START, "retrieve_memory")
builder.add_edge("retrieve_memory", "llm_call")
builder.add_edge("retrieve_memory", "memory_trigger")
builder.add_edge("memory_trigger", "llm_call")
builder.add_conditional_edges(
"llm_call",
should_continue,

View File

@@ -3,7 +3,6 @@
"""
# 标准库
import os
from pathlib import Path
# 第三方库
@@ -13,7 +12,6 @@ import requests
from bs4 import BeautifulSoup
from langchain_core.tools import tool
def _file_allow_check(filename: str) -> Path:
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
allowed_dir = Path("./user_docs").resolve()
@@ -28,13 +26,11 @@ def _file_allow_check(filename: str) -> Path:
return file_path
@tool
def get_current_temperature(location: str) -> str:
"""获取指定地点的当前温度。"""
return f'当前{location}的温度为25℃'
@tool
def read_local_file(filename: str) -> str:
"""读取用户指定名称的本地文本文件内容并返回摘要。"""
@@ -46,7 +42,6 @@ def read_local_file(filename: str) -> str:
except Exception as e:
return f"读取文件时出错:{str(e)}"
@tool
def read_pdf_summary(filename: str) -> str:
"""读取PDF文件并返回内容文本摘要。"""
@@ -61,7 +56,6 @@ def read_pdf_summary(filename: str) -> str:
except Exception as e:
return f"读取PDF出错{e}"
@tool
def read_excel_as_markdown(filename: str) -> str:
"""读取Excel文件并将其主要数据转换为Markdown表格格式。"""
@@ -73,7 +67,6 @@ def read_excel_as_markdown(filename: str) -> str:
except Exception as e:
return f"读取Excel出错{e}"
@tool
def fetch_webpage_content(url: str) -> str:
"""抓取给定URL的网页正文内容并返回清晰的纯文本。"""
@@ -91,7 +84,6 @@ def fetch_webpage_content(url: str) -> str:
except Exception as e:
return f"抓取网页时出错:{str(e)}"
# 工具列表和映射(全局常量)
AVAILABLE_TOOLS = [
get_current_temperature,

View File

@@ -4,15 +4,13 @@
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆检索节点

View File

@@ -4,12 +4,11 @@ LangGraph 状态定义模块
"""
import operator
from typing import Annotated, Any
from typing import Annotated
from typing_extensions import TypedDict
from dataclasses import dataclass
from langchain_core.messages import AnyMessage
class MessagesState(TypedDict):
"""对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add]
@@ -19,7 +18,6 @@ class MessagesState(TypedDict):
last_elapsed_time: float # 本次调用耗时(秒)
turns_since_last_summary: int # 距离上次生成摘要的轮数
@dataclass
class GraphContext:
"""图执行上下文"""

View File

@@ -1,79 +0,0 @@
"""
LangGraph 状态图构建模块 - 精简版,仅负责组装图
所有节点逻辑已拆分到独立模块
"""
from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.nodes import (
create_llm_call_node,
create_tool_call_node,
create_retrieve_memory_node,
create_summarize_node,
should_continue
)
from app.memory import Mem0Client
from app.nodes.finalize import finalize_node
class GraphBuilder:
"""LangGraph 状态图构建器 - 仅负责组装图"""
def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
"""
初始化构建器
Args:
llm: 大语言模型实例
tools: 工具列表
tools_by_name: 名称到工具函数的映射
"""
self.llm = llm
self.tools = tools
self.tools_by_name = tools_by_name
# ⭐ 创建 Mem0 客户端(懒加载,首次使用时初始化)
self.mem0_client = Mem0Client(llm)
def build(self) -> StateGraph:
"""
构建未编译的状态图
Returns:
StateGraph 实例
"""
builder = StateGraph(MessagesState, context_schema=GraphContext)
# ⭐ 通过工厂函数创建节点(依赖注入)
retrieve_memory_node = create_retrieve_memory_node(self.mem0_client)
llm_call_node = create_llm_call_node(self.llm, self.tools)
tool_call_node = create_tool_call_node(self.tools_by_name)
summarize_node = create_summarize_node(self.mem0_client)
# 添加节点
builder.add_node("retrieve_memory", retrieve_memory_node)
builder.add_node("llm_call", llm_call_node)
builder.add_node("tool_node", tool_call_node)
builder.add_node("summarize", summarize_node)
builder.add_node("finalize", finalize_node)
# 添加边
builder.add_edge(START, "retrieve_memory")
builder.add_edge("retrieve_memory", "llm_call")
builder.add_conditional_edges(
"llm_call",
should_continue,
{
"tool_node": "tool_node",
"summarize": "summarize",
"finalize": "finalize"
}
)
builder.add_edge("tool_node", "llm_call")
builder.add_edge("summarize", "finalize")
builder.add_edge("finalize", END)
return builder

View File

@@ -4,13 +4,12 @@ Mem0 记忆层客户端封装模块
"""
import asyncio
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict
from mem0 import AsyncMemory
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
from app.logger import info, warning, error
class Mem0Client:
"""Mem0 异步客户端封装类"""
@@ -37,8 +36,9 @@ class Mem0Client:
"provider": "qdrant",
"config": {
"url": QDRANT_URL, # 直接使用完整 URL
"api_key": QDRANT_API_KEY,
"collection_name": QDRANT_COLLECTION_NAME,
"embedding_model_dims": 768,
"embedding_model_dims": 1024,
}
},
"llm": {
@@ -50,7 +50,7 @@ class Mem0Client:
"embedder": {
"provider": "openai",
"config": {
"model": "embeddinggemma-300M-Q8_0",
"model": "Qwen3-Embedding-0.6B-Q8_0",
"api_key": LLAMACPP_API_KEY,
"openai_base_url": LLAMACPP_EMBEDDING_URL,
},

View File

@@ -4,15 +4,13 @@
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.utils.logging import log_state_change
from app.logger import info, error
from langchain_core.runnables.config import RunnableConfig
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:

View File

@@ -3,21 +3,17 @@ LLM 调用节点模块
负责调用大语言模型并处理响应
"""
import asyncio
import time
from typing import Any, Dict
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from langgraph.runtime import Runtime
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.prompts import create_system_prompt
from app.utils.logging import log_state_change, print_llm_input
from app.graph.state import MessagesState
from app.agent.prompts import create_system_prompt
from app.utils.logging import log_state_change
from app.logger import debug, info, error
def create_llm_call_node(llm: BaseLLM, tools: list):
"""
工厂函数:创建 LLM 调用节点

View File

@@ -0,0 +1,38 @@
from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.logger import info
# 全局变量,在 GraphBuilder 中注入
_mem0_client: Mem0Client = None
def set_mem0_client(client: Mem0Client):
global _mem0_client
_mem0_client = client
async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储"""
if _mem0_client is None:
return {}
messages = state.get("messages", [])
if not messages:
return {}
last_msg = messages[-1]
content = last_msg.content if hasattr(last_msg, 'content') else str(last_msg)
# 触发词(可自行扩展)
trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"]
if any(word in content for word in trigger_words):
user_id = config.get("metadata", {}).get("user_id", "default_user")
# 确保 Mem0 已初始化
if not _mem0_client._initialized:
await _mem0_client.initialize()
# 将用户消息作为事实来源提交给 Mem0
mem0_messages = [{"role": "user", "content": content}]
await _mem0_client.add_memories(mem0_messages, user_id=user_id)
info(f"📌 检测到记忆指令,已主动触发 Mem0 存储")
return {} # 不修改状态

View File

@@ -4,15 +4,13 @@
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆存储节点

View File

@@ -6,15 +6,13 @@
import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# 本地模块
from app.graph.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.utils.logging import log_state_change
from app.logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
"""
工厂函数:创建工具执行节点

View File

@@ -13,7 +13,7 @@ RAG 检索与生成模块
用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
示例用法:
>>> from app.rag import RAGPipeline, create_rag_tool
>>> from app.rag.rag import RAGPipeline, create_rag_tool
>>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
>>> from langchain_openai import ChatOpenAI
>>>
@@ -34,16 +34,16 @@ RAG 检索与生成模块
>>> rag_tool = create_rag_tool(retriever=retriever, llm=llm)
"""
from .retriever import (
from app.rag.retriever import (
create_base_retriever,
create_hybrid_retriever,
create_qdrant_client,
)
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from .pipeline import RAGPipeline
from .tools import create_rag_tool, create_rag_tool_sync
from app.rag.reranker import LLaMaCPPReranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
from app.rag.pipeline import RAGPipeline
from app.rag.tools import create_rag_tool_sync
__all__ = [
@@ -65,6 +65,5 @@ __all__ = [
"RAGPipeline",
# 工具创建(供 Agent 使用)
"create_rag_tool",
"create_rag_tool_sync",
]

View File

@@ -1,6 +1,6 @@
# rag/fusion.py
from typing import List, Dict, Tuple
from typing import List, Dict
from langchain_core.documents import Document
def reciprocal_rank_fusion(

View File

@@ -2,15 +2,13 @@
import asyncio
import os
from typing import List, Optional
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from .retriever import create_qdrant_client # 可能不需要直接使用
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from app.rag.reranker import LLaMaCPPReranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
class RAGPipeline:
"""

View File

@@ -2,9 +2,8 @@
重排序器模块 (适配版)
使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder
"""
import os
import requests
from typing import List, Optional
from typing import List
from langchain_core.documents import Document
class LLaMaCPPReranker:

View File

@@ -11,7 +11,6 @@ RAG 系统使用示例(重构版)
import asyncio
import sys
import os
from pathlib import Path
from dotenv import load_dotenv
@@ -19,12 +18,12 @@ from dotenv import load_dotenv
load_dotenv()
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
from rag_indexer.index_builder import IndexBuilderConfig
from rag_indexer.splitters import SplitterType
from rag.pipeline import RAGPipeline
from rag.tools import create_rag_tool
from app.rag.pipeline import RAGPipeline
from app.rag.tools import create_rag_tool_sync
from pydantic import SecretStr
# 使用本地 LLM通过 OpenAI 兼容接口)
from langchain_openai import ChatOpenAI
@@ -32,7 +31,6 @@ from rag_core.retriever_factory import create_parent_retriever
load_dotenv()
def create_llm():
"""创建本地 vLLM 服务 LLM"""
vllm_base_url = os.getenv(
@@ -60,8 +58,7 @@ async def demonstrate_full_pipeline():
print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
print("=" * 60)
retriever = retriever = create_parent_retriever(collection_name="my_docs", search_k=5)
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
if retriever is None:
print("错误:检索器未初始化,请确保索引已构建。")
@@ -103,7 +100,6 @@ async def demonstrate_full_pipeline():
import traceback
traceback.print_exc()
async def demonstrate_tool_creation():
"""
演示创建 RAG 工具(供 Agent 使用)
@@ -119,12 +115,11 @@ async def demonstrate_tool_creation():
)
retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
# 2. 创建 LLM
llm = create_llm()
# 3. 创建工具
rag_tool = create_rag_tool(
rag_tool = create_rag_tool_sync(
retriever=retriever,
llm=llm,
num_queries=3,
@@ -136,18 +131,16 @@ async def demonstrate_tool_creation():
print(f"工具描述: {rag_tool.description[:100]}...")
# 4. 模拟 Agent 调用工具
query = "请告诉我 RAG 系统的核心组件有哪些"
query = "请告诉我 打虎英雄是谁"
print(f"\n模拟调用: {query}")
print("-" * 40)
result = await rag_tool.ainvoke({"query": query})
print(result[:800] + "..." if len(result) > 800 else result)
async def main():
await demonstrate_full_pipeline()
await demonstrate_tool_creation()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -4,11 +4,11 @@ RAG 工具模块
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
"""
from typing import Optional, Callable
from typing import Callable
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from .pipeline import RAGPipeline
from app.rag.pipeline import RAGPipeline
def create_rag_tool_sync(
retriever: BaseRetriever,

307
app/test_backend.py Normal file
View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
"""
完整后端测试 - 验证 Agent 所有功能
包括:短期记忆、长期记忆、工具调用、流式对话、历史查询
"""
import asyncio
import os
import sys
import uuid
from dotenv import load_dotenv
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
load_dotenv()
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
from app.agent.history import ThreadHistoryService
from app.logger import info, warning, error
# PostgreSQL 连接字符串
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable"
)
async def print_section(title):
"""打印测试区块标题"""
print("\n" + "=" * 70)
print(f" {title}")
print("=" * 70)
async def test_short_term_memory(agent_service):
"""测试短期记忆(同一 thread_id 继续对话)"""
await print_section("测试 1: 短期记忆Short-term Memory")
thread_id = str(uuid.uuid4())
user_id = "test_user_memory"
print(f"\n使用 thread_id: {thread_id[:8]}...")
print(f"使用 user_id: {user_id}")
# 第一轮对话
print("\n[第一轮] 发送消息: '我叫张三今年28岁'")
result1 = await agent_service.process_message(
"我叫张三今年28岁", thread_id, "local", user_id
)
print(f"回复: {result1['reply'][:100]}...")
# 第二轮对话 - 测试记忆
print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
result2 = await agent_service.process_message(
"我叫什么名字?今年多大?", thread_id, "local", user_id
)
print(f"回复: {result2['reply']}")
# 验证记忆是否存在
if "张三" in result2['reply'] or "28" in result2['reply']:
print("\n✅ 短期记忆测试通过!")
return True
else:
print("\n❌ 短期记忆测试失败!")
return False
async def test_tool_calling(agent_service):
"""测试工具调用RAG 搜索)"""
await print_section("测试 2: 工具调用Tool Calling")
thread_id = str(uuid.uuid4())
user_id = "test_user_tools"
print(f"\n使用 thread_id: {thread_id[:8]}...")
print(f"使用 user_id: {user_id}")
# 发送需要 RAG 搜索的问题
print("\n发送消息: '请告诉我,打虎英雄是谁?'")
result = await agent_service.process_message(
"请告诉我,打虎英雄是谁?", thread_id, "local", user_id
)
print(f"回复: {result['reply'][:200]}...")
# 检查是否调用了 RAG 工具(回复中会有水浒传相关内容)
if "武松" in result['reply'] or "李忠" in result['reply'] or "水浒传" in result['reply']:
print("\n✅ 工具调用测试通过!")
return True
else:
print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
return None
async def test_streaming(agent_service):
"""测试流式对话"""
await print_section("测试 3: 流式对话Streaming")
thread_id = str(uuid.uuid4())
user_id = "test_user_stream"
print(f"\n使用 thread_id: {thread_id[:8]}...")
print(f"使用 user_id: {user_id}")
print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...")
print("流式输出: ", end="", flush=True)
full_reply = ""
chunk_count = 0
try:
async for chunk in agent_service.process_message_stream(
"用100字介绍一下AI人工智能", thread_id, "local", user_id
):
chunk_count += 1
if chunk.get("type") == "llm_token":
token = chunk.get("token", "")
print(token, end="", flush=True)
full_reply += token
elif chunk.get("type") == "state_update":
pass # 状态更新不显示
print(f"\n\n共收到 {chunk_count} 个 chunk")
print(f"完整回复长度: {len(full_reply)}")
if chunk_count > 0 and len(full_reply) > 10:
print("\n✅ 流式对话测试通过!")
return True
else:
print("\n❌ 流式对话测试失败!")
return False
except Exception as e:
print(f"\n❌ 流式对话异常: {e}")
return False
async def test_history_service(agent_service, history_service):
"""测试历史查询服务"""
await print_section("测试 4: 历史查询服务History Service")
user_id = "test_user_history"
# 先创建几个对话
print(f"\n为 user_id={user_id} 创建测试对话...")
thread_ids = []
for i in range(3):
thread_id = str(uuid.uuid4())
thread_ids.append(thread_id)
await agent_service.process_message(
f"这是第 {i+1} 个测试对话", thread_id, "local", user_id
)
print(f" 创建线程 {i+1}: {thread_id[:8]}...")
# 1. 测试获取用户线程列表
print("\n[4.1] 测试获取用户线程列表...")
threads = await history_service.get_user_threads(user_id, limit=10)
print(f" 找到 {len(threads)} 个线程")
if len(threads) >= 3:
print(" ✅ 线程列表查询通过")
else:
print(" ⚠️ 线程数量少于预期")
# 2. 测试获取单个线程的消息历史
if thread_ids:
test_thread_id = thread_ids[0]
print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)")
messages = await history_service.get_thread_messages(test_thread_id)
print(f" 找到 {len(messages)} 条消息")
if len(messages) >= 2: # 至少有一问一答
print(" ✅ 消息历史查询通过")
else:
print(" ⚠️ 消息数量少于预期")
# 3. 测试获取线程摘要
print(f"\n[4.3] 测试获取线程摘要...")
summary = await history_service.get_thread_summary(test_thread_id)
print(f" 摘要: {summary.get('summary', '')[:50]}...")
print(f" 消息数: {summary.get('message_count', 0)}")
if summary.get('message_count', 0) > 0:
print(" ✅ 线程摘要查询通过")
else:
print(" ⚠️ 摘要查询结果不确定")
return len(threads) >= 3
async def test_long_term_memory(agent_service):
"""测试长期记忆mem0"""
await print_section("测试 5: 长期记忆Long-term Memory - mem0")
thread_id1 = str(uuid.uuid4())
thread_id2 = str(uuid.uuid4()) # 不同的线程
user_id = "test_user_longterm"
print(f"\n使用 user_id: {user_id}")
print(f"线程 1: {thread_id1[:8]}...")
print(f"线程 2: {thread_id2[:8]}...")
# 在第一个线程中保存信息
print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'")
result1 = await agent_service.process_message(
"记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id
)
print(f"回复: {result1['reply'][:100]}...")
# 等待一下,让 mem0 保存
await asyncio.sleep(1)
# 在第二个线程中询问(不同的 thread_id
print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'")
result2 = await agent_service.process_message(
"我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id
)
print(f"回复: {result2['reply']}")
# 验证长期记忆
if "小白" in result2['reply'] or "" in result2['reply']:
print("\n✅ 长期记忆测试通过!")
return True
else:
print("\n⚠️ 长期记忆可能未启用,或需要手动验证")
return None
async def main():
"""主测试函数"""
print("\n" + "=" * 70)
print(" 后端完整功能测试")
print("=" * 70)
results = {}
try:
# 创建数据库连接和服务
print("\n正在初始化数据库连接...")
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
await checkpointer.setup()
print("✅ 数据库连接成功")
# 创建服务实例
print("\n正在初始化 Agent 服务...")
agent_service = AIAgentService(checkpointer)
await agent_service.initialize()
print("✅ Agent 服务初始化成功")
history_service = ThreadHistoryService(checkpointer)
print("✅ 历史服务初始化成功")
print(f"\n可用模型: {list(agent_service.graphs.keys())}")
# 运行测试
results["短期记忆"] = await test_short_term_memory(agent_service)
await asyncio.sleep(1)
results["工具调用"] = await test_tool_calling(agent_service)
await asyncio.sleep(1)
results["流式对话"] = await test_streaming(agent_service)
await asyncio.sleep(1)
results["历史查询"] = await test_history_service(agent_service, history_service)
await asyncio.sleep(1)
results["长期记忆"] = await test_long_term_memory(agent_service)
await asyncio.sleep(1)
# 打印总结
await print_section("测试总结")
print("\n测试结果:")
print("-" * 40)
pass_count = 0
fail_count = 0
skip_count = 0
for test_name, result in results.items():
if result is True:
status = "✅ 通过"
pass_count += 1
elif result is False:
status = "❌ 失败"
fail_count += 1
else:
status = "⚠️ 待验证"
skip_count += 1
print(f" {test_name:12s}: {status}")
print("-" * 40)
print(f"总计: {len(results)} 个测试")
print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}")
if fail_count == 0:
print("\n🎉 所有核心测试通过!")
else:
print(f"\n⚠️ 有 {fail_count} 个测试失败")
except Exception as e:
error(f"\n❌ 测试运行异常: {e}")
import traceback
traceback.print_exc()
return 1
return 0 if fail_count == 0 else 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)