This commit is contained in:
7
backend/app/agent/__init__.py
Normal file
7
backend/app/agent/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Agent submodule.

Re-exports the public service entry point of the agent package.
"""

from .service import AIAgentService

__all__ = ["AIAgentService"]
|
||||
185
backend/app/agent/history.py
Normal file
185
backend/app/agent/history.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
历史对话查询模块
|
||||
利用 LangGraph 的 checkpointer 获取对话历史和摘要
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from ..logger import error # 保持兼容,或者替换为 logger
|
||||
|
||||
class ThreadHistoryService:
    """Query service for conversation-thread history.

    Reads conversation history and per-thread summaries back out of a
    LangGraph checkpointer (AsyncPostgresSaver-style).
    """

    def __init__(self, checkpointer):
        # Checkpointer must expose a psycopg-style `.conn` and `.aget_tuple`.
        self.checkpointer = checkpointer

    async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Return summary info for every thread owned by a user.

        Args:
            user_id: user id stored in the checkpoint `metadata` JSONB column.
            limit: maximum number of threads to return.

        Returns:
            List of dicts with keys thread_id, last_updated, summary and
            message_count. Empty list on any failure.
        """
        try:
            # Query the checkpoints table for the user's thread ids.
            async with self.checkpointer.conn.cursor() as cur:
                # Newer LangGraph AsyncPostgresSaver schemas have no
                # created_at column; checkpoint_id sorts chronologically, so
                # it doubles as the ordering key. Deduplicate per thread_id.
                # The user id lives in the metadata JSONB column.
                query = """
                    SELECT
                        thread_id,
                        MAX(checkpoint_id) as last_updated
                    FROM checkpoints
                    WHERE metadata->>'user_id' = %s
                    GROUP BY thread_id
                    ORDER BY last_updated DESC
                    LIMIT %s
                """
                await cur.execute(query, (user_id, limit))
                rows = await cur.fetchall()

            threads = []
            for row in rows:
                # NOTE(review): assumes a dict row factory on the cursor —
                # confirm the connection uses e.g. `row_factory=dict_row`.
                thread_id = row['thread_id']

                # Load the latest checkpoint for this thread.
                state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})

                if state and hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict):
                    messages = state.checkpoint.get("channel_values", {}).get("messages", [])

                    if messages:
                        summary = self._extract_summary(messages)
                        message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])

                        threads.append({
                            "thread_id": thread_id,
                            # checkpoint_id is a uuid-like string that embeds
                            # timestamp information; usable as a sort key.
                            "last_updated": row['last_updated'] if row['last_updated'] else "",
                            "summary": summary,
                            "message_count": message_count
                        })

            return threads

        except Exception as e:
            error(f"获取用户线程列表失败 (user_id={user_id}): {e}")
            return []

    async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
        """Return the full message history of a thread.

        Args:
            thread_id: thread id.

        Returns:
            Messages as [{"role": "user"/"assistant", "content": "..."}];
            system messages are dropped. Empty list on failure.
        """
        try:
            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})

            if state is None:
                return []

            messages = state.checkpoint.get("channel_values", {}).get("messages", []) if hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict) else []

            if not messages:
                return []

            # Convert LangChain message objects (or plain dicts) to dicts.
            result = []
            for msg in messages:
                # Skip system messages entirely.
                if hasattr(msg, 'type') and msg.type == "system":
                    continue

                if hasattr(msg, 'type'):
                    role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else msg.type
                    result.append({
                        "role": role,
                        "content": msg.content
                    })
                elif isinstance(msg, dict):
                    role = msg.get("role", msg.get("type", "unknown"))
                    if role in ["human", "user"]:
                        role = "user"
                    elif role in ["ai", "assistant"]:
                        role = "assistant"
                    result.append({
                        "role": role,
                        "content": msg.get("content", "")
                    })

            return result

        except Exception as e:
            error(f"获取线程消息历史失败: {e}")
            return []

    async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
        """Return a short summary of a thread for history-list display.

        Args:
            thread_id: thread id.

        Returns:
            Dict with thread_id, summary, message_count (plus last_updated
            when the checkpoint metadata carries a timestamp).
        """
        try:
            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})

            # BUGFIX: aget_tuple returns a CheckpointTuple, which has no
            # `.values` attribute — read the messages from
            # checkpoint["channel_values"], exactly as get_user_threads does.
            messages = []
            if state is not None and hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict):
                messages = state.checkpoint.get("channel_values", {}).get("messages", [])

            if not messages:
                return {"thread_id": thread_id, "summary": "空对话", "message_count": 0}

            summary = self._extract_summary(messages)
            message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])

            # Last-updated timestamp, when present in the checkpoint metadata.
            last_updated = ""
            if state.metadata and "created_at" in state.metadata:
                created = state.metadata["created_at"]
                # created_at may be a datetime or already an ISO string,
                # depending on the checkpointer version — TODO confirm.
                last_updated = created.isoformat() if hasattr(created, "isoformat") else str(created)

            return {
                "thread_id": thread_id,
                "summary": summary,
                "message_count": message_count,
                "last_updated": last_updated
            }

        except Exception as e:
            error(f"获取线程摘要失败: {e}")
            return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0}

    def _extract_summary(self, messages: List) -> str:
        """Extract a one-line summary from a message list.

        Strategy:
        1. Prefer a summary produced by the summarize node, if present.
        2. Otherwise use the first 50 chars of the first user message.
        """
        # Look for an explicit summary field first.
        for msg in messages:
            if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('summary'):
                return msg.additional_kwargs['summary']
            elif isinstance(msg, dict) and msg.get('summary'):
                return msg['summary']

        # Fall back to the first user message.
        for msg in messages:
            if hasattr(msg, 'type') and msg.type == "human":
                content = msg.content
                return content[:50] + "..." if len(content) > 50 else content
            elif isinstance(msg, dict) and msg.get("role") in ["user", "human"]:
                content = msg.get("content", "")
                return content[:50] + "..." if len(content) > 50 else content

        return "空对话"
