This commit is contained in:
8
backend/app/__init__.py
Normal file
8
backend/app/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
AI Agent 应用模块
|
||||
"""
|
||||
|
||||
from ..agent import AIAgentService
|
||||
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
|
||||
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]
|
||||
7
backend/app/agent/__init__.py
Normal file
7
backend/app/agent/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Agent 子模块
|
||||
"""
|
||||
|
||||
from .service import AIAgentService
|
||||
|
||||
__all__ = ["AIAgentService"]
|
||||
185
backend/app/agent/history.py
Normal file
185
backend/app/agent/history.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
历史对话查询模块
|
||||
利用 LangGraph 的 checkpointer 获取对话历史和摘要
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from ..logger import error # 保持兼容,或者替换为 logger
|
||||
|
||||
class ThreadHistoryService:
    """Read-only queries over conversation threads persisted by a checkpointer.

    The checkpointer must provide ``aget_tuple(config)`` returning ``None`` or
    an object with ``checkpoint`` (dict) and ``metadata`` attributes.
    ``get_user_threads`` additionally uses ``checkpointer.conn.cursor()`` and
    reads rows by column name.
    """

    # Placeholder summary for a thread with no usable content.
    _EMPTY_SUMMARY = "空对话"
    # Maximum length of a summary derived from the first user message.
    _SUMMARY_LIMIT = 50
    # LangChain message types and plain role names mapped to API roles.
    _ROLE_ALIASES = {
        "human": "user",
        "user": "user",
        "ai": "assistant",
        "assistant": "assistant",
    }

    def __init__(self, checkpointer):
        """Keep a reference to the checkpointer; no I/O happens here."""
        self.checkpointer = checkpointer

    # ------------------------------------------------------------------ helpers

    async def _load_state(self, thread_id: str):
        """Fetch the latest checkpoint tuple of ``thread_id`` (may be ``None``)."""
        return await self.checkpointer.aget_tuple(
            {"configurable": {"thread_id": thread_id}}
        )

    @staticmethod
    def _messages_of(state) -> List:
        """Return the ``messages`` channel stored in a checkpoint tuple, or ``[]``."""
        checkpoint = getattr(state, "checkpoint", None)
        if not isinstance(checkpoint, dict):
            return []
        channel_values = checkpoint.get("channel_values") or {}
        return channel_values.get("messages") or []

    @classmethod
    def _role_of(cls, msg):
        """Return the normalised role of ``msg``, or ``None`` for unknown shapes."""
        if hasattr(msg, "type"):
            raw = msg.type
        elif isinstance(msg, dict):
            raw = msg.get("role", msg.get("type", "unknown"))
        else:
            return None
        return cls._ROLE_ALIASES.get(raw, raw)

    @staticmethod
    def _content_of(msg):
        """Return the content of an object-style or dict-style message."""
        if isinstance(msg, dict):
            return msg.get("content", "")
        return getattr(msg, "content", "")

    @staticmethod
    def _as_text(content) -> str:
        """Flatten message content to text.

        Content may be a plain string or a list of parts (strings or dicts
        carrying a ``text`` entry); other parts are ignored.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            return "".join(pieces)
        return "" if content is None else str(content)

    def _count_turns(self, messages: List) -> int:
        """Count user and assistant messages (tool/system messages excluded)."""
        return sum(
            1 for msg in messages if self._role_of(msg) in ("user", "assistant")
        )

    # --------------------------------------------------------------- public API

    async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """List summary information for the threads of one user.

        Args:
            user_id: Value matched against ``metadata->>'user_id'`` of the
                ``checkpoints`` table.
            limit: Maximum number of threads to return.

        Returns:
            A list of dicts with ``thread_id``, ``last_updated``, ``summary``
            and ``message_count``, newest first. Threads without messages are
            skipped. Any failure is logged and yields ``[]``.
        """
        try:
            # The checkpoints table has no created_at column here, so the
            # largest checkpoint_id per thread is used as the ordering key.
            query = """
                SELECT
                    thread_id,
                    MAX(checkpoint_id) as last_updated
                FROM checkpoints
                WHERE metadata->>'user_id' = %s
                GROUP BY thread_id
                ORDER BY last_updated DESC
                LIMIT %s
            """
            async with self.checkpointer.conn.cursor() as cur:
                await cur.execute(query, (user_id, limit))
                rows = await cur.fetchall()

            threads = []
            for row in rows:
                thread_id = row['thread_id']
                messages = self._messages_of(await self._load_state(thread_id))
                if not messages:
                    # Nothing to summarise; previously this path could reuse
                    # values from an earlier iteration or fail outright.
                    continue
                threads.append({
                    "thread_id": thread_id,
                    # checkpoint_id is an opaque id string, returned as-is.
                    "last_updated": row['last_updated'] or "",
                    "summary": self._extract_summary(messages),
                    "message_count": self._count_turns(messages),
                })
            return threads

        except Exception as e:
            error(f"获取用户线程列表失败 (user_id={user_id}): {e}")
            return []

    async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
        """Return the message history of a thread as role/content dicts.

        Args:
            thread_id: Thread identifier.

        Returns:
            ``[{"role": ..., "content": ...}]`` in stored order. ``human``/``ai``
            are mapped to ``user``/``assistant``; other types keep their own
            name; system messages and unrecognised objects are dropped.
            Returns ``[]`` for unknown threads or on error.
        """
        try:
            history = []
            for msg in self._messages_of(await self._load_state(thread_id)):
                role = self._role_of(msg)
                if role is None or role == "system":
                    continue
                history.append({"role": role, "content": self._content_of(msg)})
            return history

        except Exception as e:
            error(f"获取线程消息历史失败: {e}")
            return []

    async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
        """Return summary information for one thread (used by history lists).

        Args:
            thread_id: Thread identifier.

        Returns:
            Dict with ``thread_id``, ``summary``, ``message_count`` and, for
            non-empty threads, ``last_updated``.
        """
        try:
            state = await self._load_state(thread_id)
            # Messages live in checkpoint["channel_values"]; the checkpoint
            # tuple has no ``values`` attribute.
            messages = self._messages_of(state)
            if state is None or not messages:
                return {"thread_id": thread_id, "summary": self._EMPTY_SUMMARY, "message_count": 0}

            # Prefer an explicit created_at in metadata, else the checkpoint's
            # own "ts" timestamp when present.
            metadata = getattr(state, "metadata", None)
            created_at = metadata.get("created_at") if isinstance(metadata, dict) else None
            if created_at is not None:
                last_updated = (
                    created_at.isoformat()
                    if hasattr(created_at, "isoformat")
                    else str(created_at)
                )
            else:
                last_updated = state.checkpoint.get("ts") or ""

            return {
                "thread_id": thread_id,
                "summary": self._extract_summary(messages),
                "message_count": self._count_turns(messages),
                "last_updated": last_updated,
            }

        except Exception as e:
            error(f"获取线程摘要失败: {e}")
            return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0}

    def _extract_summary(self, messages: List) -> str:
        """Derive a short summary from a message list.

        Strategy:
            1. Use a stored ``summary`` (``additional_kwargs['summary']`` on
               message objects, ``'summary'`` key on dicts) when present.
            2. Otherwise use the first user message, cut to 50 characters
               followed by ``...``.
            3. Otherwise return the empty-conversation placeholder.
        """
        for msg in messages:
            if isinstance(msg, dict):
                stored = msg.get('summary')
            else:
                stored = (getattr(msg, 'additional_kwargs', None) or {}).get('summary')
            if stored:
                return stored

        limit = self._SUMMARY_LIMIT
        for msg in messages:
            if self._role_of(msg) == "user":
                text = self._as_text(self._content_of(msg))
                return text[:limit] + "..." if len(text) > limit else text

        return self._EMPTY_SUMMARY
|
||||
57
backend/app/agent/llm_factory.py
Normal file
57
backend/app/agent/llm_factory.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# app/llm_factory.py
|
||||
import os
|
||||
from ..config import ZHIPUAI_API_KEY, DEEPSEEK_API_KEY, VLLM_BASE_URL, LLAMACPP_API_KEY
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
class LLMFactory:
    """Namespace of constructors for the chat models the service can use.

    Each ``create_*`` static method builds a freshly configured client;
    ``CREATORS`` maps the public model name to its constructor.
    """

    @staticmethod
    def create_local():
        """Build the client for the local OpenAI-compatible endpoint."""
        options = dict(
            base_url=VLLM_BASE_URL,
            api_key=SecretStr(LLAMACPP_API_KEY),
            model="gemma-4-E4B-it",
            timeout=60.0,
            max_retries=2,
            streaming=True,
        )
        return ChatOpenAI(**options)

    @staticmethod
    def create_deepseek():
        """Build the DeepSeek client; raises ``ValueError`` without an API key."""
        if not DEEPSEEK_API_KEY:
            raise ValueError("DEEPSEEK_API_KEY not set")
        options = dict(
            base_url="https://api.deepseek.com",
            api_key=SecretStr(DEEPSEEK_API_KEY),
            model="deepseek-reasoner",
            temperature=0.1,
            max_tokens=4096,
            timeout=60.0,
            max_retries=2,
            streaming=True,
        )
        return ChatOpenAI(**options)

    @staticmethod
    def create_zhipu():
        """Build the ZhipuAI client; raises ``ValueError`` without an API key."""
        if not ZHIPUAI_API_KEY:
            raise ValueError("ZHIPUAI_API_KEY not set")
        options = dict(
            model="glm-4.7-flash",
            api_key=ZHIPUAI_API_KEY,
            temperature=0.1,
            max_tokens=4096,
            timeout=120.0,
            max_retries=3,
            streaming=True,
        )
        return ChatZhipuAI(**options)

    # Model name -> constructor, in the order the service initialises them.
    CREATORS = {
        "local": create_local,
        "deepseek": create_deepseek,
        "zhipu": create_zhipu,
    }
|
||||
37
backend/app/agent/prompts.py
Normal file
37
backend/app/agent/prompts.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# app/prompts.py
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
    """Build the chat prompt: a system message followed by the conversation.

    Args:
        tools: Optional tool objects. For each one, the name (``name`` or
            ``__name__``) and the first line of its description
            (``description`` or ``__doc__``) are listed in the system message.
            ``None`` or an empty list leaves the tool section blank.

    Returns:
        A ``ChatPromptTemplate`` with the input variables ``memory_context``
        and ``messages``.
    """
    tool_lines = []
    for tool in tools or ():
        name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
        # Plain callables have no ``description``; fall back to the docstring.
        description = getattr(tool, 'description', None) or getattr(tool, '__doc__', None) or ""
        stripped = description.strip()
        first_line = stripped.splitlines()[0] if stripped else ""
        tool_lines.append(f"- {name}: {first_line}")

    # The text below is later parsed as a template. Braces coming from tool
    # names/descriptions must be doubled, otherwise they are treated as
    # template variables and formatting the prompt fails.
    tools_section = "\n".join(tool_lines).replace("{", "{{").replace("}", "}}")

    system_template = (
        "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
        "【用户背景信息】\n"
        "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
        "{memory_context}\n"
        "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
        "【可用工具与使用规则】\n"
        f"{tools_section}\n"
        "工具调用时请直接返回所需参数,无需额外说明。\n\n"
        "【回答要求(必须遵守)】\n"
        "1. 回答必须简洁、直接。\n"
        "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。\n"
        "3. 优先利用已知用户信息进行个性化回复。\n"
        "4. 若无信息可依,礼貌询问或提供通用帮助。"
    )

    return ChatPromptTemplate.from_messages([
        ("system", system_template),
        MessagesPlaceholder(variable_name="messages"),
    ])
|
||||
23
backend/app/agent/rag_initializer.py
Normal file
23
backend/app/agent/rag_initializer.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# app/rag_initializer.py
|
||||
from ..rag.tools import create_rag_tool_sync
|
||||
from rag_core import create_parent_retriever
|
||||
from ..logger import info, warning
|
||||
|
||||
async def init_rag_tool(local_llm_creator):
    """Create the RAG retrieval tool, returning ``None`` if any step fails.

    Args:
        local_llm_creator: Zero-argument callable producing the LLM used for
            query rewriting.

    Returns:
        The tool built by ``create_rag_tool_sync``, or ``None`` after logging
        a warning when initialisation raised.
    """
    tool = None
    try:
        info("🔄 正在初始化 RAG 检索系统...")
        # The retriever is built before the rewrite LLM (argument order).
        tool = create_rag_tool_sync(
            create_parent_retriever(collection_name="rag_documents", search_k=5),
            local_llm_creator(),
            num_queries=3,
            rerank_top_n=5,
        )
        info("✅ RAG 检索工具初始化成功")
    except Exception as exc:
        warning(f"⚠️ RAG 检索工具初始化失败: {exc}")
        return None
    return tool
|
||||
154
backend/app/agent/service.py
Normal file
154
backend/app/agent/service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
AI Agent 服务类 - 支持多模型动态切换
|
||||
接收外部传入的 checkpointer,不负责管理连接生命周期
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
# 本地模块
|
||||
from ..graph.graph_builder import GraphBuilder, GraphContext
|
||||
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from .llm_factory import LLMFactory
|
||||
from .rag_initializer import init_rag_tool
|
||||
from ..logger import info, warning
|
||||
|
||||
class AIAgentService:
    """Chat service that runs one compiled graph per available model.

    The checkpointer is supplied by the caller; this class never opens or
    closes it. Call :meth:`initialize` once before processing messages.
    """

    def __init__(self, checkpointer):
        """Store the checkpointer and take private copies of the tool registry."""
        self.checkpointer = checkpointer
        self.graphs = {}
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()

    async def initialize(self):
        """Register the optional RAG tool and compile a graph for every model.

        A model whose creation or compilation fails is logged and skipped.

        Returns:
            ``self``, to allow chaining.

        Raises:
            RuntimeError: If no model could be initialised.
        """
        # 1. Optional RAG tool; init_rag_tool returns None on failure.
        rag_tool = await init_rag_tool(LLMFactory.create_local)
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool

        # 2. One compiled graph per model.
        for name, creator in LLMFactory.CREATORS.items():
            try:
                info(f"🔄 初始化模型 '{name}'...")
                llm = creator()
                builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
                self.graphs[name] = builder.compile(checkpointer=self.checkpointer)
                info(f"✅ 模型 '{name}' 初始化成功")
            except Exception as e:
                warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")

        if not self.graphs:
            raise RuntimeError("没有可用的模型")
        return self

    @staticmethod
    def _build_run_inputs(message: str, thread_id: str, user_id: str):
        """Return ``(input_state, config, context)`` for a single graph run."""
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id},
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        return input_state, config, GraphContext(user_id=user_id)

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """Run one conversation turn and return the final answer.

        Args:
            message: User text.
            thread_id: Conversation thread used by the checkpointer.
            model: Preferred model name; an unknown name falls back to the
                first initialised model.
            user_id: User identifier passed as run metadata and graph context.

        Returns:
            Dict with ``reply``, ``token_usage`` and ``elapsed_time``.

        Raises:
            RuntimeError: If no model is available.
        """
        if model not in self.graphs:
            if not self.graphs:
                raise RuntimeError("没有可用的模型")
            fallback = next(iter(self.graphs))
            # Log before overwriting so both the requested and the fallback
            # name appear in the message.
            warning(f"模型 '{model}' 不可用,已回退到 '{fallback}'")
            model = fallback

        input_state, config, context = self._build_run_inputs(message, thread_id, user_id)
        result = await self.graphs[model].ainvoke(input_state, config=config, context=context)

        return {
            "reply": result["messages"][-1].content,
            "token_usage": result.get("last_token_usage", {}),
            "elapsed_time": result.get("last_elapsed_time", 0.0),
        }

    def _serialize_value(self, value):
        """Recursively convert ``value`` into JSON-serialisable data.

        Objects with a ``content`` attribute (message-like objects) become
        ``{"role", "content", "additional_kwargs", "tool_calls"}`` with every
        field sanitised too; dicts, lists and tuples are converted element by
        element; anything ``json.dumps`` rejects is replaced by ``str(value)``.
        """
        if hasattr(value, 'content'):
            return {
                "role": getattr(value, 'type', 'message'),
                "content": self._serialize_value(getattr(value, 'content', '')),
                "additional_kwargs": self._serialize_value(getattr(value, 'additional_kwargs', {})),
                "tool_calls": self._serialize_value(getattr(value, 'tool_calls', [])),
            }
        if isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        try:
            json.dumps(value)
            return value
        except (TypeError, ValueError):
            return str(value)

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Run one conversation turn and yield JSON-serialisable events.

        Yields dicts whose ``type`` is one of:
            * ``llm_token`` - ``node``, ``token``, ``reasoning_token``, ``metadata``
            * ``state_update`` - ``data`` (plus ``messages`` when present)
            * ``custom`` - ``data``

        Raises:
            ValueError: If ``model_name`` has no initialised graph (raised when
                iteration starts; there is no fallback here).
        """
        graph = self.graphs.get(model_name)
        if graph is None:
            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")

        input_state, config, context = self._build_run_inputs(message, thread_id, user_id)

        async for chunk in graph.astream(
            input_state,
            config=config,
            context=context,
            stream_mode=["messages", "updates", "custom"],
            version="v2",
            subgraphs=True,
        ):
            chunk_type = chunk["type"]

            if chunk_type == "messages":
                message_chunk, metadata = chunk["data"]
                reasoning_token = ""
                if hasattr(message_chunk, 'additional_kwargs'):
                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
                event = {
                    "type": "llm_token",
                    "node": metadata.get("langgraph_node", "unknown"),
                    "token": getattr(message_chunk, 'content', str(message_chunk)),
                    "reasoning_token": reasoning_token,
                    # Sanitised so the consumer can json.dumps the event.
                    "metadata": self._serialize_value(metadata),
                }
            elif chunk_type == "updates":
                serialized = self._serialize_value(chunk["data"])
                event = {"type": "state_update", "data": serialized}
                if isinstance(serialized, dict) and "messages" in serialized:
                    event["messages"] = serialized["messages"]
            elif chunk_type == "custom":
                event = {"type": "custom", "data": self._serialize_value(chunk["data"])}
            else:
                continue

            yield event
