Deep Learning中的Large Batch Training相关理论与实践
背景
论文脉络梳理
- 《ON LARGE-BATCH TRAINING FOR DEEP LEARNING: GENERALIZATION GAP AND SHARP MINIMA》:这篇论文解释了Large Batch Training使收敛性变差的原因:使用Large Batch更容易落入Sharp Minima,而Sharp Minima属于过拟合,所以其泛化性比较差。
- 《Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour》:这是FaceBook提出的一篇极具争议性的论文,从实践上来说它的的复现难度也是比较大的。该论文从实践的角度出发,在ResNet上提出了一种针对Large batch training的训练方法,即learning rate scaling rule。当batch size相对于baseline增加N倍时,learning rate也要相应的增加N倍,但也指出batch size的提升有一个upper bound,超过这个值,泛化性依然会变得很差。这篇论文对learning rate scaling rule有一些公式推导,但并不本质,更多的是做了较强的假设。总体来说,这是一篇实验做得比较solid,但理论基础并不丰满的实践论文。
- 《A BAYESIAN PERSPECTIVE ON GENERALIZATION AND STOCHASTIC GRADIENT DESCENT》:这是Google发在ICLR 2018上的一篇理论和实验都比较完善的论文。因为在ResNet上已经有了Learning Rate Scaling Rule的成功经验,因此该论文从贝叶斯的角度解释了泛化性和SGD。论文的核心观点是指出了Batch Training相对于Full Batch Training来说引入了Noise,而Noise具有波动的效果,这在论文里被称为Flucturate,它可以在更新时在一定程度上偏离Sharp Minima,从而进入Broad Minima,进而有了较好的泛化性,所以Noise起了较大的作用。进一步的,论文中将SGD的更新公式进行进行分析,等价为一个微分方程的定积分结果,通过将SGD更新公式与微分方程进行等价,导出了Flucturate的表达式,确定了影响其值的变动因素,即和Learning Rate与Batch size有关。若把Flucturate看做常量,那么Learning Rate与Batch Size可以近似看做是线性关系,这与论文2中的Learning Rate Scaling Rule一致。总体来说,这篇论文数学理论相对丰满的解释了Learning Rate Scaling Rule。
- 《Don't Decay the Learning Rate, Increase the Batch Size》:这是Google发在ICLR 2018上的第二篇论文,这篇论文的实验和结论非常简单,但是理论基础依然来自于论文3,所以阅读此篇论文之前一定要精度论文3。该论文从推导出的Mini Batch SGD的Flucturate公式出发,提出了一种使用Large Batch Training的加速方法。因为在一个完整的模型训练过程中,通常会随着轮数的增加而适当对Learning Rate做Decay。通过论文3中给出的公式,即Flucturate固定时,Learning Rate与Batch Size成正比关系,引发了思考:究竟是Learning Rate本身需要Decay才能使训练过程继续,还是Learning Rate的Decay间接影响了Noise的Flucturate才能使训练过程继续?通过实验验证,真正影响训练过程的本质是Noise的Flucturate。因此我们考虑到Learning Rate与Batch Size的正比例关系,我们可以固定Learning Rate不变,而将Batch Size增加N倍来缩小Noise的Flucturate。定时增加Batch Size不但可以维持原有方式的Flucturate,还可以加速训练过程,减少Update的更新频次,增加计算通信占比,提高加速比。总体来说,该论文基于论文3为理论基础,提出了一种逐渐增加Batch Size提高计算加速比和收敛加速比的方法。
要点梳理
理论基础
- 从贝叶斯理论角度出发,论证Broad Minima相对于Sharp Minima具有更好的泛化性
- 用贝叶斯理论解释泛化性是有效的
- 贝叶斯理论与SGD
- 随机偏微分方程的与Scaling Rule的推导
优化方法
- 使用Large Batch Training提高训练速度
理论基础
从贝叶斯理论角度出发,论证broad minima相对于sharp minima具有更好的泛化性
内容
一般情况下,我们对模型参数的分布会做高斯假设
所以有
可以看出这个公式就是模型训练中Loss Function的主要部分,前面一项H(w;M)是Cost,而后面一项是正则项。我们要最小化Loss Function,本质上是最大化C(w;M)这一项。假设我们训练了两组模型参数,如何判断哪一个模型的泛化性更好?这里使用如下公式来判断。
等式右面的第二项是对模型的偏好因子,在这里应该均设置为1,消除偏置的影响。右边第一项我们叫做Bayesian Evidence Ratio,它描述了训练样本改变了我们对模型先验偏好的程度。为了计算这个比值,我们需要计算分子和分母。
使用泰勒展开式对C(w;M)在最优值w_0附近进行近似展开,得到如下式子。
至此,我们可以对上述公式的结果进行分析。上述公式中最后一项其实就是Occam Factor。通过分析我们也知道二阶导数正负衡量的是函数的凹凸性,而二阶导数的大小衡量和曲率相关。当C''(w_0)越大时,该位置附近就越弯曲,越接近sharp minima,进而导致P(y|x;M)的概率越低,这符合Occam Razor的原则,越简单的模型泛化性越好,这是因为简单的模型是在Broad Minima上。也可以提高正则系数对C''(w_0)进行惩罚,从而控制Occam factor,提高泛化性。当扩展到多个参数后,该公式如下所示。
小结
用贝叶斯理论解释泛化性是有效的
内容
这个实验主要是为了证明Bayesian Evidence的曲线和Test Cross Entropy的变化趋势是一致的,并且也复现了《Understanding deep learning requires rethinking generalization》中呢Deep Learning Model的结果。
小结
贝叶斯理论与SGD
内容
这些实验其实就是验证不同Batch Size训练出的模型在test集上的表现,并说明存在一个最佳的Batch Size,使用它训练出的模型,其泛化性优于其他Batch Size训练出的模型。
小结
随机偏微分方程的与scaling rule的推导
内容
根据中心极限定理,我们可以得出以下结论
所以标准的Stochastic Gradient Descent可以看成是标准梯度加上一个Noise,这个Noise就是α中的内容。下面进一步研究Noise的性质。
其中,F(w)为梯度的协方差项,δ_ij代表了Indicator,即当i=j时,δ_ij=1,否则等于0。这是因为样本和样本之间是相互独立的关系,所以协方差应该等于0。如果看不懂这个公式可以按照下面的原型推理,一目了然。
根据协方差矩阵的可列可拆的性质,我们求得如下期望。
至此,Noise的统计特性已经全部计算出来,下面需要和随机偏微分方程进行等价。首先,SGD的Update规则是一个离散的过程,不是连续的过程。如果我们把SGD的每一步想象成为一个连续的可微分的过程,每次Update一个偏微分算子,那么可以将上述学习率为ε的Update公式看成是某个微分方程的定积分结果,下面先介绍这个偏微分方程(这个偏微分方程的产生来自于《Handbook of Stochastic Methods》)。
这里t是连续的变量,η(t)代表了t时刻的Noise,具有如下性质。
因为我们知道Noise的期望必定等于0,而方差会有个波动的Scale,且波动的大小是以F(w)有关,所以这个Scale我们用g来表示,即Flucturate。而SGD的Update规则可以改写如下所示。
为了探求g的变化因素,我们需要将偏微分方程的最后一项的方差和SGD的α方差对应起来,得到
上面最后的积分公式推导可能会有些迷惑,大概是会迷惑在积分的方差是如何化简到二重积分这一过程,其实积分符号只是个对连续变量的求和过程,所以依然可以使用协方差的可列可拆的性质,如果还是不习惯,将积分符合和dt换成求和符号再去使用协方差公式即可轻松得到结论。
所以,我们得到了结论,SGD引入了一些Noise,这个Noise具有一定的Flucturate,它的大小是和Batch Size成反比,与Learning Rate成正比。
小结
理论总结
优化方法
理论基础公式
对于Momentum-SGD来说,形式表达为(公式推导来自于langvein动力学)
Large batch training的优化原理
作者做了三组实验,一组是标准的对Learning Rate做Decay,一组是固定Rearning Rate不变,在原来发生Learning Rate Decay的轮数将Batch Size扩大N倍(N是Learning Rate Decay的Factor,即与Learning Rate的Decay为相同力度)。另一组是二者的结合Hybrid,即先Learning Rate Decay,后变化Batch Size。实验证明三者的泛化性曲线相同,所以证明了Learning Rate Decay实际上是对g做了Scale down。然而增加Batch Size不但可以达到同样的效果,还能提高计算通信占比,并且在整体训练过程中减少Update的次数,这是Increase Batch Size Training的优化点。
关于Momentum-SGD
更大Batch Size和消除Warm Up
小结
总结
Deep Learning中的Large Batch Training相关理论与实践的更多相关文章
- ON LARGE BATCH TRAINING FOR DEEP LEARNING: GENERALIZATION GAP AND SHARP MINIMA
目录 概 主要内容 一些解决办法 Keskar N S, Mudigere D, Nocedal J, et al. On Large-Batch Training for Deep Learning ...
- Deep Learning and Shallow Learning
Deep Learning and Shallow Learning 由于 Deep Learning 现在如火如荼的势头,在各种领域逐渐占据 state-of-the-art 的地位,上个学期在一门 ...
- AndrewNG Deep learning课程笔记 - CNN
参考, An Intuitive Explanation of Convolutional Neural Networks http://www.hackcv.com/index.php/archiv ...
- Deep Learning in NLP (一)词向量和语言模型
原文转载:http://licstar.net/archives/328 Deep Learning 算法已经在图像和音频领域取得了惊人的成果,但是在 NLP 领域中尚未见到如此激动人心的结果.关于这 ...
- (转)分布式深度学习系统构建 简介 Distributed Deep Learning
HOME ABOUT CONTACT SUBSCRIBE VIA RSS DEEP LEARNING FOR ENTERPRISE Distributed Deep Learning, Part ...
- Deep Learning In NLP 神经网络与词向量
0. 词向量是什么 自然语言理解的问题要转化为机器学习的问题,第一步肯定是要找一种方法把这些符号数学化. NLP 中最直观,也是到目前为止最常用的词表示方法是 One-hot Representati ...
- Word2Vec之Deep Learning in NLP (一)词向量和语言模型
转自licstar,真心觉得不错,可惜自己有些东西没有看懂 这篇博客是我看了半年的论文后,自己对 Deep Learning 在 NLP 领域中应用的理解和总结,在此分享.其中必然有局限性,欢迎各种交 ...
- deep learning深度学习之学习笔记基于吴恩达coursera课程
feature study within neural network 在regression问题中,根据房子的size, #bedrooms原始特征可能演算出family size(可住家庭大小), ...
- 学习Data Science/Deep Learning的一些材料
原文发布于我的微信公众号: GeekArtT. 从CFA到如今的Data Science/Deep Learning的学习已经有一年的时间了.期间经历了自我的兴趣.擅长事务的探索和试验,有放弃了的项目 ...
随机推荐
- CNN算法解决MNIST数据集识别问题
网络实现程序如下 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 用于设置将记 ...
- nmap 介绍
原文地址:http://drops.wooyun.org/tips/2002 原文地址:http://infotechbits.wordpress.com/2014/05/04/introductio ...
- Windows 10 无法使用搜索栏,显示一片空白
查看这个:https://blog.csdn.net/qq_41571056/article/details/82928919
- HARD FAULT
程序陷在while(1)里面 解决办法 定点到发生死循环的位置 打开stack windows逐层查找发生死循环之前运行过的函数 导致原因 1 内存溢出或者访问越界,通常为数组或结构体访问越界.这个需 ...
- 风格豆腐干地方v出vccxzzxx
ksdfjlksdjflksdjlfjsdkflj{b7a6e0i010b7b7g2i010b7b7g2i010b7b7c8i010f1j4i010e0h3i010e0h3i010b7a6c8i010 ...
- Using iSCSI On Ubuntu 10.04 (Initiator And Target)
This guide explains how you can set up an iSCSI target and an iSCSI initiator (client), both running ...
- Linux的50个基本命令
1.ls -a 列出当前目录下的所有文件,包括以.头的隐含文件(如-/.bashrc) ls –l 列出当前目录下文件的详细信息 2. pwd 查看当前所在目录的绝对路经 3. cd 目录之间的移动 ...
- 图片处理类 类库--C#
调用如下: Bitmap bitmap = new Bitmap("C:\\Users\\Thinkpad\\Desktop\\aa.jpg"); Bitmap[] bit = n ...
- 16bit CRC算法C语言实现
#define CRC_16_POLYNOMIALS 0x8005 unsigned short CRC16_3(unsigned char* pchMsg, unsigned short wData ...
- 背水一战 Windows 10 (102) - 应用间通信: 剪切板
[源码下载] 背水一战 Windows 10 (102) - 应用间通信: 剪切板 作者:webabcd 介绍背水一战 Windows 10 之 应用间通信 剪切板 - 基础, 复制/粘贴 text ...