文章来自:公众号【机器学习炼丹术】。求关注~

其实关于BN层,我在之前的文章“梯度爆炸”那一篇中已经涉及到了,但是鉴于面试经历中多次问道这个,这里再做一个更加全面的讲解。

Internal Covariate Shift(ICS)

Batch Normalization的原论文作者给了Internal Covariate Shift一个较规范的定义:在深层网络训练的过程中,由于网络中参数变化而引起内部结点数据分布发生变化的这一过程被称作Internal Covariate Shift。

这里做一个简单的数学定义,对于全链接网络而言,第i层的数学表达可以体现为:

\(Z^i=W^i\times input^i+b^i\)

\(input^{i+1}=g^i(Z^i)\)

  • 第一个公式就是一个简单的线性变换;
  • 第二个公式是表示一个激活函数的过程。

【怎么理解ICS问题】

我们知道,随着梯度下降的进行,每一层的参数\(W^i,b^i\)都会不断地更新,这意味着\(Z^i\)的分布也不断地改变,从而\(input^{i+1}\)的分布发生了改变。这意味着,除了第一层的输入数据不改变,之后所有层的输入数据的分布都会随着模型参数的更新发生改变,而每一层就要不停的去适应这种数据分布的变化,这个过程就是Internal Covariate Shift。

BN解决的问题

【ICS带来的收敛速度慢】

因为每一层的参数不断发生变化,从而每一层的计算结果的分布发生变化,后层网络不断地适应这种分布变化,这个时候会让整个网络的学习速度过慢。

【梯度饱和问题】

因为神经网络中经常会采用sigmoid,tanh这样的饱和激活函数(saturated actication function),因此模型训练有陷入梯度饱和区的风险。解决这样的梯度饱和问题有两个思路:第一种就是更为非饱和性激活函数,例如线性整流函数ReLU可以在一定程度上解决训练进入梯度饱和区的问题。另一种思路是,我们可以让激活函数的输入分布保持在一个稳定状态来尽可能避免它们陷入梯度饱和区,这也就是Normalization的思路。

Batch Normalization

batchNormalization就像是名字一样,对一个batch的数据进行normalization。

现在假设一个batch有3个数据,每个数据有两个特征:(1,2),(2,3),(0,1)

如果做一个简单的normalization,那么就是计算均值和方差,把数据减去均值除以标准差,变成0均值1方差的标准形式。

对于第一个特征来说:

\(\mu=\frac{1}{3}(1+2+0)=1\)

\(\sigma^2=\frac{1}{3}((1-1)^2+(2-1)^2+(0-1)^2)=0.67\)

【通用公式】

\(\mu=\frac{1}{m}\sum_{i=1}^m{Z}\)

\(\sigma^2=\frac{1}{m}\sum_{i=1}^m(Z-\mu)\)

\(\hat{Z}=\frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}\)

  • 其中m表示一个batch的数量。
  • \(\epsilon\)是一个极小数,防止分母为0。

目前为止,我们做到了让每个特征的分布均值为0,方差为1。这样分布都一样,一定不会有ICS问题

如同上面提到的,Normalization操作我们虽然缓解了ICS问题,让每一层网络的输入数据分布都变得稳定,但却导致了数据表达能力的缺失。每一层的分布都相同,所有任务的数据分布都相同,模型学啥呢

【0均值1方差数据的弊端】

  1. 数据表达能力的缺失;
  2. 通过让每一层的输入分布均值为0,方差为1,会使得输入在经过sigmoid或tanh激活函数时,容易陷入非线性激活函数的线性区域。(线性区域和饱和区域都不理想,最好是非线性区域)

为了解决这个问题,BN层引入了两个可学习的参数\(\gamma\)和\(\beta\),这样,经过BN层normalization的数据其实是服从\(\beta\)均值,\(\gamma^2\)方差的数据。

所以对于某一层的网络来说,我们现在变成这样的流程:

  1. \(Z=W\times input^i+b\)
  2. \(\hat{Z}=\gamma \times \frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}+\beta\)
  3. \(input^{i+1}=g(\hat{Z})\)

(上面公式中,省略了\(i\),总的来说是表示第i层的网络层产生第i+1层输入数据的过程)

测试阶段的BN