|
||||
57
backend/app/agent/llm_factory.py
Normal file
57
backend/app/agent/llm_factory.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# app/llm_factory.py
|
||||
import os
|
||||
from ..config import ZHIPUAI_API_KEY, DEEPSEEK_API_KEY, VLLM_BASE_URL, LLAMACPP_API_KEY
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
class LLMFactory:
    """Factory for constructing the chat-model client of each supported backend."""

    @staticmethod
    def create_zhipu():
        """Build the ZhipuAI chat model; raises ValueError if the key is unset."""
        if not ZHIPUAI_API_KEY:
            raise ValueError("ZHIPUAI_API_KEY not set")
        return ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=ZHIPUAI_API_KEY,
            temperature=0.1,
            max_tokens=4096,
            timeout=120.0,
            max_retries=3,
            streaming=True,
        )

    @staticmethod
    def create_deepseek():
        """Build the DeepSeek chat model; raises ValueError if the key is unset."""
        if not DEEPSEEK_API_KEY:
            raise ValueError("DEEPSEEK_API_KEY not set")
        return ChatOpenAI(
            base_url="https://api.deepseek.com",
            api_key=SecretStr(DEEPSEEK_API_KEY),
            model="deepseek-reasoner",
            temperature=0.1,
            max_tokens=4096,
            timeout=60.0,
            max_retries=2,
            streaming=True,
        )

    @staticmethod
    def create_local():
        """Build the locally-served OpenAI-compatible chat model."""
        return ChatOpenAI(
            base_url=VLLM_BASE_URL,
            api_key=SecretStr(LLAMACPP_API_KEY),
            model="gemma-4-E4B-it",
            timeout=60.0,
            max_retries=2,
            streaming=True,
        )

    # Backend-name -> constructor map.
    # NOTE(review): these values are staticmethod objects; calling them
    # straight out of the dict requires Python >= 3.10 — confirm the
    # deployment interpreter.
    CREATORS = {
        "local": create_local,
        "deepseek": create_deepseek,
        "zhipu": create_zhipu,
    }
