1.优化器算法简述

首先来看一下梯度下降最常见的三种变形 BGD,SGD,MBGD,这三种形式的区别就是取决于我们用多少数据来计算目标函数的梯度,这样的话自然就涉及到一个 trade-off,即参数更新的准确率和运行时间。

2.Batch Gradient Descent (BGD)

梯度更新规则:

BGD 采用整个训练集的数据来计算 cost function 对参数的梯度:

 

缺点:

由于这种方法是在一次更新中,就对整个数据集计算梯度,所以计算起来非常慢,遇到很大量的数据集也会非常棘手,而且不能投入新数据实时更新模型。

for i in range(nb_epochs):
params_grad = evaluate_gradient(loss_function, data, params)
params = params - learning_rate * params_grad

我们会事先定义一个迭代次数 epoch,首先计算梯度向量 params_grad,然后沿着梯度的方向更新参数 params,learning rate 决定了我们每一步迈多大。

Batch gradient descent 对于凸函数可以收敛到全局极小值,对于非凸函数可以收敛到局部极小值。

3.Stochastic Gradient Descent (SGD)

梯度更新规则:

和 BGD 的一次用所有数据计算梯度相比,SGD 每次更新时对每个样本进行梯度更新,对于很大的数据集来说,可能会有相似的样本,这样 BGD 在计算梯度时会出现冗余,而 SGD 一次只进行一次更新,就没有冗余,而且比较快,并且可以新增样本。

for i in range(nb_epochs):
np.random.shuffle(data)
for example in data:
params_grad = evaluate_gradient(loss_function, example, params)
params = params - learning_rate * params_grad

看代码,可以看到区别,就是整体数据集是个循环,其中对每个样本进行一次参数更新。

随机梯度下降是通过每个样本来迭代更新一次,如果样本量很大的情况,那么可能只用其中部分的样本,就已经将theta迭代到最优解了,对比上面的批量梯度下降,迭代一次需要用到十几万训练样本,一次迭代不可能最优,如果迭代10次的话就需要遍历训练样本10次。缺点是SGD的噪音较BGD要多,使得SGD并不是每次迭代都向着整体最优化方向。所以虽然训练速度快,但是准确度下降,并不是全局最优。虽然包含一定的随机性,但是从期望上来看,它是等于正确的导数的。

缺点:

SGD 因为更新比较频繁,会造成 cost function 有严重的震荡。

BGD 可以收敛到局部极小值,当然 SGD 的震荡可能会跳到更好的局部极小值处。

当我们稍微减小 learning rate,SGD 和 BGD 的收敛性是一样的。

4.Mini-Batch Gradient Descent (MBGD)

梯度更新规则:

MBGD 每一次利用一小批样本,即 n 个样本进行计算,这样它可以降低参数更新时的方差,收敛更稳定,另一方面可以充分地利用深度学习库中高度优化的矩阵操作来进行更有效的梯度计算。

和 SGD 的区别是每一次循环不是作用于每个样本,而是具有 n 个样本的批次。

for i in range(nb_epochs):
np.random.shuffle(data)
for batch in get_batches(data, batch_size=50):
params_grad = evaluate_gradient(loss_function, batch, params)
params = params - learning_rate * params_grad

超参数设定值:  n 一般取值在 50~256

缺点:(两大缺点)

  1. 不过 Mini-batch gradient descent 不能保证很好的收敛性,learning rate 如果选择的太小,收敛速度会很慢,如果太大,loss function 就会在极小值处不停地震荡甚至偏离。(有一种措施是先设定大一点的学习率,当两次迭代之间的变化低于某个阈值后,就减小 learning rate,不过这个阈值的设定需要提前写好,这样的话就不能够适应数据集的特点。)对于非凸函数,还要避免陷于局部极小值处,或者鞍点处,因为鞍点周围的error是一样的,所有维度的梯度都接近于0,SGD 很容易被困在这里。(会在鞍点或者局部最小点震荡跳动,因为在此点处,如果是训练集全集带入即BGD,则优化会停止不动,如果是mini-batch或者SGD,每次找到的梯度都是不同的,就会发生震荡,来回跳动。)
  2. SGD对所有参数更新时应用同样的 learning rate,如果我们的数据是稀疏的,我们更希望对出现频率低的特征进行大一点的更新。LR会随着更新的次数逐渐变小。

5.Adam:Adaptive Moment Estimation

这个算法是另一种计算每个参数的自适应学习率的方法。相当于 RMSprop + Momentum

除了像 Adadelta 和 RMSprop 一样存储了过去梯度的平方 vt 的指数衰减平均值 ,也像 momentum 一样保持了过去梯度 mt 的指数衰减平均值

如果 mt 和 vt 被初始化为 0 向量,那它们就会向 0 偏置,所以做了偏差校正,通过计算偏差校正后的 mt 和 vt 来抵消这些偏差:

梯度更新规则:

超参数设定值:
建议 β1 = 0.9,β2 = 0.999,ϵ = 10e−8

实践表明,Adam 比其他适应性学习方法效果要好。

参考文献:

https://www.cnblogs.com/guoyaohua/p/8542554.html

Pytorch学习笔记08----优化器算法Optimizer详解(SGD、Adam)的更多相关文章

  1. 深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)

    在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...

  2. pytorch学习笔记(十二):详解 Module 类

    Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单. 本文主要关注 Module 类的内部是怎么样 ...

  3. 机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集

    机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集 关键字:FPgrowth.频繁项集.条件FP树.非监督学习作者:米 ...

  4. Linux防火墙iptables学习笔记(三)iptables命令详解和举例[转载]

     Linux防火墙iptables学习笔记(三)iptables命令详解和举例 2008-10-16 23:45:46 转载 网上看到这个配置讲解得还比较易懂,就转过来了,大家一起看下,希望对您工作能 ...

  5. (转)live555学习笔记10-h264 RTP传输详解(2)

    参考: 1,live555学习笔记10-h264 RTP传输详解(2) http://blog.csdn.net/niu_gao/article/details/6936108 2,H264 sps ...

  6. 大数据学习笔记——Spark工作机制以及API详解

    Spark工作机制以及API详解 本篇文章将会承接上篇关于如何部署Spark分布式集群的博客,会先对RDD编程中常见的API进行一个整理,接着再结合源代码以及注释详细地解读spark的作业提交流程,调 ...

  7. 学习笔记--Grunt、安装、图文详解

    学习笔记--Git安装.图文详解 安装Git成功后,现在安装Gruntjs,官网:http://gruntjs.com/ 一.安装node 参考node.js 安装.图文详解 (最新的node会自动安 ...

  8. Java8学习笔记(五)--Stream API详解[转]

    为什么需要 Stream Stream 作为 Java 8 的一大亮点,它与 java.io 包里的 InputStream 和 OutputStream 是完全不同的概念.它也不同于 StAX 对 ...

  9. 【官方文档】Nginx模块Nginx-Rtmp-Module学习笔记(一) RTMP 命令详解

    源码地址:https://github.com/Tinywan/PHP_Experience 说明: rtmp的延迟主要取决于播放器设置,但流式传输软件,流的比特率和网络速度(以及响应时间“ping” ...

随机推荐

  1. 万能构造解决Rolle中值问题

    只要原函数是两个函数的乘积形式,皆可此构造.

  2. 2021 ICPC Gran Premio de Mexico 2da Fecha部分题题解

    前面的水题,在队友的配合下,很快就拿下了,剩下几道大毒瘤题,一直罚座三个小时,好让人自闭...但不得不说,这些题的质量是真的高! H. Haunted House 首先看这个题,大眼一扫,觉得是某种数 ...

  3. 一步一步学ROP之gadgets和2free篇(蒸米spark)

    目录 一步一步学ROP之gadgets和2free篇(蒸米spark) 0x00序 0x01 通用 gadgets part2 0x02 利用mmap执行任意shellcode 0x03 堆漏洞利用之 ...

  4. RabbitMQ多消费者顺序性消费消息实现

    最近起了个项目消息中心,用来中转各个系统中产生的消息,用到的是RabbitMQ,由于UAT环境.生产环境每台消费者服务都是多台,有些消息要求按顺序消费,所以需要采取一定的措施保证消息的顺序消费,下面讲 ...

  5. lvs 四层负载相关

    都打开 /etc/sysctl.conf 中的 net.ip4.ip_forward=1.开启路由转发功能. 分发器 : eth0:192.168.1.66 (VIP) eth1:192.168.2. ...

  6. python生成有声小说模拟真人发音

    生成有声小说原理 文字是1500字内的生成微软文档说说 用代码实现小说爬取正本 实现每章小说1450字 实现自动剪切后添加封面 实现自动上传 用python代码实现爬取小说,本案列以一本小说为实列代码 ...

  7. Windows 常用配置

    安装IIS服务器 在服务器管理器中,选择"角色"添加角色 进入添加角色向导,在安装界面,选择服务器角色为:" Web服务器(IIS) " 角色服务勾选:应用程序 ...

  8. 记录一个很傻的错误(C++)

    使用的vscode写代码,导入了vector,memory,然后忘了导入string.但是代码中能够提示std::string也就让我忘了导入string.然后就莫名其妙的报错了.找了很久的错.记录下 ...

  9. SpringCould | Nacos与Feign

    服务注册Nacos 介绍 概念 一个更易于构建云原生应用的动态服务发现.配置管理和服务管理平台. Nacos: Dynamic Naming and Configuration Service Nac ...

  10. Linux usb 2. 协议分析

    文章目录 0. 背景 1. USB 协议传输格式 1.1 Packet 1.1.1 Token Packet 1.1.2 Data Packet 1.1.3 Handshake Packet 1.1. ...