我们知道BN在每一层计算的\(\mu\)与\(\sigma^2\) 都是基于当前batch中的训练数据,但是这就带来了一个问题:我们在预测阶段,有可能只需要预测一个样本或很少的样本,没有像训练样本中那么多的数据,这样的\(\sigma^2\)和\(\mu\)要怎么计算呢?

利用训练集训练好模型之后,其实每一层的BN层都保留下了每一个batch算出来的\(\mu\)和\(\sigma^2\).然后呢利用整体的训练集来估计测试集的\(\mu_{test}\)和\(\sigma_{test}^2\)

\(\mu_{test}=E(\mu_{train})\)

\(\sigma_{test}^2=\frac{m}{m-1}E(\sigma_{train}^2)\)

然后再对测试机进行BN层:

当然,计算训练集的\(\mu\)和\(\simga\)的方法除了上面的求均值之外。吴恩达老师在其课程中也提出了,可以使用指数加权平均的方法。不过都是同样的道理,根据整个训练集来估计测试机的均值方差。

BN层的好处有哪些

  1. BN使得网络中每层输入数据的分布相对稳定,加速模型学习速度。

    BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,从而实现了网络中层与层之间的解耦,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。

  2. BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题

    通过normalize操作可以让激活函数的输入数据落在梯度非饱和区,缓解梯度消失的问题;另外通过自适应学习\(\gamma\)与 \(\beta\) 又让数据保留更多的原始信息。

  3. BN具有一定的正则化效果

    在Batch Normalization中,由于我们使用mini-batch的均值与方差作为对整体训练样本均值与方差的估计,尽管每一个batch中的数据都是从总体样本中抽样得到,但不同mini-batch的均值与方差会有所不同,这就为网络的学习过程中增加了随机噪音

BN与其他normalizaiton的比较

【weight normalization】

Weight Normalization是对网络权值进行normalization,也就是L2 norm。

相对于BN有下面的优势:

  1. WN通过重写神经网络的权重的方式来加速网络参数的收敛,不依赖于mini-batch。BN因为以来minibatch所以BN不能用于RNN网路,而WN可以。而且BN要保存每一个batch的均值方差,所以WN节省内存;
  2. BN的优点中有正则化效果,但是添加噪音不适合对噪声敏感的强化学习、GAN等网络。WN可以引入更小的噪音。

但是WN要特别注意参数初始化的选择。


【Layer normalization】

更常见的比较是BN与LN的比较。

BN层有两个缺点:

  1. 无法进行在线学习,因为在线学习的mini-batch为1;LN可以
  2. 之前提到的BN不能用在RNN中;LN可以
  3. 消耗一定的内存来记录均值和方差;LN不用

但是,在CNN中LN并没有取得比BN更好的效果。

参考链接:

  1. https://zhuanlan.zhihu.com/p/34879333
  2. https://www.zhihu.com/question/59728870
  3. https://zhuanlan.zhihu.com/p/113233908
  4. https://www.zhihu.com/question/55890057/answer/267872896





