1.优化器算法简述

首先来看一下梯度下降最常见的三种变形 BGD,SGD,MBGD,这三种形式的区别就是取决于我们用多少数据来计算目标函数的梯度,这样的话自然就涉及到一个 trade-off,即参数更新的准确率和运行时间。

2.Batch Gradient Descent (BGD)

梯度更新规则:

BGD 采用整个训练集的数据来计算 cost function 对参数的梯度:

 

缺点:

由于这种方法是在一次更新中,就对整个数据集计算梯度,所以计算起来非常慢,遇到很大量的数据集也会非常棘手,而且不能投入新数据实时更新模型。

for i in range(nb_epochs):
params_grad = evaluate_gradient(loss_function, data, params)
params = params - learning_rate * params_grad

我们会事先定义一个迭代次数 epoch,首先计算梯度向量 params_grad,然后沿着梯度的方向更新参数 params,learning rate 决定了我们每一步迈多大。

Batch gradient descent 对于凸函数可以收敛到全局极小值,对于非凸函数可以收敛到局部极小值。

3.Stochastic Gradient Descent (SGD)

梯度更新规则:

和 BGD 的一次用所有数据计算梯度相比,SGD 每次更新时对每个样本进行梯度更新,对于很大的数据集来说,可能会有相似的样本,这样 BGD 在计算梯度时会出现冗余,而 SGD 一次只进行一次更新,就没有冗余,而且比较快,并且可以新增样本。

for i in range(nb_epochs):
np.random.shuffle(data)
for example in data:
params_grad = evaluate_gradient(loss_function, example, params)
params = params - learning_rate * params_grad

看代码,可以看到区别,就是整体数据集是个循环,其中对每个样本进行一次参数更新。

随机梯度下降是通过每个样本来迭代更新一次,如果样本量很大的情况,那么可能只用其中部分的样本,就已经将theta迭代到最优解了,对比上面的批量梯度下降,迭代一次需要用到十几万训练样本,一次迭代不可能最优,如果迭代10次的话就需要遍历训练样本10次。缺点是SGD的噪音较BGD要多,使得SGD并不是每次迭代都向着整体最优化方向。所以虽然训练速度快,但是准确度下降,并不是全局最优。虽然包含一定的随机性,但是从期望上来看,它是等于正确的导数的。

缺点:

SGD 因为更新比较频繁,会造成 cost function 有严重的震荡。

BGD 可以收敛到局部极小值,当然 SGD 的震荡可能会跳到更好的局部极小值处。

当我们稍微减小 learning rate,SGD 和 BGD 的收敛性是一样的。

4.Mini-Batch Gradient Descent (MBGD)

梯度更新规则:

MBGD 每一次利用一小批样本,即 n 个样本进行计算,这样它可以降低参数更新时的方差,收敛更稳定,另一方面可以充分地利用深度学习库中高度优化的矩阵操作来进行更有效的梯度计算。

和 SGD 的区别是每一次循环不是作用于每个样本,而是具有 n 个样本的批次。

for i in range(nb_epochs):
np.random.shuffle(data)
for batch in get_batches(data, batch_size=50):
params_grad = evaluate_gradient(loss_function, batch, params)
params = params - learning_rate * params_grad

超参数设定值:  n 一般取值在 50~256

缺点:(两大缺点)

  1. 不过 Mini-batch gradient descent 不能保证很好的收敛性,learning rate 如果选择的太小,收敛速度会很慢,如果太大,loss function 就会在极小值处不停地震荡甚至偏离。(有一种措施是先设定大一点的学习率,当两次迭代之间的变化低于某个阈值后,就减小 learning rate,不过这个阈值的设定需要提前写好,这样的话就不能够适应数据集的特点。)对于非凸函数,还要避免陷于局部极小值处,或者鞍点处,因为鞍点周围的error是一样的,所有维度的梯度都接近于0,SGD 很容易被困在这里。(会在鞍点或者局部最小点震荡跳动,因为在此点处,如果是训练集全集带入即BGD,则优化会停止不动,如果是mini-batch或者SGD,每次找到的梯度都是不同的,就会发生震荡,来回跳动。)
  2. SGD对所有参数更新时应用同样的 learning rate,如果我们的数据是稀疏的,我们更希望对出现频率低的特征进行大一点的更新。LR会随着更新的次数逐渐变小。

5.Adam:Adaptive Moment Estimation

这个算法是另一种计算每个参数的自适应学习率的方法。相当于 RMSprop + Momentum

除了像 Adadelta 和 RMSprop 一样存储了过去梯度的平方 vt 的指数衰减平均值 ,也像 momentum 一样保持了过去梯度 mt 的指数衰减平均值

如果 mt 和 vt 被初始化为 0 向量,那它们就会向 0 偏置,所以做了偏差校正,通过计算偏差校正后的 mt 和 vt 来抵消这些偏差:

梯度更新规则:

超参数设定值:
建议 β1 = 0.9,β2 = 0.999,ϵ = 10e−8

实践表明,Adam 比其他适应性学习方法效果要好。

