本文转自:谷歌工程师:聊一聊深度学习的weight initialization

TLDR (or the take-away)

Weight Initialization matters!!! 深度学习中的weight initialization对模型收敛速度和模型质量有重要影响!

  • 在ReLU activation function中推荐使用Xavier Initialization的变种,暂且称之为He Initialization:

  • 使用Batch Normalization Layer可以有效降低深度网络对weight初始化的依赖:

实验代码请参见我的Github

背景

深度学习模型训练的过程本质是对weight(即参数 W)进行更新,这需要每个参数有相应的初始值。有人可能会说:“参数初始化有什么难点?直接将所有weight初始化为0或者初始化为随机数!”对一些简单的机器学习模型,或当optimization function是convex function时,这些简单的方法确实有效。

然而对于深度学习而言,非线性函数被疯狂叠加,产生如本文题图所示的non-convex function,如何选择参数初始值便成为一个值得探讨的问题——其本质是初始参数的选择应使得objective function便于被优化。事实上,在学术界这也是一个被actively研究的领域。

TLDR里已经涵盖了本文的核心要点,下面在正文中,我们来深入了解一下前因后果。

初始化为0的可行性?

答案是不可行。 这是一道送分题 哈哈!为什么将所有W初始化为0是错误的呢?是因为如果所有的参数都是0,那么所有神经元的输出都将是相同的,那在back propagation的时候同一层内所有神经元的行为也是相同的 --- gradient相同,weight update也相同。这显然是一个不可接受的结果。

可行的几种初始化方式

pre-training

pre-training是早期训练神经网络的有效初始化方法,一个便于理解的例子是先使用greedy layerwise auto-encoder做unsupervised pre-training,然后再做fine-tuning。具体过程可以参见UFLDL的一个tutorial,因为这不是本文重点,就在这里简略的说一下:

  • pre-training阶段,将神经网络中的每一层取出,构造一个auto-encoder做训练,使得输入层和输出层保持一致。在这一过程中,参数得以更新,形成初始值

  • fine-tuning阶段,将pre-train过的每一层放回神经网络,利用pre-train阶段得到的参数初始值和训练数据对模型进行整体调整。在这一过程中,参数进一步被更新,形成最终模型。

随着数据量的增加以及activation function (参见我的另一篇文章) 的发展,pre-training的概念已经渐渐发生变化。目前,从零开始训练神经网络时我们也很少采用auto-encoder进行pre-training,而是直奔主题做模型训练。不想从零开始训练神经网络时,我们往往选择一个已经训练好的在任务A上的模型(称为pre-trained model),将其放在任务B上做模型调整(称为fine-tuning)。

random initialization

随机初始化是很多人目前经常使用的方法,然而这是有弊端的,一旦随机分布选择不当,就会导致网络优化陷入困境。下面举几个例子。

核心代码见下方,完整代码请参见我的Github

这里我们创建了一个10层的神经网络,非线性变换为tanh,每一层的参数都是随机正态分布,均值为0,标准差为0.01。下图给出了每一层输出值分布的直方图。

随着层数的增加,我们看到输出值迅速向0靠拢,在后几层中,几乎所有的输出值 x 都很接近0!回忆优化神经网络的back propagation算法,根据链式法则,gradient等于当前函数的gradient乘以后一层的gradient,这意味着输出值 x 是计算gradient中的乘法因子,直接导致gradient很小,使得参数难以被更新!

让我们将初始值调大一些:

均值仍然为0,标准差现在变为1,下图是每一层输出值分布的直方图:

几乎所有的值集中在-1或1附近,神经元saturated了!注意到tanh在-1和1附近的gradient都接近0,这同样导致了gradient太小,参数难以被更新。

Xavier initialization

Xavier initialization可以解决上面的问题!其初始化方式也并不复杂。Xavier初始化的基本思想是保持输入和输出的方差一致,这样就避免了所有输出值都趋向于0。注意,为了问题的简便,Xavier初始化的推导过程是基于线性函数的,但是它在一些非线性神经元中也很有效。让我们试一下:

Woohoo!输出值在很多层之后依然保持着良好的分布,这很有利于我们优化神经网络!之前谈到Xavier initialization是在线性函数上推导得出,这说明它对非线性函数并不具有普适性,所以这个例子仅仅说明它对tanh很有效,那么对于目前最常用的ReLU神经元呢(关于不同非线性神经元的比较请参考这里)?继续做一下实验:

前面看起来还不错,后面的趋势却是越来越接近0。幸运的是,He initialization可以用来解决ReLU初始化的问题。

He initialization

He initialization的思想是:在ReLU网络中,假定每一层有一半的神经元被激活,另一半为0,所以,要保持variance不变,只需要在Xavier的基础上再除以2:

看起来效果非常好,推荐在ReLU网络中使用!

Batch Normalization Layer

Batch Normalization是一种巧妙而粗暴的方法来削弱bad initialization的影响,其基本思想是:If you want it, just make it!

我们想要的是在非线性activation之前,输出值应该有比较好的分布(例如高斯分布),以便于back propagation时计算gradient,更新weight。Batch Normalization将输出值强行做一次Gaussian Normalization和线性变换:

Batch Normalization中所有的操作都是平滑可导,这使得back propagation可以有效运行并学到相应的参数γ,β。需要注意的一点是Batch Normalization在training和testing时行为有所差别。Training时μβ和σβ由当前batch计算得出;在Testing时μβ和σβ应使用Training时保存的均值或类似的经过处理的值,而不是由当前batch计算。

随机初始化,无Batch Normalization:

随机初始化,有Batch Normalization:

(转载)深度学习的weight initialization的更多相关文章

  1. [深度学习] 权重初始化--Weight Initialization

    深度学习中的weight initialization对模型收敛速度和模型质量有重要影响! 在ReLU activation function中推荐使用Xavier Initialization的变种 ...

  2. 深度学习 weight initialization

    转自: https://www.leiphone.com/news/201703/3qMp45aQtbxTdzmK.htmla https://blog.csdn.net/shuzfan/articl ...

  3. paper 53 :深度学习(转载)

    转载来源:http://blog.csdn.net/fengbingchun/article/details/50087005 这篇文章主要是为了对深度学习(DeepLearning)有个初步了解,算 ...

  4. <深度学习优化策略-3> 深度学习网络加速器Weight Normalization_WN

    前面我们学习过深度学习中用于加速网络训练.提升网络泛化能力的两种策略:Batch Normalization(Batch Normalization)和Layer Normalization(LN). ...

  5. [转载]Deep Learning(深度学习)学习笔记整理

    转载自:http://blog.csdn.net/zouxy09/article/details/8775360 感谢原作者:zouxy09@qq.com 八.Deep learning训练过程 8. ...

  6. 百度首席科学家 Andrew Ng谈深度学习的挑战和未来(转载)

    转载:http://www.csdn.net/article/2014-07-10/2820600 人工智能被认为是下一个互联网大事件,当下,谷歌.微软.百度等知名的高科技公司争相投入资源,占领深度学 ...

  7. 近200篇机器学习&深度学习资料分享【转载】

    编者按:本文收集了百来篇关于机器学习和深度学习的资料,含各种文档,视频,源码等.而且原文也会不定期的更新,望看到文章的朋友能够学到更多. <Brief History of Machine Le ...

  8. 转载-【深度学习】深入理解Batch Normalization批标准化

      全文转载于郭耀华-[深度学习]深入理解Batch Normalization批标准化:   文章链接Batch Normalization: Accelerating Deep Network T ...

  9. 转载:深度学习在NLP中的应用

    之前研究的CRF算法,在中文分词,词性标注,语义分析中应用非常广泛.但是分词技术只是NLP的一个基础部分,在人机对话,机器翻译中,深度学习将大显身手.这篇文章,将展示深度学习的强大之处,区别于之前用符 ...

随机推荐

  1. CodeForces999E 双dfs // 标记覆盖 // tarjan缩点

    http://codeforces.com/problemset/problem/999/E 题意 有向图    给你n个点,m条边,以及一个初始点s,问你至少还需要增加多少条边,使得初始点s与剩下其 ...

  2. 写给IT技术爱好者的一封信

    写给IT技术爱好者的一封信>当前运维素质的分析<... ---------------------- 虽相貌平平,但勤学苦练,亦收获颇丰!如果你决定要成为一名IT从业者,你需要承受以下的东 ...

  3. beego 实现API自动化文档

    安装beego和bee工具 1.beego安装 go get -u github.com/astaxie/beego 2.安装bee工具 go get -u github.com/beego/bee ...

  4. nginx 限速最容易理解的说明

    nginx 限速研究汇报 写在前面 这两天服务器带宽爆了,情况如下图: 出于降低带宽峰值的原因,我开始各种疯狂的研究nginx限速.下面是我研究过程中的心得!(花了好几个小时的时间写的人生第一篇技术类 ...

  5. ffmpeg的各种黑科技

    获取音频的时长 /** * 获取视频文件的时长 * @param ffmpegPath 是ffmpeg软件存放的目录,sourceFile是目标文件 * @return */ public Strin ...

  6. Nginx 简单安装

    关于nginx的介绍:http://baike.baidu.com/link?url=UZjpP8y4MUgrjqACfxd7c3Cm4L8PnJsKkxSJQ3CrRcKZUjkkl4zGNs6Pr ...

  7. 在Java中如何高效的判断数组中是否包含某个元素

    原文出处: hollischuang(@Hollis_Chuang) 如何检查一个数组(无序)是否包含一个特定的值?这是一个在Java中经常用到的并且非常有用的操作.同时,这个问题在Stack Ove ...

  8. JS数据结构库

    lodash https://lodash.com/docs#now https://lodash.com/ A modern JavaScript utility library deliverin ...

  9. JavaSE回顾及巩固的自学之路(二)——————进入JavaSE

    好的.今天接着上一篇文章对JavaSE的历程初步介绍,开始对JavaSE的技术性知识进行探讨. 首先,选择编程,成为一名程序员,应该会了解一些计算机的相关基础知识,毕竟,以后就是和计算机打交道了嘛.s ...

  10. thymeleaf 传参到js的onclick事件中

    html: <img th:onclick="'javascript:imgClick(\''+${card.id}+'\',\''+${card.name}+'\');'" ...