干货 | 这可能全网最好的BatchNorm详解的更多相关文章

  1. [转帖]HTTPS系列干货(一):HTTPS 原理详解

    HTTPS系列干货(一):HTTPS 原理详解 https://tech.upyun.com/article/192/HTTPS%E7%B3%BB%E5%88%97%E5%B9%B2%E8%B4%A7 ...

  2. 【转】HTTPS系列干货(一):HTTPS 原理详解

    HTTPS系列干货(一):HTTPS 原理详解 前言 HTTPS(全称:HyperText Transfer Protocol over Secure Socket Layer),其实 HTTPS 并 ...

  3. Mybatis系列全解(五):全网最全!详解Mybatis的Mapper映射文件

    封面:洛小汐 作者:潘潘 若不是生活所迫,谁愿意背负一身才华. 前言 上节我们介绍了 < Mybatis系列全解(四):全网最全!Mybatis配置文件 XML 全貌详解 >,内容很详细( ...

  4. HTTPS系列干货(一):HTTPS 原理详解

    HTTPS(全称:HyperText Transfer Protocol over Secure Socket Layer),其实 HTTPS 并不是一个新鲜协议,Google 很早就开始启用了,初衷 ...

  5. 【腾讯Bugly干货分享】iOS10 SiriKit QQ适配详解

    本文来自于腾讯bugly开发者社区,非经作者同意,请勿转载,原文地址:http://dev.qq.com/topic/57ece0331288fb4d31137da6 1. 概述 苹果在iOS10开放 ...

  6. 干货分享:Academic Essay写作套路详解

    你想过如何中立的表达自己吗?大概只有10%不到的同学,会真正重视这个细节.但很多留学生能顺利写完作文已经不容易,还要注意什么中立不中立的.我知道这个标准,对许多同学有些过分,但很残酷的告诉你,这的确是 ...

  7. 干货分享:Research Essay写作规范详解

    同学们在刚到国外时觉得一切都很新鲜,感觉到处都在吸引着他们,但是大部分留学生在刚碰到Research Essay便是一头包.其实Research Essay也没有想象中的那么难,只是留学生们初次接触, ...

  8. 全程干货,requests模块与selenium框架详解

    requests模块 前言: 通常我们利用Python写一些WEB程序.webAPI部署在服务端,让客户端request,我们作为服务器端response数据: 但也可以反主为客利用Python的re ...

  9. Mybatis系列全解(四):全网最全!Mybatis配置文件XML全貌详解

    封面:洛小汐 作者:潘潘 做大事和做小事的难度是一样的.两者都会消耗你的时间和精力,所以如果决心做事,就要做大事,要确保你的梦想值得追求,未来的收获可以配得上你的努力. 前言 上一篇文章 <My ...

随机推荐

  1. GIT更换连接方式

    1-使用 git remote -v 查看对应的克隆地址: git remote -v origin https://github.com/username/repository.git (fetch ...

  2. Kubernetes 两步验证 - 使用 Serverless 实现动态准入控制

    作者:CODING - 王炜 1. 背景 如果对 Kubernetes 集群安全特别关注,那么我们可能想要实现这些需求: 如何实现 Kubernetes 集群的两步验证,除了集群凭据,还需要提供一次性 ...

  3. Nginx 从入门到放弃(五)

    nginx的rewrite重写 nginx具有将一个路由经过加工变形成另外一个路由的功能,这就叫做重写. 重写中用到的指令 if (条件) {} 设定条件,再进行重写 set # 设定变量 retur ...

  4. 恕我直言你可能真的不会java第12篇-如何使用Stream API对Map类型元素排序

    在这篇文章中,您将学习如何使用Java对Map进行排序.前几日有位朋友面试遇到了这个问题,看似很简单的问题,但是如果不仔细研究一下也是很容易让人懵圈的面试题.所以我决定写这样一篇文章.在Java中,有 ...

  5. 赞!7000 字学习笔记,一天搞定 MySQL

    MySQL数据库简介 MySQL近两年一直稳居第二,随时有可能超过Oracle计晋升为第一名,因为MySQL的性能一直在被优化,同时安全机制也是逐渐成熟,更重要的是开源免费的. MySQL是一种关系数 ...

  6. Python 之父说 Python 历史

    前言 本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理. 作者:鸿影洲冷 这篇文章主要内容来源于 Python 编程语言的最初设计者 ...

  7. 洛谷 P3574 [POI2014]FAR-FarmCraft

    题目传送门 题目描述 输入输出格式 输入格式: 输出格式: 一行,包含一个整数,代表题目中所说的最小时间. 输入输出样例 样例输入 样例输出 提示 分析 我们设f[x]为遍历完以x为根的子树且将这棵子 ...

  8. 使用LLDB和debugserver对ios程序进行调试

    在没有WIFI的情况下,使用USB连接IOS设备,使用辅助插件usbmuxd来辅助调试.我其实也想用wifi调试,奈何公司的wifi绑定了mac地址,而我又使用的是黑苹果虚拟机,使用桥接的方式修改网段 ...

  9. python中获取文件路径的几种方式

    # 如果执行文件为E:\aa\bb\aa.py 1.获取当前路径 current_path11 = os.path.abspath(__file__) current_path12 = os.path ...

  10. day68 form组件

    目录 一.自定义分页器的拷贝和使用 二.Forms组件 1 前戏 2 form组件的基本功能 3 基本使用 4 基本方法 5 渲染标签 6 展示提示信息 7 钩子函数(HOOK) 8 forms组件其 ...