Refactor code and centralize config settings
Some checks failed
Build and deploy AI Agent service / deploy (push) Failing after 47m14s

2026-04-21 11:02:16 +08:00
parent 726236eaff
commit 8b354b7ccc
50 changed files with 4025 additions and 6 deletions

backend/app/__init__.py Normal file

@@ -0,0 +1,8 @@
"""
AI Agent 应用模块
"""
from ..agent import AIAgentService
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]


@@ -0,0 +1,7 @@
"""
Agent 子模块
"""
from .service import AIAgentService
__all__ = ["AIAgentService"]


@@ -0,0 +1,185 @@
"""
历史对话查询模块
利用 LangGraph 的 checkpointer 获取对话历史和摘要
"""
from typing import List, Dict, Any
from ..logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService:
"""线程历史查询服务"""
def __init__(self, checkpointer):
self.checkpointer = checkpointer
async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
"""
Fetch summary info for all threads belonging to a user.
Args:
user_id: user ID
limit: maximum number of threads to return
Returns:
A list of threads, each with thread_id, last_updated, summary, message_count
"""
try:
# Query the checkpoints table for the user's threads.
# Assumes the connection uses a dict row factory (rows are accessed by column name below).
async with self.checkpointer.conn.cursor() as cur:
# In newer LangGraph versions, the checkpoints table created by AsyncPostgresSaver
# has no created_at column; checkpoint_id serves as the time-ordering key instead.
# We deduplicate by thread_id and sort by checkpoint_id.
# User metadata lives in the metadata JSONB column.
query = """
SELECT
thread_id,
MAX(checkpoint_id) as last_updated
FROM checkpoints
WHERE metadata->>'user_id' = %s
GROUP BY thread_id
ORDER BY last_updated DESC
LIMIT %s
"""
await cur.execute(query, (user_id, limit))
rows = await cur.fetchall()
threads = []
for row in rows:
thread_id = row['thread_id']
# Load the thread's latest state
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
if state and hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict):
messages = state.checkpoint.get("channel_values", {}).get("messages", [])
if messages:
summary = self._extract_summary(messages)
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
threads.append({
"thread_id": thread_id,
# checkpoint_id is a UUID-like string; it may embed timestamp info
# and also works as a plain unique identifier
"last_updated": row['last_updated'] if row['last_updated'] else "",
"summary": summary,
"message_count": message_count
})
return threads
except Exception as e:
error(f"Failed to fetch thread list (user_id={user_id}): {e}")
return []
async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
"""
Fetch the full message history of a thread.
Args:
thread_id: thread ID
Returns:
A list of messages in the form [{"role": "user/assistant", "content": "..."}]
"""
try:
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
if state is None:
return []
messages = state.checkpoint.get("channel_values", {}).get("messages", []) if hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict) else []
if not messages:
return []
# Convert LangChain message objects into dicts
result = []
for msg in messages:
# Skip system messages
if hasattr(msg, 'type') and msg.type == "system":
continue
if hasattr(msg, 'type'):
role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else msg.type
result.append({
"role": role,
"content": msg.content
})
elif isinstance(msg, dict):
role = msg.get("role", msg.get("type", "unknown"))
if role in ["human", "user"]:
role = "user"
elif role in ["ai", "assistant"]:
role = "assistant"
result.append({
"role": role,
"content": msg.get("content", "")
})
return result
except Exception as e:
error(f"Failed to fetch thread message history: {e}")
return []
async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
"""
Fetch a thread summary (for the history list view).
Args:
thread_id: thread ID
Returns:
A dict with the summary information
"""
try:
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
if state is None or not state.values:
return {"thread_id": thread_id, "summary": "空对话", "message_count": 0}
messages = state.values.get("messages", [])
summary = self._extract_summary(messages)
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
# Determine the last-updated time
last_updated = ""
if state.metadata and "created_at" in state.metadata:
last_updated = state.metadata["created_at"].isoformat()
return {
"thread_id": thread_id,
"summary": summary,
"message_count": message_count,
"last_updated": last_updated
}
except Exception as e:
error(f"Failed to fetch thread summary: {e}")
return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0}
def _extract_summary(self, messages: List) -> str:
"""
Extract a summary from the message list.
Strategy:
1. If a summary produced by the summarize node exists, prefer it.
2. Otherwise use the first 50 characters of the first user message.
"""
# Look for an existing summary field
for msg in messages:
if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('summary'):
return msg.additional_kwargs['summary']
elif isinstance(msg, dict) and msg.get('summary'):
return msg['summary']
# Fall back to the first user message
for msg in messages:
if hasattr(msg, 'type') and msg.type == "human":
content = msg.content
return content[:50] + "..." if len(content) > 50 else content
elif isinstance(msg, dict) and msg.get("role") in ["user", "human"]:
content = msg.get("content", "")
return content[:50] + "..." if len(content) > 50 else content
return "空对话"


@@ -0,0 +1,57 @@
# app/llm_factory.py
import os
from ..config import ZHIPUAI_API_KEY, DEEPSEEK_API_KEY, VLLM_BASE_URL, LLAMACPP_API_KEY
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
class LLMFactory:
@staticmethod
def create_zhipu():
api_key = ZHIPUAI_API_KEY
if not api_key:
raise ValueError("ZHIPUAI_API_KEY not set")
return ChatZhipuAI(
model="glm-4.7-flash",
api_key=api_key,
temperature=0.1,
max_tokens=4096,
timeout=120.0,
max_retries=3,
streaming=True,
)
@staticmethod
def create_deepseek():
api_key = DEEPSEEK_API_KEY
if not api_key:
raise ValueError("DEEPSEEK_API_KEY not set")
return ChatOpenAI(
base_url="https://api.deepseek.com",
api_key=SecretStr(api_key),
model="deepseek-reasoner",
temperature=0.1,
max_tokens=4096,
timeout=60.0,
max_retries=2,
streaming=True,
)
@staticmethod
def create_local():
base_url = VLLM_BASE_URL
return ChatOpenAI(
base_url=base_url,
api_key=SecretStr(LLAMACPP_API_KEY),
model="gemma-4-E4B-it",
timeout=60.0,
max_retries=2,
streaming=True,
)
# Map of model names to creator functions
CREATORS = {
"local": create_local,
"deepseek": create_deepseek,
"zhipu": create_zhipu,
}
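
Usage sketch: `CREATORS` lets callers resolve a model by name with a graceful fallback; the staticmethod objects stored in the dict are directly callable this way on Python 3.10+ (assumes the relevant API keys are configured):

```python
# Resolve a creator by name, falling back to the local model
name = "zhipu"
creator = LLMFactory.CREATORS.get(name, LLMFactory.CREATORS["local"])
llm = creator()
print(llm.invoke("你好").content)
```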


@@ -0,0 +1,37 @@
# app/prompts.py
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
def create_system_prompt(tools: list | None = None) -> ChatPromptTemplate:
"""
Create the system prompt template, optionally injecting tool descriptions.
"""
tools_section = ""
if tools:
tool_descs = []
for tool in tools:
# Take the tool name and the first line of its description
name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
desc = (tool.description or "").split('\n')[0]
tool_descs.append(f"- {name}: {desc}")
tools_section = "\n".join(tool_descs)
system_template = (
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
"【用户背景信息】\n"
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
"{memory_context}\n"
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
"【可用工具与使用规则】\n"
f"{tools_section}\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接。\n"
"2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。\n"
"3. 优先利用已知用户信息进行个性化回复。\n"
"4. 若无信息可依,礼貌询问或提供通用帮助。"
)
return ChatPromptTemplate.from_messages([
("system", system_template),
MessagesPlaceholder(variable_name="messages")
])
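
For illustration, rendering the template fills in `{memory_context}` at call time, while the tool list is baked in when the template is created; a small sketch (hypothetical values, `AVAILABLE_TOOLS` from app.graph.graph_tools):

```python
from app.graph.graph_tools import AVAILABLE_TOOLS

prompt = create_system_prompt(AVAILABLE_TOOLS)
rendered = prompt.invoke({
    "memory_context": "用户名叫小明,喜欢简洁的回答。",
    "messages": [("human", "今天北京天气怎么样?")],
})
print(rendered.to_messages()[0].content)  # system message with tool descriptions injected
```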


@@ -0,0 +1,23 @@
# app/rag_initializer.py
from ..rag.tools import create_rag_tool_sync
from rag_core import create_parent_retriever
from ..logger import info, warning
async def init_rag_tool(local_llm_creator):
"""Initialize the RAG tool; returns None on failure"""
try:
info("🔄 Initializing RAG retrieval system...")
retriever = create_parent_retriever(
collection_name="rag_documents",
search_k=5,
)
rewrite_llm = local_llm_creator()
rag_tool = create_rag_tool_sync(
retriever, rewrite_llm,
num_queries=3, rerank_top_n=5
)
info("✅ RAG retrieval tool initialized")
return rag_tool
except Exception as e:
warning(f"⚠️ RAG retrieval tool initialization failed: {e}")
return None


@@ -0,0 +1,154 @@
"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import json
# 本地模块
from ..graph.graph_builder import GraphBuilder, GraphContext
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from .llm_factory import LLMFactory
from .rag_initializer import init_rag_tool
from ..logger import info, warning
class AIAgentService:
def __init__(self, checkpointer):
self.checkpointer = checkpointer
self.graphs = {}
self.tools = AVAILABLE_TOOLS.copy()
self.tools_by_name = TOOLS_BY_NAME.copy()
async def initialize(self):
# 1. Initialize the RAG tool (if available)
rag_tool = await init_rag_tool(LLMFactory.create_local)
if rag_tool:
self.tools.append(rag_tool)
self.tools_by_name[rag_tool.name] = rag_tool
# 2. Build a graph per model
for name, creator in LLMFactory.CREATORS.items():
try:
info(f"🔄 Initializing model '{name}'...")
llm = creator()
builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
graph = builder.compile(checkpointer=self.checkpointer)
self.graphs[name] = graph
info(f"✅ Model '{name}' initialized")
except Exception as e:
warning(f"⚠️ Model '{name}' failed to initialize: {e}")
if not self.graphs:
raise RuntimeError("No usable models")
return self
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
"""Process a user message; returns a dict with the reply, token usage, and elapsed time"""
if model not in self.graphs:
# Fall back to the first available model
available = list(self.graphs.keys())
if not available:
raise RuntimeError("No usable models")
requested = model
model = available[0]
warning(f"Model '{requested}' unavailable; falling back to '{model}'")
graph = self.graphs[model]
config = {
"configurable": {"thread_id": thread_id},
"metadata": {"user_id": user_id}
}
input_state = {"messages": [{"role": "user", "content": message}]}
context = GraphContext(user_id=user_id)
result = await graph.ainvoke(input_state, config=config, context=context)
reply = result["messages"][-1].content
token_usage = result.get("last_token_usage", {})
elapsed_time = result.get("last_elapsed_time", 0.0)
return {
"reply": reply,
"token_usage": token_usage,
"elapsed_time": elapsed_time
}
def _serialize_value(self, value):
"""Recursively convert LangChain objects into JSON-serializable structures"""
if hasattr(value, 'content'):
msg_type = getattr(value, 'type', 'message')
return {
"role": msg_type,
"content": getattr(value, 'content', ''),
"additional_kwargs": getattr(value, 'additional_kwargs', {}),
"tool_calls": getattr(value, 'tool_calls', [])
}
elif isinstance(value, dict):
return {k: self._serialize_value(v) for k, v in value.items()}
elif isinstance(value, (list, tuple)):
return [self._serialize_value(item) for item in value]
else:
try:
json.dumps(value)
return value
except (TypeError, ValueError):
return str(value)
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
"""Stream-process a message; returns an async generator"""
graph = self.graphs.get(model_name)
if not graph:
raise ValueError(f"Model '{model_name}' not found or not initialized")
config = {
"configurable": {"thread_id": thread_id},
"metadata": {"user_id": user_id}
}
input_state = {"messages": [{"role": "user", "content": message}]}
context = GraphContext(user_id=user_id)
async for chunk in graph.astream(
input_state,
config=config,
context=context,
stream_mode=["messages", "updates", "custom"],
version="v2",
subgraphs=True
):
chunk_type = chunk["type"]
processed_event = {}
if chunk_type == "messages":
message_chunk, metadata = chunk["data"]
node_name = metadata.get("langgraph_node", "unknown")
token_content = getattr(message_chunk, 'content', str(message_chunk))
reasoning_token = ""
if hasattr(message_chunk, 'additional_kwargs'):
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
processed_event = {
"type": "llm_token",
"node": node_name,
"token": token_content,
"reasoning_token": reasoning_token,
"metadata": metadata
}
elif chunk_type == "updates":
updates_data = chunk["data"]
serialized_data = self._serialize_value(updates_data)
processed_event = {
"type": "state_update",
"data": serialized_data
}
if "messages" in serialized_data:
processed_event["messages"] = serialized_data["messages"]
elif chunk_type == "custom":
serialized_data = self._serialize_value(chunk["data"])
processed_event = {
"type": "custom",
"data": serialized_data
}
else:
continue
if processed_event:
yield processed_event

backend/app/backend.py Normal file

@@ -0,0 +1,212 @@
"""
FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
采用依赖注入模式,优雅管理资源生命周期
"""
import os
from .config import DB_URI, BACKEND_PORT
import uuid
import json
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from .agent.service import AIAgentService
from .agent.history import ThreadHistoryService
from .logger import info, error
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifecycle management: create and inject the global services"""
# 1. Create the database connection pool and initialize tables (checkpointer only)
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
await checkpointer.setup()
# 2. Build the AI Agent service
agent_service = AIAgentService(checkpointer)
await agent_service.initialize()
# 3. Create the history query service
history_service = ThreadHistoryService(checkpointer)
# 4. Store the service instances on app.state
app.state.agent_service = agent_service
app.state.history_service = history_service
# Application running...
yield
# 5. On shutdown the database connection is released automatically (async with handles it)
info("🛑 Application shutting down; database connection pool released")
app = FastAPI(lifespan=lifespan)
# CORS middleware (allow cross-origin requests from the frontend)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ========== Health check endpoint ==========
@app.get("/health")
async def health_check():
"""Health check endpoint for Docker and CI/CD monitoring"""
return {"status": "ok", "service": "ai-agent-backend"}
# ========== Pydantic models ==========
class ChatRequest(BaseModel):
message: str
thread_id: str | None = None
model: str = "zhipu"
user_id: str = "default_user"
class ChatResponse(BaseModel):
reply: str
thread_id: str
model_used: str
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
elapsed_time: float = 0.0
# ========== Dependency-injection helpers ==========
def get_agent_service(request: Request) -> AIAgentService:
"""Return the global AIAgentService instance from app.state"""
return request.app.state.agent_service
def get_history_service(request: Request) -> ThreadHistoryService:
"""Return the global ThreadHistoryService instance from app.state"""
return request.app.state.history_service
# ========== HTTP endpoints ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
request: ChatRequest,
agent_service: AIAgentService = Depends(get_agent_service)
):
"""同步对话接口,支持模型选择"""
if not request.message:
raise HTTPException(status_code=400, detail="message required")
thread_id = request.thread_id or str(uuid.uuid4())
result = await agent_service.process_message(
request.message, thread_id, request.model, request.user_id
)
# 提取 token 统计信息
token_usage = result.get("token_usage", {})
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
elapsed_time = result.get("elapsed_time", 0.0)
actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
return ChatResponse(
reply=result["reply"],
thread_id=thread_id,
model_used=actual_model,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
elapsed_time=elapsed_time
)
# ========== History endpoints ==========
@app.get("/threads")
async def list_threads(
user_id: str = Query("default_user", description="user ID"),
limit: int = Query(50, ge=1, le=200, description="max number of results"),
history_service: ThreadHistoryService = Depends(get_history_service)
):
"""List the current user's conversation threads"""
threads = await history_service.get_user_threads(user_id, limit)
return {"threads": threads}
@app.get("/thread/{thread_id}/messages")
async def get_thread_messages(
thread_id: str,
user_id: str = Query("default_user", description="用户 ID"),
history_service: ThreadHistoryService = Depends(get_history_service)
):
"""获取指定线程的完整消息历史"""
messages = await history_service.get_thread_messages(thread_id)
return {"messages": messages}
@app.get("/thread/{thread_id}/summary")
async def get_thread_summary(
thread_id: str,
user_id: str = Query("default_user", description="用户 ID"),
history_service: ThreadHistoryService = Depends(get_history_service)
):
"""获取指定线程的摘要信息"""
summary = await history_service.get_thread_summary(thread_id)
return summary
# ========== Streaming chat endpoint ==========
@app.post("/chat/stream")
async def chat_stream_endpoint(
request: ChatRequest,
agent_service: AIAgentService = Depends(get_agent_service)
):
"""Streaming chat endpoint (SSE)"""
if not request.message:
raise HTTPException(status_code=400, detail="message required")
thread_id = request.thread_id or str(uuid.uuid4())
async def event_generator():
try:
async for chunk in agent_service.process_message_stream(
request.message, thread_id, request.model, request.user_id
):
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
error(f"Streaming response error: {e}")
yield f"data: {json.dumps({'type': 'error', 'message': str(e)}, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",  # disable Nginx buffering
}
)
# ========== WebSocket endpoint (optional) ==========
@app.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
agent_service: AIAgentService = Depends(get_agent_service)
):
await websocket.accept()
try:
while True:
data = await websocket.receive_json()
message = data.get("message")
thread_id = data.get("thread_id", str(uuid.uuid4()))
model = data.get("model", "zhipu")
user_id = data.get("user_id", "default_user")
if not message:
await websocket.send_json({"error": "missing message"})
continue
# process_message returns a dict; send only the reply text
result = await agent_service.process_message(message, thread_id, model, user_id)
actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
await websocket.send_json({"reply": result["reply"], "thread_id": thread_id, "model_used": actual_model})
except WebSocketDisconnect:
pass
if __name__ == "__main__":
import uvicorn
# Port from the environment, default 8079 (avoids clashing with llama.cpp on 8081)
port = int(BACKEND_PORT)
uvicorn.run(app, host="0.0.0.0", port=port)
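
On the wire, `/chat/stream` emits standard SSE frames (`data: {...}\n\n`, terminated by `data: [DONE]`). A client-side sketch, assuming `httpx` is installed and the backend runs on the default port:

```python
import json
import httpx

def stream_chat(message: str, base_url: str = "http://127.0.0.1:8079"):
    payload = {"message": message, "model": "zhipu", "user_id": "default_user"}
    with httpx.stream("POST", f"{base_url}/chat/stream", json=payload, timeout=None) as r:
        for line in r.iter_lines():
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            event = json.loads(data)
            if event.get("type") == "llm_token":
                print(event["token"], end="", flush=True)

stream_chat("你好,介绍一下你自己")
```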

backend/app/config.py Normal file

@@ -0,0 +1,50 @@
"""
环境变量集中管理模块
所有配置项统一定义,避免散落在各个文件中
"""
import os
# ========== Graph 执行追踪配置 ==========
# 是否启用 Graph 流转追踪(通过环境变量控制)
ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"
# ========== 记忆提取配置 ==========
# 记忆提取间隔:每 N 轮对话生成一次摘要
MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
# ========== Mem0 记忆层配置 ==========
# Qdrant 向量数据库地址
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
# ========== llm 配置 ==========
# LLM 模型配置
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key")
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key")
# ========== 后端服务配置 ==========
# 数据库连接字符串
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)
# 后端服务端口
BACKEND_PORT = int(os.getenv("BACKEND_PORT", "8079"))
# ========== 日志配置 ==========
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
DEBUG = os.getenv("DEBUG", "false").lower() == "true"
# ========== Reranker 服务配置 ==========
LLAMACPP_RERANKER_URL = os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083")
# ========== 第三方 API 密钥 ==========
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY", "")
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")


@@ -0,0 +1,8 @@
"""
Graph 子模块
"""
from .graph_builder import GraphBuilder
from .state import MessagesState, GraphContext
__all__ = ["GraphBuilder", "MessagesState", "GraphContext"]


@@ -0,0 +1,83 @@
"""
LangGraph 状态图构建模块 - 精简版,仅负责组装图
所有节点逻辑已拆分到独立模块
"""
from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END
from .state import MessagesState, GraphContext
from ..nodes import (
should_continue,
create_llm_call_node,
create_tool_call_node,
create_retrieve_memory_node,
create_summarize_node,
finalize_node,
)
from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client
from ..memory import Mem0Client
class GraphBuilder:
"""LangGraph state-graph builder; only assembles the graph"""
def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
"""
Initialize the builder
Args:
llm: the language-model instance
tools: list of tools
tools_by_name: mapping from tool name to tool function
"""
self.llm = llm
self.tools = tools
self.tools_by_name = tools_by_name
# ⭐ Create the Mem0 client (lazy; initialized on first use)
self.mem0_client = Mem0Client(llm)
def build(self) -> StateGraph:
"""
Build the uncompiled state graph
Returns:
a StateGraph instance
"""
# Inject the global client
set_mem0_client(self.mem0_client)
builder = StateGraph(MessagesState, context_schema=GraphContext)
# ⭐ Create nodes via factory functions (dependency injection)
retrieve_memory_node = create_retrieve_memory_node(self.mem0_client)
llm_call_node = create_llm_call_node(self.llm, self.tools)
tool_call_node = create_tool_call_node(self.tools_by_name)
summarize_node = create_summarize_node(self.mem0_client)
# Add nodes
builder.add_node("retrieve_memory", retrieve_memory_node)
builder.add_node("memory_trigger", memory_trigger_node)
builder.add_node("llm_call", llm_call_node)
builder.add_node("tool_node", tool_call_node)
builder.add_node("summarize", summarize_node)
builder.add_node("finalize", finalize_node)
# Add edges
builder.add_edge(START, "retrieve_memory")
builder.add_edge("retrieve_memory", "memory_trigger")
builder.add_edge("memory_trigger", "llm_call")
builder.add_conditional_edges(
"llm_call",
should_continue,
{
"tool_node": "tool_node",
"summarize": "summarize",
"finalize": "finalize"
}
)
builder.add_edge("tool_node", "llm_call")
builder.add_edge("summarize", "finalize")
builder.add_edge("finalize", END)
return builder
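
For local experiments the same builder can be compiled against an in-memory checkpointer instead of PostgreSQL; a sketch assuming the app package layout in this commit (MemorySaver stands in for AsyncPostgresSaver):

```python
from langgraph.checkpoint.memory import MemorySaver
from app.agent.llm_factory import LLMFactory
from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME

llm = LLMFactory.create_local()
graph = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build().compile(
    checkpointer=MemorySaver()
)
```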


@@ -0,0 +1,95 @@
"""
工具定义模块 - 纯函数工具,无依赖 AIAgent 类
"""
# 标准库
from pathlib import Path
# 第三方库
import pandas as pd
import pypdf
import requests
from bs4 import BeautifulSoup
from langchain_core.tools import tool
def _file_allow_check(filename: str) -> Path:
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
allowed_dir = Path("./user_docs").resolve()
allowed_dir.mkdir(exist_ok=True)
file_path = (allowed_dir / filename).resolve()
if not str(file_path).startswith(str(allowed_dir)):
raise ValueError("错误:非法文件路径。")
if not file_path.exists():
raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")
return file_path
@tool
def get_current_temperature(location: str) -> str:
"""Get the current temperature at the given location."""
return f'当前{location}的温度为25℃'
@tool
def read_local_file(filename: str) -> str:
"""Read the named local text file and return an excerpt of its content."""
try:
file_path = _file_allow_check(filename)
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..."
except Exception as e:
return f"读取文件时出错:{str(e)}"
@tool
def read_pdf_summary(filename: str) -> str:
"""Read a PDF file and return a text excerpt of its content."""
try:
file_path = _file_allow_check(filename)
text = ""
with open(file_path, 'rb') as f:
reader = pypdf.PdfReader(f)
for page in reader.pages[:3]:
text += page.extract_text()
return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..."
except Exception as e:
return f"读取PDF出错:{e}"
@tool
def read_excel_as_markdown(filename: str) -> str:
"""Read an Excel file and convert its main data into a Markdown table."""
try:
file_path = _file_allow_check(filename)
df = pd.read_excel(file_path)
markdown_table = df.head(10).to_markdown(index=False)
return f"Excel文件 '{filename}' 的数据预览(前10行):\n{markdown_table}"
except Exception as e:
return f"读取Excel出错:{e}"
@tool
def fetch_webpage_content(url: str) -> str:
"""Fetch the main text content of the given URL and return it as clean plain text."""
try:
response = requests.get(url, timeout=10)
response.raise_for_status()
soup = BeautifulSoup(response.text, 'html.parser')
for script in soup(["script", "style"]):
script.decompose()
text = soup.get_text()
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
text = '\n'.join(chunk for chunk in chunks if chunk)
return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..."
except Exception as e:
return f"抓取网页时出错:{str(e)}"
# Tool list and name mapping (module-level constants)
AVAILABLE_TOOLS = [
get_current_temperature,
read_local_file,
fetch_webpage_content,
read_pdf_summary,
read_excel_as_markdown
]
TOOLS_BY_NAME = {tool.name: tool for tool in AVAILABLE_TOOLS}
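
Since these are `@tool`-decorated functions, they are invoked with a dict of arguments rather than positionally; for example:

```python
# Direct invocation, bypassing the agent
print(get_current_temperature.invoke({"location": "北京"}))
print(TOOLS_BY_NAME["fetch_webpage_content"].invoke({"url": "https://example.com"}))
```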


@@ -0,0 +1,76 @@
"""
记忆检索节点模块
负责从 Mem0 检索相关长期记忆
"""
from typing import Any, Dict
# 本地模块
from .state import MessagesState
from ..memory.mem0_client import Mem0Client
from ..utils.logging import log_state_change
from ..logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆检索节点
Args:
mem0_client: Mem0 客户端实例
Returns:
异步节点函数
"""
from langchain_core.runnables.config import RunnableConfig
async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
记忆检索节点 - 使用 Mem0
Args:
state: 当前对话状态
config: 运行时配置
Returns:
包含 memory_context 的状态更新
"""
log_state_change("retrieve_memory", state, "进入")
# 从 metadata 中获取 user_id
user_id = config.get("metadata", {}).get("user_id", "default_user")
# 兼容 dict 和对象两种消息格式
last_msg = state["messages"][-1]
if isinstance(last_msg, dict):
query = str(last_msg.get("content", ""))
else:
query = str(last_msg.content)
memory_text_parts = []
# 确保 Mem0 已初始化(懒加载)
if not mem0_client._initialized:
await mem0_client.initialize()
if mem0_client.mem0:
try:
# 异步调用 Mem0 语义检索
facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
if facts:
memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts))
else:
debug("🔍 [记忆检索] 未找到相关记忆")
except Exception as e:
from app.logger import warning
warning(f"⚠️ Mem0 检索失败: {e}")
else:
from app.logger import warning
warning("⚠️ Mem0 未初始化,跳过记忆检索")
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
result = {"memory_context": memory_context}
log_state_change("retrieve_memory", {**state, **result}, "离开")
return result
return retrieve_memory


@@ -0,0 +1,25 @@
"""
LangGraph 状态定义模块
包含 MessagesState 和 GraphContext
"""
import operator
from typing import Annotated
from typing_extensions import TypedDict
from dataclasses import dataclass
from langchain_core.messages import AnyMessage
class MessagesState(TypedDict):
"""对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add]
llm_calls: int
memory_context: str
last_token_usage: dict # 本次调用的 token 使用详情
last_elapsed_time: float # 本次调用耗时(秒)
turns_since_last_summary: int # 距离上次生成摘要的轮数
@dataclass
class GraphContext:
"""图执行上下文"""
user_id: str
# 可扩展更多上下文信息
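
The `Annotated[..., operator.add]` reducer is what makes node updates append to the message list instead of replacing it; conceptually LangGraph merges each update like this:

```python
from langchain_core.messages import HumanMessage, AIMessage

existing = [HumanMessage("你好")]
update = [AIMessage("你好!有什么可以帮你?")]
merged = existing + update  # operator.add on lists → concatenation
assert len(merged) == 2
```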

backend/app/logger.py Normal file

@@ -0,0 +1,56 @@
"""
统一的日志模块 - 基于环境变量控制日志级别
类似 C# 的条件编译效果,开发时打印详细调试信息,生产环境只输出关键信息
"""
import os
from .config import LOG_LEVEL, DEBUG
import logging
from typing import Any
from dotenv import load_dotenv
# 先加载环境变量
load_dotenv()
# 从环境变量读取日志级别,默认 INFO
# 根据环境变量控制是否显示详细调试信息
DEBUG_MODE = DEBUG
# 创建统一的日志器
logger = logging.getLogger("ai_agent")
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
# 避免重复添加 handler
if not logger.handlers:
handler = logging.StreamHandler()
# 重要handler 也需要设置级别,否则可能继承根 logger 的级别
handler.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
def debug(msg: Any, *args, **kwargs):
"""调试日志,仅在 DEBUG 环境变量为 true 时打印"""
if DEBUG_MODE:
logger.debug(msg, *args, **kwargs)
def info(msg: Any, *args, **kwargs):
"""信息日志"""
logger.info(msg, *args, **kwargs)
def warning(msg: Any, *args, **kwargs):
"""警告日志"""
logger.warning(msg, *args, **kwargs)
def error(msg: Any, *args, **kwargs):
"""错误日志"""
logger.error(msg, *args, **kwargs)


@@ -0,0 +1,7 @@
"""
Mem0 记忆层模块
"""
from .mem0_client import Mem0Client
__all__ = ["Mem0Client"]


@@ -0,0 +1,146 @@
"""
Mem0 memory-layer client wrapper.
Handles Mem0 initialization, retrieval, and storage.
"""
import time
import asyncio
from typing import Optional, List, Dict
from mem0 import AsyncMemory
from ..config import (
QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY,
VLLM_BASE_URL, LLM_API_KEY,
LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
)
from ..logger import info, warning, error
class Mem0Client:
"""Async Mem0 client wrapper"""
def __init__(self, llm_instance):
"""
Initialize the Mem0 client
Args:
llm_instance: LangChain LLM instance (used for fact extraction)
"""
self.llm = llm_instance
self.mem0: Optional[AsyncMemory] = None
self._initialized = False
async def initialize(self):
"""Asynchronously initialize the Mem0 client and run a real connectivity test"""
if self._initialized:
return
try:
# Mem0 configuration
config = {
"vector_store": {
"provider": "qdrant",
"config": {
"url": QDRANT_URL,  # full URL used directly
"api_key": QDRANT_API_KEY,
"collection_name": QDRANT_COLLECTION_NAME,
"embedding_model_dims": 1024,
}
},
"llm": {
"provider": "openai",
"config": {
"model": "LLM_MODEL",  # NOTE: literal placeholder; presumably meant to be a real model name
"api_key": LLM_API_KEY,
"openai_base_url": VLLM_BASE_URL,
"temperature": 0.1,
"max_tokens": 2000,
}
},
"embedder": {
"provider": "openai",
"config": {
"model": "Qwen3-Embedding-0.6B-Q8_0",
"api_key": LLAMACPP_API_KEY,
"openai_base_url": LLAMACPP_EMBEDDING_URL,
},
},
"version": "v1.1"
}
self.mem0 = AsyncMemory.from_config(config)
info("✅ Mem0 config loaded; starting connectivity test...")
# Real connectivity test: one search call verifies both Qdrant and the embedder are reachable
await asyncio.wait_for(
self.mem0.search("ping", user_id="test", limit=1),
timeout=60.0
)
info("✅ Mem0 connectivity test passed; initialization complete")
self._initialized = True
except asyncio.TimeoutError:
error("❌ Mem0 connectivity test timed out (60s); check the Qdrant or embedding service")
self.mem0 = None
self._initialized = False
except Exception as e:
error(f"❌ Mem0 initialization or connectivity test failed: {e}")
import traceback
error(f"Details:\n{traceback.format_exc()}")
self.mem0 = None
self._initialized = False
async def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[str]:
"""
Retrieve relevant memories
Args:
query: query text
user_id: user ID
limit: maximum number of results
Returns:
List[str]: list of memory facts
"""
if not self.mem0:
warning("⚠️ Mem0 not initialized; skipping memory retrieval")
return []
try:
memories = await asyncio.wait_for(
self.mem0.search(query, user_id=user_id, limit=limit),
timeout=30.0
)
if memories and "results" in memories:
facts = [m["memory"] for m in memories["results"] if m.get("memory")]
if facts:
info(f"🔍 [memory retrieval] Mem0 returned {len(facts)} memories")
return facts
info("🔍 [memory retrieval] no relevant memories found")
return []
except asyncio.TimeoutError:
warning("⚠️ Mem0 search timed out (30s); skipping this retrieval")
return []
except Exception as e:
warning(f"⚠️ Mem0 search failed: {e}")
return []
async def add_memories(self, messages, user_id):
if not self.mem0:
return False
start = time.time()
try:
info(f"📝 Starting Mem0 add (message count: {len(messages)})")
await asyncio.wait_for(
self.mem0.add(messages, user_id=user_id, metadata={"type": "conversation"}),
timeout=60.0
)
info(f"✅ Mem0 add finished in {time.time() - start:.2f}s")
return True
except asyncio.TimeoutError:
error(f"❌ Mem0 add timed out (60s) after {time.time() - start:.2f}s")
return False
except Exception as e:
error(f"❌ Mem0 add failed: {e}")
return False


@@ -0,0 +1,19 @@
"""
节点模块 - 导出所有 LangGraph 节点函数
"""
from .router import should_continue
from .llm_call import create_llm_call_node
from .tool_call import create_tool_call_node
from ..graph.retrieve_memory import create_retrieve_memory_node
from .summarize import create_summarize_node
from .finalize import finalize_node
__all__ = [
"should_continue",
"create_llm_call_node",
"create_tool_call_node",
"create_retrieve_memory_node",
"create_summarize_node",
"finalize_node",
]


@@ -0,0 +1,45 @@
"""
完成事件节点模块
负责发送完成事件包含token使用情况和耗时信息
"""
from typing import Any, Dict
from langgraph.config import get_stream_writer
# 本地模块
from ..graph.state import MessagesState
from ..utils.logging import log_state_change
from ..logger import info, error
from langchain_core.runnables.config import RunnableConfig
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
完成事件节点 - 发送完成事件包含token使用情况和耗时信息
Args:
state: 当前对话状态
config: 运行时配置
Returns:
空字典(完成节点,无状态更新)
"""
log_state_change("finalize", state, "进入")
try:
# 获取流式写入器并发送完成事件
writer = get_stream_writer()
writer({
"type": "custom",
"data": {
"type": "done",
"token_usage": state.get("last_token_usage", {}),
"elapsed_time": state.get("last_elapsed_time", 0.0)
}
})
info("🏁 [完成事件] 已发送完成事件包含token使用情况和耗时信息")
except Exception as e:
error(f"❌ [完成事件] 发送完成事件时发生异常: {e}")
log_state_change("finalize", state, "离开")
return {}


@@ -0,0 +1,150 @@
"""
LLM 调用节点模块
负责调用大语言模型并处理响应
"""
import time
from typing import Any, Dict
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AIMessage
# 本地模块
from ..graph.state import MessagesState
from ..agent.prompts import create_system_prompt
from ..utils.logging import log_state_change
from ..logger import debug, info, error
def create_llm_call_node(llm: BaseLLM, tools: list):
"""
工厂函数:创建 LLM 调用节点
Args:
llm: LangChain LLM 实例
tools: 工具列表
Returns:
异步节点函数
"""
# 构建调用链
prompt = create_system_prompt(tools)
llm_with_tools = llm.bind_tools(tools)
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
chain = prompt | llm_with_tools
from langchain_core.runnables.config import RunnableConfig
async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
LLM 调用节点(异步方法)
Args:
state: 当前对话状态
config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息
Returns:
更新后的状态字典
"""
log_state_change("llm_call", state, "进入")
memory_context = state.get("memory_context", "暂无用户信息")
start_time = time.time()
try:
# Manually astream and stitch all chunks into the final response;
# LangGraph automatically observes every token produced along the way.
chunks = []
async for chunk in chain.astream(
{
"messages": state["messages"],
"memory_context": memory_context
},
config=config
):
chunks.append(chunk)
# Merge all chunks into the final AIMessage
if chunks:
response = chunks[0]
for chunk in chunks[1:]:
response = response + chunk
else:
response = AIMessage(content="")
elapsed_time = time.time() - start_time
# Extract token usage (metadata formats differ across LLM providers)
token_usage = {}
input_tokens = 0
output_tokens = 0
# Try response_metadata first
if hasattr(response, 'response_metadata') and response.response_metadata:
meta = response.response_metadata
if 'token_usage' in meta:
token_usage = meta['token_usage']
elif 'usage' in meta:
token_usage = meta['usage']
# Fall back to additional_kwargs
if not token_usage and hasattr(response, 'additional_kwargs'):
add_kwargs = response.additional_kwargs
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
token_usage = add_kwargs['llm_output']['token_usage']
# Pull out the concrete token counts
if token_usage:
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
# Log the full LLM output
debug("\n" + "="*80)
debug("📥 [LLM output] full model response:")
debug(f"  message type: {response.type.upper()}")
debug(f"  content length: {len(str(response.content))} chars")
debug("-"*80)
debug(f"{response.content}")
# Log response statistics
info(f"⏱️ [LLM stats] call took {elapsed_time:.2f}s")
info(f"📊 [LLM stats] tokens: input={input_tokens}, output={output_tokens}, total={input_tokens + output_tokens}")
if token_usage:
debug(f"📋 [LLM stats] detailed usage: {token_usage}")
debug("="*80 + "\n")
result = {
"messages": [response],
"llm_calls": state.get('llm_calls', 0) + 1,
"last_token_usage": token_usage,
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1  # bump the counter
}
log_state_change("llm_call", {**state, **result}, "leave")
return result
except Exception as e:
elapsed_time = time.time() - start_time
error(f"\n❌ [LLM error] call failed (took {elapsed_time:.2f}s)")
error(f"  error type: {type(e).__name__}")
error(f"  error message: {str(e)}")
import traceback
traceback.print_exc()
debug("="*80 + "\n")
# Return a friendly error message
error_response = AIMessage(
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
)
error_result = {
"messages": [error_response],
"llm_calls": state.get('llm_calls', 0),
"last_token_usage": {},
"last_elapsed_time": elapsed_time,
"turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1  # bump the counter even on failure
}
log_state_change("llm_call", state, "leave (exception)")
return error_result
return call_llm


@@ -0,0 +1,38 @@
from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from ..graph.state import MessagesState
from ..memory.mem0_client import Mem0Client
from ..logger import info
# Module-level client, injected by GraphBuilder
_mem0_client: Mem0Client = None
def set_mem0_client(client: Mem0Client):
global _mem0_client
_mem0_client = client
async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""Detect memory directives in the user message; on a hit, proactively store via Mem0"""
if _mem0_client is None:
return {}
messages = state.get("messages", [])
if not messages:
return {}
last_msg = messages[-1]
content = last_msg.content if hasattr(last_msg, 'content') else str(last_msg)
# Trigger words (extend as needed); roughly "remember", "note down", "save", "memo", "don't forget"
trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"]
if any(word in content for word in trigger_words):
user_id = config.get("metadata", {}).get("user_id", "default_user")
# Make sure Mem0 is initialized
if not _mem0_client._initialized:
await _mem0_client.initialize()
# Submit the user message to Mem0 as a source of facts
info("📌 Memory directive detected; triggering Mem0 storage")
mem0_messages = [{"role": "user", "content": content}]
await _mem0_client.add_memories(mem0_messages, user_id=user_id)
return {}  # no state change


@@ -0,0 +1,48 @@
"""
路由决策节点
根据当前状态决定下一步走向
"""
from typing import Literal
from langchain_core.messages import AIMessage
# 本地模块
from ..config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
from ..graph.state import MessagesState
from ..logger import info
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'finalize']:
"""
决定下一步:工具调用、生成摘要还是结束
Args:
state: 当前对话状态
Returns:
下一个节点名称
"""
last_message = state["messages"][-1]
# 1. 如果需要调用工具,优先进入工具节点
if isinstance(last_message, AIMessage) and last_message.tool_calls:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
return 'tool_node'
# 2. 如果是 AI 的最终回复,判断是否达到摘要生成阈值
if isinstance(last_message, AIMessage):
turns = state.get("turns_since_last_summary", 0)
if turns >= MEMORY_SUMMARIZE_INTERVAL:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 收到 AI 最终回复,已达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 转向 'summarize'")
return 'summarize'
else:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")
return 'finalize'
# 3. 其他情况(如只有用户消息)直接结束
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
return 'finalize'


@@ -0,0 +1,87 @@
"""
记忆存储节点模块
负责将对话历史提交给 Mem0 进行事实提取和存储
"""
from typing import Any, Dict
# 本地模块
from ..graph.state import MessagesState
from ..memory.mem0_client import Mem0Client
from ..utils.logging import log_state_change
from ..logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆存储节点
Args:
mem0_client: Mem0 客户端实例
Returns:
异步节点函数
"""
from langchain_core.runnables.config import RunnableConfig
async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
记忆存储节点 - 使用 Mem0
Args:
state: 当前对话状态
config: 运行时配置
Returns:
重置计数器的状态更新
"""
log_state_change("summarize", state, "进入")
messages = state["messages"]
if len(messages) < 4:
debug("📝 [记忆添加] 对话过短,跳过")
return {"turns_since_last_summary": 0}
# 从 metadata 中获取 user_id
user_id = config.get("metadata", {}).get("user_id", "default_user")
# 确保 Mem0 已初始化(懒加载)
if not mem0_client._initialized:
await mem0_client.initialize()
# 将整个对话历史转换为 Mem0 需要的消息格式
mem0_messages = []
for msg in messages:
# 兼容 dict 和对象两种格式
if isinstance(msg, dict):
msg_type = msg.get("type", "")
msg_content = msg.get("content", "")
else:
msg_type = getattr(msg, 'type', '')
msg_content = getattr(msg, 'content', '')
if msg_type == "human":
mem0_messages.append({"role": "user", "content": msg_content})
elif msg_type == "ai":
mem0_messages.append({"role": "assistant", "content": msg_content})
elif msg_type == "tool":
mem0_messages.append({"role": "system", "content": f"[工具返回] {msg_content}"})
if mem0_client.mem0:
try:
# 异步调用 Mem0 自动提取并存储事实
success = await mem0_client.add_memories(
mem0_messages,
user_id=user_id
)
if success:
info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取")
except Exception as e:
error(f"❌ Mem0 记忆添加失败: {e}")
else:
warning("⚠️ Mem0 未初始化,跳过记忆添加")
log_state_change("summarize", state, "离开")
return {"turns_since_last_summary": 0}
return summarize_conversation


@@ -0,0 +1,101 @@
"""
工具执行节点模块
负责执行 AI 调用的工具函数
"""
import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.config import get_stream_writer
# 本地模块
from ..graph.state import MessagesState
from ..utils.logging import log_state_change
from ..logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
"""
工厂函数:创建工具执行节点
Args:
tools_by_name: 名称到工具函数的映射字典
Returns:
异步节点函数
"""
from langchain_core.runnables.config import RunnableConfig
async def call_tools(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
工具执行节点(异步方法)
Args:
state: 当前对话状态
config: 运行时配置
Returns:
包含 ToolMessage 的状态更新
"""
log_state_change("tool_node", state, "进入")
last_message = state['messages'][-1]
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
log_state_change("tool_node", state, "离开(无工具调用)")
return {"messages": []}
results = []
loop = asyncio.get_event_loop()
info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")
for tool_call in last_message.tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool_id = tool_call["id"]
tool_func = tools_by_name.get(tool_name)
debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")
if tool_func is None:
err_msg = f"Tool {tool_name} not found"
debug(f" └─ ❌ {err_msg}")
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
continue
# 获取流式写入器并发送工具调用开始事件
writer = get_stream_writer()
writer({"type": "custom", "data": {"type": "tool_start", "tool": tool_name}})
try:
# 修复闭包问题:将变量作为默认参数传入 lambda
# 如果工具支持异步 (ainvoke),优先使用异步调用
if hasattr(tool_func, 'ainvoke'):
observation = await tool_func.ainvoke(tool_args)
else:
observation = await loop.run_in_executor(
None,
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
)
# 字符打印
result_preview = str(observation).replace("\n", " ")
debug(f" └─ ✅ 结果: {result_preview}")
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
# 发送工具调用完成事件
writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": True}})
except Exception as e:
debug(f" └─ ❌ 异常: {e}")
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
# 发送工具调用失败事件
writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": False, "error": str(e)}})
info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
result = {"messages": results}
log_state_change("tool_node", {**state, **result}, "离开")
return result
return call_tools

backend/app/rag/README.md Normal file

@@ -0,0 +1,391 @@
# Online RAG Retriever
This module implements stage two of the RAG system: **online retrieval and generation**. It takes a user question, retrieves context from the knowledge base, applies several advanced strategies to denoise and fuse the results, and feeds the augmented context to a large language model (LLM).
## 🎯 Core Architecture
### Tech Stack
| Component | Technology | Version | Notes |
|:-----|:---------|:-----|:-----|
| **Base retrieval** | `Qdrant` | 1.17+ | HNSW dense-vector search |
| **Hybrid retrieval** | `Qdrant` + `BM25` | built-in | dense + sparse vector fusion |
| **Query rewriting** | `LangChain` | built-in | `MultiQueryGenerator` multi-query rewriting |
| **RRF fusion** | custom | - | `reciprocal_rank_fusion` reciprocal rank fusion |
| **Reranking** | `llama.cpp` | local service | OpenAI-compatible Rerank API |
| **Orchestration** | `asyncio` | Python 3.10+ | async parallel retrieval |
### Retrieval Pipeline
```
┌─────────────────────────────────────────────────────────────┐
│ User question                                                │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ MultiQueryGenerator                                          │
│ multi-query rewriting (num_queries=3)                        │
│ "如何申请项目资金?" → ["项目资金申请流程", "经费申请步骤"]      │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ Parallel retrieval (asyncio.gather)                          │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐           │
│ │ query 1      │ │ query 2      │ │ query 3      │           │
│ │ (k=20)       │ │ (k=20)       │ │ (k=20)       │           │
│ └──────────────┘ └──────────────┘ └──────────────┘           │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ reciprocal_rank_fusion (RRF)                                 │
│ RRF_score(d) = Σ 1/(k + rank_q(d))   (k=60)                  │
│ fuse multi-query results, deduplicate and re-rank            │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ LLaMaCPPReranker                                             │
│ remote reranking (bge-reranker-v2-m3)                        │
│ return the top-N (top_n=5) most relevant documents           │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ Return augmented context                                     │
│ format_context() → formatted output                          │
└─────────────────────────────────────────────────────────────┘
```
### Technical Features
- ✅ **Multi-query rewriting**: the LLM rewrites a single question into several queries from different angles
- ✅ **RRF fusion**: Reciprocal Rank Fusion, a fusion algorithm that needs no score normalization
- ✅ **Remote reranking**: uses the OpenAI-compatible Rerank API of a llama.cpp service
- ✅ **Hybrid retrieval support**: dense vectors + BM25 sparse vectors
- ✅ **Async parallel retrieval**: queries run in parallel for lower latency
- ✅ **Graceful degradation**: falls back to the fused results when the reranker is unavailable
## 📂 Architecture and File Layout
```
app/rag/
├── __init__.py
├── retriever.py        # Qdrant base and hybrid retrieval
├── reranker.py         # llama.cpp remote reranker
├── query_transform.py  # multi-query rewriting generator
├── fusion.py           # RRF reciprocal rank fusion
├── pipeline.py         # RAG pipeline orchestration
└── tools.py            # LangChain Tool wrapper
```
## 🎯 Roadmap and Algorithm Notes
### Level 1: Basic Similarity Search
- **Core algorithm**: approximate nearest-neighbor search (ANN, typically HNSW). The user question is embedded into a vector, its cosine similarity against the stored vectors is computed, and the K closest chunks are returned.
- **Trade-offs**: extremely fast, but it only captures "semantic similarity". For specific proper nouns, IDs, or order numbers, pure vector retrieval often fails (producing "hallucinated" matches).
- **Implementation guide**:
  - Use `rag_indexer.embedders.LlamaCppEmbedder` as the embedding model
  - Use `create_base_retriever` from `app/rag/retriever.py` to build the base retriever
  - Configure `search_kwargs={"k": 20}` for the initial recall
```python
from app.rag.retriever import create_base_retriever
retriever = create_base_retriever(
    collection_name="rag_documents",
    embeddings=embeddings,
    search_kwargs={"k": 20}
)
docs = retriever.invoke("什么是 RAG?")
```
### Level 2: Hybrid Search + Reranker
Hybrid retrieval combines the "semantic generalization" of vectors with the "exact matching" of keywords, then uses a reranking model to filter out noise.
**1. Initial recall (hybrid retrieval)**
- **Core idea**: combine HNSW-based dense-vector similarity search with TF-IDF-based BM25 sparse retrieval.
- **Implementation guide**: use `create_hybrid_retriever` from `app/rag/retriever.py` with `dense_k=10` and `sparse_k=10`, for 20 recalled results in total.
```python
from app.rag.retriever import create_hybrid_retriever
retriever = create_hybrid_retriever(
    collection_name="rag_documents",
    embeddings=embeddings,
    dense_k=10,
    sparse_k=10,
    score_threshold=0.3
)
```
**2. Second-stage reranking (Cross-Encoder)**
- **Core idea**: unlike bi-encoders (embed separately, then compare distances), a cross-encoder concatenates "user question + one retrieved document" and feeds the pair through a Transformer, which directly outputs a 0–1 relevance score with very high precision.
- **Implementation guide**:
  - Use the `LLaMaCPPReranker` class from `app/rag/reranker.py`, serving the `bge-reranker-v2-m3` model
  - Set `top_n=5` to keep the 5 most relevant results
```python
from app.rag.reranker import LLaMaCPPReranker
reranker = LLaMaCPPReranker(
    base_url="http://127.0.0.1:8083",
    api_key="your-api-key",
    top_n=5
)
sorted_docs = reranker.compress_documents(documents, query)
```
### Level 3: RAG-Fusion (Multi-Query Rewriting + Reciprocal Rank Fusion)
RAG-Fusion uses the LLM to think divergently, rewriting a single question into several related ones to widen the search, then merges the results with a rank-based statistical algorithm.
**1. Multi-query rewriting**
- **Core idea**: overcome poorly phrased or narrowly framed initial questions.
- **Implementation guide**: use the `MultiQueryGenerator` class from `app/rag/query_transform.py` with `num_queries=3` to generate three queries from different angles.
```python
from app.rag.query_transform import MultiQueryGenerator
generator = MultiQueryGenerator(llm=llm, num_queries=3)
queries = await generator.agenerate("如何申请项目资金?")
# Returns: ["如何申请项目资金?", "项目资金申请流程是什么?", "申请项目经费需要哪些步骤?"]
```
**2. Reciprocal Rank Fusion (RRF)**
- **Core idea**: RRF is a fusion algorithm that needs no score normalization. The formula is `RRF_score(d) = Σ 1/(k + rank_q(d))`, which keeps any single extreme retrieval run from dominating the result.
- **Implementation guide**: use `reciprocal_rank_fusion` from `app/rag/fusion.py` with `k=60`.
```python
from app.rag.fusion import reciprocal_rank_fusion
# Retrieval results from multiple queries
doc_lists = [result1, result2, result3]
fused_docs = reciprocal_rank_fusion(doc_lists, k=60)
```
### Level 4: Agentic RAG / Self-RAG (agents and self-reflection)
- **Core idea**: ReAct-style (Reasoning and Acting) state-machine routing on LangGraph. The model does not retrieve blindly on every turn; it first judges the question: "Is this small talk, or does it need the knowledge base?" In the latter case it emits a `ToolCall` instruction, which triggers retrieval.
- **Implementation guide**: use the `search_knowledge_base` tool from `app/rag/tools.py` and bind it into the LangGraph state machine.
- **Flow (simplified from the original diagram)**:
```
User "公司报销流程?" → LangGraph Agent (reasoning: internal-policy question, needs a lookup)
  → RAG_Tool: ToolCall search_knowledge_base → Qdrant (RAG-Fusion, hybrid retrieval, Cross-Encoder rerank)
  → returns the 5 most relevant rules → Agent: context sufficient, writes the answer → "根据知识库规定..."
```
### Level 5: GraphRAG Integration (graph- and relation-based RAG)
- **Core idea**: combine the structured relations of a knowledge graph with the semantic similarity of vector retrieval to handle complex cross-document relational reasoning.
- **Implementation guide**:
  - Build the knowledge graph with the `langchain_community.graphs` module
  - Use a local LLM (e.g. `Gemma-4-E4B`) for entity/relation extraction
  - Implement hybrid retrieval logic combining vector similarity with graph-path analysis
```python
from langchain_community.graphs import Neo4jGraph
from langchain_experimental.graph_transformers import LLMGraphTransformer
# Entity/relation extraction
transformer = LLMGraphTransformer(llm=local_llm)
graph_documents = transformer.convert_to_graph_documents(documents)
# Store in the graph database
graph = Neo4jGraph(url="bolt://localhost:7687")
graph.add_graph_documents(graph_documents)
```
## 🔧 Core Components
### 1. Retriever (retriever.py)
Vector retrieval backed by Qdrant.
**Base retriever**
```python
from app.rag.retriever import create_base_retriever
retriever = create_base_retriever(
    collection_name="rag_documents",
    embeddings=embeddings,
    search_kwargs={"k": 20}
)
```
**Hybrid retriever**
```python
from app.rag.retriever import create_hybrid_retriever
retriever = create_hybrid_retriever(
    collection_name="rag_documents",
    embeddings=embeddings,
    dense_k=10,
    sparse_k=10,
    score_threshold=0.3
)
```
### 2. Multi-query rewriting (query_transform.py)
Uses the LLM to rewrite the user question into several versions, widening the search.
```python
from app.rag.query_transform import MultiQueryGenerator
generator = MultiQueryGenerator(llm=llm, num_queries=3)
queries = await generator.agenerate("如何申请项目资金?")
```
### 3. RRF fusion (fusion.py)
Reciprocal Rank Fusion: `RRF_score(d) = Σ 1/(k + rank_q(d))`
```python
from app.rag.fusion import reciprocal_rank_fusion
# Retrieval results from multiple queries
doc_lists = [result1, result2, result3]
fused_docs = reciprocal_rank_fusion(doc_lists, k=60)
```
### 4. Reranker (reranker.py)
Reranks retrieval results via the OpenAI-compatible Rerank API of a llama.cpp service.
```python
from app.rag.reranker import LLaMaCPPReranker
reranker = LLaMaCPPReranker(
    base_url="http://127.0.0.1:8083",
    api_key="your-api-key",
    top_n=5
)
sorted_docs = reranker.compress_documents(documents, query)
```
### 5. RAG pipeline (pipeline.py)
The full retrieval pipeline composed from the components above.
```python
from app.rag.pipeline import RAGPipeline
pipeline = RAGPipeline(
    retriever=retriever,
    llm=llm,
    num_queries=3,
    rerank_top_n=5,
)
# Async retrieval
docs = await pipeline.aretrieve("如何申请项目资金?")
# Format the context
context = pipeline.format_context(docs)
```
## 🔄 Integration with the Agent System
### Wrapping as a LangChain Tool
```python
from langchain_core.tools import tool
from app.rag.pipeline import RAGPipeline
@tool
def search_knowledge_base(query: str) -> str:
    """Search the knowledge base for relevant information"""
    docs = pipeline.retrieve(query)
    return pipeline.format_context(docs)
```
### Binding into LangGraph
```python
from app.graph.graph_builder import GraphBuilder
# Add the RAG tool to the tool list
tools = AVAILABLE_TOOLS + [search_knowledge_base]
# Build the graph
builder = GraphBuilder(llm, tools, tools_by_name)
graph = builder.build().compile(checkpointer=checkpointer)
```
## ⚙️ Environment Configuration
| Variable | Description | Default |
|:-------|:-----|:-------|
| `QDRANT_URL` | Qdrant vector database address | `http://127.0.0.1:6333` |
| `QDRANT_API_KEY` | Qdrant API key | - |
| `LLAMACPP_RERANKER_URL` | llama.cpp reranker service address | `http://127.0.0.1:8083` |
| `LLAMACPP_API_KEY` | llama.cpp API key | - |
## 🚀 Quick Start
```python
# 1. Initialize the embedding model
from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 2. Create the retriever
from app.rag.retriever import create_base_retriever
retriever = create_base_retriever(
    collection_name="rag_documents",
    embeddings=embeddings,
    search_kwargs={"k": 20}
)
# 3. Create the RAG pipeline
from app.rag.pipeline import RAGPipeline
pipeline = RAGPipeline(
    retriever=retriever,
    llm=llm,
    num_queries=3,
    rerank_top_n=5,
)
# 4. Run retrieval
docs = pipeline.retrieve("如何申请项目资金?")
# 5. Format the context
context = pipeline.format_context(docs)
print(context)
```
## 📊 Retrieval Strategy Comparison
| Strategy | Pros | Cons | Best For |
|:-----|:-----|:-----|:---------|
| **Basic vector retrieval** | fast, good semantics | weak on proper nouns | general Q&A |
| **Hybrid retrieval** | semantics + keyword matching | needs sparse-vector setup | domain-term queries |
| **Multi-query + RRF** | wide coverage, stable results | slightly higher latency | complex questions |
| **Reranking** | high precision | needs an extra model | final ranking stage |
## 🤝 Integration with rag_indexer
- **Vector store**: shares the Qdrant collection; make sure the embedding model matches
- **Document store**: parent chunks live in PostgreSQL, mapped by UUID
- **Collection name**: defaults to the `rag_documents` collection
See [rag_indexer/README.md](../../rag_indexer/README.md) for details.


@@ -0,0 +1,69 @@
"""
RAG 检索与生成模块
提供在线检索与生成功能,包括:
- 基础向量检索(稠密向量 / 混合检索)
- 重排序Cross-Encoder
- 多路查询改写Multi-Query
- RRF 融合Reciprocal Rank Fusion
- 完整的 RAG 流水线
- Agent 工具封装
固定流水线:
用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
示例用法:
>>> from app.rag.rag import RAGPipeline, create_rag_tool
>>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
>>> from langchain_openai import ChatOpenAI
>>>
>>> # 获取基础检索器(如父子块检索器)
>>> config = IndexBuilderConfig(collection_name="my_docs")
>>> builder = IndexBuilder(config)
>>> retriever = builder.retriever
>>>
>>> # 创建 LLM 和流水线
>>> llm = ChatOpenAI(model="gpt-3.5-turbo")
>>> pipeline = RAGPipeline(retriever=retriever, llm=llm)
>>>
>>> # 检索
>>> docs = await pipeline.aretrieve("什么是 RAG")
>>> context = pipeline.format_context(docs)
>>>
>>> # 创建 Agent 工具
>>> rag_tool = create_rag_tool(retriever=retriever, llm=llm)
"""
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_qdrant_client,
)
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from .pipeline import RAGPipeline
from .tools import create_rag_tool_sync
__all__ = [
# retriever factory functions
"create_base_retriever",
"create_hybrid_retriever",
"create_qdrant_client",
# reranker
"LLaMaCPPReranker",
# query-rewriting generator
"MultiQueryGenerator",
# fusion algorithm
"reciprocal_rank_fusion",
# main pipeline
"RAGPipeline",
# tool creation (for the Agent)
"create_rag_tool_sync",
]

backend/app/rag/fusion.py Normal file

@@ -0,0 +1,36 @@
# rag/fusion.py
from typing import List, Dict
from langchain_core.documents import Document
def reciprocal_rank_fusion(
doc_lists: List[List[Document]],
k: int = 60
) -> List[Document]:
"""
Fuse multiple retrieval result lists with RRF.
Args:
doc_lists: retrieval result lists, one per query
k: RRF constant, conventionally 60
Returns:
documents sorted by descending RRF score
"""
# Use document content as the identity key (same content with different metadata counts as one document).
# Using docstore IDs would be better; this simplification keys on a content prefix.
doc_to_score: Dict[str, float] = {}
doc_map: Dict[str, Document] = {}
for docs in doc_lists:
for rank, doc in enumerate(docs, start=1):
# Build a unique key (content prefix + source, so identical text from different files stays distinct)
doc_id = f"{doc.page_content[:200]}_{doc.metadata.get('source', '')}"
if doc_id not in doc_map:
doc_map[doc_id] = doc
score = doc_to_score.get(doc_id, 0.0) + 1.0 / (k + rank)
doc_to_score[doc_id] = score
# Sort by score
sorted_ids = sorted(doc_to_score.keys(), key=lambda x: doc_to_score[x], reverse=True)
return [doc_map[doc_id] for doc_id in sorted_ids]
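
A worked toy example of the scoring: a document that shows up in several result lists accumulates `1/(k + rank)` per appearance, so repeated mid-rank hits can beat a single top hit:

```python
from langchain_core.documents import Document

a = Document(page_content="报销流程", metadata={"source": "a.md"})
b = Document(page_content="差旅标准", metadata={"source": "b.md"})
# b appears in both lists (ranks 2 and 1): 1/62 + 1/61 ≈ 0.0325
# a appears once at rank 1:                1/61        ≈ 0.0164
fused = reciprocal_rank_fusion([[a, b], [b]], k=60)
assert fused[0].page_content == "差旅标准"
```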


@@ -0,0 +1,91 @@
# rag/pipeline.py
import asyncio
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from ..config import LLAMACPP_RERANKER_URL, LLAMACPP_API_KEY
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
class RAGPipeline:
    """
    Fixed-flow RAG retrieval pipeline:
    multi-query rewrite → parallel retrieval → RRF fusion → reranking → parent documents
    """
    def __init__(
        self,
        retriever,  # base retriever; should return parent documents (e.g. a ParentDocumentRetriever)
        llm: BaseLanguageModel,
        num_queries: int = 3,
        rerank_top_n: int = 5,
    ):
        """
        Args:
            retriever: base retriever object; must implement an async ainvoke(query) method
            llm: language model used to generate the query variants
            num_queries: number of query variants to generate
            rerank_top_n: number of documents to return after reranking
        """
        self.retriever = retriever
        self.llm = llm
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n
        # Initialize the components
        self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
        self.reranker = LLaMaCPPReranker(
            base_url=LLAMACPP_RERANKER_URL,
            api_key=LLAMACPP_API_KEY,
            top_n=rerank_top_n,
        )
    async def aretrieve(self, query: str) -> List[Document]:
        """
        Run the full retrieval flow asynchronously.
        """
        # Step 1: generate the query variants
        queries = await self.query_generator.agenerate(query)
        # Always include the original query, and keep it first
        if query in queries:
            queries.remove(query)
        queries.insert(0, query)
        # Step 2: retrieve in parallel (one document list per query)
        tasks = [self.retriever.ainvoke(q) for q in queries]
        doc_lists = await asyncio.gather(*tasks)
        # Step 3: RRF fusion
        fused_docs = reciprocal_rank_fusion(doc_lists)
        # Step 4: rerank
        try:
            final_docs = self.reranker.compress_documents(fused_docs, query)
        except Exception:
            # If the reranker is unavailable, fall back to the top N fused documents
            final_docs = fused_docs[:self.rerank_top_n]
        return final_docs
    def retrieve(self, query: str) -> List[Document]:
        """Synchronous entry point (wraps the async method; must not be called from a running event loop)."""
        return asyncio.run(self.aretrieve(query))
    def format_context(self, documents: List[Document]) -> str:
        """Format a document list into a context string."""
        if not documents:
            return ""
        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)

43
backend/app/rag/query_transform.py Normal file
View File

@@ -0,0 +1,43 @@
# rag/query_transform.py
from typing import List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
MULTI_QUERY_PROMPT = PromptTemplate.from_template(
    """你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。
原始问题: {question}
请生成 {num_queries} 个不同版本的查询,每个版本一行。
确保每个版本都是独立、完整的查询语句。
生成 {num_queries} 个查询:"""
)
class MultiQueryGenerator:
    """Multi-query generator (does not depend on LangChain's MultiQueryRetriever)."""
    def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
        self.llm = llm
        self.num_queries = num_queries
        self.prompt = MULTI_QUERY_PROMPT
    def generate(self, query: str) -> List[str]:
        """Generate several query variants synchronously."""
        prompt_str = self.prompt.format(num_queries=self.num_queries, question=query)
        response = self.llm.invoke(prompt_str)
        # Split the response into lines, dropping blanks and surrounding whitespace
        lines = response.content.strip().split('\n')
        queries = [line.strip() for line in lines if line.strip()]
        # Fall back to the original query if nothing was generated
        return queries[:self.num_queries] if queries else [query]
    async def agenerate(self, query: str) -> List[str]:
        """Generate several query variants asynchronously."""
        prompt_str = self.prompt.format(num_queries=self.num_queries, question=query)
        response = await self.llm.ainvoke(prompt_str)
        lines = response.content.strip().split('\n')
        queries = [line.strip() for line in lines if line.strip()]
        return queries[:self.num_queries] if queries else [query]
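
The line-splitting logic can be exercised without a real model using langchain_core's FakeListChatModel test helper; a sketch with a made-up canned response:

from langchain_core.language_models import FakeListChatModel
from app.rag.query_transform import MultiQueryGenerator  # import path assumed

fake_llm = FakeListChatModel(
    responses=["水浒传中谁打死了老虎?\n\n武松打虎出自哪一回?\n景阳冈上打虎的人是谁?"]
)
gen = MultiQueryGenerator(llm=fake_llm, num_queries=3)
# Blank lines are dropped and each remaining line becomes one variant
print(gen.generate("打虎英雄是谁?"))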

75
backend/app/rag/reranker.py Normal file
View File

@@ -0,0 +1,75 @@
"""
重排序器模块 (适配版)
使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder
"""
import requests
from typing import List
from langchain_core.documents import Document
class LLaMaCPPReranker:
"""使用远程 llama.cpp 服务对检索结果重排序。"""
def __init__(self,
base_url: str,
api_key: str,
top_n: int = 5,
timeout: int = 60):
"""
初始化远程重排序器
Args:
base_url: llama.cpp 服务的地址和端口,默认为环境变量 LLAMACPP_RERANKER_URL 或 "http://127.0.0.1:8083"
top_n: 返回前 N 个结果。
api_key: API 密钥,默认为环境变量 LLAMACPP_API_KEY 或 "huang1998"
timeout: 请求超时时间(秒)。
"""
self.base_url = base_url
self.api_key = api_key
self.top_n = top_n
self.timeout = timeout
self.endpoint = f"{self.base_url}/rerank"
def compress_documents(
self, documents: List[Document], query: str
) -> List[Document]:
"""
对文档进行重排序
Args:
documents: 待排序的文档列表
query: 查询字符串
Returns:
排序后的文档列表
"""
if not documents:
return []
# 准备请求体
# 根据 llama.cpp 的 OpenAI 兼容性,文档是一个字符串列表
payload = {
"model": "bge-reranker-v2-m3",
"query": query,
"documents": [doc.page_content for doc in documents],
"top_n": self.top_n
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
try:
response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout)
response.raise_for_status() # 检查请求是否成功
results = response.json()
# 解析返回结果
# 返回格式: {"results": [{"index": 0, "document": "...", "relevance_score": 0.8}, ...]}
# 按相关性得分降序排列
sorted_indices = [item["index"] for item in results["results"]]
sorted_docs = [documents[idx] for idx in sorted_indices]
return sorted_docs
except Exception as e:
print(f"警告: 远程重排序过程出错,将使用原始排序。错误: {e}")
return documents[:self.top_n]
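
A standalone sketch of the reranker; the URL and key are placeholders, and if no llama.cpp server with reranking enabled is listening there, the except branch simply returns the first top_n documents unchanged:

from langchain_core.documents import Document
from app.rag.reranker import LLaMaCPPReranker  # import path assumed

reranker = LLaMaCPPReranker(base_url="http://127.0.0.1:8083",
                            api_key="changeme", top_n=2)
docs = [
    Document(page_content="武松在景阳冈打死了一只老虎。", metadata={"source": "shuihu.txt"}),
    Document(page_content="Qdrant 支持稠密与稀疏向量。", metadata={"source": "qdrant.md"}),
    Document(page_content="李逵沂岭杀四虎。", metadata={"source": "shuihu.txt"}),
]
for d in reranker.compress_documents(docs, query="打虎英雄是谁?"):
    print(d.metadata["source"], d.page_content)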

199
backend/app/rag/retriever.py Normal file
View File

@@ -0,0 +1,199 @@
"""
Qdrant 向量检索器模块
提供基于 Qdrant 的基础向量检索和混合检索Dense + Sparse功能。
核心原理:
- 基础检索:将查询文本转换为向量,在 Qdrant 中进行近似最近邻ANN搜索
使用余弦相似度返回最相似的 k 个文档。
- 混合检索:结合稠密向量检索(语义相似)和 BM25 稀疏向量检索(关键词匹配),
通过加权或分数融合提高召回精度。
使用示例:
>>> from rag_core import LlamaCppEmbedder
>>> embedder = LlamaCppEmbedder()
>>> embeddings = embedder.as_langchain_embeddings()
>>>
>>> # 创建基础检索器
>>> retriever = create_base_retriever(
... collection_name="my_docs",
... embeddings=embeddings,
... search_kwargs={"k": 10}
... )
>>>
>>> # 执行检索
>>> docs = retriever.invoke("什么是 RAG")
"""
from typing import Optional, Dict, Any
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from langchain_qdrant import QdrantVectorStore
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from rag_core import QDRANT_URL, QDRANT_API_KEY
# 模块级常量
DEFAULT_SEARCH_K = 20
DEFAULT_SCORE_THRESHOLD = 0.3
def create_qdrant_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
timeout: int = 30,
) -> QdrantClient:
"""
创建并返回一个配置好的 Qdrant 客户端。
优先使用传入参数,若未提供则回退到环境变量 QDRANT_URL 和 QDRANT_API_KEY。
Args:
url: Qdrant 服务地址,例如 "http://localhost:6333"
默认从环境变量 QDRANT_URL 读取。
api_key: API 密钥(若 Qdrant 启用了认证)。
默认从环境变量 QDRANT_API_KEY 读取。
timeout: 请求超时时间(秒),默认 30 秒。
Returns:
配置好的 QdrantClient 实例。
Raises:
ValueError: 如果 url 为空且环境变量也未设置。
"""
effective_url = url or QDRANT_URL
if not effective_url:
raise ValueError(
"Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL"
)
effective_api_key = api_key or QDRANT_API_KEY
client_kwargs = {
"url": effective_url,
"timeout": timeout,
}
if effective_api_key:
client_kwargs["api_key"] = effective_api_key
return QdrantClient(**client_kwargs)
def create_base_retriever(
collection_name: str,
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
) -> BaseRetriever:
"""
创建基础向量检索器(仅稠密向量检索)。
该检索器使用嵌入模型将查询转为向量,在 Qdrant 集合中执行 ANN 搜索,
返回语义上最相似的文档块。
Args:
collection_name: Qdrant 集合名称(需预先创建并索引)。
embeddings: LangChain 兼容的嵌入模型实例。
search_kwargs: 搜索参数,可包含:
- k (int): 返回的文档数量,默认 20。
- score_threshold (float): 相似度阈值,仅返回高于此分数的文档。
- filter (dict): Qdrant 过滤条件。
若为 None则使用默认值 {"k": 20}。
client: 可选的 Qdrant 客户端实例。若未提供,将自动创建。
Returns:
BaseRetriever 实例,可直接调用 .invoke(query) 或 .ainvoke(query) 检索。
Raises:
ValueError: 如果集合不存在或嵌入模型无效。
"""
# 合并默认搜索参数
merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
if search_kwargs:
merged_search_kwargs.update(search_kwargs)
# 创建或复用 Qdrant 客户端
if client is None:
client = create_qdrant_client()
# 验证集合是否存在(可选,便于提前发现问题)
try:
client.get_collection(collection_name)
except UnexpectedResponse as e:
if e.status_code == 404:
raise ValueError(
f"Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档。"
)
raise
# 构建向量存储
vector_store = QdrantVectorStore(
client=client,
collection_name=collection_name,
embedding=embeddings,
)
# 返回检索器
return vector_store.as_retriever(search_kwargs=merged_search_kwargs)
def create_hybrid_retriever(
collection_name: str,
embeddings: Embeddings,
dense_k: int = 10,
sparse_k: int = 10,
score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD,
client: Optional[QdrantClient] = None,
) -> BaseRetriever:
"""
创建混合检索器(稠密向量 + BM25 稀疏向量)。
混合检索结合了语义相似度Dense和关键词匹配Sparse
能够更好地处理专有名词、精确匹配等场景。
注意:此功能要求 Qdrant 集合已配置稀疏向量字段并生成了 BM25 索引。
若集合未配置稀疏向量,将回退到纯稠密检索(不会报错,但检索效果降级)。
Args:
collection_name: Qdrant 集合名称。
embeddings: 嵌入模型(用于稠密向量)。
dense_k: 稠密向量检索返回数量,默认 10。
sparse_k: 稀疏向量检索返回数量,默认 10。
score_threshold: 相似度阈值,默认 0.3。
client: 可选的 Qdrant 客户端实例。
Returns:
BaseRetriever 实例,配置了混合搜索参数。
"""
total_k = dense_k + sparse_k
search_kwargs = {
"k": total_k,
}
if score_threshold is not None:
search_kwargs["score_threshold"] = score_threshold
# 复用基础检索器创建逻辑,只需调整搜索参数
return create_base_retriever(
collection_name=collection_name,
embeddings=embeddings,
search_kwargs=search_kwargs,
client=client,
)
# 可选:提供异步友好的辅助函数
async def acreate_base_retriever(
collection_name: str,
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
) -> BaseRetriever:
"""
异步创建基础向量检索器(与同步版本功能相同)。
适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。
"""
# 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可
return create_base_retriever(collection_name, embeddings, search_kwargs, client)
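
For reference, a sketch of what genuinely hybrid retrieval could look like with langchain_qdrant's HYBRID mode. This is not what create_hybrid_retriever above does; it assumes the collection was indexed with both a dense and a sparse vector, that the fastembed package is installed, and that the vector names match those used at indexing time:

from langchain_qdrant import QdrantVectorStore, RetrievalMode, FastEmbedSparse

def create_true_hybrid_retriever(collection_name, embeddings, client, k=20):
    sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")  # BM25 sparse encoder
    vector_store = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embeddings,
        sparse_embedding=sparse_embeddings,
        retrieval_mode=RetrievalMode.HYBRID,
        sparse_vector_name="sparse",  # must match the name used when indexing
    )
    return vector_store.as_retriever(search_kwargs={"k": k})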

146
backend/app/rag/test.py Normal file
View File

@@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""
RAG system usage example (refactored)
Demonstrates:
1. Getting a parent/child-chunk retriever via rag_core's create_parent_retriever
2. Creating the fixed-flow RAGPipeline (multi-query rewrite → RRF fusion → reranking → parent documents)
3. Wrapping the pipeline as a LangChain tool for an Agent to call
Note: this file uses relative imports, so run it as a module from the backend
directory, e.g. `python -m app.rag.test`.
"""
import asyncio
import sys
import os
from dotenv import load_dotenv
# Load environment variables (Qdrant URL, PostgreSQL connection, etc.)
load_dotenv()
# Add the project root to the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from .pipeline import RAGPipeline
from .tools import create_rag_tool_sync
from pydantic import SecretStr
# Local LLM via an OpenAI-compatible endpoint
from langchain_openai import ChatOpenAI
from rag_core.retriever_factory import create_parent_retriever
def create_llm():
    """Create an LLM client for the local vLLM service."""
    vllm_base_url = os.getenv(
        "VLLM_BASE_URL",
        "http://127.0.0.1:8081/v1"
    )
    return ChatOpenAI(
        base_url=vllm_base_url,
        api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
        model="gemma-4-E2B-it",
        timeout=60.0,     # request timeout in seconds
        max_retries=2,    # automatic retries on failure
        streaming=True,   # enable streaming output
    )
async def demonstrate_full_pipeline():
    """
    Full pipeline demo:
    - get a ParentDocumentRetriever from rag_core
    - create the RAGPipeline
    - run a query and print the results
    """
    print("=" * 60)
    print("Demo: fixed-flow RAG retrieval (multi-query + RRF + rerank + parent docs)")
    print("=" * 60)
    # 1. Create the retriever
    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
    if retriever is None:
        print("Error: retriever not initialized; make sure the index has been built.")
        return
    # 2. Create the LLM used for query rewriting
    llm = create_llm()
    # 3. Create the RAGPipeline (fixed flow)
    pipeline = RAGPipeline(
        retriever=retriever,
        llm=llm,
        num_queries=3,    # generate 3 query variants
        rerank_top_n=5,   # return 5 parent documents
    )
    # 4. Run the retrieval
    query = "打虎英雄是谁?"
    print(f"\nQuery: {query}")
    print("-" * 40)
    try:
        documents = await pipeline.aretrieve(query)
        print(f"Returned {len(documents)} parent documents\n")
        # Print result previews
        for i, doc in enumerate(documents, 1):
            content_preview = doc.page_content.replace("\n", " ")[:150]
            source = doc.metadata.get("source", "未知来源")
            print(f"{i}. [source: {source}]")
            print(f"   {content_preview}...\n")
        # Optional: format the full context
        # context = pipeline.format_context(documents)
        # print(context)
    except Exception as e:
        print(f"Retrieval failed: {e}")
        import traceback
        traceback.print_exc()
async def demonstrate_tool_creation():
    """
    Demo: creating the RAG tool (for Agent use)
    """
    print("\n" + "=" * 60)
    print("Demo: creating the RAG tool (for LangGraph Agent calls)")
    print("=" * 60)
    # 1. Get the retriever (same as above)
    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
    # 2. Create the LLM
    llm = create_llm()
    # 3. Create the tool
    rag_tool = create_rag_tool_sync(
        retriever=retriever,
        llm=llm,
        num_queries=3,
        rerank_top_n=5,
        collection_name="rag_documents",
    )
    print(f"Tool name: {rag_tool.name}")
    print(f"Tool description: {rag_tool.description[:100]}...")
    # 4. Simulate an Agent tool call
    query = "请告诉我 打虎英雄是谁?"
    print(f"\nSimulated call: {query}")
    print("-" * 40)
    result = await rag_tool.ainvoke({"query": query})
    print(result[:800] + "..." if len(result) > 800 else result)
async def main():
    await demonstrate_full_pipeline()
    await demonstrate_tool_creation()
if __name__ == "__main__":
    asyncio.run(main())
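
test.py only simulates a direct tool invocation; handing the tool to an actual agent is one more step. A sketch using langgraph's prebuilt ReAct agent (not part of this commit):

from langgraph.prebuilt import create_react_agent

async def demo_agent(rag_tool, llm):
    # The agent decides on its own when to call the knowledge-base tool
    agent = create_react_agent(llm, tools=[rag_tool])
    result = await agent.ainvoke({"messages": [("user", "打虎英雄是谁?")]})
    print(result["messages"][-1].content)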

54
backend/app/rag/tools.py Normal file
View File

@@ -0,0 +1,54 @@
"""
RAG 工具模块
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
"""
from typing import Callable
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from .pipeline import RAGPipeline
def create_rag_tool_sync(
retriever: BaseRetriever,
llm: BaseLanguageModel,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(同步版本,用于不支持异步的旧版 Agent
参数同 create_rag_tool。
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
)
@tool
def search_knowledge_base_sync(query: str) -> str:
"""在知识库中搜索与查询相关的文档片段(同步版本)。
功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容。
"""
try:
documents = pipeline.retrieve(query) # 内部调用异步方法并等待
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_sync
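
The package docstring also references an async create_rag_tool, which is absent from this commit. A hypothetical sketch of that counterpart, mirroring the factory above but awaiting the pipeline directly so no nested event loop is needed:

def create_rag_tool(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    num_queries: int = 3,
    rerank_top_n: int = 5,
    collection_name: str = "rag_documents",
) -> Callable:
    """Hypothetical async counterpart of create_rag_tool_sync."""
    pipeline = RAGPipeline(retriever=retriever, llm=llm,
                           num_queries=num_queries, rerank_top_n=rerank_top_n)

    @tool
    async def search_knowledge_base(query: str) -> str:
        """在知识库中搜索与查询相关的文档片段。"""
        try:
            documents = await pipeline.aretrieve(query)
            if not documents:
                return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
            return pipeline.format_context(documents)
        except Exception as e:
            return f"检索过程中发生错误: {str(e)}"

    return search_knowledge_base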

304
backend/app/test_backend.py Normal file
View File

@@ -0,0 +1,304 @@
#!/usr/bin/env python3
"""
Full backend test - verifies all Agent features
Covers: short-term memory, long-term memory, tool calling, streaming chat, history queries
"""
import asyncio
import os
import sys
import uuid
from dotenv import load_dotenv
# Add the project root to the Python path (the file lives under backend/app/)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
load_dotenv()
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from .config import DB_URI
from ..agent import AIAgentService
from ..agent.history import ThreadHistoryService
from ..logger import info, warning, error
async def print_section(title):
    """Print a test section header."""
    print("\n" + "=" * 70)
    print(f"  {title}")
    print("=" * 70)
async def test_short_term_memory(agent_service):
    """Test short-term memory (continuing a conversation on the same thread_id)."""
    await print_section("Test 1: Short-term Memory")
    thread_id = str(uuid.uuid4())
    user_id = "test_user_memory"
    print(f"\nUsing thread_id: {thread_id[:8]}...")
    print(f"Using user_id: {user_id}")
    # Round 1
    print("\n[Round 1] Sending: '我叫张三,今年28岁'")
    result1 = await agent_service.process_message(
        "我叫张三,今年28岁", thread_id, "local", user_id
    )
    print(f"Reply: {result1['reply'][:100]}...")
    # Round 2 - check the memory
    print("\n[Round 2] Sending: '我叫什么名字?今年多大?'")
    result2 = await agent_service.process_message(
        "我叫什么名字?今年多大?", thread_id, "local", user_id
    )
    print(f"Reply: {result2['reply']}")
    # Verify the memory took hold
    if "张三" in result2['reply'] or "28" in result2['reply']:
        print("\n✅ Short-term memory test passed!")
        return True
    else:
        print("\n❌ Short-term memory test failed!")
        return False
async def test_tool_calling(agent_service):
    """Test tool calling (RAG search)."""
    await print_section("Test 2: Tool Calling")
    thread_id = str(uuid.uuid4())
    user_id = "test_user_tools"
    print(f"\nUsing thread_id: {thread_id[:8]}...")
    print(f"Using user_id: {user_id}")
    # Ask a question that requires a RAG search
    print("\nSending: '请告诉我,打虎英雄是谁?'")
    result = await agent_service.process_message(
        "请告诉我,打虎英雄是谁?", thread_id, "local", user_id
    )
    print(f"Reply: {result['reply'][:200]}...")
    # Check whether the RAG tool was used (the reply should reference 水浒传 content)
    if "武松" in result['reply'] or "李忠" in result['reply'] or "水浒传" in result['reply']:
        print("\n✅ Tool-calling test passed!")
        return True
    else:
        print("\n⚠️ Tool-calling result inconclusive; verify manually")
        return None
async def test_streaming(agent_service):
    """Test streaming chat."""
    await print_section("Test 3: Streaming")
    thread_id = str(uuid.uuid4())
    user_id = "test_user_stream"
    print(f"\nUsing thread_id: {thread_id[:8]}...")
    print(f"Using user_id: {user_id}")
    print("\nSending: '用100字介绍一下AI人工智能' (streaming)...")
    print("Streamed output: ", end="", flush=True)
    full_reply = ""
    chunk_count = 0
    try:
        async for chunk in agent_service.process_message_stream(
            "用100字介绍一下AI人工智能", thread_id, "local", user_id
        ):
            chunk_count += 1
            if chunk.get("type") == "llm_token":
                token = chunk.get("token", "")
                print(token, end="", flush=True)
                full_reply += token
            elif chunk.get("type") == "state_update":
                pass  # do not display state updates
        print(f"\n\nReceived {chunk_count} chunks")
        print(f"Full reply length: {len(full_reply)}")
        if chunk_count > 0 and len(full_reply) > 10:
            print("\n✅ Streaming test passed!")
            return True
        else:
            print("\n❌ Streaming test failed!")
            return False
    except Exception as e:
        print(f"\n❌ Streaming raised an exception: {e}")
        return False
async def test_history_service(agent_service, history_service):
    """Test the history query service."""
    await print_section("Test 4: History Service")
    user_id = "test_user_history"
    # Create a few conversations first
    print(f"\nCreating test conversations for user_id={user_id}...")
    thread_ids = []
    for i in range(3):
        thread_id = str(uuid.uuid4())
        thread_ids.append(thread_id)
        await agent_service.process_message(
            f"这是第 {i+1} 个测试对话", thread_id, "local", user_id
        )
        print(f"  Created thread {i+1}: {thread_id[:8]}...")
    # 1. List the user's threads
    print("\n[4.1] Listing the user's threads...")
    threads = await history_service.get_user_threads(user_id, limit=10)
    print(f"  Found {len(threads)} threads")
    if len(threads) >= 3:
        print("  ✅ Thread listing passed")
    else:
        print("  ⚠️ Fewer threads than expected")
    # 2. Fetch one thread's message history
    if thread_ids:
        test_thread_id = thread_ids[0]
        print(f"\n[4.2] Fetching message history (thread_id={test_thread_id[:8]}...)")
        messages = await history_service.get_thread_messages(test_thread_id)
        print(f"  Found {len(messages)} messages")
        if len(messages) >= 2:  # at least one question and one answer
            print("  ✅ Message history passed")
        else:
            print("  ⚠️ Fewer messages than expected")
        # 3. Fetch the thread summary
        print(f"\n[4.3] Fetching the thread summary...")
        summary = await history_service.get_thread_summary(test_thread_id)
        print(f"  Summary: {summary.get('summary', '')[:50]}...")
        print(f"  Message count: {summary.get('message_count', 0)}")
        if summary.get('message_count', 0) > 0:
            print("  ✅ Thread summary passed")
        else:
            print("  ⚠️ Summary result inconclusive")
    return len(threads) >= 3
async def test_long_term_memory(agent_service):
    """Test long-term memory (mem0)."""
    await print_section("Test 5: Long-term Memory - mem0")
    thread_id1 = str(uuid.uuid4())
    thread_id2 = str(uuid.uuid4())  # a different thread
    user_id = "test_user_longterm"
    print(f"\nUsing user_id: {user_id}")
    print(f"Thread 1: {thread_id1[:8]}...")
    print(f"Thread 2: {thread_id2[:8]}...")
    # Store a fact in the first thread
    print("\n[Thread 1] Sending: '记住,我的宠物名字叫小白,是一只猫'")
    result1 = await agent_service.process_message(
        "记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id
    )
    print(f"Reply: {result1['reply'][:100]}...")
    # Give mem0 a moment to persist
    await asyncio.sleep(1)
    # Ask in the second thread (different thread_id)
    print("\n[Thread 2] Sending: '我的宠物叫什么名字?是什么动物?'")
    result2 = await agent_service.process_message(
        "我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id
    )
    print(f"Reply: {result2['reply']}")
    # Verify the long-term memory
    if "小白" in result2['reply'] or "猫" in result2['reply']:
        print("\n✅ Long-term memory test passed!")
        return True
    else:
        print("\n⚠️ Long-term memory may be disabled; verify manually")
        return None
async def main():
    """Main test entry point."""
    print("\n" + "=" * 70)
    print("  Full backend feature test")
    print("=" * 70)
    results = {}
    fail_count = 0
    try:
        # Set up the database connection and the services
        print("\nInitializing database connection...")
        async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
            await checkpointer.setup()
            print("✅ Database connection established")
            # Create the service instances
            print("\nInitializing Agent service...")
            agent_service = AIAgentService(checkpointer)
            await agent_service.initialize()
            print("✅ Agent service initialized")
            history_service = ThreadHistoryService(checkpointer)
            print("✅ History service initialized")
            print(f"\nAvailable models: {list(agent_service.graphs.keys())}")
            # Run the tests
            results["short-term memory"] = await test_short_term_memory(agent_service)
            await asyncio.sleep(1)
            results["tool calling"] = await test_tool_calling(agent_service)
            await asyncio.sleep(1)
            results["streaming"] = await test_streaming(agent_service)
            await asyncio.sleep(1)
            results["history queries"] = await test_history_service(agent_service, history_service)
            await asyncio.sleep(1)
            results["long-term memory"] = await test_long_term_memory(agent_service)
            await asyncio.sleep(1)
            # Print the summary
            await print_section("Test Summary")
            print("\nResults:")
            print("-" * 40)
            pass_count = 0
            skip_count = 0
            for test_name, result in results.items():
                if result is True:
                    status = "✅ passed"
                    pass_count += 1
                elif result is False:
                    status = "❌ failed"
                    fail_count += 1
                else:
                    status = "⚠️ needs manual check"
                    skip_count += 1
                print(f"  {test_name:20s}: {status}")
            print("-" * 40)
            print(f"Total: {len(results)} tests")
            print(f"Passed: {pass_count}, failed: {fail_count}, manual: {skip_count}")
            if fail_count == 0:
                print("\n🎉 All core tests passed!")
            else:
                print(f"\n⚠️ {fail_count} test(s) failed")
    except Exception as e:
        error(f"\n❌ Test run raised an exception: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0 if fail_count == 0 else 1
if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)

7
backend/app/utils/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
工具模块
"""
from .logging import log_state_change, print_llm_input
__all__ = ["log_state_change", "print_llm_input"]

61
backend/app/utils/logging.py Normal file
View File

@@ -0,0 +1,61 @@
"""
LangGraph 节点日志工具模块
提供状态流转追踪和 LLM 输入输出打印功能
"""
from ..config import ENABLE_GRAPH_TRACE
from ..logger import debug, info
def log_state_change(node_name: str, state: dict, prefix: str = "进入"):
"""
记录状态变化日志
Args:
node_name: 节点名称
state: 当前状态
prefix: 日志前缀("进入""离开"
"""
from app.logger import info
messages = state.get("messages", [])
msg_count = len(messages)
last_msg = messages[-1] if messages else None
last_info = ""
if last_msg:
# 兼容 dict 和对象两种格式
if isinstance(last_msg, dict):
content_preview = str(last_msg.get("content", ""))[:10].replace("\n", " ")
msg_type = last_msg.get("type", "unknown")
else:
content_preview = str(last_msg.content)[:10].replace("\n", " ")
msg_type = getattr(last_msg, 'type', 'unknown')
last_info = f"{msg_type.upper()}: {content_preview}"
info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}")
def print_llm_input(prompt_value):
"""
RunnableLambda 回调函数:打印格式化后发送给 LLM 的完整消息
Args:
prompt_value: ChatPromptValue 对象,包含格式化后的消息列表
Returns:
原样返回 prompt_value不影响链式调用
"""
if not ENABLE_GRAPH_TRACE:
return prompt_value
messages = prompt_value.messages # ChatPromptValue 提供 .messages 属性
debug("\n" + "=" * 80)
debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:")
debug(f" 总消息数: {len(messages)}")
debug("-" * 80)
for i, msg in enumerate(messages):
content_preview = str(msg.content) # 完整输出
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
debug("\n" + "=" * 80 + "\n")
return prompt_value
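
A sketch of splicing print_llm_input into an LCEL chain via RunnableLambda; the prompt is a placeholder and FakeListChatModel stands in for a real model:

from langchain_core.language_models import FakeListChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from app.utils.logging import print_llm_input  # import path assumed

prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant."),
    ("human", "{question}"),
])
llm = FakeListChatModel(responses=["placeholder answer"])
# The tap sits between prompt formatting and the model call; with
# ENABLE_GRAPH_TRACE on it logs the exact messages, then passes them through
chain = prompt | RunnableLambda(print_llm_input) | llm
print(chain.invoke({"question": "什么是 RAG?"}).content)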