重构代码,统一config配置
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 47m14s

This commit is contained in:
2026-04-21 11:02:16 +08:00
parent 726236eaff
commit 8b354b7ccc
50 changed files with 4025 additions and 6 deletions

View File

@@ -0,0 +1,7 @@
"""
Agent 子模块
"""
from .service import AIAgentService
__all__ = ["AIAgentService"]

View File

@@ -0,0 +1,185 @@
"""
历史对话查询模块
利用 LangGraph 的 checkpointer 获取对话历史和摘要
"""
from typing import List, Dict, Any
from ..logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService:
    """Query conversation history and summaries through a LangGraph checkpointer.

    Reads the ``checkpoints`` table created by ``AsyncPostgresSaver`` directly
    (for listing a user's threads) and uses ``checkpointer.aget_tuple`` to load
    individual thread state.
    """

    def __init__(self, checkpointer):
        # checkpointer: AsyncPostgresSaver-like object exposing `.conn`
        # (async psycopg connection) and `.aget_tuple(config)`.
        self.checkpointer = checkpointer

    async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Return summary info for every thread belonging to a user.

        Args:
            user_id: User ID stored in the checkpoint ``metadata`` JSONB column.
            limit: Maximum number of threads to return.

        Returns:
            List of dicts with thread_id, last_updated, summary, message_count.
            Empty list on any failure (errors are logged, never raised).
        """
        try:
            async with self.checkpointer.conn.cursor() as cur:
                # Newer LangGraph versions create a checkpoints table without a
                # created_at column; checkpoint_id is monotonically increasing,
                # so MAX(checkpoint_id) per thread works as the ordering key.
                query = """
                    SELECT
                        thread_id,
                        MAX(checkpoint_id) as last_updated
                    FROM checkpoints
                    WHERE metadata->>'user_id' = %s
                    GROUP BY thread_id
                    ORDER BY last_updated DESC
                    LIMIT %s
                """
                await cur.execute(query, (user_id, limit))
                rows = await cur.fetchall()
                threads: List[Dict[str, Any]] = []
                for row in rows:
                    # NOTE(review): row['thread_id'] assumes the connection uses
                    # a dict-like row factory; default psycopg cursors return
                    # tuples — confirm the checkpointer's connection config.
                    thread_id = row['thread_id']
                    # Load the latest checkpoint for this thread.
                    state = await self.checkpointer.aget_tuple(
                        {"configurable": {"thread_id": thread_id}}
                    )
                    if state and hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict):
                        messages = state.checkpoint.get("channel_values", {}).get("messages", [])
                        if messages:
                            summary = self._extract_summary(messages)
                            message_count = len([
                                m for m in messages
                                if hasattr(m, 'type') and m.type in ["human", "ai"]
                            ])
                            threads.append({
                                "thread_id": thread_id,
                                # checkpoint_id is a UUID-like string that
                                # encodes ordering; exposed as an opaque marker.
                                "last_updated": row['last_updated'] if row['last_updated'] else "",
                                "summary": summary,
                                "message_count": message_count
                            })
                return threads
        except Exception as e:
            error(f"获取用户线程列表失败 (user_id={user_id}): {e}")
            return []

    async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
        """Return the full message history of a thread.

        Args:
            thread_id: Thread ID.

        Returns:
            Messages as [{"role": "user"/"assistant", "content": "..."}];
            system messages are skipped. Empty list on failure.
        """
        try:
            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
            if state is None:
                return []
            messages = (
                state.checkpoint.get("channel_values", {}).get("messages", [])
                if hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict)
                else []
            )
            if not messages:
                return []
            # Convert LangChain message objects (or plain dicts) to dicts.
            result: List[Dict[str, str]] = []
            for msg in messages:
                if hasattr(msg, 'type') and msg.type == "system":
                    continue
                if hasattr(msg, 'type'):
                    role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else msg.type
                    result.append({
                        "role": role,
                        "content": msg.content
                    })
                elif isinstance(msg, dict):
                    role = msg.get("role", msg.get("type", "unknown"))
                    if role in ["human", "user"]:
                        role = "user"
                    elif role in ["ai", "assistant"]:
                        role = "assistant"
                    result.append({
                        "role": role,
                        "content": msg.get("content", "")
                    })
            return result
        except Exception as e:
            error(f"获取线程消息历史失败: {e}")
            return []

    async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
        """Return a thread summary (for history-list display).

        Args:
            thread_id: Thread ID.

        Returns:
            Dict with thread_id, summary, message_count and last_updated.
        """
        try:
            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
            if state is None:
                return {"thread_id": thread_id, "summary": "空对话", "message_count": 0}
            # BUG FIX: aget_tuple returns a CheckpointTuple, which has no
            # `.values` attribute (that belongs to StateSnapshot). Reading
            # `state.values` always raised AttributeError and fell into the
            # except branch ("加载失败"). Read channel_values like the other
            # methods of this class do.
            messages = (
                state.checkpoint.get("channel_values", {}).get("messages", [])
                if hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict)
                else []
            )
            if not messages:
                return {"thread_id": thread_id, "summary": "空对话", "message_count": 0}
            summary = self._extract_summary(messages)
            message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
            # Last-updated timestamp, if the checkpoint metadata carries one.
            last_updated = ""
            if state.metadata and "created_at" in state.metadata:
                ts = state.metadata["created_at"]
                # metadata comes from a JSONB column, so created_at is usually
                # already a string; only call isoformat() on datetime-likes.
                last_updated = ts.isoformat() if hasattr(ts, "isoformat") else str(ts)
            return {
                "thread_id": thread_id,
                "summary": summary,
                "message_count": message_count,
                "last_updated": last_updated
            }
        except Exception as e:
            error(f"获取线程摘要失败: {e}")
            return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0}

    def _extract_summary(self, messages: List) -> str:
        """Extract a short summary from a message list.

        Strategy:
            1. Prefer a ``summary`` produced by the summarize node, if present.
            2. Otherwise use the first 50 chars of the first user message.
        """
        # Look for an explicit summary field first.
        for msg in messages:
            if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('summary'):
                return msg.additional_kwargs['summary']
            elif isinstance(msg, dict) and msg.get('summary'):
                return msg['summary']
        # Fall back to the first user message, truncated to 50 chars.
        for msg in messages:
            if hasattr(msg, 'type') and msg.type == "human":
                content = msg.content
                return content[:50] + "..." if len(content) > 50 else content
            elif isinstance(msg, dict) and msg.get("role") in ["user", "human"]:
                content = msg.get("content", "")
                return content[:50] + "..." if len(content) > 50 else content
        return "空对话"

View File

@@ -0,0 +1,57 @@
# app/llm_factory.py
import os
from ..config import ZHIPUAI_API_KEY, DEEPSEEK_API_KEY, VLLM_BASE_URL, LLAMACPP_API_KEY
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
class LLMFactory:
@staticmethod
def create_zhipu():
api_key = ZHIPUAI_API_KEY
if not api_key:
raise ValueError("ZHIPUAI_API_KEY not set")
return ChatZhipuAI(
model="glm-4.7-flash",
api_key=api_key,
temperature=0.1,
max_tokens=4096,
timeout=120.0,
max_retries=3,
streaming=True,
)
@staticmethod
def create_deepseek():
api_key = DEEPSEEK_API_KEY
if not api_key:
raise ValueError("DEEPSEEK_API_KEY not set")
return ChatOpenAI(
base_url="https://api.deepseek.com",
api_key=SecretStr(api_key),
model="deepseek-reasoner",
temperature=0.1,
max_tokens=4096,
timeout=60.0,
max_retries=2,
streaming=True,
)
@staticmethod
def create_local():
base_url = VLLM_BASE_URL
return ChatOpenAI(
base_url=base_url,
api_key=SecretStr(LLAMACPP_API_KEY),
model="gemma-4-E4B-it",
timeout=60.0,
max_retries=2,
streaming=True,
)
# 模型创建器映射
CREATORS = {
"local": create_local,
"deepseek": create_deepseek,
"zhipu": create_zhipu,
}

View File

@@ -0,0 +1,37 @@
# app/prompts.py
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
def create_system_prompt(tools: list | None = None) -> ChatPromptTemplate:
    """Build the system prompt template, optionally injecting tool descriptions.

    Args:
        tools: Optional list of tool objects; each contributes a
            "- name: first line of description" entry to the prompt.

    Returns:
        A ChatPromptTemplate with one system message (expects a
        ``memory_context`` variable) followed by a ``messages`` placeholder.
    """
    tools_section = ""
    if tools:
        tool_descs = []
        for tool in tools:
            # Tolerate both LangChain tools and plain callables.
            name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
            # BUG FIX: direct `tool.description` raised AttributeError for
            # objects without that attribute; use getattr with a default.
            desc = (getattr(tool, 'description', '') or "").split('\n')[0]
            tool_descs.append(f"- {name}: {desc}")
        # Escape braces so tool descriptions cannot be misread as template
        # variables by ChatPromptTemplate's f-string parsing.
        tools_section = "\n".join(tool_descs).replace("{", "{{").replace("}", "}}")
    system_template = (
        "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
        "【用户背景信息】\n"
        "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
        "{memory_context}\n"
        "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
        "【可用工具与使用规则】\n"
        f"{tools_section}\n"
        "工具调用时请直接返回所需参数,无需额外说明。\n\n"
        "【回答要求(必须遵守)】\n"
        "1. 回答必须简洁、直接。\n"
        "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。\n"
        "3. 优先利用已知用户信息进行个性化回复。\n"
        "4. 若无信息可依,礼貌询问或提供通用帮助。"
    )
    return ChatPromptTemplate.from_messages([
        ("system", system_template),
        MessagesPlaceholder(variable_name="messages")
    ])

View File

@@ -0,0 +1,23 @@
# app/rag_initializer.py
from ..rag.tools import create_rag_tool_sync
from rag_core import create_parent_retriever
from ..logger import info, warning
async def init_rag_tool(local_llm_creator):
    """Initialize the RAG retrieval tool; return None when setup fails."""
    try:
        info("🔄 正在初始化 RAG 检索系统...")
        # Parent-document retriever over the shared document collection.
        parent_retriever = create_parent_retriever(
            collection_name="rag_documents",
            search_k=5,
        )
        # The local LLM is used only for query rewriting.
        tool = create_rag_tool_sync(
            parent_retriever,
            local_llm_creator(),
            num_queries=3,
            rerank_top_n=5,
        )
        info("✅ RAG 检索工具初始化成功")
        return tool
    except Exception as exc:
        # Best-effort: the agent can still run without retrieval.
        warning(f"⚠️ RAG 检索工具初始化失败: {exc}")
        return None

View File

@@ -0,0 +1,154 @@
"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer,不负责管理连接生命周期
"""
import json
# 本地模块
from ..graph.graph_builder import GraphBuilder, GraphContext
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from .llm_factory import LLMFactory
from .rag_initializer import init_rag_tool
from ..logger import info, warning
class AIAgentService:
    """AI agent service supporting dynamic switching between multiple models.

    The checkpointer is injected from outside; this class does not manage the
    connection lifecycle.
    """

    def __init__(self, checkpointer):
        self.checkpointer = checkpointer
        # model name -> compiled LangGraph graph
        self.graphs = {}
        # Copy module-level tool lists so later RAG additions stay local.
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()

    async def initialize(self):
        """Initialize the optional RAG tool and build one graph per model.

        Models that fail to initialize are skipped with a warning.

        Raises:
            RuntimeError: if no model could be initialized at all.
        """
        # 1. Optional RAG tool (best-effort; None on failure).
        rag_tool = await init_rag_tool(LLMFactory.create_local)
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool
        # 2. Build a graph for each configured model backend.
        for name, creator in LLMFactory.CREATORS.items():
            try:
                info(f"🔄 初始化模型 '{name}'...")
                llm = creator()
                builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
                graph = builder.compile(checkpointer=self.checkpointer)
                self.graphs[name] = graph
                info(f"✅ 模型 '{name}' 初始化成功")
            except Exception as e:
                warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")
        if not self.graphs:
            raise RuntimeError("没有可用的模型")
        return self

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """Process one user message; return reply, token usage and timing.

        Args:
            message: The user's input text.
            thread_id: Conversation thread ID (checkpointer key).
            model: Preferred model name; falls back to the first available.
            user_id: User ID propagated via config metadata and GraphContext.

        Returns:
            Dict with "reply", "token_usage" and "elapsed_time".

        Raises:
            RuntimeError: if no model is available.
        """
        if model not in self.graphs:
            available = list(self.graphs.keys())
            if not available:
                raise RuntimeError("没有可用的模型")
            # BUG FIX: capture the requested name before overwriting `model`,
            # otherwise the warning showed the fallback model in both slots.
            requested = model
            model = available[0]
            warning(f"模型 '{requested}' 不可用,已回退到 '{model}'")
        graph = self.graphs[model]
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)
        result = await graph.ainvoke(input_state, config=config, context=context)
        reply = result["messages"][-1].content
        token_usage = result.get("last_token_usage", {})
        elapsed_time = result.get("last_elapsed_time", 0.0)
        return {
            "reply": reply,
            "token_usage": token_usage,
            "elapsed_time": elapsed_time
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects to JSON-serializable data.

        Message-like objects (anything with a ``content`` attribute) become
        role/content dicts; containers are walked recursively; anything else
        that json can't encode falls back to ``str``.
        """
        if hasattr(value, 'content'):
            msg_type = getattr(value, 'type', 'message')
            return {
                "role": msg_type,
                "content": getattr(value, 'content', ''),
                "additional_kwargs": getattr(value, 'additional_kwargs', {}),
                "tool_calls": getattr(value, 'tool_calls', [])
            }
        elif isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        elif isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        else:
            try:
                # Probe serializability; keep the value as-is when it works.
                json.dumps(value)
                return value
            except (TypeError, ValueError):
                return str(value)

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Stream-process a message; yields processed event dicts.

        Yields events of type "llm_token", "state_update" or "custom".

        Raises:
            ValueError: if the requested model was not initialized.
        """
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)
        # NOTE(review): chunks are indexed as dicts ("type"/"data") here, but
        # recent LangGraph versions yield tuples when stream_mode is a list
        # (and 3-tuples with subgraphs=True) — confirm against the pinned
        # langgraph version.
        async for chunk in graph.astream(
            input_state,
            config=config,
            context=context,
            stream_mode=["messages", "updates", "custom"],
            version="v2",
            subgraphs=True
        ):
            chunk_type = chunk["type"]
            processed_event = {}
            if chunk_type == "messages":
                message_chunk, metadata = chunk["data"]
                node_name = metadata.get("langgraph_node", "unknown")
                token_content = getattr(message_chunk, 'content', str(message_chunk))
                reasoning_token = ""
                if hasattr(message_chunk, 'additional_kwargs'):
                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
                processed_event = {
                    "type": "llm_token",
                    "node": node_name,
                    "token": token_content,
                    "reasoning_token": reasoning_token,
                    "metadata": metadata
                }
            elif chunk_type == "updates":
                updates_data = chunk["data"]
                serialized_data = self._serialize_value(updates_data)
                processed_event = {
                    "type": "state_update",
                    "data": serialized_data
                }
                if "messages" in serialized_data:
                    processed_event["messages"] = serialized_data["messages"]
            elif chunk_type == "custom":
                serialized_data = self._serialize_value(chunk["data"])
                processed_event = {
                    "type": "custom",
                    "data": serialized_data
                }
            else:
                continue
            if processed_event:
                yield processed_event