|
||||
37
backend/app/agent/prompts.py
Normal file
37
backend/app/agent/prompts.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# app/prompts.py
|
||||
from typing import Optional

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
def create_system_prompt(tools: Optional[list] = None) -> ChatPromptTemplate:
    """Build the system prompt template, optionally injecting tool descriptions.

    Args:
        tools: optional list of tool objects; each contributes a
            "- name: description" line to the prompt's tool section.

    Returns:
        A ChatPromptTemplate with the system message plus a messages
        placeholder (variable "messages"; also expects "memory_context").
    """
    tools_section = ""
    if tools:
        tool_descs = []
        for tool in tools:
            # Take the tool name and the first line of its description.
            # BUGFIX: use getattr for `description` too — plain callables
            # without that attribute previously raised AttributeError even
            # though `name` already had a fallback.
            name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
            desc = (getattr(tool, 'description', '') or "").split('\n')[0]
            tool_descs.append(f"- {name}: {desc}")
        tools_section = "\n".join(tool_descs)

    system_template = (
        "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
        "【用户背景信息】\n"
        "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
        "{memory_context}\n"
        "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
        "【可用工具与使用规则】\n"
        f"{tools_section}\n"
        "工具调用时请直接返回所需参数,无需额外说明。\n\n"
        "【回答要求(必须遵守)】\n"
        "1. 回答必须简洁、直接。\n"
        "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。\n"
        "3. 优先利用已知用户信息进行个性化回复。\n"
        "4. 若无信息可依,礼貌询问或提供通用帮助。"
    )

    return ChatPromptTemplate.from_messages([
        ("system", system_template),
        MessagesPlaceholder(variable_name="messages")
    ])
|
||||
23
backend/app/agent/rag_initializer.py
Normal file
23
backend/app/agent/rag_initializer.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# app/rag_initializer.py
|
||||
from ..rag.tools import create_rag_tool_sync
|
||||
from rag_core import create_parent_retriever
|
||||
from ..logger import info, warning
|
||||
|
||||
async def init_rag_tool(local_llm_creator):
    """Initialize the RAG retrieval tool; returns None on failure.

    Args:
        local_llm_creator: zero-arg callable that builds the LLM used for
            query rewriting.
    """
    try:
        info("🔄 正在初始化 RAG 检索系统...")
        retriever = create_parent_retriever(
            collection_name="rag_documents",
            search_k=5,
        )
        tool = create_rag_tool_sync(
            retriever,
            local_llm_creator(),
            num_queries=3,
            rerank_top_n=5,
        )
        info("✅ RAG 检索工具初始化成功")
        return tool
    except Exception as e:
        warning(f"⚠️ RAG 检索工具初始化失败: {e}")
        return None