参考文献:

https://www.cnblogs.com/guoyaohua/p/8542554.html

Pytorch学习笔记08----优化器算法Optimizer详解(SGD、Adam)的更多相关文章

  1. 深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)

    在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...

  2. pytorch学习笔记(十二):详解 Module 类

    Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单. 本文主要关注 Module 类的内部是怎么样 ...

  3. 机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集

    机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集 关键字:FPgrowth.频繁项集.条件FP树.非监督学习作者:米 ...

  4. Linux防火墙iptables学习笔记(三)iptables命令详解和举例[转载]

     Linux防火墙iptables学习笔记(三)iptables命令详解和举例 2008-10-16 23:45:46 转载 网上看到这个配置讲解得还比较易懂,就转过来了,大家一起看下,希望对您工作能 ...

  5. (转)live555学习笔记10-h264 RTP传输详解(2)

    参考: 1,live555学习笔记10-h264 RTP传输详解(2) http://blog.csdn.net/niu_gao/article/details/6936108 2,H264 sps ...

  6. 大数据学习笔记——Spark工作机制以及API详解

    Spark工作机制以及API详解 本篇文章将会承接上篇关于如何部署Spark分布式集群的博客,会先对RDD编程中常见的API进行一个整理,接着再结合源代码以及注释详细地解读spark的作业提交流程,调 ...

  7. 学习笔记--Grunt、安装、图文详解

    学习笔记--Git安装.图文详解 安装Git成功后,现在安装Gruntjs,官网:http://gruntjs.com/ 一.安装node 参考node.js 安装.图文详解 (最新的node会自动安 ...

  8. Java8学习笔记(五)--Stream API详解[转]

    为什么需要 Stream Stream 作为 Java 8 的一大亮点,它与 java.io 包里的 InputStream 和 OutputStream 是完全不同的概念.它也不同于 StAX 对 ...

  9. 【官方文档】Nginx模块Nginx-Rtmp-Module学习笔记(一) RTMP 命令详解

    源码地址:https://github.com/Tinywan/PHP_Experience 说明: rtmp的延迟主要取决于播放器设置,但流式传输软件,流的比特率和网络速度(以及响应时间“ping” ...

随机推荐

  1. 攻防世界 web3.backup

    如果网站存在备份文件,常见的备份文件后缀名有:.git ..svn..swp..~..bak..bash_history..bkf尝试在URL后面,依次输入常见的文件备份扩展名. ip/index.p ...

  2. bash执行顺序:alias --> function --> builtin --> program

    linux bash的执行顺序如下所示: 先 alias --> function --> builtin --> program 后 验证过程: 1,在bash shell中有内置 ...

  3. 数组中出现次数超过一半的数字 牛客网 剑指Offer

    数组中出现次数超过一半的数字 牛客网 剑指Offer 题目描述 数组中有一个数字出现的次数超过数组长度的一半,请找出这个数字.例如输入一个长度为9的数组{1,2,3,2,2,2,5,4,2}.由于数字 ...

  4. 2021CCPC华为云挑战赛 部分题题解

    CDN流量调度问题 题看了没多久就看出来是\(DP\)的题,然后就设了状态\(f[i][j]\)表示到前\(i\)个点时已经用了\(j\)个节点的最小总代价,结果发现转移时\(O(nm^2)\),但这 ...

  5. linux删除文件未释放

    https://access.redhat.com/solutions/2316 $ /usr/sbin/lsof | grep deleted ora 25575 data 33u REG 65,6 ...

  6. 动手写一个简单的Web框架(模板渲染)

    动手写一个简单的Web框架(模板渲染) 在百度上搜索jinja2,显示的大部分内容都是jinja2的渲染语法,这个不是Web框架需要做的事,最终,居然在Werkzeug的官方文档里找到模板渲染的代码. ...

  7. 【Git 系列】基础知识全集

    Git 是一种分布式版本控制系统,它可以不受网络连接的限制,加上其它众多优点,目前已经成为程序开发人员做项目版本管理时的首选,非开发人员也可以用 Git 来做自己的文档版本管理工具. 一.Git 基础 ...

  8. 重新整理 .net core 实践篇——— 权限源码阅读四十五]

    前言 简单介绍一下权限源码阅读一下. 正文 一直有人对授权这个事情上争论不休,有的人认为在输入账户密码给后台这个时候进行了授权,因为认为发送了一个身份令牌,令牌里面可能有些用户角色信息,认为这就是授权 ...

  9. [cf1184E]Daleks' Invasion

    先求出任意一棵最小生成树,然后对边分类讨论1.非树边,答案即最小生成树的环上的最长边2.树边,反过来考虑,相当于对于每一个点对那条路经打上标记,取min对于1直接用倍增维护即可,对于2可以用树链剖分/ ...

  10. MS17-010漏洞利用

    MS17-010漏洞利用 1.安装虚拟机win7 x64,实现利用ms17-010实现对其win7 x64主机开始渗透,查看该主机信息,打开远程桌面,抓取用户名和密码并破译,创建一个 : 学号.txt ...