一个典型的SGD过程中,一个epoch内的一批样本的平均梯度与梯度方差,在下图中得到了展示。

无论什么样的网络结构,无论是哪一层网络的梯度,大体上都遵循下面这样的规律:

高信号/噪音比一段时间之后,信号/噪音比逐渐降低,收敛速度减缓,梯度的方差增大,梯度均值减小。

噪音增加的作用及其必要性会在另一篇文章中阐述,这里仅讨论噪音的产生对于模型收敛速度能够产生怎样的影响。

首先定义模型收敛速度:训练后期,噪音梯度导致权重更新时,导致系统新增的熵 H(混乱度)对于SGD迭代次数 t 的导数。

对于第k层的权重的梯度,每一轮(时间t)更新:

\[\frac{\partial {{\mathbf{W}}^{\left( k \right)}}}{\partial t}=-\nabla \operatorname{E}({{\mathbf{W}}^{\left( k \right)}})+\beta _{\left( k \right)}^{-1}\xi \left( t \right)\]

其中E是全局损失函数, $\beta $是信号/噪音比,$\xi $是高斯白噪音, $P\left( \xi \left( t \right) \right)=Norm\left( 0,\sigma \left( t \right) \right)$ ,方差$\sigma \left( t \right)$随着时间增加而变大。

因为使用高噪音进行梯度下降更新权重W时引进了额外的熵,考虑熵的变化$\Delta H({{\mathbf{W}}^{(k)}})$

假设将损失函数E分割成非常多个小区间,问题转化为:$\Delta H({{\mathbf{W}}^{(k)}})\text{=}\Delta H({{\text{E}}_{1}}({{\mathbf{W}}^{(k)}}),{{\text{E}}_{2}}({{\mathbf{W}}^{(k)}})......{{\text{E}}_{N}}({{\mathbf{W}}^{(k)}}))$

已知$\operatorname{H}\left( E \right)=-\underset{\text{i}}{\mathop{\sum }}\,p\left( {{\text{E}}_{\text{i}}} \right)\log p\left( {{\text{E}}_{\text{i}}} \right)$

\[\frac{\partial \operatorname{H}}{\partial p}=-\left( \sum\limits_{\text{i}}{\log \left( p\left( {{E}_{i}} \right) \right)+1} \right)\]

又已知系统达到热平衡后,使熵最大的p(W)分布是玻尔兹曼分布(参见Boltzmann与最大熵的关联文章

${{p}_{E={{E}_{i}}}}\left( \mathbf{W} \right)=\frac{1}{\text{Z}}{{\text{e}}^{-\beta {{E}_{i}}\left( \mathbf{W} \right)}}$ ,Z是配分函数partition function $Z=\sum\limits_{E'}{{{e}^{-\beta E'(\mathbf{W})}}}$

考虑热平衡附近时,p怎样随着E改变:

\[\frac{\partial p}{\partial E}=\frac{\partial }{\partial {{E}_{\text{i}}}}\left( {{{e}^{-\beta {{E}_{i}}}}}/{\left( {{e}^{-\beta {{E}_{i}}}}+\sum\nolimits_{k\ne i}{{{e}^{-\beta {{E}_{k}}}}} \right)}\; \right)=-\beta p(1-p)\]

使用链式法则得到:

$\frac{\partial \text{H}}{\partial t}=\sum\limits_{i}{\frac{\partial \text{H}}{\partial {{p}_{i}}}\frac{\partial {{p}_{i}}}{\partial {{\text{E}}_{i}}}\frac{\partial {{\text{E}}_{i}}}{\partial \mathbf{W}}\frac{\partial \mathbf{W}}{\partial t}}$

训练到接近收敛时,尽管每次更新权重时计算的loss的白噪音会越来越大,但全局loss E会稳定得多,并且逐渐下降到一个比较小的区间内,所以只考虑该区间内对应的$\Delta \text{H}$以及$\Delta \text{t}$,带入前面求出的偏导得到:

\[\frac{\partial H}{\partial t}=\sum\limits_{\text{i}}{\left( \log \left( {{p}_{\text{i}}} \right)+1 \right)\beta {{p}_{i}}(1-{{p}_{i}})\nabla {{E}_{i}}(\mathbf{W})(-\nabla {{E}_{i}}(\mathbf{W})+\beta _{(k)}^{-1}\xi (t))}\]

噪音项在求期望时被平均成0,同时使用泰勒级数在p=1附近展开ln(p) :$\ln (p)=(p-1)-\frac{1}{2}{{(p-1)}^{2}}+\frac{1}{3}{{(p-1)}^{3}}-......$

可推出

$(\log (p)+1)(1-p)=-p\log p+1-p+\log p\approx -p\log p+1-p+(p-1)-\frac{1}{2}{{(p-1)}^{2}}=-p\log p-\frac{1}{2}{{(p-1)}^{2}}$

当p_i接近1时,忽略二次项,得到熵H,既 -plogp

继续带入可得(注意beta后面是预期值符号,不是损失函数E)

\[\frac{\partial H}{\partial t}\approx \beta \sum\limits_{\text{i}}{-{{p}_{i}}{{\left( \nabla {{E}_{i}}(\mathbf{W}) \right)}^{2}}}H=-\beta \operatorname{E}\left[ {{\left( \nabla E(\mathbf{W}) \right)}^{2}} \right]H\]

这里看出当训练时在全局loss逐渐收敛到一个小区间E_i内,p_i趋近于1,这时候熵的该变量与训练迭代次数满足上述微分方程。

解微分方程得到:

$H=H_{0}\exp\left(-\beta\mathbb{E}\left[(\nabla E(W))^{2}\right])t\right)$

该方程只在全局loss相对稳定之后成立,此时SGD噪音带来的熵随训练时间的增加而指数减少。

半衰期之前一直被当做常量来看待,但其实半衰期随着全局梯度平方的预期值的减小,会逐渐增大。

也就是说要从噪音里引入固定量的熵,所消耗的时间(迭代轮数)会越来越多,甚至比普通的指数衰减花费更多的时间。

第k层权重更新噪音引入的熵 会以 给定下一层特征层时输入数据X的熵 的形式展现。

\[\Delta H(\delta {{\mathbf{W}}^{(k)}})=\Delta H(X|{{T}^{(k+1)}})\]

噪音引入的熵的作用,会在下面几篇介绍信息瓶颈理论的文章里详细阐述。

SGD训练时收敛速度的变化研究。的更多相关文章

  1. 将caffe训练时loss的变化曲线用matlab绘制出来

    1. 首先是提取 训练日志文件; 2. 然后是matlab代码: clear all; close all; clc; log_file = '/home/wangxiao/Downloads/43_ ...

  2. DenseNet算法详解——思路就是highway,DneseNet在训练时十分消耗内存

    论文笔记:Densely Connected Convolutional Networks(DenseNet模型详解) 2017年09月28日 11:58:49 阅读数:1814 [ 转载自http: ...

  3. caffe︱深度学习参数调优杂记+caffe训练时的问题+dropout/batch Normalization

    一.深度学习中常用的调节参数 本节为笔者上课笔记(CDA深度学习实战课程第一期) 1.学习率 步长的选择:你走的距离长短,越短当然不会错过,但是耗时间.步长的选择比较麻烦.步长越小,越容易得到局部最优 ...

  4. 理解dropout——本质是通过阻止特征检测器的共同作用来防止过拟合 Dropout是指在模型训练时随机让网络某些隐含层节点的权重不工作,不工作的那些节点可以暂时认为不是网络结构的一部分,但是它的权重得保留下来(只是暂时不更新而已),因为下次样本输入时它可能又得工作了

    理解dropout from:http://blog.csdn.net/stdcoutzyx/article/details/49022443 http://www.cnblogs.com/torna ...

  5. caffe下训练时遇到的一些问题汇总

    1.报错:“db_lmdb.hpp:14] Check failed:mdb_status ==0(112 vs.0)磁盘空间不足.” 这问题是由于lmdb在windows下无法使用lmdb的库,所以 ...

  6. 基于google earth engine 云计算平台的全国水体变化研究

    第一个博客密码忘记了,今天才来开通第二个博客,时间已经过去两年了,三年的硕士生涯,真的是感慨良多,最有收获的一段时光,莫过于在实验室一个人敲着代码了,研三来得到中科院深圳先进院,在这里开始了新的研究生 ...

  7. A TensorBoard plugin for visualizing arbitrary tensors in a video as your network trains.Beholder是一个TensorBoard插件,用于在模型训练时查看视频帧。

    Beholder is a TensorBoard plugin for viewing frames of a video while your model trains. It comes wit ...

  8. 使用Deeplearning4j进行GPU训练时,出错的解决方法

    一.问题 使用deeplearning4j进行GPU训练时,可能会出现java.lang.UnsatisfiedLinkError: no jnicudnn in java.library.path错 ...

  9. Android8.0运行时权限策略变化和适配方案

    版权声明:转载必须注明本文转自严振杰的博客:http://blog.yanzhenjie.comAndroid8.0也就是Android O即将要发布了,有很多新特性,目前我们可以通过AndroidS ...

