ICLR 2018 | Deep Gradient Compression: Reducing the Communication Bandwidth for Distributed Training
为了降低大规模分布式训练时的通信开销,作者提出了一种名为深度梯度压缩(Deep Gradient Compression, DGC)的方法。DGC通过稀疏化技术,在每次迭代时只选择发送一部分比较“重要”的梯度元素,以达到降低整个训练过程通信量的目的。为了保证使用DGC后模型的精度,作者还使用了几种花里胡哨的技术,包括动量修正(momentum correction)、本地梯度裁剪(local gradient cliping)、动量因子遮蔽(momentum factor masking)和预训练(warmup training)。
梯度稀疏化
为了降低训练过程中的通信量,我们可以让每个节点在每次迭代中只发送那些“重要”的梯度。因为“不重要”的梯度元素对模型参数更新的贡献比较小,所以就可以不用发送这些对更新贡献较小的梯度。那么问题来了,我们如何知道梯度对参数更新的贡献呢,换句话说,如何评估梯度的重要性?作者在这里使用了一种启发式方法——以梯度元素的大小是否超过某个阈值来判断该元素的重要性——当然,这也是目前最常用的方法。为了防止丢失大量的信息,我们可以把每次迭代中没有超过阈值的小梯度元素存起来,在下次迭代中加回到原始梯度向量中。随着训练的进行,这些较小的梯度元素会累加地越来越大,直至在以后的某次迭代中超过阈值,被节点发送出去。
令\(F(w)\)是我们想要优化的损失函数,同步的分布式SGD算法在\(N\)个训练节点上会进行如下更新:
F(w) &= \frac{1}{|\mathcal{X}|}\sum_{x\in \mathcal{X}}f(x,w)\\
w_{t+1} &= w_t -\eta\frac{1}{Nb}\sum_{k=1}^N\sum_{x\in\mathcal{B}_{k,t}}\triangledown f(x,w_t)
\end{aligned}\tag{1}\label{1}
\]
其中\(\mathcal{X}\)是训练数据集,\(w\)是神经网络的权值,\(f(x,w_t)\)是由\(x\in \mathcal{X}\)计算的损失值,\(\eta\)是学习率,\(\mathcal{B}_{k,t}\)是第\(t\)次迭代中第\(k\)个节点上读取的一个batch的数据样本,每个batch的大小为\(b\)。考虑将\(w\)拉直后在第\(i\)个位置上的权值\(w^{(i)}\),经过\(T\)轮迭代后,我们有:
\]
等式\(\ref{2}\)表明本地梯度累加可以看作将batch size从\(Nb\)增加到\(NbT\),其中\(T\)是两次稀疏更新的间隔,即每进行\(T\)次迭代就发送一次\(w^{(i)}\)的梯度。这是作者说等式\(\ref{2}\)满足学习率缩放规则,学习率\(\eta T\)中的\(T\)和批量大小\(NbT\)中的\(T\)相互抵消了。说实话,这里有点没太看懂。个人理解的学习率缩放指的是学习率与batch size等比例增加,比如batch size从128变为1024,那么学习率应该变为原来的8倍。搞不清楚这里出现的两个\(T\)是什么意思。。。
本地梯度累加
动量修正
在\(N\)个节点上使用标准的动量SGD进行分布式训练的过程如下所述:
\]
这里\(m\)是动量项,\(\triangledown_{k,t}\)是梯度\(\frac{1}{Nb}\sum_{x\in \mathcal{B}_{k,t}}\triangledown f(x,w_t)\)的简写形式。考虑将权重\(w\)拉直后第\(i\)个位置上的元素\(w^{(i)}\),经过\(T\)轮迭代后,\(w^{(i)}\)的变化为:
\]
如果动量SGD直接使用稀疏梯度进行更新(算法1的第15行),那么整个更新过程就不再等价于等式\(\ref{3}\),而是变成了:
v_{k,t} &= v_{k,t-1}+\triangledown_{k,t}\\
u_t &= mu_{t-1}+\sum_{k=1}^{N}sparse(v_{k,t})\\
w_{t+1} &= w_t-\eta u_t
\end{aligned}\tag{5}\label{5}
\]
这里,\(v_{k,t}\)是节点\(k\)上的梯度累加和,一旦累加结果大于阈值,它将会被编码处理然后发送出去,以参与\(u_t\)的更新。随后,节点\(k\)上的累加结果\(v_{k,t}\)在sparse()
中通过掩码被清空。经过\(T\)轮稀疏更新后,权值\(w^{(i)}\)变为:
\]
等式\(\ref{6}\)相比于等式\(\ref{4}\)少了\(\sum_{\tau=0}^{T-1}m^\tau\)这一项,这就导致收敛速率的降低。
在图中,等式\(\ref{4}\)会从点A优化到点B,但是由于每个节点上的本地梯度会累加,等式\(\ref{4}\)会到达点C。当梯度稀疏性很高时,那些“不重要”梯度的更新间隔\(T\)会显著增加,这就会导致模型性能的下降。为了避免上述情况,我们需要对等式\(\ref{5}\)进行动量修正,以令其与等式\(\ref{3}\)等价。如果我们等式\(\ref{3}\)中的速度\(u_t\)看成“梯度”,那么等式\(\ref{3}\)的第二项可以看成针对“梯度”\(u_t\)的标准SGD算法。因此,我们可以在局部累积速度\(u_t\)而不是实际梯度\(\triangledown_{k,t}\)从而将等式\(\ref{5}\)变成与等式\(\ref{3}\)相近的形式:
u_{k,t} &= mu_{k,t-1} +\triangledown_{k,t} \\
v_{k,t} &= v_{k,t-1}+u_{k,t} \\
w_{t+1} &= w_{t}-\eta\sum_{k=1}^Nsparse(v_{k,t})
\end{aligned}\tag{7}\label{7}
\]
这里前两项是修正后的局部梯度累加,累加结果\(v_{k,t}\)用于随后的稀疏化和通信。通过对局部累加的简单修改,我们可以由等式\(\ref{7}\)推导出等式\(\ref{4}\)中的累加折扣因子\(\sum_{\tau =0}^{T-1} m^\tau\),如图中(b)所示。注意,动量修正只对更新方程进行调整,并不会引入任何超参数。
本地梯度裁剪
梯度裁剪被广泛地用于防止梯度爆炸。该方法会在梯度的L2范数之和超过某一阈值时对梯度进行重缩放。一般地,从所有节点进行梯度聚合之后执行梯度裁剪。因为我们会在每个训练节点的每次迭代中独立地累加梯度,所以我们会在将本次梯度\(G_t\)加到前一次的累加梯度\(G_{t-1}\)之前进行梯度裁剪。如果所有\(N\)个节点具有相同的梯度分布,那么我们将阈值缩放\(N^{-\frac{1}{2}}\)。在实践中,我们发现局部梯度裁剪与标准梯度裁剪在训练中的行为非常相似,这表明我们的假设在实际数据中是有效的。正如我们将在第4节中看到的那样,动量修正和局部梯度裁剪有助于将AN4语料库中的单词错误率从14.1%降低到12.9%,而训练曲线更接近带动量的SGD。
梯度陈旧性问题
因为延迟了小梯度的更新,所以当这些更新确实发生时,它们已经变得陈旧了。在作者的实验中,当梯度稀疏度为99.9%时,大多数参数每600到1000次才迭代更新一次,这种陈旧性会降低收敛速度和模型性能。为了解决这个问题,作者使用了动量因子遮蔽和预训练等技术。
动量因子遮蔽
根据文献[1]中结论,异步SGD会产生一个隐式动量(implicit momentum),从而导致收敛变慢。本地梯度累加跟异步GD存在相似性:不能及时更新梯度产生staleness。文献[1]发现负动量(negative momentum)能一定程度抵消隐式动量的效果,提高收敛速度。本文作者采用了一种类似的方法,如果累加梯度\(v_{k,t}\)大于阈值(即本次迭代将会进行数据传输和权重更新),那么就将\(v_{k,t}\)和\(u_{k,t}\)中对应元素清零,从而防止陈旧的动量影响模型权重的更新:
\]
预训练
在训练的早期,网络权重会迅速地变化,梯度会比较稠密。因此,在训练刚开始时,我们需要使用一个较小地学习率来减缓神经网络的权重变化速率以及增加梯度的稀疏性。这里作者说在预训练阶段按照指数速率控制梯度的稀疏性:75%、93.75%、98.4375%、99.6%、99.9%,搞不懂是怎么手动控制梯度稀疏性的。。。
参考文献
[1] Mitliagkas, Ioannis, et al. "Asynchrony begets momentum, with an application to deep learning." 2016 54th Annual Allerton Conference on Communication, Control, and Computing (Allerton). IEEE, 2016.
ICLR 2018 | Deep Gradient Compression: Reducing the Communication Bandwidth for Distributed Training的更多相关文章
- INTERSPEECH 2014 | 1-Bit Stochastic Gradient Descent and its Application to Data-Parallel Distributed Training of Speech DNNs
这篇文章之前也读过,不过读的不太仔细,论文中的一些细节并没有注意到.最近为了写开题报告,又把这篇论文细读了一遍.据笔者了解,这篇论文应该是梯度量化领域的开山之作,首次使用了梯度量化技术来降低分布式神经 ...
- [CVPR2018] Context-aware Deep Feature Compression for High-speed Visual Tracking
基于内容感知深度特征压缩的高速视觉跟踪 论文下载:http://cn.arxiv.org/abs/1803.10537对于视频这种高维度数据,作者训练了多个自编码器AE来进行数据压缩,至于怎么选择具体 ...
- 论文笔记——Deep Model Compression Distilling Knowledge from Noisy Teachers
论文地址:https://arxiv.org/abs/1610.09650 主要思想 这篇文章就是用teacher-student模型,用一个teacher模型来训练一个student模型,同时对te ...
- NASH:基于丰富网络态射和爬山算法的神经网络架构搜索 | ICLR 2018
论文提出NASH方法来进行神经网络结构搜索,核心思想与之前的EAS方法类似,使用网络态射来生成一系列效果一致且继承权重的复杂子网,本文的网络态射更丰富,而且仅需要简单的爬山算法辅助就可以完成搜索,耗时 ...
- 基于层级表达的高效网络搜索方法 | ICLR 2018
论文基于层级表达提出高效的进化算法来进行神经网络结构搜索,通过层层堆叠来构建强大的卷积结构.论文的搜索方法简单,从实验结果看来,达到很不错的准确率,值得学习 来源:[晓飞的算法工程笔记] 公众号 ...
- MLHPC 2016 | Communication Quantization for Data-parallel Training of Deep Neural Networks
本文主要研究HPC上进行数据并行训练的可行性.作者首先在HPC上实现了两种通信量化算法(1 Bit SGD以及阈值量化),然后提出了自适应量化算法以解决它们的缺点.此外,发挥出量化算法的性能,作者还自 ...
- A Deep Neural Network Approach To Speech Bandwidth Expansion
题名:一种用于语音带宽扩展的深度神经网络方法 作者:Kehuang Li:Chin-Hui Lee 2015年出来的 摘要 本文提出了一种基于深度神经网络(DNN)的语音带宽扩展(BWE)方法.利用对 ...
- Federated Machine Learning: Concept and Applications
郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布! Qiang Yang, Yang Liu, Tianjian Chen, and Yongxin Tong. 2019. Federate ...
- AI系统——梯度累积算法
明天博士论文要答辩了,只有一张12G二手卡,今晚通宵要搞定10个模型实验 挖槽,突然想出一个T9开天霹雳模型,加载不进去我那张12G的二手卡,感觉要错过今年上台Best Paper领奖 上面出现的 ...
随机推荐
- AI系统——机器学习和深度学习算法流程
终于考上人工智能的研究僧啦,不知道机器学习和深度学习有啥区别,感觉一切都是深度学习 挖槽,听说学长已经调了10个月的参数准备发有2000亿参数的T9开天霹雳模型,我要调参发T10准备拿个Best Pa ...
- 移动端H5选择本地图片
移动端H5选择本地图片 html://input<input type="file" accept="image/*" capture="cam ...
- 使用医学影像开源库cornerstone.js解析Dicom图像显示到HTML中
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...
- 【Java】Eclipse常用快捷键
Eclipse常用快捷键 * 1.补全代码的声明:alt + / * 2.快速修复: ctrl + 1 * 3.批量导包:ctrl + shift + o * 4.使用单行注释:ctrl + / * ...
- [FatFs 学习] SD卡总结-SPI模式
SD卡为移动设备提供了安全的,大容量存储解决方法.它本身可以通过两种总线模式和MCU进行数据传输,一种是称为SD BUS的4位串行数据模式,另一种就是大家熟知的4线SPI Bus模式.一些廉价,低端的 ...
- java内部类细节
1 package face_09; 2 /* 3 * 为什么内部类能直接访问外部类中的成员呢? 4 * 那是因为内部类持有了外部类的引用. 外部类名.this 5 * 6 */ 7 class Ou ...
- 集合框架-Map重点方法entrySet演示
1 package cn.itcast.p6.map.demo; 2 3 import java.util.HashMap; 4 import java.util.Iterator; 5 import ...
- TCP可靠性
目录 一:TCP可靠性 1.通过序列号与确认应答提高可靠性 一:TCP可靠性 简介 TCP 通过检验和.序列号.确认应答.重发控制.连接管理以及窗口控制等机制实现可靠性传输. 1.通过序列号与确认应答 ...
- URL Rewrite(四种重定向策略)
目录 一:Rewrite基本概述 1.Rewrite简介 2.Rewrite基本概述 3.Rewrite作用 4.什么是URL? 二:rewrite语法 三:Rewrite标记Flag 1.last和 ...
- python31day
内容回顾 网编总结,思维导图 计划 并发编程的开始,计划6天 操作系统1天 进程2天 线程2天 携程1天 今日内容 操作系统 多道操作系统: 从顺序的一个个执行的思路变成:并行轮流使用cpu 一个程序 ...