1. 摘要

  • 训练深层的神经网络非常困难,因为在训练的过程中,随着前面层数参数的改变,每层输入的分布也会随之改变。这需要我们设置较小的学习率并且谨慎地对参数进行初始化,因此训练过程比较缓慢。

  • 作者将这种现象称之为 internal covariate shift,通过对每层的输入进行归一化来解决这个问题。

  • 引入 BN 后,我们可以不用太在意参数的初始化,同时使用更大的学习率,而且也会有正则化的效果,在一些情况下可以不用再使用 Dropout。

2. 介绍

因为网络中每一层的输入都受到前面所有层参数的影响,因此参数的一个小的改变将会随着网络加深而被逐渐放大。在这样的情况下,每一层网络都需要一直去适应新的输入分布,训练过程也就变得很慢。

考虑如下的网络计算

\(F_1\) 和 \(F_2\) 是任意的变换,\(\Theta_1\) 和 \(\Theta_2\) 是需要学习的参数。学习 \(\Theta_2\) 可以看作是输入 \(x=F_1(u, \Theta_1)\) 被传入子网络

因此,如果 \(x\) 的分布保持不变的话, \(\Theta_2\) 就不用去重新适应来补偿 \(x\) 分布的变化。

另一方面,如果我们采用 Sigmoid 激活函数的话,随着网络加深,我们很容易落入到饱和区域,容易产生梯度消失现象,减缓训练过程。但是,如果我们能够确保非线性输入的分布维持稳定的话,优化就不容易陷入到饱和区域,从而能够加速训练。

3. 小批量归一化

针对每一层的所有输入进行白化代价非常高并且不是处处可微,因此作者进行了两个必要的简化。

第一,我们独立地归一化一层中的每一个特征,让它们具有零均值标准方差。针对一个 \(d\) 维的输入 \(x=(x^{(1)}...x^{(d)})\),我们将分别对每一维进行归一化。

但是,这样简单地归一化每个输入会改变当前层网络的表达能力。比如,归一化 Sigmoid 函数的输入将会使它们落入激活函数的线性区域。为了解决这个问题,我们要保证嵌入到网络中的变换能够表示恒等变换。对此,我们引入一对参数 \(\gamma^{(k)}, \beta^{(k)}\) 来对归一化后的值再进行缩放和平移。

这样,通过设定 \(\gamma^{(k)}=\sqrt {Var[x^{(k)}]}, \beta^{(k)}=E[x^{(k)}]\) ,如果原始激活值是最优的话,我们也能够恢复到原有状态

第二,用小批量样本来产生对每个激活值均值和方差的估计。针对 \(m\) 个样本的小批量,归一化变换如下所示:

在训练过程中,我们需要计算 BN 变换参数的梯度来进行反向传播,根据链式法则,有

因此,BN 在网络中引入了对激活值的归一化,并且是一个可微的变换。这样,每一层都可以在同样的分布上持续学习而不用担心内部偏移问题,所以可以加速训练过程。最后,在归一化后学习到的仿射变换允许网络表示恒等变换,因此也保留了网络的容量也即表示能力。

4. 测试

虽然对小批量的激活值进行归一化在训练时是非常有效的,但在测试时却是不必要也不可取的,我们想让输出只确定地依赖于输入。因此,一旦训练好了一个网络,我们用训练时总体的均值和方差来进行归一化。

忽略 \(\epsilon\),这些归一化后的激活值就具有了和训练时一样的零均值和标准方差。我们采用无偏的方差估计

期望是根据训练过程中所有大小为 \(m\) 的小批量样本计算的,\(\sigma_B^2\) 代表它们的方差。同时,我们使用滑动平均来跟踪训练过程中每个小批量的均值和方差。

5. 在全连接和卷积网络中引入 BN

针对全连接网络,我们在非线性之前加入 BN 变换,对 \(x=Wu+b\) 进行归一化。注意到,偏置 \(b\) 可以被忽略,因为在后序的减去均值的过程中它的作用被抵消了。因此,就变成了

\[z=g(Wu+b) \to z=g(BN(Wu))\]

对于 \(x\) 的每一个维度我们学习一对参数 \(\gamma^{(k)}, \beta^{(k)}\) 。