|
||||
212
backend/app/backend.py
Normal file
212
backend/app/backend.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
|
||||
采用依赖注入模式,优雅管理资源生命周期
|
||||
"""
|
||||
|
||||
import os
|
||||
from .config import DB_URI, BACKEND_PORT
|
||||
import uuid
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from .agent.service import AIAgentService
|
||||
from .agent.history import ThreadHistoryService
|
||||
from .logger import info, error
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build shared services on startup, release on exit.

    Opens the checkpointer, prepares its tables, then publishes the agent and
    history services on ``app.state`` for the dependency functions to read.
    """
    async with AsyncPostgresSaver.from_conn_string(DB_URI) as saver:
        await saver.setup()

        # The agent service must be fully initialised before serving requests.
        service = AIAgentService(saver)
        await service.initialize()

        app.state.agent_service = service
        app.state.history_service = ThreadHistoryService(saver)

        # The application serves requests while suspended here.
        yield

    # Leaving the ``async with`` block has closed the checkpointer.
    info("🛑 应用关闭,数据库连接池已释放")
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# CORS 中间件(允许前端跨域)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# ========== 健康检查端点 ==========
|
||||
@app.get("/health")
async def health_check():
    """Liveness probe returning a fixed status payload."""
    payload = {"status": "ok", "service": "ai-agent-backend"}
    return payload
|
||||
|
||||
# ========== Pydantic 模型 ==========
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
thread_id: str | None = None
|
||||
model: str = "zhipu"
|
||||
user_id: str = "default_user"
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
reply: str
|
||||
thread_id: str
|
||||
model_used: str
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
elapsed_time: float = 0.0
|
||||
|
||||
# ========== 依赖注入函数 ==========
|
||||
def get_agent_service(request: Request) -> AIAgentService:
    """Dependency: return the ``AIAgentService`` stored on ``app.state``."""
    shared_state = request.app.state
    return shared_state.agent_service
|
||||
|
||||
def get_history_service(request: Request) -> ThreadHistoryService:
    """Dependency: return the ``ThreadHistoryService`` stored on ``app.state``."""
    shared_state = request.app.state
    return shared_state.history_service
|
||||
|
||||
# ========== HTTP 端点 ==========
|
||||
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
    request: ChatRequest,
    agent_service: AIAgentService = Depends(get_agent_service)
):
    """Blocking chat endpoint with model selection.

    Responds 400 when ``message`` is empty. A missing ``thread_id`` starts a
    new thread with a random UUID. ``model_used`` reports the requested model
    when it is initialised, otherwise the first initialised model.
    """
    if not request.message:
        raise HTTPException(status_code=400, detail="message required")

    thread_id = request.thread_id or str(uuid.uuid4())
    outcome = await agent_service.process_message(
        request.message, thread_id, request.model, request.user_id
    )

    # Token counters may use either the OpenAI-style or the generic key names.
    usage = outcome.get("token_usage", {})
    prompt_count = usage.get('prompt_tokens', usage.get('input_tokens', 0))
    completion_count = usage.get('completion_tokens', usage.get('output_tokens', 0))

    if request.model in agent_service.graphs:
        served_by = request.model
    else:
        served_by = next(iter(agent_service.graphs.keys()))

    return ChatResponse(
        reply=outcome["reply"],
        thread_id=thread_id,
        model_used=served_by,
        input_tokens=prompt_count,
        output_tokens=completion_count,
        total_tokens=prompt_count + completion_count,
        elapsed_time=outcome.get("elapsed_time", 0.0),
    )
|
||||
|
||||
# ========== 历史查询接口 ==========
|
||||
@app.get("/threads")
async def list_threads(
    user_id: str = Query("default_user", description="用户 ID"),
    limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
    history_service: ThreadHistoryService = Depends(get_history_service)
):
    """Return the conversation list of ``user_id`` as ``{"threads": [...]}``."""
    return {"threads": await history_service.get_user_threads(user_id, limit)}
|
||||
|
||||
@app.get("/thread/{thread_id}/messages")
async def get_thread_messages(
    thread_id: str,
    user_id: str = Query("default_user", description="用户 ID"),
    history_service: ThreadHistoryService = Depends(get_history_service)
):
    """Return the full message history of a thread as ``{"messages": [...]}``."""
    # NOTE(review): user_id is accepted but not used to check that the thread
    # belongs to this user - confirm whether an ownership check is needed.
    return {"messages": await history_service.get_thread_messages(thread_id)}
|
||||
|
||||
@app.get("/thread/{thread_id}/summary")
async def get_thread_summary(
    thread_id: str,
    user_id: str = Query("default_user", description="用户 ID"),
    history_service: ThreadHistoryService = Depends(get_history_service)
):
    """Return the summary dict produced by the history service for a thread."""
    # NOTE(review): user_id is accepted but not used to check that the thread
    # belongs to this user - confirm whether an ownership check is needed.
    return await history_service.get_thread_summary(thread_id)
|
||||
|
||||
# ========== 流式对话接口 ==========
|
||||
@app.post("/chat/stream")
async def chat_stream_endpoint(
    request: ChatRequest,
    agent_service: AIAgentService = Depends(get_agent_service)
):
    """Streaming chat endpoint using server-sent events.

    Each event from the agent is sent as one ``data:`` frame of JSON. The
    stream always ends with ``data: [DONE]``; a failure is reported as an
    ``{"type": "error"}`` frame just before it.
    """
    if not request.message:
        raise HTTPException(status_code=400, detail="message required")

    thread_id = request.thread_id or str(uuid.uuid4())

    def sse_frame(payload) -> str:
        # One SSE frame: JSON body with non-ASCII text left readable.
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def event_generator():
        try:
            events = agent_service.process_message_stream(
                request.message, thread_id, request.model, request.user_id
            )
            async for event in events:
                yield sse_frame(event)
            yield "data: [DONE]\n\n"
        except Exception as e:
            error(f"流式响应异常: {e}")
            yield sse_frame({'type': 'error', 'message': str(e)})
            yield "data: [DONE]\n\n"

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # turn off Nginx response buffering
    }
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
||||
|
||||
# ========== WebSocket 端点(可选) ==========
|
||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket chat endpoint (optional alternative to the HTTP routes).

    Each incoming JSON object may contain ``message`` (required),
    ``thread_id``, ``model`` and ``user_id``. The answer frame contains
    ``reply`` (text), ``thread_id``, ``model_used``, ``token_usage`` and
    ``elapsed_time``. Invalid input or a processing failure produces an
    ``{"error": ...}`` frame and the connection stays open.
    """
    # Read the service from the connection's app state. The HTTP dependency
    # get_agent_service declares a ``Request`` parameter, which is not
    # supplied for WebSocket routes.
    agent_service: AIAgentService = websocket.app.state.agent_service

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            message = data.get("message") if isinstance(data, dict) else None
            if not message:
                await websocket.send_json({"error": "missing message"})
                continue

            # ``or`` also covers an explicit null sent by the client.
            thread_id = data.get("thread_id") or str(uuid.uuid4())
            model = data.get("model") or "zhipu"
            user_id = data.get("user_id") or "default_user"

            try:
                result = await agent_service.process_message(message, thread_id, model, user_id)
            except Exception as exc:
                # Report the failure instead of dropping the connection.
                error(f"WebSocket 消息处理异常: {exc}")
                await websocket.send_json({"error": str(exc), "thread_id": thread_id})
                continue

            if model in agent_service.graphs:
                actual_model = model
            else:
                actual_model = next(iter(agent_service.graphs.keys()))

            await websocket.send_json({
                # process_message returns a dict; send the text like /chat does.
                "reply": result["reply"],
                "thread_id": thread_id,
                "model_used": actual_model,
                "token_usage": result.get("token_usage") or {},
                "elapsed_time": result.get("elapsed_time", 0.0),
            })
    except WebSocketDisconnect:
        # Client closed the connection; nothing to clean up.
        pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
# 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突)
|
||||
port = int(BACKEND_PORT)
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
50
backend/app/config.py
Normal file
50
backend/app/config.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
环境变量集中管理模块
|
||||
所有配置项统一定义,避免散落在各个文件中
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
# ========== Graph 执行追踪配置 ==========
|
||||
# 是否启用 Graph 流转追踪(通过环境变量控制)
|
||||
ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"
|
||||
|
||||
# ========== 记忆提取配置 ==========
|
||||
# 记忆提取间隔:每 N 轮对话生成一次摘要
|
||||
MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
|
||||
|
||||
# ========== Mem0 记忆层配置 ==========
|
||||
# Qdrant 向量数据库地址
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
|
||||
|
||||
# ========== llm 配置 ==========
|
||||
# LLM 模型配置
|
||||
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
|
||||
LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key")
|
||||
|
||||
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
|
||||
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
|
||||
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key")
|
||||
|
||||
# ========== 后端服务配置 ==========
|
||||
# 数据库连接字符串
|
||||
DB_URI = os.getenv(
|
||||
"DB_URI",
|
||||
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
|
||||
)
|
||||
# 后端服务端口
|
||||
BACKEND_PORT = int(os.getenv("BACKEND_PORT", "8079"))
|
||||
|
||||
# ========== 日志配置 ==========
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
DEBUG = os.getenv("DEBUG", "false").lower() == "true"
|
||||
|
||||
# ========== Reranker 服务配置 ==========
|
||||
LLAMACPP_RERANKER_URL = os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083")
|
||||
|
||||
# ========== 第三方 API 密钥 ==========
|
||||
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY", "")
|
||||
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
|
||||
8
backend/app/graph/__init__.py
Normal file
8
backend/app/graph/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Graph 子模块
|
||||
"""
|
||||
|
||||
from .graph_builder import GraphBuilder
|
||||
from .state import MessagesState, GraphContext
|
||||
|
||||
__all__ = ["GraphBuilder", "MessagesState", "GraphContext"]
|
||||
83
backend/app/graph/graph_builder.py
Normal file
83
backend/app/graph/graph_builder.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
LangGraph 状态图构建模块 - 精简版,仅负责组装图
|
||||
所有节点逻辑已拆分到独立模块
|
||||
"""
|
||||
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from .state import MessagesState, GraphContext
|
||||
from ..nodes import (
|
||||
should_continue,
|
||||
create_llm_call_node,
|
||||
create_tool_call_node,
|
||||
create_retrieve_memory_node,
|
||||
create_summarize_node,
|
||||
finalize_node,
|
||||
)
|
||||
from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from ..memory import Mem0Client
|
||||
|
||||
|
||||
class GraphBuilder:
    """Assembles the (uncompiled) LangGraph state graph; node logic lives elsewhere."""

    def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
        """Remember the model and tools and create the Mem0 client.

        Args:
            llm: Language model instance handed to the LLM node and to Mem0.
            tools: Tools offered to the model.
            tools_by_name: Mapping from tool name to tool, used by the tool node.
        """
        self.llm = llm
        self.tools = tools
        self.tools_by_name = tools_by_name
        # Created here; the retrieval node initialises it on first use.
        self.mem0_client = Mem0Client(llm)

    def build(self) -> StateGraph:
        """Wire nodes and edges and return the uncompiled ``StateGraph``.

        Flow: START -> retrieve_memory -> memory_trigger -> llm_call, then
        ``should_continue`` routes to tool_node (which loops back to
        llm_call), summarize (then finalize) or finalize -> END.
        """
        # Publish the client module-wide for the memory trigger node.
        set_mem0_client(self.mem0_client)

        graph = StateGraph(MessagesState, context_schema=GraphContext)

        # Nodes that need dependencies are produced by factory functions.
        nodes = {
            "retrieve_memory": create_retrieve_memory_node(self.mem0_client),
            "memory_trigger": memory_trigger_node,
            "llm_call": create_llm_call_node(self.llm, self.tools),
            "tool_node": create_tool_call_node(self.tools_by_name),
            "summarize": create_summarize_node(self.mem0_client),
            "finalize": finalize_node,
        }
        for node_name, node_fn in nodes.items():
            graph.add_node(node_name, node_fn)

        linear_prefix = (
            (START, "retrieve_memory"),
            ("retrieve_memory", "memory_trigger"),
            ("memory_trigger", "llm_call"),
        )
        for source, target in linear_prefix:
            graph.add_edge(source, target)

        routes = {
            "tool_node": "tool_node",
            "summarize": "summarize",
            "finalize": "finalize",
        }
        graph.add_conditional_edges("llm_call", should_continue, routes)

        linear_suffix = (
            ("tool_node", "llm_call"),
            ("summarize", "finalize"),
            ("finalize", END),
        )
        for source, target in linear_suffix:
            graph.add_edge(source, target)

        return graph
|
||||
95
backend/app/graph/graph_tools.py
Normal file
95
backend/app/graph/graph_tools.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
工具定义模块 - 纯函数工具,无依赖 AIAgent 类
|
||||
"""
|
||||
|
||||
# 标准库
|
||||
from pathlib import Path
|
||||
|
||||
# 第三方库
|
||||
import pandas as pd
|
||||
import pypdf
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain_core.tools import tool
|
||||
|
||||
def _file_allow_check(filename: str) -> Path:
|
||||
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
|
||||
allowed_dir = Path("./user_docs").resolve()
|
||||
allowed_dir.mkdir(exist_ok=True)
|
||||
|
||||
file_path = (allowed_dir / filename).resolve()
|
||||
if not str(file_path).startswith(str(allowed_dir)):
|
||||
raise ValueError("错误:非法文件路径。")
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")
|
||||
|
||||
return file_path
|
||||
|
||||
@tool
def get_current_temperature(location: str) -> str:
    """获取指定地点的当前温度。"""
    # The docstring above is the tool description shown to the model.
    # Placeholder implementation: always reports a fixed 25 degrees.
    reading = f'当前{location}的温度为25℃'
    return reading