|
||||
154
backend/app/agent/service.py
Normal file
154
backend/app/agent/service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
AI Agent 服务类 - 支持多模型动态切换
|
||||
接收外部传入的 checkpointer,不负责管理连接生命周期
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
# 本地模块
|
||||
from ..graph.graph_builder import GraphBuilder, GraphContext
|
||||
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from .llm_factory import LLMFactory
|
||||
from .rag_initializer import init_rag_tool
|
||||
from ..logger import info, warning
|
||||
|
||||
class AIAgentService:
    """AI agent service with dynamic multi-model switching.

    Receives an externally managed checkpointer; it does not own the
    connection lifecycle.
    """

    def __init__(self, checkpointer):
        self.checkpointer = checkpointer
        # model name -> compiled LangGraph graph
        self.graphs = {}
        # Copies, so the RAG tool can be appended without mutating the
        # module-level shared collections.
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()

    async def initialize(self):
        """Build one graph per configured model backend; returns self.

        Raises:
            RuntimeError: when no model could be initialized.
        """
        # 1. Optional RAG tool.
        rag_tool = await init_rag_tool(LLMFactory.create_local)
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool

        # 2. Build a graph for each model backend; failures are logged and
        # skipped so one bad backend doesn't take down the rest.
        for name, creator in LLMFactory.CREATORS.items():
            try:
                info(f"🔄 初始化模型 '{name}'...")
                llm = creator()
                builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
                graph = builder.compile(checkpointer=self.checkpointer)
                self.graphs[name] = graph
                info(f"✅ 模型 '{name}' 初始化成功")
            except Exception as e:
                warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")

        if not self.graphs:
            raise RuntimeError("没有可用的模型")
        return self

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """Process one user message.

        Returns:
            Dict with "reply", "token_usage" and "elapsed_time".

        Raises:
            RuntimeError: when no model is available at all.
        """
        if model not in self.graphs:
            # Fall back to the first available model.
            available = list(self.graphs.keys())
            if not available:
                raise RuntimeError("没有可用的模型")
            # BUGFIX: capture the requested name before reassigning `model`;
            # the old code logged the fallback name twice.
            requested = model
            model = available[0]
            warning(f"模型 '{requested}' 不可用,已回退到 '{model}'")

        graph = self.graphs[model]
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        result = await graph.ainvoke(input_state, config=config, context=context)

        reply = result["messages"][-1].content
        token_usage = result.get("last_token_usage", {})
        elapsed_time = result.get("last_elapsed_time", 0.0)

        return {
            "reply": reply,
            "token_usage": token_usage,
            "elapsed_time": elapsed_time
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable data."""
        if hasattr(value, 'content'):
            # Message-like object: flatten to a plain dict.
            msg_type = getattr(value, 'type', 'message')
            return {
                "role": msg_type,
                "content": getattr(value, 'content', ''),
                "additional_kwargs": getattr(value, 'additional_kwargs', {}),
                "tool_calls": getattr(value, 'tool_calls', [])
            }
        elif isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        elif isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        else:
            try:
                json.dumps(value)
                return value
            except (TypeError, ValueError):
                # Last resort: stringify whatever json cannot handle.
                return str(value)

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Stream-process a message; yields processed event dicts.

        Raises:
            ValueError: when the named model was not initialized.
        """
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")

        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        # NOTE(review): with a list stream_mode and subgraphs=True, LangGraph's
        # astream typically yields tuples rather than {"type", "data"} dicts —
        # confirm a wrapper normalizes the chunk shape before relying on this.
        async for chunk in graph.astream(
            input_state,
            config=config,
            context=context,
            stream_mode=["messages", "updates", "custom"],
            version="v2",
            subgraphs=True
        ):
            chunk_type = chunk["type"]
            processed_event = {}

            if chunk_type == "messages":
                message_chunk, metadata = chunk["data"]
                node_name = metadata.get("langgraph_node", "unknown")
                token_content = getattr(message_chunk, 'content', str(message_chunk))
                reasoning_token = ""
                if hasattr(message_chunk, 'additional_kwargs'):
                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")

                processed_event = {
                    "type": "llm_token",
                    "node": node_name,
                    "token": token_content,
                    "reasoning_token": reasoning_token,
                    "metadata": metadata
                }
            elif chunk_type == "updates":
                updates_data = chunk["data"]
                serialized_data = self._serialize_value(updates_data)
                processed_event = {
                    "type": "state_update",
                    "data": serialized_data
                }
                # Surface messages at the top level for the frontend.
                if "messages" in serialized_data:
                    processed_event["messages"] = serialized_data["messages"]
            elif chunk_type == "custom":
                serialized_data = self._serialize_value(chunk["data"])
                processed_event = {
                    "type": "custom",
                    "data": serialized_data
                }
            else:
                continue

            if processed_event:
                yield processed_event
|
||||
Reference in New Issue
Block a user