针对卷积网络,我们想要归一化保持卷积的特性,因此,不同样本的所有位置的同一个特征图用同样的方式进行归一化。对于一个大小为 \(m\) 的小批量样本和大小为 \(p×q\) 的特征图,有效的小批次为

\[m'=m\cdot pq\]

对于每一个特征图我们学习一对参数 \(\gamma^{(k)}, \beta^{(k)}\) 。

6. BN 允许更高的学习率

在传统的网络中,太高的学习率可能会导致梯度消失或者爆炸,也会使网络陷入在糟糕的局部最优解。但引入 BN 后,它会阻止参数的小的变化被放大成激活值和梯度的更大变化或者次优变化,比如说不会让训练陷入到非线性的饱和区域。

BN 也使得训练对参数的规模更适应。正常情况下,大的学习率会增大参数,然后放大反向传播的梯度导致模型爆炸。但是,引入 BN 后,反向传播不受参数的规模影响。实际上,对于标量 \(a\),有

\[x_i = Wu_i\]
\[\mu = \frac{1}{m}\sum_{i=1}^{m}x_i\]
\[\sigma^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i-\mu)^2\]
\[\hat x_i = \frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}}\]
\[y_i =\gamma \hat x_i+\beta=BN[Wu_i]\]

所以,\(BN[Wu] = BN[(aW)u]\),\(\hat x_i\) 求取时分子分母都放大 \(a\) 倍。反向传播时,有

\[\frac{\partial BN[(aW)u]}{\partial x} = \frac{1}{a} \cdot \frac{\partial BN[Wu]}{\partial x}\]

\[\tag{1}\frac{\partial BN[(aW)u]}{\partial u} = \frac{\partial BN[Wu]}{\partial u}\]
\[\tag{2}\frac{\partial BN[(aW)u]}{\partial (aW)} = \frac{1}{a} \cdot \frac{\partial BN[Wu]}{\partial W}\]

由式 (1) 可以看到,参数的规模不影响梯度的反向传播。而且,由式 (2) 知,较大的参数将会获得较小的梯度,BN 能够稳定参数的增长

7. BN 的正则化效果

当使用 BN 进行训练时,小批次中的一个样本会与其他样本结合在一起被传入网络中,网络不再会为某个给定的训练样例生成确定值。在实验中,作者发现这种效应有利于网络的泛化。引入 BN 后,Dropout 可以被移除或减少作用。

8. 加速 BN 网络的训练

仅仅在网络中添加 BN 不能充分利用这种方法的优越性,除此之外,作者采用了以下策略:

  • 增大学习率
  • 移除 Dropout
  • 减小 L2 正则化权重
  • 加快学习率衰减
  • 移除局部响应归一化
  • 更彻底地打乱训练数据,防止同样的数据总出现在同一个批次中
  • 减少光度畸变

9. 实验结果

可以看到,在 MNIST 数据集上,引入 BN 后网络网络收敛很快,并且输入的分布更加稳定。

在 ImageNet 数据集上,引入 BN 后很快就达到了原来 Inception 网络取得的准确率。