|
||||
|
||||
@tool
def read_local_file(filename: str) -> str:
    """读取用户指定名称的本地文本文件内容并返回摘要。"""
    # The docstring above is the tool description shown to the model.
    # Returns the first 1000 characters of the UTF-8 file; any failure is
    # reported as text rather than raised.
    try:
        with open(_file_allow_check(filename), 'r', encoding='utf-8') as handle:
            body = handle.read()
    except Exception as e:
        return f"读取文件时出错:{str(e)}"
    return f"文件 '{filename}' 的内容开头:\n{body[:1000]}..."
|
||||
|
||||
@tool
def read_pdf_summary(filename: str) -> str:
    """读取PDF文件并返回内容文本摘要。"""
    # The docstring above is the tool description shown to the model.
    # Returns up to 2000 characters extracted from the first three pages;
    # any failure is reported as text rather than raised.
    try:
        with open(_file_allow_check(filename), 'rb') as handle:
            leading_pages = pypdf.PdfReader(handle).pages[:3]
            extracted = "".join(page.extract_text() for page in leading_pages)
        return f"PDF文件 '{filename}' 的前几页内容:\n{extracted[:2000]}..."
    except Exception as e:
        return f"读取PDF出错:{e}"
|
||||
|
||||
@tool
def read_excel_as_markdown(filename: str) -> str:
    """读取Excel文件,并将其主要数据转换为Markdown表格格式。"""
    # The docstring above is the tool description shown to the model.
    # Renders the first 10 rows as a Markdown table; any failure is reported
    # as text rather than raised.
    try:
        frame = pd.read_excel(_file_allow_check(filename))
        preview = frame.head(10)
        table = preview.to_markdown(index=False)
        return f"Excel文件 '{filename}' 的数据预览(前10行):\n{table}"
    except Exception as e:
        return f"读取Excel出错:{e}"
|
||||
|
||||
@tool
def fetch_webpage_content(url: str) -> str:
    """抓取给定URL的网页正文内容,并返回清晰的纯文本。"""
    # The docstring above is the tool description shown to the model.
    # Returns the first 1500 characters of the page text; any failure is
    # reported as text rather than raised.
    # NOTE(review): the URL is chosen by the model/user and fetched from the
    # server, so internal addresses are reachable - confirm whether an
    # allow-list or private-network block is required.
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        # Parse the raw bytes. When the server sends no charset,
        # response.text falls back to ISO-8859-1 and garbles non-Latin
        # pages; BeautifulSoup can detect the encoding from the document.
        content_type = response.headers.get("Content-Type", "")
        declared = response.encoding if "charset" in content_type.lower() else None
        soup = BeautifulSoup(response.content, 'html.parser', from_encoding=declared)

        for element in soup(["script", "style"]):
            element.decompose()

        stripped_lines = (line.strip() for line in soup.get_text().splitlines())
        fragments = (piece.strip() for line in stripped_lines for piece in line.split(" "))
        text = '\n'.join(fragment for fragment in fragments if fragment)
        return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..."
    except Exception as e:
        return f"抓取网页时出错:{str(e)}"
|
||||
|
||||
# 工具列表和映射(全局常量)
|
||||
AVAILABLE_TOOLS = [
|
||||
get_current_temperature,
|
||||
read_local_file,
|
||||
fetch_webpage_content,
|
||||
read_pdf_summary,
|
||||
read_excel_as_markdown
|
||||
]
|
||||
TOOLS_BY_NAME = {tool.name: tool for tool in AVAILABLE_TOOLS}
|
||||
76
backend/app/graph/retrieve_memory.py
Normal file
76
backend/app/graph/retrieve_memory.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
记忆检索节点模块
|
||||
负责从 Mem0 检索相关长期记忆
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
# 本地模块
|
||||
from .state import MessagesState
|
||||
from ..memory.mem0_client import Mem0Client
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import debug
|
||||
|
||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
    """
    Factory: build the memory-retrieval node bound to ``mem0_client``.

    Args:
        mem0_client: Mem0 client instance used for the lookup.

    Returns:
        An async node function ``retrieve_memory(state, config)``.
    """
    from langchain_core.runnables.config import RunnableConfig
    # Same logger module as the file-level ``from ..logger import debug``;
    # imported here because only ``debug`` is imported at module level.
    from ..logger import warning

    # Context string used when nothing could be retrieved.
    no_memory_text = "暂无用户信息"

    async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
        """
        Memory-retrieval node: look up memories relevant to the last message.

        Args:
            state: Current conversation state.
            config: Runtime config; ``metadata.user_id`` selects the user.

        Returns:
            A state update containing ``memory_context``.
        """
        log_state_change("retrieve_memory", state, "进入")

        messages = state.get("messages") or []
        if not messages:
            # Nothing to build a query from; previously this raised IndexError.
            result = {"memory_context": no_memory_text}
            log_state_change("retrieve_memory", {**state, **result}, "离开")
            return result

        # user_id comes from the run metadata, with a shared fallback.
        user_id = config.get("metadata", {}).get("user_id", "default_user")

        # Messages may be plain dicts or message objects.
        last_msg = messages[-1]
        if isinstance(last_msg, dict):
            query = str(last_msg.get("content", ""))
        else:
            query = str(last_msg.content)

        memory_text_parts = []

        # Lazy initialisation of the Mem0 client.
        if not mem0_client._initialized:
            await mem0_client.initialize()

        if mem0_client.mem0:
            try:
                facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
                if facts:
                    memory_text_parts.append(
                        "【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts)
                    )
                else:
                    debug("🔍 [记忆检索] 未找到相关记忆")
            except Exception as e:
                # Retrieval is best-effort: a failure must not break the turn.
                warning(f"⚠️ Mem0 检索失败: {e}")
        else:
            warning("⚠️ Mem0 未初始化,跳过记忆检索")

        memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else no_memory_text
        result = {"memory_context": memory_context}
        log_state_change("retrieve_memory", {**state, **result}, "离开")
        return result

    return retrieve_memory
|
||||
25
backend/app/graph/state.py
Normal file
25
backend/app/graph/state.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
LangGraph 状态定义模块
|
||||
包含 MessagesState 和 GraphContext
|
||||
"""
|
||||
|
||||
import operator
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from dataclasses import dataclass
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
class MessagesState(TypedDict):
    """Typed dictionary describing the conversation state."""

    # Message list, annotated with ``operator.add``.
    messages: Annotated[list[AnyMessage], operator.add]
    llm_calls: int
    memory_context: str
    # Token-usage details of the most recent call.
    last_token_usage: dict
    # Elapsed time of the most recent call, in seconds.
    last_elapsed_time: float
    # Number of turns since a summary was last generated.
    turns_since_last_summary: int
|
||||
|
||||
@dataclass
class GraphContext:
    """Execution context for a graph run."""

    # Identifier of the user the run belongs to.
    user_id: str
    # Further context fields can be added here.
|
||||
56
backend/app/logger.py
Normal file
56
backend/app/logger.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
统一的日志模块 - 基于环境变量控制日志级别
|
||||
类似 C# 的条件编译效果,开发时打印详细调试信息,生产环境只输出关键信息
|
||||
"""
|
||||
|
||||
import os
|
||||
from .config import LOG_LEVEL, DEBUG
|
||||
import logging
|
||||
from typing import Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 先加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 从环境变量读取日志级别,默认 INFO
|
||||
|
||||
|
||||
# Whether verbose debug output is enabled (taken from the DEBUG config value).
DEBUG_MODE = DEBUG

# Resolve the configured level name to a numeric level.
# The name is upper-cased and the result must be an int: a lowercase name such
# as "info" would otherwise resolve to the *function* ``logging.info`` and make
# ``setLevel`` raise. Anything unrecognised falls back to INFO.
_level = getattr(logging, str(LOG_LEVEL).upper(), None)
if not isinstance(_level, int):
    _level = logging.INFO

# Single shared logger for the application.
logger = logging.getLogger("ai_agent")
logger.setLevel(_level)

# Only attach a handler once, so re-imports do not duplicate output.
if not logger.handlers:
    handler = logging.StreamHandler()
    # The handler gets the same level as the logger.
    handler.setLevel(_level)
    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