随机推荐

  1. ThreeJs 模型的缩放、移动、旋转 以及使用鼠标对三维物体的缩放

    首先我们创建一个模型对象 var geometry = new THREE.BoxGeometry( 100, 100, 100); //边长100的正方体 var material = new TH ...

  2. JVM内存模型和面试题解析

    一.JVM运行时区域 其中, 线程私有的:程序计数器,虚拟机栈,本地方法栈 线程共享的:堆,方法区,直接内存 1 程序计数器 程序计数器是一块较小的内存空间,可以看作是当前线程所执行的字节码的行号指示 ...

  3. php+mysql 原生事务回滚

    <?php $conn = mysql_connect('127.0.0.1', 'root', ''); mysql_select_db('msc_test'); mysql_query('S ...

  4. 面试常问MySQL性能优化问题

    知识综述: [1] MySQL中锁的种类: 常见的表锁和行锁,也有Metadata Lock等等,表锁是对一整张表加锁,分为读锁和写锁,因为是锁住整张表,所以会导致并发能力下降,一般是做ddl处理时使 ...

  5. MyEclipse Web项目部署失败:Deployment failure on Tomcat 7.x.Could not copy all resources to XXX.

    在做第一个MyEclipse web项目时,总是部署失败: Deployment failure on Tomcat 7.x.Could not copy all resources to XXX.I ...

  6. 《Dare To Dream》第六次作业:团队项目系统设计改进与详细设计

    团队项目系统设计改进与详细设计 一.团队项目系统设计改进 任务1: a.分析项目系统设计说明书初稿的不足,特别是软件系统结构模型建模不完善内容.  初稿的不足:缺乏每个模块的具体业务流程详细设计和流程 ...

  7. 《Dare To Dream 》第三次作业--团队项目的原型设计与开发

    一.实验目的与要求 1.掌握软件原型开发技术:  2.学习使用软件原型开发工具: 二.实验内容与步骤 任务1:针对实验六团队项目选题,采用适当的原型开发工具设计团队项目原型: 任务2:在团队博客发布博 ...

  8. 【博客开篇】服务器配置:Windows2008R2+PHP5.6+SQLServer2008(X64)

    现下流行LAMP,如果选择Windows服务器,那么一般都会选择IIS+Asp.Net+SQL Server(可以简称为WINS),这些配置起来,都是非常方便的. 但也有一些特殊的服务器配置,例如:W ...

  9. 爬虫之scrapy入门

    1.介绍 Scrapy是一个为了爬取网站数据,提取结构性数据而编写的应用框架. 其可以应用在数据挖掘,信息处理或存储历史数据等一系列的程序中.其最初是为了页面抓取 (更确切来说, 网络抓取 )所设计的 ...

  10. 初学python---排序

    1.永久性排序 sort() a = [12,45,1,25,3] a.sort() print(a)  ----[1, 3, 12, 25, 45] 2.临时排序 sorted() a = [12, ...