获取更多精彩,请关注「seniusen」!

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift的更多相关文章

  1. Batch normalization:accelerating deep network training by reducing internal covariate shift的笔记

    说实话,这篇paper看了很久,,到现在对里面的一些东西还不是很好的理解. 下面是我的理解,当同行看到的话,留言交流交流啊!!!!! 这篇文章的中心点:围绕着如何降低  internal covari ...

  2. Deep Learning 27:Batch normalization理解——读论文“Batch normalization: Accelerating deep network training by reducing internal covariate shift ”——ICML 2015

    这篇经典论文,甚至可以说是2015年最牛的一篇论文,早就有很多人解读,不需要自己着摸,但是看了论文原文Batch normalization: Accelerating deep network tr ...

  3. 图像分类(二)GoogLenet Inception_v2:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    Inception V2网络中的代表是加入了BN(Batch Normalization)层,并且使用 2个 3*3卷积替代 1个5*5卷积的改进版,如下图所示: 其特点如下: 学习VGG用2个 3* ...

  4. 论文笔记:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    ICML, 2015 S. Ioffe and C. Szegedy 解决什么问题(What) 分布不一致导致训练慢:每一层的分布会受到前层的影响,当前层分布发生变化时,后层网络需要去适应这个分布,训 ...

  5. Batch Normalization:Accelerating Deep Network Training by Reducing Internal Covariate Shift(BN)

    internal covariate shift(ics):训练深度神经网络是复杂的,因为在训练过程中,每层的输入分布会随着之前层的参数变化而发生变化.所以训练需要更小的学习速度和careful参数初 ...

  6. Batch Normalization原理及其TensorFlow实现——为了减少深度神经网络中的internal covariate shift,论文中提出了Batch Normalization算法,首先是对”每一层“的输入做一个Batch Normalization 变换

    批标准化(Bactch Normalization,BN)是为了克服神经网络加深导致难以训练而诞生的,随着神经网络深度加深,训练起来就会越来越困难,收敛速度回很慢,常常会导致梯度弥散问题(Vanish ...

  7. 深度学习网络层之 Batch Normalization

    Batch Normalization Ioffe 和 Szegedy 在2015年<Batch Normalization: Accelerating Deep Network Trainin ...

  8. 【深度学习】深入理解Batch Normalization批标准化

    这几天面试经常被问到BN层的原理,虽然回答上来了,但还是感觉答得不是很好,今天仔细研究了一下Batch Normalization的原理,以下为参考网上几篇文章总结得出. Batch Normaliz ...

  9. 解读Batch Normalization

    原文转自:http://blog.csdn.net/shuzfan/article/details/50723877 本次所讲的内容为Batch Normalization,简称BN,来源于<B ...

随机推荐

  1. STM32启动代码分析

    STM32启动文件简单分析(STM32F10x.s适用范围)定时器, 型号, 名字在<<STM32不完全手册里面>>,我们所有的例程都采用了一个叫STM32F10x.s的启动文 ...

  2. Oracle高级函数篇之递归查询start with connect by prior简单用法

    路飞:" 把原来CSDN的博客转移到博客园咯!" 前段时间,自己负责的任务中刚好涉及到了组织关系的业务需求,自己用了oracle递归查询.下面简单来举个例子.在工作中我们经常会遇到 ...

  3. WebGl 旋转(矩阵变换)

    代码1: <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF- ...

  4. C语言写的2048小游戏

    基于"基于C_语言的2048算法设计_颜冠鹏.pdf" 这一篇文献提供的思路 在中国知网上能找到 就不贴具体内容了 [摘 要] 针对2048的游戏规则,分析了该游戏的算法特点,对其 ...

  5. 解决在 win10 下 vs2017 中创建 MFC 程序拖放文件接收不到 WM_DROPFILES 消息问题

    解决方案 这个问题是由于 win10 的安全机制搞的鬼,即使以管理员权限运行也不行,因为它会把 WM_DROPFILES 消息过滤掉,那怎么办呢?只需在窗口初始化 OnInitDialog() 里添加 ...

  6. Python学习 :socket基础

    socket基础 什么是socket? - socket为接口通道,内部封装了IP地址.端口.协议等信息:我们可以看作是以前的通过电话机拨号上网的年代,socket即为电话线 socket通信流程 我 ...

  7. aircrack-ng 破解无线网络

    1.科普当今时代,wifi 已成为我们不可缺少的一部分,上网.看视频.玩游戏,没有 wifi 你就等着交高额的流量费吧,本来我想单独的写 wpa 破解和 wps 破解,后来觉得分开写过于繁琐,索性合并 ...

  8. linux——nginx的安装及配置

    目录 1. 在Linux上安装nginx 2. 配置nginx反向代理 1. 在Linux上安装ngix 1.1 在以下网页下载nginx的tar包,并将其传到linux中 http://nginx. ...

  9. GC错误

    如果出现GC错误,可设置客户端 set mapreduce.map.java.opts 设置一下 R的GC错误,在顶端设置这个参数 options(java.parameters = "-X ...

  10. FlexPaper 里的pdf2json.exe 下载地址

    在使用FlexPaper 做在线阅读,需要使用到pdf2json.exe,将PDF转成JSON或者XML格式,网上很少下载的,现在提供一个下载的地址 http://pan.baidu.com/s/1i ...