|
||||
|
||||
|
||||
def debug(msg: Any, *args, **kwargs):
    """Log at DEBUG level, but only when ``DEBUG_MODE`` is truthy."""
    if not DEBUG_MODE:
        return
    logger.debug(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def info(msg: Any, *args, **kwargs):
    """Log ``msg`` at INFO level on the shared logger."""
    logger.info(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warning(msg: Any, *args, **kwargs):
    """Log ``msg`` at WARNING level on the shared logger."""
    logger.warning(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def error(msg: Any, *args, **kwargs):
    """Log ``msg`` at ERROR level on the shared logger."""
    logger.error(msg, *args, **kwargs)
|
||||
7
backend/app/memory/__init__.py
Normal file
7
backend/app/memory/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Mem0 记忆层模块
|
||||
"""
|
||||
|
||||
from .mem0_client import Mem0Client
|
||||
|
||||
__all__ = ["Mem0Client"]
|
||||
146
backend/app/memory/mem0_client.py
Normal file
146
backend/app/memory/mem0_client.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from ..config import LLM_API_KEY
|
||||
from ..config import VLLM_BASE_URL
|
||||
import time
|
||||
"""
|
||||
Mem0 记忆层客户端封装模块
|
||||
负责 Mem0 的初始化、检索和存储
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, List, Dict
|
||||
from mem0 import AsyncMemory
|
||||
|
||||
from ..config import (
|
||||
QDRANT_URL,QDRANT_COLLECTION_NAME,QDRANT_API_KEY,
|
||||
VLLM_BASE_URL, LLM_API_KEY,
|
||||
LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
|
||||
)
|
||||
from ..logger import info, warning, error
|
||||
|
||||
class Mem0Client:
    """Async wrapper around a Mem0 ``AsyncMemory`` instance.

    All operations are best-effort: failures are logged and reported through
    the return value (``[]`` / ``False``) instead of being raised.
    """

    # Timeouts in seconds, applied with ``asyncio.wait_for``. The log messages
    # are derived from these so the reported value cannot drift from the real one.
    INIT_TIMEOUT_SECONDS = 60.0
    SEARCH_TIMEOUT_SECONDS = 30.0
    ADD_TIMEOUT_SECONDS = 60.0

    def __init__(self, llm_instance):
        """
        Create an uninitialised client; call ``initialize()`` before use.

        Args:
            llm_instance: LangChain LLM instance; stored on ``self.llm``.
        """
        self.llm = llm_instance
        self.mem0: Optional[AsyncMemory] = None
        self._initialized = False

    async def initialize(self):
        """Build the Mem0 client and verify connectivity with a real search.

        On success ``self.mem0`` is set and ``self._initialized`` becomes True.
        On timeout or any other error both are reset (``None`` / ``False``)
        and the error is logged; nothing is raised.
        """
        if self._initialized:
            return

        try:
            config = {
                "vector_store": {
                    "provider": "qdrant",
                    "config": {
                        "url": QDRANT_URL,  # full URL is used directly
                        "api_key": QDRANT_API_KEY,
                        "collection_name": QDRANT_COLLECTION_NAME,
                        "embedding_model_dims": 1024,
                    }
                },
                "llm": {
                    "provider": "openai",
                    "config": {
                        # NOTE(review): this is the literal string "LLM_MODEL",
                        # not a config value — looks like a placeholder; confirm
                        # the intended model name.
                        "model": "LLM_MODEL",
                        "api_key": LLM_API_KEY,
                        "openai_base_url": VLLM_BASE_URL,
                        "temperature": 0.1,
                        "max_tokens": 2000,
                    }
                },
                "embedder": {
                    "provider": "openai",
                    "config": {
                        "model": "Qwen3-Embedding-0.6B-Q8_0",
                        "api_key": LLAMACPP_API_KEY,
                        "openai_base_url": LLAMACPP_EMBEDDING_URL,
                    },
                },
                "version": "v1.1"
            }

            self.mem0 = AsyncMemory.from_config(config)
            info("✅ Mem0 配置加载成功,开始连接测试...")

            # Connectivity probe: one real search exercises both the vector
            # store and the embedding service.
            await asyncio.wait_for(
                self.mem0.search("ping", user_id="test", limit=1),
                timeout=self.INIT_TIMEOUT_SECONDS
            )
            info("✅ Mem0 实际连接测试成功,初始化完成")
            self._initialized = True

        except asyncio.TimeoutError:
            # The message previously said "10s" although the timeout is 60s.
            error(
                f"❌ Mem0 连接测试超时 ({self.INIT_TIMEOUT_SECONDS:.0f}s),"
                "请检查 Qdrant 或 Embedding 服务响应"
            )
            self.mem0 = None
            self._initialized = False
        except Exception as e:
            error(f"❌ Mem0 初始化或连接测试失败: {e}")
            import traceback
            error(f"详细错误信息:\n{traceback.format_exc()}")
            self.mem0 = None
            self._initialized = False

    async def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[str]:
        """
        Search for memories relevant to ``query``.

        Args:
            query: Query text.
            user_id: User whose memories are searched.
            limit: Maximum number of results requested.

        Returns:
            List[str]: The ``memory`` text of each result; ``[]`` when the
            client is not initialised, nothing matched, or the call failed.
        """
        if not self.mem0:
            warning("⚠️ Mem0 未初始化,跳过记忆检索")
            return []

        try:
            memories = await asyncio.wait_for(
                self.mem0.search(query, user_id=user_id, limit=limit),
                timeout=self.SEARCH_TIMEOUT_SECONDS
            )

            if memories and "results" in memories:
                # Keep only entries that actually carry memory text.
                facts = [m["memory"] for m in memories["results"] if m.get("memory")]
                if facts:
                    info(f"🔍 [记忆检索] Mem0 返回 {len(facts)} 条记忆")
                    return facts

            info("🔍 [记忆检索] 未找到相关记忆")
            return []

        except asyncio.TimeoutError:
            warning(
                f"⚠️ Mem0 检索超时 ({self.SEARCH_TIMEOUT_SECONDS:.0f}s),跳过本次记忆检索"
            )
            return []
        except Exception as e:
            warning(f"⚠️ Mem0 检索失败: {e}")
            return []

    async def add_memories(self, messages, user_id):
        """
        Submit ``messages`` to Mem0 for storage.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.
            user_id: User the memories belong to.

        Returns:
            bool: True when the add call completed; False when the client is
            not initialised, the call timed out, or it failed.
        """
        if not self.mem0:
            return False

        # Monotonic clock: elapsed time is unaffected by wall-clock changes.
        start = time.monotonic()
        try:
            info(f"📝 开始 Mem0 add,消息数: {len(messages)}")
            await asyncio.wait_for(
                self.mem0.add(messages, user_id=user_id, metadata={"type": "conversation"}),
                timeout=self.ADD_TIMEOUT_SECONDS
            )
            info(f"✅ Mem0 add 完成,耗时: {time.monotonic() - start:.2f}s")
            return True
        except asyncio.TimeoutError:
            error(
                f"❌ Mem0 记忆添加超时 ({self.ADD_TIMEOUT_SECONDS:.0f}s),"
                f"已等待 {time.monotonic() - start:.2f}s"
            )
            return False
        except Exception as e:
            # Matches search_memories: storage is best-effort, so other
            # failures are logged and reported as False instead of raised.
            error(f"❌ Mem0 记忆添加失败: {e}")
            return False
|
||||
19
backend/app/nodes/__init__.py
Normal file
19
backend/app/nodes/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
节点模块 - 导出所有 LangGraph 节点函数
|
||||
"""
|
||||
|
||||
from .router import should_continue
|
||||
from .llm_call import create_llm_call_node
|
||||
from .tool_call import create_tool_call_node
|
||||
from ..graph.retrieve_memory import create_retrieve_memory_node
|
||||
from .summarize import create_summarize_node
|
||||
from .finalize import finalize_node
|
||||
|
||||
__all__ = [
|
||||
"should_continue",
|
||||
"create_llm_call_node",
|
||||
"create_tool_call_node",
|
||||
"create_retrieve_memory_node",
|
||||
"create_summarize_node",
|
||||
"finalize_node",
|
||||
]
|
||||
45
backend/app/nodes/finalize.py
Normal file
45
backend/app/nodes/finalize.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
完成事件节点模块
|
||||
负责发送完成事件,包含token使用情况和耗时信息
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from ..graph.state import MessagesState
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import info, error
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """
    Completion node: emit a "done" event carrying token usage and elapsed time.

    Args:
        state: Current conversation state.
        config: Runtime config (not read here).

    Returns:
        An empty dict; this node makes no state update.
    """
    log_state_change("finalize", state, "进入")

    try:
        emit = get_stream_writer()
        done_payload = {
            "type": "done",
            "token_usage": state.get("last_token_usage", {}),
            "elapsed_time": state.get("last_elapsed_time", 0.0),
        }
        emit({"type": "custom", "data": done_payload})
        info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息")
    except Exception as exc:
        # A failure to emit is logged but does not fail the node.
        error(f"❌ [完成事件] 发送完成事件时发生异常: {exc}")

    log_state_change("finalize", state, "离开")
    return {}
|
||||
150
backend/app/nodes/llm_call.py
Normal file
150
backend/app/nodes/llm_call.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
LLM 调用节点模块
|
||||
负责调用大语言模型并处理响应
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
from ..graph.state import MessagesState
|
||||
from ..agent.prompts import create_system_prompt
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import debug, info, error
|
||||
|
||||
def _extract_token_usage(response) -> dict:
    """Return the token-usage mapping attached to ``response``, or ``{}``.

    Looks, in order, at ``response_metadata['token_usage']``,
    ``response_metadata['usage']``,
    ``additional_kwargs['llm_output']['token_usage']`` and finally the
    ``usage_metadata`` attribute. Missing or ``None`` containers are skipped
    instead of raising.
    """
    meta = getattr(response, "response_metadata", None) or {}
    for key in ("token_usage", "usage"):
        if meta.get(key):
            return meta[key]

    extra = getattr(response, "additional_kwargs", None) or {}
    llm_output = extra.get("llm_output") or {}
    if llm_output.get("token_usage"):
        return llm_output["token_usage"]

    # NOTE(review): streamed LangChain messages usually report usage on
    # ``usage_metadata`` (input_tokens / output_tokens) — confirm for the
    # provider in use.
    usage_metadata = getattr(response, "usage_metadata", None)
    if usage_metadata:
        return dict(usage_metadata)

    return {}


def create_llm_call_node(llm: BaseLLM, tools: list):
    """
    Factory: build the LLM-call node.

    Args:
        llm: LangChain LLM instance; ``tools`` are bound to it.
        tools: List of tools offered to the model.

    Returns:
        An async node function ``call_llm(state, config)``.
    """
    from langchain_core.runnables.config import RunnableConfig

    # Build the chain once per node: system prompt -> tool-bound model.
    prompt = create_system_prompt(tools)
    llm_with_tools = llm.bind_tools(tools)
    chain = prompt | llm_with_tools

    async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
        """
        LLM-call node.

        Streams the chain, merges the chunks into one message and records
        timing and token usage. On any error a friendly assistant message is
        returned instead of raising.

        Args:
            state: Current conversation state.
            config: Injected run config, forwarded to the chain.

        Returns:
            The state update dict.
        """
        log_state_change("llm_call", state, "进入")

        memory_context = state.get("memory_context", "暂无用户信息")
        start_time = time.time()

        try:
            # Stream manually and fold the chunks into the final response.
            response = None
            async for chunk in chain.astream(
                {
                    "messages": state["messages"],
                    "memory_context": memory_context
                },
                config=config
            ):
                response = chunk if response is None else response + chunk

            if response is None:
                # The stream produced nothing.
                response = AIMessage(content="")

            elapsed_time = time.time() - start_time

            token_usage = _extract_token_usage(response)
            # Providers name the counters differently; ``or 0`` guards None.
            input_tokens = token_usage.get(
                'prompt_tokens', token_usage.get('input_tokens', 0)
            ) or 0
            output_tokens = token_usage.get(
                'completion_tokens', token_usage.get('output_tokens', 0)
            ) or 0

            # Full model output (debug only).
            debug("\n" + "="*80)
            debug("📥 [LLM输出] 大模型返回的完整响应:")
            debug(f" 消息类型: {response.type.upper()}")
            debug(f" 内容长度: {len(str(response.content))} 字符")
            debug("-"*80)
            debug(f"{response.content}")

            # Call statistics.
            info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
            info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
            if token_usage:
                debug(f"📋 [LLM统计] 详细用量: {token_usage}")
            debug("="*80 + "\n")

            result = {
                "messages": [response],
                "llm_calls": state.get('llm_calls', 0) + 1,
                "last_token_usage": token_usage,
                "last_elapsed_time": elapsed_time,
                # Counts turns towards the next summary.
                "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1
            }

            log_state_change("llm_call", {**state, **result}, "离开")
            return result

        except Exception as e:
            elapsed_time = time.time() - start_time
            error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
            error(f" 错误类型: {type(e).__name__}")
            error(f" 错误信息: {str(e)}")
            # Route the traceback through the logger instead of raw stderr.
            import traceback
            error(f"详细错误信息:\n{traceback.format_exc()}")
            debug("="*80 + "\n")

            # Friendly fallback reply for the user.
            error_response = AIMessage(
                content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
            )
            error_result = {
                "messages": [error_response],
                "llm_calls": state.get('llm_calls', 0),
                "last_token_usage": {},
                "last_elapsed_time": elapsed_time,
                # The counter advances even when the call failed.
                "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1
            }

            log_state_change("llm_call", state, "离开(异常)")
            return error_result

    return call_llm
|
||||
38
backend/app/nodes/memory_trigger.py
Normal file
38
backend/app/nodes/memory_trigger.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Any, Dict
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from ..graph.state import MessagesState
|
||||
from ..memory.mem0_client import Mem0Client
|
||||
from ..logger import info
|
||||
|
||||
# 全局变量,在 GraphBuilder 中注入
|
||||
_mem0_client: Mem0Client = None
|
||||
|
||||
def set_mem0_client(client: Mem0Client):
    """Store ``client`` in the module-level ``_mem0_client`` used by the trigger node."""
    global _mem0_client
    _mem0_client = client
|
||||
|
||||
async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储

    Checks the last message for a "remember this" trigger word and, on a hit,
    submits that message to Mem0. Storage is best-effort: failures are logged
    and never break the graph run.

    Args:
        state: Current conversation state.
        config: Runtime config; ``metadata.user_id`` selects the user.

    Returns:
        Always an empty dict; this node does not modify state.
    """
    if _mem0_client is None:
        return {}

    messages = state.get("messages", [])
    if not messages:
        return {}

    # Messages may be plain dicts or message objects. Previously a dict was
    # stringified whole, so its repr was what got stored as the memory.
    last_msg = messages[-1]
    if isinstance(last_msg, dict):
        raw_content = last_msg.get("content", "")
    elif hasattr(last_msg, 'content'):
        raw_content = last_msg.content
    else:
        raw_content = last_msg
    content = raw_content if isinstance(raw_content, str) else str(raw_content)

    # Trigger words (extend as needed).
    trigger_words = ("记住", "记下", "保存", "备忘", "记录下", "别忘了")
    if not any(word in content for word in trigger_words):
        return {}

    user_id = config.get("metadata", {}).get("user_id", "default_user")
    try:
        # Lazy initialisation of the Mem0 client.
        if not _mem0_client._initialized:
            await _mem0_client.initialize()

        # Submit the user's message as the source of facts.
        info("📌 检测到记忆指令,已主动触发 Mem0 存储")
        mem0_messages = [{"role": "user", "content": content}]
        await _mem0_client.add_memories(mem0_messages, user_id=user_id)
    except Exception as exc:
        # Imported here because only ``info`` is imported at module level.
        from ..logger import warning
        warning(f"⚠️ Mem0 记忆添加失败: {exc}")

    return {}  # state is not modified
|
||||
48
backend/app/nodes/router.py
Normal file
48
backend/app/nodes/router.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
路由决策节点
|
||||
根据当前状态决定下一步走向
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
from ..config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
|
||||
from ..graph.state import MessagesState
|
||||
from ..logger import info
|
||||
|
||||
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'finalize']:
    """
    Pick the next node: run tools, store a summary, or finish.

    Args:
        state: Current conversation state.

    Returns:
        Name of the next node.
    """
    last_message = state["messages"][-1]

    # Anything that is not an AI message (e.g. a bare user message) ends the run.
    if not isinstance(last_message, AIMessage):
        if ENABLE_GRAPH_TRACE:
            info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
        return 'finalize'

    # Pending tool calls take priority over everything else.
    if last_message.tool_calls:
        if ENABLE_GRAPH_TRACE:
            info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'")
        return 'tool_node'

    # Final AI reply: summarise once enough turns have accumulated.
    turns = state.get("turns_since_last_summary", 0)
    threshold_reached = turns >= MEMORY_SUMMARIZE_INTERVAL

    if ENABLE_GRAPH_TRACE:
        if threshold_reached:
            info(f"🔀 [路由决策] 收到 AI 最终回复,已达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 转向 'summarize'")
        else:
            info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")

    return 'summarize' if threshold_reached else 'finalize'
|
||||
87
backend/app/nodes/summarize.py
Normal file
87
backend/app/nodes/summarize.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
记忆存储节点模块
|
||||
负责将对话历史提交给 Mem0 进行事实提取和存储
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
# 本地模块
|
||||
from ..graph.state import MessagesState
|
||||
from ..memory.mem0_client import Mem0Client
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import debug, info, error, warning
|
||||
|
||||
def create_summarize_node(mem0_client: Mem0Client):
    """
    Factory: build the memory-storage node bound to ``mem0_client``.

    Args:
        mem0_client: Mem0 client instance.

    Returns:
        An async node function ``summarize_conversation(state, config)``.
    """
    from langchain_core.runnables.config import RunnableConfig

    def _to_mem0_message(msg):
        """Convert one history message to a Mem0 message dict, or None to skip it."""
        # Messages may be plain dicts or message objects.
        if isinstance(msg, dict):
            kind, text = msg.get("type", ""), msg.get("content", "")
        else:
            kind, text = getattr(msg, 'type', ''), getattr(msg, 'content', '')

        if kind == "human":
            return {"role": "user", "content": text}
        if kind == "ai":
            return {"role": "assistant", "content": text}
        if kind == "tool":
            # Tool output is passed along as a tagged system message.
            return {"role": "system", "content": f"[工具返回] {text}"}
        return None

    async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
        """
        Memory-storage node: hand the conversation history to Mem0.

        Args:
            state: Current conversation state.
            config: Runtime config; ``metadata.user_id`` selects the user.

        Returns:
            A state update resetting ``turns_since_last_summary`` to 0.
        """
        log_state_change("summarize", state, "进入")

        history = state["messages"]
        if len(history) < 4:
            # Too short to be worth storing.
            debug("📝 [记忆添加] 对话过短,跳过")
            return {"turns_since_last_summary": 0}

        user_id = config.get("metadata", {}).get("user_id", "default_user")

        # Lazy initialisation of the Mem0 client.
        if not mem0_client._initialized:
            await mem0_client.initialize()

        # Convert the whole history; unknown message types are dropped.
        converted = (_to_mem0_message(item) for item in history)
        mem0_messages = [entry for entry in converted if entry is not None]

        if not mem0_client.mem0:
            warning("⚠️ Mem0 未初始化,跳过记忆添加")
        else:
            try:
                stored = await mem0_client.add_memories(
                    mem0_messages,
                    user_id=user_id
                )
                if stored:
                    info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取")
            except Exception as exc:
                error(f"❌ Mem0 记忆添加失败: {exc}")

        log_state_change("summarize", state, "离开")
        return {"turns_since_last_summary": 0}

    return summarize_conversation
|
||||
101
backend/app/nodes/tool_call.py
Normal file
101
backend/app/nodes/tool_call.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
工具执行节点模块
|
||||
负责执行 AI 调用的工具函数
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from ..graph.state import MessagesState
|
||||
from ..utils.logging import log_state_change
|
||||
from ..logger import debug, info
|
||||
|
||||
def create_tool_call_node(tools_by_name: Dict[str, Any]):
    """
    Factory: build the tool-execution node.

    Args:
        tools_by_name: Mapping from tool name to tool object.

    Returns:
        An async node function ``call_tools(state, config)``.
    """
    from langchain_core.runnables.config import RunnableConfig

    async def call_tools(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
        """
        Tool-execution node.

        Runs every tool call requested by the last AI message and returns one
        ``ToolMessage`` per call. Unknown tools and tool errors are reported
        in the message content instead of being raised.

        Args:
            state: Current conversation state.
            config: Runtime config (not read here).

        Returns:
            A state update containing the ToolMessages.
        """
        log_state_change("tool_node", state, "进入")

        last_message = state['messages'][-1]
        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
            log_state_change("tool_node", state, "离开(无工具调用)")
            return {"messages": []}

        results = []

        info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具")

        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool_id = tool_call["id"]
            tool_func = tools_by_name.get(tool_name)

            debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}")

            if tool_func is None:
                err_msg = f"Tool {tool_name} not found"
                debug(f" └─ ❌ {err_msg}")
                results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
                continue

            # Emit a "tool started" event on the custom stream.
            writer = get_stream_writer()
            writer({"type": "custom", "data": {"type": "tool_start", "tool": tool_name}})

            try:
                # Prefer the async entry point when the tool has one.
                if hasattr(tool_func, 'ainvoke'):
                    observation = await tool_func.ainvoke(tool_args)
                else:
                    # We are inside a coroutine, so get_running_loop() is the
                    # right accessor. Both the tool and its args are bound as
                    # lambda defaults so the worker thread never sees values
                    # from a later loop iteration.
                    observation = await asyncio.get_running_loop().run_in_executor(
                        None,
                        lambda func=tool_func, args=tool_args: func.invoke(args)
                    )

                # Single-line preview for the debug log.
                result_preview = str(observation).replace("\n", " ")
                debug(f" └─ ✅ 结果: {result_preview}")
                results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))

                # Emit a "tool finished" event.
                writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": True}})
            except Exception as e:
                debug(f" └─ ❌ 异常: {e}")
                results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))

                # Emit a "tool failed" event.
                writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": False, "error": str(e)}})

        info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")

        result = {"messages": results}
        log_state_change("tool_node", {**state, **result}, "离开")
        return result

    return call_tools
|
||||
391
backend/app/rag/README.md
Normal file
391
backend/app/rag/README.md
Normal file
@@ -0,0 +1,391 @@
|
||||
# 在线 RAG 检索与生成系统 (Online RAG Retriever)
|
||||
|
||||
该模块负责 RAG 系统的阶段二:**在线检索与生成**。它接收用户提问,从知识库中检索出上下文,利用各种高级策略去噪、融合,并作为增强上下文输入给大语言模型 (LLM)。
|
||||
|
||||
## 🎯 核心架构
|
||||
|
||||
### 技术栈
|
||||
|
||||
| 组件 | 技术选型 | 版本 | 说明 |
|
||||
|:-----|:---------|:-----|:-----|
|
||||
| **基础检索** | `Qdrant` | 1.17+ | HNSW 稠密向量检索 |
|
||||
| **混合检索** | `Qdrant` + `BM25` | 内置 | 稠密 + 稀疏向量融合 |
|
||||
| **查询改写** | `LangChain` | 内置 | `MultiQueryGenerator` 多路改写 |
|
||||
| **RRF 融合** | 自实现 | - | `reciprocal_rank_fusion` 倒数排名融合 |
|
||||
| **重排序** | `llama.cpp` | 本地服务 | OpenAI 兼容 Rerank API |
|
||||
| **编排框架** | `asyncio` | Python 3.10+ | 异步并行检索 |
|
||||
|
||||
### 检索流水线
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ 用户提问 │
|
||||
└──────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ MultiQueryGenerator │
|
||||
│ 多路查询改写 (num_queries=3) │
|
||||
│ "如何申请项目资金?" → ["项目资金申请流程", "经费申请步骤"] │
|
||||
└──────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ 并行检索 (asyncio.gather) │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
│ │ 查询1 检索 │ │ 查询2 检索 │ │ 查询3 检索 │ │
|
||||
│ │ (k=20) │ │ (k=20) │ │ (k=20) │ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||
└──────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ reciprocal_rank_fusion (RRF) │
|
||||
│ RRF_score(d) = Σ 1/(k + rank_q(d)) (k=60) │
|
||||
│ 融合多路检索结果,去重排序 │
|
||||
└──────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ LLaMaCPPReranker │
|
||||
│ 远程重排序 (bge-reranker-v2-m3) │
|
||||
│ 返回 Top-N (top_n=5) 最相关文档 │
|
||||
└──────────────────────┬──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ 返回增强上下文 │
|
||||
│ format_context() → 格式化输出 │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 技术特性
|
||||
|
||||
- ✅ **多路查询改写**:通过 LLM 将单一问题改写为多个不同角度的查询
|
||||
- ✅ **RRF 融合算法**:Reciprocal Rank Fusion,无需评分归一化的融合算法
|
||||
- ✅ **远程重排序**:使用 llama.cpp 服务的 OpenAI 兼容 Rerank API
|
||||
- ✅ **混合检索支持**:稠密向量 + BM25 稀疏向量混合检索
|
||||
- ✅ **异步并行检索**:多路查询并行执行,提升检索速度
|
||||
- ✅ **优雅降级**:重排序器不可用时自动降级到基础融合结果
|
||||
|
||||
## 📂 架构与文件结构
|
||||
|
||||
```
|
||||
app/rag/
|
||||
├── __init__.py
|
||||
├── retriever.py # Qdrant 基础检索与混合检索
|
||||
├── reranker.py # llama.cpp 远程重排序器
|
||||
├── query_transform.py # 多路查询改写生成器
|
||||
├── fusion.py # RRF 倒数排名融合算法
|
||||
├── pipeline.py # RAG 流水线编排
|
||||
└── tools.py # LangChain Tool 封装
|
||||
```
|
||||
|
||||
## 🎯 演进路线与算法详解 (Roadmap)
|
||||
|
||||
### Level 1: 基础向量搜索 (Basic Similarity Search)
|
||||
|
||||
- **核心算法**: 近似最近邻搜索 (ANN, 常用 HNSW 算法)。将用户问题转化为向量后,计算它与库中向量的余弦相似度 (Cosine Similarity),取距离最近的 K 个块。
|
||||
- **优缺点**: 速度极快。但只能捕捉"语义相似",如果用户搜索特定专有名词、编号、订单号,纯向量检索往往会失效(产生"幻觉"匹配)。
|
||||
- **实现指南**:
|
||||
- 使用 `rag_indexer.embedders.LlamaCppEmbedder` 作为嵌入模型
|
||||
- 使用 `app/rag/retriever.py` 中的 `create_base_retriever` 创建基础检索器
|
||||
- 配置 `search_kwargs={"k": 20}` 进行初步召回
|
||||
|
||||
```python
|
||||
from app.rag.retriever import create_base_retriever
|
||||
|
||||
retriever = create_base_retriever(
|
||||
collection_name="rag_documents",
|
||||
embeddings=embeddings,
|
||||
search_kwargs={"k": 20}
|
||||
)
|
||||
docs = retriever.invoke("什么是 RAG?")
|
||||
```
|
||||
|
||||
### Level 2: 混合检索与重排序 (Hybrid Search + Reranker)
|
||||
|
||||
混合检索旨在结合向量的"语义泛化"与关键词的"精准匹配",随后利用重排序模型过滤噪声。
|
||||
|
||||
**1. 基础召回 (混合检索)**
|
||||
|
||||
- **核心原理**: 结合基于 HNSW 的 Dense Vector 相似度搜索与基于 TF-IDF 的 BM25 稀疏检索 (Sparse Vector)。
|
||||
- **实现指南**: 使用 `app/rag/retriever.py` 中的 `create_hybrid_retriever` 函数,配置 `dense_k=10` 和 `sparse_k=10`,总召回 20 条结果。
|
||||
|
||||
```python
|
||||
from app.rag.retriever import create_hybrid_retriever
|
||||
|
||||
retriever = create_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
embeddings=embeddings,
|
||||
dense_k=10,
|
||||
sparse_k=10,
|
||||
score_threshold=0.3
|
||||
)
|
||||
```
|
||||
|
||||
**2. 二次精排 (Cross-Encoder)**
|
||||
|
||||
- **核心原理**: 不同于双塔模型(分别算向量再求距离),交叉编码器将"用户问题 + 检索到的单个文档"拼接后整体输入 Transformer 模型,由模型直接输出 0~1 的相关性得分,精度极高。
|
||||
- **实现指南**:
|
||||
- 使用 `app/rag/reranker.py` 中的 `LLaMaCPPReranker` 类,加载 `bge-reranker-v2-m3` 模型
|
||||
- 设置 `top_n=5` 保留最相关的 5 条结果
|
||||
|
||||
```python
|
||||
from app.rag.reranker import LLaMaCPPReranker
|
||||
|
||||
reranker = LLaMaCPPReranker(
|
||||
base_url="http://127.0.0.1:8083",
|
||||
api_key="your-api-key",
|
||||
top_n=5
|
||||
)
|
||||
sorted_docs = reranker.compress_documents(documents, query)
|
||||
```
|
||||
|
||||
### Level 3: RAG-Fusion (多路改写与倒数排名融合)
|
||||
|
||||
RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果。
|
||||
|
||||
**1. 多路查询改写**
|
||||
|
||||
- **核心原理**: 克服用户初始提问词不达意或视角受限的问题。
|
||||
- **实现指南**: 使用 `app/rag/query_transform.py` 中的 `MultiQueryGenerator` 类,配置 `num_queries=3` 生成 3 个不同角度的查询。
|
||||
|
||||
```python
|
||||
from app.rag.query_transform import MultiQueryGenerator
|
||||
|
||||
generator = MultiQueryGenerator(llm=llm, num_queries=3)
|
||||
queries = await generator.agenerate("如何申请项目资金?")
|
||||
# 返回:["如何申请项目资金?", "项目资金申请流程是什么?", "申请项目经费需要哪些步骤?"]
|
||||
```
|
||||
|
||||
**2. 倒数排名融合 (RRF)**
|
||||
|
||||
- **核心原理**: RRF (Reciprocal Rank Fusion) 是一种无需评分归一化的融合算法。公式为 `RRF_score(d) = Σ 1/(k + rank_q(d))`,有效避免某一极端检索结果主导全局。
|
||||
- **实现指南**: 使用 `app/rag/fusion.py` 中的 `reciprocal_rank_fusion` 函数,配置 `k=60` 实现倒数排名融合。
|
||||
|
||||
```python
|
||||
from app.rag.fusion import reciprocal_rank_fusion
|
||||
|
||||
# 多个查询的检索结果
|
||||
doc_lists = [result1, result2, result3]
|
||||
fused_docs = reciprocal_rank_fusion(doc_lists, k=60)
|
||||
```
|
||||
|
||||
### Level 4: Agentic RAG / Self-RAG (智能体与自我反思)
|
||||
|
||||
- **核心原理**: 基于 LangGraph 的 ReAct (Reasoning and Acting) 状态机路由。大模型并非每次都去死板地执行检索,而是先判断问题:"这是闲聊?还是需要查知识库?"。如果是后者,模型输出一个 `ToolCall` 指令,触发检索。
|
||||
- **实现指南**: 使用 `app/rag/tools.py` 中的 `search_knowledge_base` 工具,将其绑定到 LangGraph 状态机中。
|
||||
|
||||
- **示意图**:
|
||||
|
||||
```
|
||||
┌──────────┐ ┌──────────────┐ ┌──────────┐ ┌────────┐
|
||||
│ User │────>│ LangGraph │────>│ RAG_Tool │────>│ Qdrant │
|
||||
│ │ │ Agent │ │ │ │ │
|
||||
│ "公司报 │ │ 思考: 这是 │ │ ToolCall │ │ RAG- │
|
||||
│ 销流程?"│ │ 内部规章问题 │ │ search_ │ │ Fusion │
|
||||
│ │ │ 需要查资料 │ │ knowledge│ │ & 混合 │
|
||||
│ │<────│ 资料充分, │<────│ 返回最相 │<────│ 检索 │
|
||||
│ "根据知 │ │ 开始撰写回答 │ │ 关5条规定 │ │ Cross- │
|
||||
│ 识库规定 │ │ │ │ │ │ Encoder│
|
||||
│ ..." │ │ │ │ │ │ 重排 │
|
||||
└──────────┘ └──────────────┘ └──────────┘ └────────┘
|
||||
```
|
||||
|
||||
### Level 5: GraphRAG 集成 (基于图和关系的 RAG)
|
||||
|
||||
- **核心原理**: 结合知识图谱的结构化关系和向量检索的语义相似度,解决跨文档复杂关系推理问题。
|
||||
- **实现指南**:
|
||||
- 使用 `langchain_community.graphs` 模块构建知识图谱
|
||||
- 配置本地大模型(如 `Gemma-4-E4B`)用于实体关系抽取
|
||||
- 实现混合检索逻辑,结合向量相似度和图路径分析
|
||||
|
||||
```python
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from langchain_experimental.graph_transformers import LLMGraphTransformer
|
||||
|
||||
# 实体关系抽取
|
||||
transformer = LLMGraphTransformer(llm=local_llm)
|
||||
graph_documents = transformer.convert_to_graph_documents(documents)
|
||||
|
||||
# 存储到图数据库
|
||||
graph = Neo4jGraph(url="bolt://localhost:7687")
|
||||
graph.add_graph_documents(graph_documents)
|
||||
```
|
||||
|
||||
## 🔧 核心组件详解
|
||||
|
||||
### 1. 检索器 (retriever.py)
|
||||
|
||||
提供基于 Qdrant 的向量检索能力。
|
||||
|
||||
**基础检索器**:
|
||||
```python
|
||||
from app.rag.retriever import create_base_retriever
|
||||
|
||||
retriever = create_base_retriever(
|
||||
collection_name="rag_documents",
|
||||
embeddings=embeddings,
|
||||
search_kwargs={"k": 20}
|
||||
)
|
||||
```
|
||||
|
||||
**混合检索器**:
|
||||
```python
|
||||
from app.rag.retriever import create_hybrid_retriever
|
||||
|
||||
retriever = create_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
embeddings=embeddings,
|
||||
dense_k=10,
|
||||
sparse_k=10,
|
||||
score_threshold=0.3
|
||||
)
|
||||
```
|
||||
|
||||
### 2. 多路查询改写 (query_transform.py)
|
||||
|
||||
通过 LLM 将用户问题改写为多个不同版本,扩大搜索面。
|
||||
|
||||
```python
|
||||
from app.rag.query_transform import MultiQueryGenerator
|
||||
|
||||
generator = MultiQueryGenerator(llm=llm, num_queries=3)
|
||||
queries = await generator.agenerate("如何申请项目资金?")
|
||||
```
|
||||
|
||||
### 3. RRF 融合算法 (fusion.py)
|
||||
|
||||
Reciprocal Rank Fusion 算法,公式:`RRF_score(d) = Σ 1/(k + rank_q(d))`
|
||||
|
||||
```python
|
||||
from app.rag.fusion import reciprocal_rank_fusion
|
||||
|
||||
# 多个查询的检索结果
|
||||
doc_lists = [result1, result2, result3]
|
||||
fused_docs = reciprocal_rank_fusion(doc_lists, k=60)
|
||||
```
|
||||
|
||||
### 4. 重排序器 (reranker.py)
|
||||
|
||||
使用 llama.cpp 服务的 OpenAI 兼容 Rerank API 对检索结果重排序。
|
||||
|
||||
```python
|
||||
from app.rag.reranker import LLaMaCPPReranker
|
||||
|
||||
reranker = LLaMaCPPReranker(
|
||||
base_url="http://127.0.0.1:8083",
|
||||
api_key="your-api-key",
|
||||
top_n=5
|
||||
)
|
||||
sorted_docs = reranker.compress_documents(documents, query)
|
||||
```
|
||||
|
||||
### 5. RAG 流水线 (pipeline.py)
|
||||
|
||||
组合上述组件的完整检索流水线。
|
||||
|
||||
```python
|
||||
from app.rag.pipeline import RAGPipeline
|
||||
|
||||
pipeline = RAGPipeline(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
num_queries=3,
|
||||
rerank_top_n=5,
|
||||
)
|
||||
|
||||
# 异步检索
|
||||
docs = await pipeline.aretrieve("如何申请项目资金?")
|
||||
|
||||
# 格式化上下文
|
||||
context = pipeline.format_context(docs)
|
||||
```
|
||||
|
||||
## 🔄 与 Agent 系统集成
|
||||
|
||||
### 封装为 LangChain Tool
|
||||
|
||||
```python
|
||||
from langchain_core.tools import tool
|
||||
from app.rag.pipeline import RAGPipeline
|
||||
|
||||
@tool
|
||||
def search_knowledge_base(query: str) -> str:
|
||||
"""搜索知识库获取相关信息"""
|
||||
docs = pipeline.retrieve(query)
|
||||
return pipeline.format_context(docs)
|
||||
```
|
||||
|
||||
### 绑定到 LangGraph
|
||||
|
||||
```python
|
||||
from app.graph.graph_builder import GraphBuilder
|
||||
|
||||
# 将 RAG 工具添加到工具列表
|
||||
tools = AVAILABLE_TOOLS + [search_knowledge_base]
|
||||
|
||||
# 构建图
|
||||
builder = GraphBuilder(llm, tools, tools_by_name)
|
||||
graph = builder.build().compile(checkpointer=checkpointer)
|
||||
```
|
||||
|
||||
## ⚙️ 环境配置
|
||||
|
||||
| 变量名 | 说明 | 默认值 |
|
||||
|:-------|:-----|:-------|
|
||||
| `QDRANT_URL` | Qdrant 向量数据库地址 | `http://127.0.0.1:6333` |
|
||||
| `QDRANT_API_KEY` | Qdrant API 密钥 | - |
|
||||
| `LLAMACPP_RERANKER_URL` | llama.cpp 重排序服务地址 | `http://127.0.0.1:8083` |
|
||||
| `LLAMACPP_API_KEY` | llama.cpp API 密钥 | - |
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
```python
|
||||
# 1. 初始化嵌入模型
|
||||
from rag_core.embedders import LlamaCppEmbedder
|
||||
embedder = LlamaCppEmbedder()
|
||||
embeddings = embedder.as_langchain_embeddings()
|
||||
|
||||
# 2. 创建检索器
|
||||
from app.rag.retriever import create_base_retriever
|
||||
retriever = create_base_retriever(
|
||||
collection_name="rag_documents",
|
||||
embeddings=embeddings,
|
||||
search_kwargs={"k": 20}
|
||||
)
|
||||
|
||||
# 3. 创建 RAG 流水线
|
||||
from app.rag.pipeline import RAGPipeline
|
||||
pipeline = RAGPipeline(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
num_queries=3,
|
||||
rerank_top_n=5,
|
||||
)
|
||||
|
||||
# 4. 执行检索
|
||||
docs = pipeline.retrieve("如何申请项目资金?")
|
||||
|
||||
# 5. 格式化上下文
|
||||
context = pipeline.format_context(docs)
|
||||
print(context)
|
||||
```
|
||||
|
||||
## 📊 检索策略对比
|
||||
|
||||
| 策略 | 优点 | 缺点 | 适用场景 |
|
||||
|:-----|:-----|:-----|:---------|
|
||||
| **基础向量检索** | 速度快,语义理解好 | 专有名词匹配差 | 通用问答 |
|
||||
| **混合检索** | 语义 + 关键词匹配 | 需要配置稀疏向量 | 专业术语查询 |
|
||||
| **多路改写 + RRF** | 搜索面广,结果稳定 | 延迟略高 | 复杂问题 |
|
||||
| **重排序** | 精度高 | 依赖额外模型 | 最终精排 |
|
||||
|
||||
## 🤝 与 rag_indexer 集成
|
||||
|
||||
- **向量存储**:共享 Qdrant 集合,确保嵌入模型一致
|
||||
- **文档存储**:使用 PostgreSQL 存储父块,通过 UUID 映射
|
||||
- **集合名称**:默认使用 `rag_documents` 集合
|
||||
|
||||
详见 [rag_indexer/README.md](../../rag_indexer/README.md)
|
||||
69
backend/app/rag/__init__.py
Normal file
69
backend/app/rag/__init__.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
RAG 检索与生成模块
|
||||
|
||||
提供在线检索与生成功能,包括:
|
||||
- 基础向量检索(稠密向量 / 混合检索)
|
||||
- 重排序(Cross-Encoder)
|
||||
- 多路查询改写(Multi-Query)
|
||||
- RRF 融合(Reciprocal Rank Fusion)
|
||||
- 完整的 RAG 流水线
|
||||
- Agent 工具封装
|
||||
|
||||
固定流水线:
|
||||
用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
||||
|
||||
示例用法:
|
||||
>>> from app.rag import RAGPipeline, create_rag_tool_sync as create_rag_tool
|
||||
>>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
|
||||
>>> from langchain_openai import ChatOpenAI
|
||||
>>>
|
||||
>>> # 获取基础检索器(如父子块检索器)
|
||||
>>> config = IndexBuilderConfig(collection_name="my_docs")
|
||||
>>> builder = IndexBuilder(config)
|
||||
>>> retriever = builder.retriever
|
||||
>>>
|
||||
>>> # 创建 LLM 和流水线
|
||||
>>> llm = ChatOpenAI(model="gpt-3.5-turbo")
|
||||
>>> pipeline = RAGPipeline(retriever=retriever, llm=llm)
|
||||
>>>
|
||||
>>> # 检索
|
||||
>>> docs = await pipeline.aretrieve("什么是 RAG?")
|
||||
>>> context = pipeline.format_context(docs)
|
||||
>>>
|
||||
>>> # 创建 Agent 工具
|
||||
>>> rag_tool = create_rag_tool(retriever=retriever, llm=llm)
|
||||
"""
|
||||
|
||||
from .retriever import (
|
||||
create_base_retriever,
|
||||
create_hybrid_retriever,
|
||||
create_qdrant_client,
|
||||
)
|
||||
from .reranker import LLaMaCPPReranker
|
||||
from .query_transform import MultiQueryGenerator
|
||||
from .fusion import reciprocal_rank_fusion
|
||||
from .pipeline import RAGPipeline
|
||||
from .tools import create_rag_tool_sync
|
||||
|
||||
|
||||
__all__ = [
|
||||
# 检索器工厂函数
|
||||
"create_base_retriever",
|
||||
"create_hybrid_retriever",
|
||||
"create_qdrant_client",
|
||||
|
||||
# 重排序器
|
||||
"LLaMaCPPReranker",
|
||||
|
||||
# 查询改写生成器
|
||||
"MultiQueryGenerator",
|
||||
|
||||
# 融合算法
|
||||
"reciprocal_rank_fusion",
|
||||
|
||||
# 主流水线
|
||||
"RAGPipeline",
|
||||
|
||||
# 工具创建(供 Agent 使用)
|
||||
"create_rag_tool_sync",
|
||||
]
|
||||
36
backend/app/rag/fusion.py
Normal file
36
backend/app/rag/fusion.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# rag/fusion.py
|
||||
|
||||
from typing import List, Dict
|
||||
from langchain_core.documents import Document
|
||||
|
||||
def reciprocal_rank_fusion(
    doc_lists: List[List[Document]],
    k: int = 60
) -> List[Document]:
    """Fuse several ranked result lists with Reciprocal Rank Fusion (RRF).

    Every document scores ``sum(1 / (k + rank))`` over the lists it appears
    in, where ``rank`` is its 1-based position in that list.

    Args:
        doc_lists: One ranked list of documents per query.
        k: RRF smoothing constant; 60 is the customary value.

    Returns:
        The distinct documents ordered by descending RRF score. Ties keep
        first-seen order, and the first instance seen of a document is the
        object that is returned.
    """
    scores: Dict[tuple, float] = {}
    first_seen: Dict[tuple, Document] = {}

    for docs in doc_lists:
        # A list contributes at most once per document (at its best rank), so
        # a retriever that repeats a hit cannot inflate that document's score.
        counted = set()
        for rank, doc in enumerate(docs, start=1):
            # Identity is the full text plus its source. Keying on a truncated
            # prefix would merge distinct chunks that merely share a long
            # common header and silently drop one of them.
            key = (doc.page_content, str(doc.metadata.get("source", "")))
            if key in counted:
                continue
            counted.add(key)
            first_seen.setdefault(key, doc)
            scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank)

    # sorted() is stable, so equal scores keep dict insertion order.
    ranked_keys = sorted(scores, key=scores.__getitem__, reverse=True)
    return [first_seen[key] for key in ranked_keys]
|
||||
91
backend/app/rag/pipeline.py
Normal file
91
backend/app/rag/pipeline.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# rag/pipeline.py
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from ..config import LLAMACPP_RERANKER_URL, LLAMACPP_API_KEY
|
||||
from typing import List
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from .reranker import LLaMaCPPReranker
|
||||
from .query_transform import MultiQueryGenerator
|
||||
from .fusion import reciprocal_rank_fusion
|
||||
|
||||
class RAGPipeline:
    """Fixed-flow RAG retrieval pipeline.

    Stages: multi-query rewrite -> concurrent retrieval -> RRF fusion ->
    rerank -> return the top documents.
    """

    def __init__(
        self,
        retriever,  # base retriever; must expose an async ``ainvoke(query)``
        llm: BaseLanguageModel,
        num_queries: int = 3,
        rerank_top_n: int = 5,
    ):
        """Wire up the query generator and the remote reranker.

        Args:
            retriever: Base retriever object implementing ``ainvoke(query)``.
            llm: Language model used to generate query variants.
            num_queries: Number of query variants to request.
            rerank_top_n: Number of documents returned after reranking.
        """
        self.retriever = retriever
        self.llm = llm
        self.num_queries = num_queries
        self.rerank_top_n = rerank_top_n

        self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
        self.reranker = LLaMaCPPReranker(
            base_url=LLAMACPP_RERANKER_URL,
            api_key=LLAMACPP_API_KEY,
            top_n=rerank_top_n,
        )

    async def aretrieve(self, query: str) -> List[Document]:
        """Run the whole retrieval flow asynchronously.

        Args:
            query: The user's original question.

        Returns:
            The reranked documents, or the first ``rerank_top_n`` fused
            documents when reranking fails.
        """
        # Step 1: query variants. The original question always goes first;
        # dict.fromkeys drops repeated variants while keeping order, so the
        # same search is never issued (and weighted) twice.
        variants = await self.query_generator.agenerate(query)
        queries = list(dict.fromkeys([query, *variants]))

        # Step 2: one retrieval per query, run concurrently.
        doc_lists = await asyncio.gather(
            *(self.retriever.ainvoke(q) for q in queries)
        )

        # Step 3: RRF fusion.
        fused_docs = reciprocal_rank_fusion(list(doc_lists))

        # Step 4: rerank. compress_documents is a blocking call, so it runs in
        # the default executor instead of stalling the event loop.
        try:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None, self.reranker.compress_documents, fused_docs, query
            )
        except Exception:
            # Best-effort stage: if the reranker is unavailable, fall back to
            # the fused ordering rather than failing the whole retrieval.
            return fused_docs[:self.rerank_top_n]

    def retrieve(self, query: str) -> List[Document]:
        """Synchronous entry point around :meth:`aretrieve`.

        Works both from plain synchronous code and when called on a thread
        whose event loop is already running. In the latter case the caller is
        blocked until retrieval finishes, as with any synchronous call.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: the simple path is safe.
            return asyncio.run(self.aretrieve(query))

        # asyncio.run() refuses to start inside a running loop, so drive the
        # coroutine on a private loop in a short-lived worker thread.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, self.aretrieve(query)).result()

    def format_context(self, documents: List[Document]) -> str:
        """Render documents as one context string; empty input gives ``""``."""
        if not documents:
            return ""

        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "未知来源")
            parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
        return "\n".join(parts)
|
||||
43
backend/app/rag/query_transform.py
Normal file
43
backend/app/rag/query_transform.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# rag/query_transform.py
|
||||
|
||||
from typing import List
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
# Prompt text for the multi-query rewrite step. It has two placeholders:
# {num_queries} (how many variants to produce) and {question} (the original
# user question). The model is asked to emit one variant per line.
_MULTI_QUERY_TEMPLATE = """你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。
这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。

原始问题: {question}

请生成 {num_queries} 个不同版本的查询,每个版本一行。
确保每个版本都是独立、完整的查询语句。

生成 {num_queries} 个查询:"""

MULTI_QUERY_PROMPT = PromptTemplate.from_template(_MULTI_QUERY_TEMPLATE)
|
||||
|
||||
class MultiQueryGenerator:
    """Multi-query generator (independent of LangChain's MultiQueryRetriever)."""

    def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
        """Store the model and the number of variants to request.

        Args:
            llm: Language model used to rewrite the question.
            num_queries: Maximum number of variants returned per call.
        """
        self.llm = llm
        self.num_queries = num_queries
        self.prompt = MULTI_QUERY_PROMPT

    def _parse_response(self, response, query: str) -> List[str]:
        """Turn a raw model response into a clean list of query variants."""
        import re  # only this helper needs it

        # Chat models return a message object with ``.content``; plain LLMs
        # return the text itself.
        text = response if isinstance(response, str) else getattr(response, "content", "")
        if not isinstance(text, str):
            text = str(text)

        queries: List[str] = []
        for line in text.splitlines():
            # Models frequently answer with "1. ...", "2) ..." or "- ...".
            # The marker is not part of the query. The (?!\d) guard keeps
            # leading decimals such as "3.5 ..." intact.
            cleaned = re.sub(
                r"^\s*(?:\d{1,2}[.)、)](?!\d)\s*|[-*•]\s+)", "", line
            ).strip()
            if cleaned and cleaned not in queries:
                queries.append(cleaned)

        # Always hand back at least the original question.
        return queries[:self.num_queries] if queries else [query]

    def generate(self, query: str) -> List[str]:
        """Synchronously generate query variants for ``query``.

        Returns:
            Up to ``num_queries`` distinct variants, or ``[query]`` when the
            model produced nothing usable.
        """
        prompt_str = self.prompt.format(num_queries=self.num_queries, question=query)
        return self._parse_response(self.llm.invoke(prompt_str), query)

    async def agenerate(self, query: str) -> List[str]:
        """Asynchronously generate query variants for ``query``.

        Returns:
            Up to ``num_queries`` distinct variants, or ``[query]`` when the
            model produced nothing usable.
        """
        prompt_str = self.prompt.format(num_queries=self.num_queries, question=query)
        return self._parse_response(await self.llm.ainvoke(prompt_str), query)
|
||||
75
backend/app/rag/reranker.py
Normal file
75
backend/app/rag/reranker.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
重排序器模块 (适配版)
|
||||
使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder
|
||||
"""
|
||||
import requests
|
||||
from typing import List
|
||||
from langchain_core.documents import Document
|
||||
|
||||
class LLaMaCPPReranker:
    """Rerank retrieved documents through a remote llama.cpp rerank endpoint."""

    def __init__(self,
                 base_url: str,
                 api_key: str,
                 top_n: int = 5,
                 timeout: int = 60,
                 model: str = "bge-reranker-v2-m3"):
        """Configure the remote reranker.

        Args:
            base_url: Address of the llama.cpp service, e.g.
                "http://127.0.0.1:8083". A trailing slash is tolerated.
            api_key: Key sent as a Bearer token. It has no default here and
                must be supplied by the caller.
            top_n: Number of documents to keep.
            timeout: Request timeout in seconds.
            model: Model name sent in the request body.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.top_n = top_n
        self.timeout = timeout
        self.model = model
        # Normalise so that "http://host/" does not become "http://host//rerank".
        self.endpoint = f"{self.base_url.rstrip('/')}/rerank"

    def compress_documents(
        self, documents: List[Document], query: str
    ) -> List[Document]:
        """Rerank ``documents`` against ``query``.

        Args:
            documents: Candidate documents.
            query: Query string.

        Returns:
            At most ``top_n`` documents, most relevant first. On any failure
            the first ``top_n`` documents are returned in their input order.
        """
        if not documents:
            return []

        # The endpoint expects the documents as a list of plain strings.
        payload = {
            "model": self.model,
            "query": query,
            "documents": [doc.page_content for doc in documents],
            "top_n": self.top_n
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        try:
            response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout)
            response.raise_for_status()
            results = response.json()

            # Expected shape:
            # {"results": [{"index": 0, "relevance_score": 0.8, ...}, ...]}
            items = results["results"]
            # Sort and truncate here rather than trusting the server to have
            # ordered the results and honoured "top_n".
            if all("relevance_score" in item for item in items):
                items = sorted(
                    items, key=lambda item: item["relevance_score"], reverse=True
                )
            return [documents[item["index"]] for item in items[:self.top_n]]

        except Exception as e:
            # Deliberate best-effort boundary: a broken reranker must degrade
            # retrieval quality, not break retrieval.
            print(f"警告: 远程重排序过程出错,将使用原始排序。错误: {e}")
            return documents[:self.top_n]
|
||||
199
backend/app/rag/retriever.py
Normal file
199
backend/app/rag/retriever.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Qdrant 向量检索器模块
|
||||
|
||||
提供基于 Qdrant 的基础向量检索和混合检索(Dense + Sparse)功能。
|
||||
|
||||
核心原理:
|
||||
- 基础检索:将查询文本转换为向量,在 Qdrant 中进行近似最近邻(ANN)搜索,
|
||||
使用余弦相似度返回最相似的 k 个文档。
|
||||
- 混合检索:结合稠密向量检索(语义相似)和 BM25 稀疏向量检索(关键词匹配),
|
||||
通过加权或分数融合提高召回精度。
|
||||
|
||||
使用示例:
|
||||
>>> from rag_core import LlamaCppEmbedder
|
||||
>>> embedder = LlamaCppEmbedder()
|
||||
>>> embeddings = embedder.as_langchain_embeddings()
|
||||
>>>
|
||||
>>> # 创建基础检索器
|
||||
>>> retriever = create_base_retriever(
|
||||
... collection_name="my_docs",
|
||||
... embeddings=embeddings,
|
||||
... search_kwargs={"k": 10}
|
||||
... )
|
||||
>>>
|
||||
>>> # 执行检索
|
||||
>>> docs = retriever.invoke("什么是 RAG?")
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from langchain_qdrant import QdrantVectorStore
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from rag_core import QDRANT_URL, QDRANT_API_KEY
|
||||
|
||||
# 模块级常量
|
||||
DEFAULT_SEARCH_K = 20
|
||||
DEFAULT_SCORE_THRESHOLD = 0.3
|
||||
|
||||
|
||||
def create_qdrant_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    timeout: int = 30,
) -> QdrantClient:
    """Build a configured Qdrant client.

    Explicit arguments win. A falsy ``url`` / ``api_key`` falls back to the
    module-level ``QDRANT_URL`` / ``QDRANT_API_KEY`` imported from
    ``rag_core`` (presumably environment-derived; verify there).

    Args:
        url: Qdrant service address, e.g. "http://localhost:6333".
        api_key: API key; only passed to the client when non-empty.
        timeout: Request timeout in seconds.

    Returns:
        A ``QdrantClient`` instance.

    Raises:
        ValueError: If neither ``url`` nor ``QDRANT_URL`` provides an address.
    """
    resolved_url = url if url else QDRANT_URL
    if not resolved_url:
        raise ValueError(
            "Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL"
        )

    options = {"url": resolved_url, "timeout": timeout}

    resolved_key = api_key if api_key else QDRANT_API_KEY
    if resolved_key:
        options["api_key"] = resolved_key

    return QdrantClient(**options)
|
||||
|
||||
|
||||
def create_base_retriever(
    collection_name: str,
    embeddings: Embeddings,
    search_kwargs: Optional[Dict[str, Any]] = None,
    client: Optional[QdrantClient] = None,
) -> BaseRetriever:
    """Build a dense-vector retriever over an existing Qdrant collection.

    Args:
        collection_name: Name of the Qdrant collection; it must already exist.
        embeddings: LangChain-compatible embedding model.
        search_kwargs: Search parameters layered over the default
            ``{"k": DEFAULT_SEARCH_K}`` (for example ``k``,
            ``score_threshold`` or ``filter``).
        client: Qdrant client to reuse; one is created when omitted.

    Returns:
        The retriever produced by ``QdrantVectorStore.as_retriever``.

    Raises:
        ValueError: If Qdrant answers 404 for the collection.
        UnexpectedResponse: Any other unexpected Qdrant response, re-raised.
    """
    # Caller-supplied values override the default k.
    effective_kwargs = {"k": DEFAULT_SEARCH_K, **(search_kwargs or {})}

    qdrant = create_qdrant_client() if client is None else client

    # Probe the collection up front so a missing index fails here with a
    # clear message instead of on the first query.
    try:
        qdrant.get_collection(collection_name)
    except UnexpectedResponse as e:
        if e.status_code != 404:
            raise
        raise ValueError(
            f"Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档。"
        )

    store = QdrantVectorStore(
        client=qdrant,
        collection_name=collection_name,
        embedding=embeddings,
    )
    return store.as_retriever(search_kwargs=effective_kwargs)
|
||||
|
||||
|
||||
def create_hybrid_retriever(
    collection_name: str,
    embeddings: Embeddings,
    dense_k: int = 10,
    sparse_k: int = 10,
    score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD,
    client: Optional[QdrantClient] = None,
) -> BaseRetriever:
    """Build the retriever used for "hybrid" search.

    NOTE(review): despite the name, this function passes no sparse/BM25
    configuration. It only asks the retriever built by
    ``create_base_retriever`` for ``dense_k + sparse_k`` results and applies
    an optional score threshold.

    Args:
        collection_name: Name of the Qdrant collection.
        embeddings: Embedding model for the dense vectors.
        dense_k: First half of the result budget.
        sparse_k: Second half of the result budget.
        score_threshold: Minimum score; omitted from the search parameters
            when ``None``.
        client: Optional Qdrant client to reuse.

    Returns:
        The retriever returned by ``create_base_retriever``.
    """
    params: Dict[str, Any] = {"k": dense_k + sparse_k}
    if score_threshold is not None:
        params["score_threshold"] = score_threshold

    return create_base_retriever(
        collection_name=collection_name,
        embeddings=embeddings,
        search_kwargs=params,
        client=client,
    )
|
||||
|
||||
|
||||
# 可选:提供异步友好的辅助函数
|
||||
async def acreate_base_retriever(
    collection_name: str,
    embeddings: Embeddings,
    search_kwargs: Optional[Dict[str, Any]] = None,
    client: Optional[QdrantClient] = None,
) -> BaseRetriever:
    """Create the base vector retriever from async code.

    Takes the same arguments and returns the same result as
    ``create_base_retriever``. That function is synchronous and probes the
    collection over the network, so it runs in the default executor and the
    event loop stays responsive while it works.

    Raises:
        Whatever ``create_base_retriever`` raises.
    """
    import asyncio  # local import: the rest of the module is synchronous

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, create_base_retriever, collection_name, embeddings, search_kwargs, client
    )
|
||||
146
backend/app/rag/test.py
Normal file
146
backend/app/rag/test.py
Normal file
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RAG 系统使用示例(重构版)
|
||||
|
||||
演示:
|
||||
1. 使用 IndexBuilder 获取父子块检索器
|
||||
2. 创建固定流程的 RAGPipeline(多路改写 → RRF融合 → 重排序 → 返回父文档)
|
||||
3. 将流水线封装为 LangChain 工具,供 Agent 调用
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量(Qdrant URL、PostgreSQL 连接等)
|
||||
load_dotenv()
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
||||
from rag_indexer.index_builder import IndexBuilderConfig
|
||||
from rag_indexer.splitters import SplitterType
|
||||
from .pipeline import RAGPipeline
|
||||
from .tools import create_rag_tool_sync
|
||||
from pydantic import SecretStr
|
||||
# 使用本地 LLM(通过 OpenAI 兼容接口)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from rag_core.retriever_factory import create_parent_retriever
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def create_llm():
    """Create the chat model client for the local OpenAI-compatible server.

    The base URL comes from ``VLLM_BASE_URL`` and the key from
    ``LLAMACPP_API_KEY``; both have hard-coded fallbacks.
    """
    endpoint = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
    secret = SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123"))

    return ChatOpenAI(
        base_url=endpoint,
        api_key=secret,
        model="gemma-4-E2B-it",
        timeout=60.0,    # request timeout in seconds
        max_retries=2,   # automatic retries after a failure
        streaming=True,  # keep streaming output enabled
    )
|
||||
|
||||
async def demonstrate_full_pipeline():
    """Demo: build a RAGPipeline over the parent retriever and run one query.

    Prints a preview of each returned document. A failed retrieval is
    reported with a traceback instead of propagating.
    """
    banner = "=" * 60
    print(banner)
    print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
    print(banner)

    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
    if retriever is None:
        print("错误:检索器未初始化,请确保索引已构建。")
        return

    # The LLM is only used for query rewriting.
    llm = create_llm()

    pipeline = RAGPipeline(
        retriever=retriever,
        llm=llm,
        num_queries=3,   # three query variants
        rerank_top_n=5,  # five documents after reranking
    )

    query = "打虎英雄是谁?"
    print(f"\n查询: {query}")
    print("-" * 40)

    try:
        documents = await pipeline.aretrieve(query)
        print(f"返回 {len(documents)} 个父文档\n")

        for position, document in enumerate(documents, start=1):
            snippet = document.page_content.replace("\n", " ")[:150]
            origin = document.metadata.get("source", "未知来源")
            print(f"{position}. 【来源:{origin}】")
            print(f" {snippet}...\n")
    except Exception as e:
        print(f"检索失败: {e}")
        import traceback
        traceback.print_exc()
|
||||
|
||||
async def demonstrate_tool_creation():
    """Demo: wrap the pipeline as a LangChain tool and invoke it once.

    Prints the tool's name and description, then calls it the way an agent
    would and prints up to 800 characters of the result.
    """
    print("\n" + "=" * 60)
    print("演示:创建 RAG 工具(供 LangGraph Agent 调用)")
    print("=" * 60)

    # 1. Retriever, obtained the same way as in demonstrate_full_pipeline.
    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
    if retriever is None:
        # Same guard as the sibling demo: without an index nothing can run.
        print("错误:检索器未初始化,请确保索引已构建。")
        return

    # 2. LLM used for query rewriting.
    llm = create_llm()

    # 3. The tool itself.
    rag_tool = create_rag_tool_sync(
        retriever=retriever,
        llm=llm,
        num_queries=3,
        rerank_top_n=5,
        collection_name="rag_documents",
    )

    print(f"工具名称: {rag_tool.name}")
    print(f"工具描述: {rag_tool.description[:100]}...")

    # 4. Simulate an agent calling the tool.
    query = "请告诉我 打虎英雄是谁?"
    print(f"\n模拟调用: {query}")
    print("-" * 40)

    result = await rag_tool.ainvoke({"query": query})
    print(result[:800] + "..." if len(result) > 800 else result)
|
||||
|
||||
async def main():
    """Run both demos one after the other."""
    for demo in (demonstrate_full_pipeline, demonstrate_tool_creation):
        await demo()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
54
backend/app/rag/tools.py
Normal file
54
backend/app/rag/tools.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
RAG 工具模块
|
||||
|
||||
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
||||
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
|
||||
"""
|
||||
from typing import Callable
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from .pipeline import RAGPipeline
|
||||
|
||||
def create_rag_tool_sync(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    num_queries: int = 3,
    rerank_top_n: int = 5,
    collection_name: str = "rag_documents",
) -> Callable:
    """Create a synchronous RAG search tool bound to one pipeline.

    Args:
        retriever: Base retriever handed to the pipeline.
        llm: Language model used for query rewriting.
        num_queries: Number of query variants.
        rerank_top_n: Number of documents kept after reranking.
        collection_name: Only used in the "nothing found" message.

    Returns:
        The ``@tool``-decorated ``search_knowledge_base_sync`` function.
    """
    rag_pipeline = RAGPipeline(
        retriever=retriever,
        llm=llm,
        num_queries=num_queries,
        rerank_top_n=rerank_top_n,
    )

    # The docstring below is the tool description presented to the model, so
    # it is runtime text and is kept verbatim.
    @tool
    def search_knowledge_base_sync(query: str) -> str:
        """在知识库中搜索与查询相关的文档片段(同步版本)。

        功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。

        Args:
            query: 用户提出的问题或查询字符串

        Returns:
            格式化后的相关文档内容。
        """
        try:
            hits = rag_pipeline.retrieve(query)
            if hits:
                return rag_pipeline.format_context(hits)
            return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
        except Exception as e:
            # Errors are reported to the agent as text rather than raised.
            return f"检索过程中发生错误: {str(e)}"

    return search_knowledge_base_sync
|
||||
304
backend/app/test_backend.py
Normal file
304
backend/app/test_backend.py
Normal file
@@ -0,0 +1,304 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
完整后端测试 - 验证 Agent 所有功能
|
||||
包括:短期记忆、长期记忆、工具调用、流式对话、历史查询
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from .config import DB_URI
|
||||
import sys
|
||||
import uuid
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 添加项目根目录到 Python 路径 (现在文件在 backend/app/ 下,backend 就是根)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from ..agent import AIAgentService
|
||||
from ..agent.history import ThreadHistoryService
|
||||
from ..logger import info, warning, error
|
||||
|
||||
# PostgreSQL 连接字符串
|
||||
|
||||
async def print_section(title):
    """Print ``title`` framed by two 70-character rules, after a blank line."""
    rule = "=" * 70
    print("\n" + rule)
    print(f" {title}")
    print(rule)
|
||||
|
||||
async def test_short_term_memory(agent_service):
    """Check short-term memory: two turns on the same thread_id.

    Returns:
        True when the second reply mentions the name or the age given in the
        first turn, otherwise False.
    """
    await print_section("测试 1: 短期记忆(Short-term Memory)")

    thread_id = str(uuid.uuid4())
    user_id = "test_user_memory"

    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")

    # Turn 1: state the facts.
    print("\n[第一轮] 发送消息: '我叫张三,今年28岁'")
    first = await agent_service.process_message(
        "我叫张三,今年28岁", thread_id, "local", user_id
    )
    print(f"回复: {first['reply'][:100]}...")

    # Turn 2: ask for them back on the same thread.
    print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
    second = await agent_service.process_message(
        "我叫什么名字?今年多大?", thread_id, "local", user_id
    )
    print(f"回复: {second['reply']}")

    remembered = any(fact in second['reply'] for fact in ("张三", "28"))
    if not remembered:
        print("\n❌ 短期记忆测试失败!")
        return False

    print("\n✅ 短期记忆测试通过!")
    return True
|
||||
|
||||
async def test_tool_calling(agent_service):
    """Check tool calling by asking a question that needs the knowledge base.

    Returns:
        True when the reply contains one of the expected keywords, otherwise
        None (inconclusive, needs manual verification).
    """
    await print_section("测试 2: 工具调用(Tool Calling)")

    thread_id = str(uuid.uuid4())
    user_id = "test_user_tools"

    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")

    print("\n发送消息: '请告诉我,打虎英雄是谁?'")
    result = await agent_service.process_message(
        "请告诉我,打虎英雄是谁?", thread_id, "local", user_id
    )
    print(f"回复: {result['reply'][:200]}...")

    # Keyword match is used as the evidence that retrieval took place.
    expected = ("武松", "李忠", "水浒传")
    if any(word in result['reply'] for word in expected):
        print("\n✅ 工具调用测试通过!")
        return True

    print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
    return None
|
||||
|
||||
async def test_streaming(agent_service):
    """Check streaming: echo tokens as they arrive and count the chunks.

    Returns:
        True when at least one chunk arrived and more than 10 characters of
        text were streamed; False otherwise or when streaming raised.
    """
    await print_section("测试 3: 流式对话(Streaming)")

    thread_id = str(uuid.uuid4())
    user_id = "test_user_stream"

    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")

    print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...")
    print("流式输出: ", end="", flush=True)

    full_reply = ""
    chunk_count = 0

    try:
        stream = agent_service.process_message_stream(
            "用100字介绍一下AI人工智能", thread_id, "local", user_id
        )
        async for chunk in stream:
            chunk_count += 1
            # Only token chunks are shown; other chunk types (such as state
            # updates) are counted but not displayed.
            if chunk.get("type") != "llm_token":
                continue
            token = chunk.get("token", "")
            print(token, end="", flush=True)
            full_reply += token

        print(f"\n\n共收到 {chunk_count} 个 chunk")
        print(f"完整回复长度: {len(full_reply)} 字")

        passed = chunk_count > 0 and len(full_reply) > 10
        if passed:
            print("\n✅ 流式对话测试通过!")
        else:
            print("\n❌ 流式对话测试失败!")
        return passed

    except Exception as e:
        print(f"\n❌ 流式对话异常: {e}")
        return False
|
||||
|
||||
async def test_history_service(agent_service, history_service):
    """Check the history service: thread list, messages and summary.

    Creates three conversations for one user, then queries them back.

    Returns:
        True when at least three threads are listed for the user.
    """
    await print_section("测试 4: 历史查询服务(History Service)")

    user_id = "test_user_history"

    # Seed three conversations to query against.
    print(f"\n为 user_id={user_id} 创建测试对话...")

    thread_ids = []
    for i in range(3):
        thread_id = str(uuid.uuid4())
        thread_ids.append(thread_id)

        await agent_service.process_message(
            f"这是第 {i+1} 个测试对话", thread_id, "local", user_id
        )
        print(f" 创建线程 {i+1}: {thread_id[:8]}...")

    # 4.1: thread list for the user.
    print("\n[4.1] 测试获取用户线程列表...")
    threads = await history_service.get_user_threads(user_id, limit=10)
    print(f" 找到 {len(threads)} 个线程")

    enough_threads = len(threads) >= 3
    print(" ✅ 线程列表查询通过" if enough_threads else " ⚠️ 线程数量少于预期")

    if thread_ids:
        # 4.2: message history of the first thread created above.
        test_thread_id = thread_ids[0]
        print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)")
        messages = await history_service.get_thread_messages(test_thread_id)
        print(f" 找到 {len(messages)} 条消息")

        # At least one question and one answer are expected.
        print(" ✅ 消息历史查询通过" if len(messages) >= 2 else " ⚠️ 消息数量少于预期")

        # 4.3: summary of the same thread.
        print(f"\n[4.3] 测试获取线程摘要...")
        summary = await history_service.get_thread_summary(test_thread_id)
        print(f" 摘要: {summary.get('summary', '')[:50]}...")
        print(f" 消息数: {summary.get('message_count', 0)}")

        if summary.get('message_count', 0) > 0:
            print(" ✅ 线程摘要查询通过")
        else:
            print(" ⚠️ 摘要查询结果不确定")

    return len(threads) >= 3
|
||||
|
||||
async def test_long_term_memory(agent_service):
    """Check that a fact told in one thread is recalled in another (mem0).

    A pet's name and species are stated in a first thread; after a short
    pause the same user asks about the pet in a second, unrelated thread.

    Args:
        agent_service: Object providing an awaitable ``process_message``
            whose result supports ``result['reply']``.

    Returns:
        True when the second reply mentions the pet's name or species;
        None when it does not (reported as "needs manual verification").
    """
    await print_section("测试 5: 长期记忆(Long-term Memory - mem0)")

    thread_id1 = str(uuid.uuid4())
    thread_id2 = str(uuid.uuid4())  # deliberately a different thread
    user_id = "test_user_longterm"

    print(f"\n使用 user_id: {user_id}")
    for label, tid in (("线程 1", thread_id1), ("线程 2", thread_id2)):
        print(f"{label}: {tid[:8]}...")

    # Store the fact in the first thread.
    print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'")
    result1 = await agent_service.process_message(
        "记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id
    )
    print(f"回复: {result1['reply'][:100]}...")

    # Pause briefly before querying from the other thread.
    await asyncio.sleep(1)

    # Ask from the second thread (different thread_id, same user).
    print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'")
    result2 = await agent_service.process_message(
        "我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id
    )
    print(f"回复: {result2['reply']}")

    # The memory counts as recalled if either keyword appears in the reply.
    recalled = "小白" in result2['reply'] or "猫" in result2['reply']
    if not recalled:
        print("\n⚠️ 长期记忆可能未启用,或需要手动验证")
        return None

    print("\n✅ 长期记忆测试通过!")
    return True
|
||||
|
||||
async def main():
    """Run the full backend test sequence and print a summary.

    Opens the Postgres checkpointer, builds the agent and history services,
    runs the five test coroutines in a fixed order (sleeping one second
    after each), then prints a per-test status table and the totals.

    Returns:
        0 when no test returned ``False``; 1 when at least one did, or when
        any exception was raised during the run.
    """
    banner = "=" * 70
    print("\n" + banner)
    print(" 后端完整功能测试")
    print(banner)

    results = {}

    try:
        # Database connection and services.
        print("\n正在初始化数据库连接...")
        async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
            await checkpointer.setup()
            print("✅ 数据库连接成功")

            print("\n正在初始化 Agent 服务...")
            agent_service = AIAgentService(checkpointer)
            await agent_service.initialize()
            print("✅ Agent 服务初始化成功")

            history_service = ThreadHistoryService(checkpointer)
            print("✅ 历史服务初始化成功")

            print(f"\n可用模型: {list(agent_service.graphs.keys())}")

            # Tests run in this order; each entry builds its coroutine lazily.
            test_plan = (
                ("短期记忆", lambda: test_short_term_memory(agent_service)),
                ("工具调用", lambda: test_tool_calling(agent_service)),
                ("流式对话", lambda: test_streaming(agent_service)),
                ("历史查询", lambda: test_history_service(agent_service, history_service)),
                ("长期记忆", lambda: test_long_term_memory(agent_service)),
            )
            for label, start in test_plan:
                results[label] = await start()
                await asyncio.sleep(1)

            # Summary.
            await print_section("测试总结")
            print("\n测试结果:")
            print("-" * 40)

            # Identity checks: only the exact True / False objects count as
            # pass / fail; anything else is "needs verification".
            outcomes = list(results.values())
            pass_count = sum(1 for outcome in outcomes if outcome is True)
            fail_count = sum(1 for outcome in outcomes if outcome is False)
            skip_count = len(outcomes) - pass_count - fail_count

            for test_name, outcome in results.items():
                if outcome is True:
                    status = "✅ 通过"
                elif outcome is False:
                    status = "❌ 失败"
                else:
                    status = "⚠️ 待验证"
                print(f" {test_name:12s}: {status}")

            print("-" * 40)
            print(f"总计: {len(results)} 个测试")
            print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}")

            if fail_count:
                print(f"\n⚠️ 有 {fail_count} 个测试失败")
            else:
                print("\n🎉 所有核心测试通过!")

    except Exception as e:
        error(f"\n❌ 测试运行异常: {e}")
        import traceback

        traceback.print_exc()
        return 1

    return 1 if fail_count else 0
|
||||
|
||||
# Script entry point: run the async test suite and use its return value as
# the process exit status.
if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
|
||||
7
backend/app/utils/__init__.py
Normal file
7
backend/app/utils/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
工具模块
|
||||
"""
|
||||
|
||||
from .logging import (
    log_state_change,
    print_llm_input,
)

# Names re-exported as the public API of this package.
__all__ = [
    "log_state_change",
    "print_llm_input",
]
|
||||
61
backend/app/utils/logging.py
Normal file
61
backend/app/utils/logging.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
LangGraph 节点日志工具模块
|
||||
提供状态流转追踪和 LLM 输入输出打印功能
|
||||
"""
|
||||
|
||||
from ..config import ENABLE_GRAPH_TRACE
|
||||
from ..logger import debug, info
|
||||
|
||||
|
||||
def log_state_change(node_name: str, state: dict, prefix: str = "进入"):
    """Log a one-line summary of the graph state at a node boundary.

    The emitted line contains the node name, the prefix, the number of
    entries in ``state["messages"]`` and a short preview of the last
    message: its type in upper case plus the first 10 characters of its
    content, with newlines replaced by spaces.

    Args:
        node_name: Name of the node being entered or left.
        state: Current graph state; only its ``"messages"`` entry is read.
        prefix: Label printed after the node name (e.g. "进入" or "离开").
    """
    # ``info`` is the module-level import from the package logger; no
    # function-local import is needed, and none should hard-code the
    # top-level package name.

    # A missing or None "messages" entry is treated as an empty history.
    messages = state.get("messages") or []
    msg_count = len(messages)
    last_msg = messages[-1] if messages else None

    last_info = ""
    if last_msg:
        # Messages may be plain dicts or message objects.
        if isinstance(last_msg, dict):
            raw_content = last_msg.get("content", "")
            msg_type = last_msg.get("type", "unknown")
        else:
            # getattr with defaults: an object lacking ``content`` or
            # ``type`` must not break logging.
            raw_content = getattr(last_msg, "content", "")
            msg_type = getattr(last_msg, "type", "unknown")

        content_preview = str(raw_content)[:10].replace("\n", " ")
        # str() guards against a non-string type value (e.g. None).
        last_info = f"{str(msg_type).upper()}: {content_preview}"

    info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}")
|
||||
|
||||
|
||||
def print_llm_input(prompt_value):
    """Pass-through hook that dumps the formatted LLM input to the debug log.

    Intended to be used as a callable step in a chain: when
    ``ENABLE_GRAPH_TRACE`` is truthy, every message in
    ``prompt_value.messages`` is written in full via ``debug``; otherwise
    nothing is logged.

    Args:
        prompt_value: Object exposing a ``messages`` sequence whose items
            have ``type`` and ``content`` attributes.

    Returns:
        ``prompt_value`` itself, unchanged, so the chain continues normally.
    """
    if ENABLE_GRAPH_TRACE:
        outgoing = prompt_value.messages
        heavy_rule = "=" * 80

        debug("\n" + heavy_rule)
        debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:")
        debug(f" 总消息数: {len(outgoing)}")
        debug("-" * 80)

        for index, message in enumerate(outgoing):
            body = str(message.content)  # full content, no truncation
            debug(f" [{index}] {message.type.upper():10s}: {body}")

        debug("\n" + heavy_rule + "\n")

    return prompt_value
|
||||
Reference in New Issue
Block a user