Pytorch学习笔记08----优化器算法Optimizer详解(SGD、Adam)
1.优化器算法简述
首先来看一下梯度下降最常见的三种变形 BGD,SGD,MBGD,这三种形式的区别就是取决于我们用多少数据来计算目标函数的梯度,这样的话自然就涉及到一个 trade-off,即参数更新的准确率和运行时间。
2.Batch Gradient Descent (BGD)
梯度更新规则:
BGD 采用整个训练集的数据来计算 cost function 对参数的梯度:
缺点:
由于这种方法是在一次更新中,就对整个数据集计算梯度,所以计算起来非常慢,遇到很大量的数据集也会非常棘手,而且不能投入新数据实时更新模型。
for i in range(nb_epochs):
params_grad = evaluate_gradient(loss_function, data, params)
params = params - learning_rate * params_grad
我们会事先定义一个迭代次数 epoch,首先计算梯度向量 params_grad,然后沿着梯度的方向更新参数 params,learning rate 决定了我们每一步迈多大。
Batch gradient descent 对于凸函数可以收敛到全局极小值,对于非凸函数可以收敛到局部极小值。
3.Stochastic Gradient Descent (SGD)
梯度更新规则:
和 BGD 的一次用所有数据计算梯度相比,SGD 每次更新时对每个样本进行梯度更新,对于很大的数据集来说,可能会有相似的样本,这样 BGD 在计算梯度时会出现冗余,而 SGD 一次只进行一次更新,就没有冗余,而且比较快,并且可以新增样本。

for i in range(nb_epochs):
np.random.shuffle(data)
for example in data:
params_grad = evaluate_gradient(loss_function, example, params)
params = params - learning_rate * params_grad
看代码,可以看到区别,就是整体数据集是个循环,其中对每个样本进行一次参数更新。

随机梯度下降是通过每个样本来迭代更新一次,如果样本量很大的情况,那么可能只用其中部分的样本,就已经将theta迭代到最优解了,对比上面的批量梯度下降,迭代一次需要用到十几万训练样本,一次迭代不可能最优,如果迭代10次的话就需要遍历训练样本10次。缺点是SGD的噪音较BGD要多,使得SGD并不是每次迭代都向着整体最优化方向。所以虽然训练速度快,但是准确度下降,并不是全局最优。虽然包含一定的随机性,但是从期望上来看,它是等于正确的导数的。
缺点:
SGD 因为更新比较频繁,会造成 cost function 有严重的震荡。
BGD 可以收敛到局部极小值,当然 SGD 的震荡可能会跳到更好的局部极小值处。
当我们稍微减小 learning rate,SGD 和 BGD 的收敛性是一样的。
4.Mini-Batch Gradient Descent (MBGD)
梯度更新规则:
MBGD 每一次利用一小批样本,即 n 个样本进行计算,这样它可以降低参数更新时的方差,收敛更稳定,另一方面可以充分地利用深度学习库中高度优化的矩阵操作来进行更有效的梯度计算。

和 SGD 的区别是每一次循环不是作用于每个样本,而是具有 n 个样本的批次。
for i in range(nb_epochs):
np.random.shuffle(data)
for batch in get_batches(data, batch_size=50):
params_grad = evaluate_gradient(loss_function, batch, params)
params = params - learning_rate * params_grad
超参数设定值: n 一般取值在 50~256
缺点:(两大缺点)
- 不过 Mini-batch gradient descent 不能保证很好的收敛性,learning rate 如果选择的太小,收敛速度会很慢,如果太大,loss function 就会在极小值处不停地震荡甚至偏离。(有一种措施是先设定大一点的学习率,当两次迭代之间的变化低于某个阈值后,就减小 learning rate,不过这个阈值的设定需要提前写好,这样的话就不能够适应数据集的特点。)对于非凸函数,还要避免陷于局部极小值处,或者鞍点处,因为鞍点周围的error是一样的,所有维度的梯度都接近于0,SGD 很容易被困在这里。(会在鞍点或者局部最小点震荡跳动,因为在此点处,如果是训练集全集带入即BGD,则优化会停止不动,如果是mini-batch或者SGD,每次找到的梯度都是不同的,就会发生震荡,来回跳动。)
- SGD对所有参数更新时应用同样的 learning rate,如果我们的数据是稀疏的,我们更希望对出现频率低的特征进行大一点的更新。LR会随着更新的次数逐渐变小。
5.Adam:Adaptive Moment Estimation
这个算法是另一种计算每个参数的自适应学习率的方法。相当于 RMSprop + Momentum
除了像 Adadelta 和 RMSprop 一样存储了过去梯度的平方 vt 的指数衰减平均值 ,也像 momentum 一样保持了过去梯度 mt 的指数衰减平均值:

如果 mt 和 vt 被初始化为 0 向量,那它们就会向 0 偏置,所以做了偏差校正,通过计算偏差校正后的 mt 和 vt 来抵消这些偏差:

梯度更新规则:

超参数设定值:
建议 β1 = 0.9,β2 = 0.999,ϵ = 10e−8
实践表明,Adam 比其他适应性学习方法效果要好。
参考文献:
https://www.cnblogs.com/guoyaohua/p/8542554.html
Pytorch学习笔记08----优化器算法Optimizer详解(SGD、Adam)的更多相关文章
- 深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)
在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...
- pytorch学习笔记(十二):详解 Module 类
Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单. 本文主要关注 Module 类的内部是怎么样 ...
- 机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集
机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集 关键字:FPgrowth.频繁项集.条件FP树.非监督学习作者:米 ...
- Linux防火墙iptables学习笔记(三)iptables命令详解和举例[转载]
Linux防火墙iptables学习笔记(三)iptables命令详解和举例 2008-10-16 23:45:46 转载 网上看到这个配置讲解得还比较易懂,就转过来了,大家一起看下,希望对您工作能 ...
- (转)live555学习笔记10-h264 RTP传输详解(2)
参考: 1,live555学习笔记10-h264 RTP传输详解(2) http://blog.csdn.net/niu_gao/article/details/6936108 2,H264 sps ...
- 大数据学习笔记——Spark工作机制以及API详解
Spark工作机制以及API详解 本篇文章将会承接上篇关于如何部署Spark分布式集群的博客,会先对RDD编程中常见的API进行一个整理,接着再结合源代码以及注释详细地解读spark的作业提交流程,调 ...
- 学习笔记--Grunt、安装、图文详解
学习笔记--Git安装.图文详解 安装Git成功后,现在安装Gruntjs,官网:http://gruntjs.com/ 一.安装node 参考node.js 安装.图文详解 (最新的node会自动安 ...
- Java8学习笔记(五)--Stream API详解[转]
为什么需要 Stream Stream 作为 Java 8 的一大亮点,它与 java.io 包里的 InputStream 和 OutputStream 是完全不同的概念.它也不同于 StAX 对 ...
- 【官方文档】Nginx模块Nginx-Rtmp-Module学习笔记(一) RTMP 命令详解
源码地址:https://github.com/Tinywan/PHP_Experience 说明: rtmp的延迟主要取决于播放器设置,但流式传输软件,流的比特率和网络速度(以及响应时间“ping” ...
随机推荐
- numpy读取本地数据和索引
1.numpy读取数据 np.loadtxt(fname,dtype=np.float,delimiter=None,skiprows=0,usecols=None,unpack=False) 做一个 ...
- openstack 后期维护(四)--- 删除僵尸卷
前言: 在长时间使用openstack之后,删除虚机后,经常会有因这样那样的问题,导致卷处于僵尸状态,无法删除! 状态一: 虚机已近删除,然而卷却挂在到了 None上无法删除 解决办法: 1.# ci ...
- RabbitMQ的安装及入门使(Windows)
1.安装Erlang所以在安装rabbitMQ之前,需要先安装Erlang .点击下载Erlang 执行下载下来的Erlang,全部点击"下一步"就行.安装完成设置一下环境变量. ...
- Java测试开发--MySql之C3P0连接池(八)
连接池C3P0! 连接池技术的目的:解决建立数据库连接耗费资源和时间很多的问题,提高性能 ! 下面以案例演示下C3P0的操作流程. 1.测试准备: ①MySql数据库一枚②database名为myte ...
- yum install hadoop related client
yum list avaliable hadoop\* yum list installed yum repolist repo is in /etc/yum.repos.d yum install ...
- Flask WTForm disable choice field
Flask disable choice field ChoiceField = { render_kw={'disabled':''} } form.my_field.render_kw = {'d ...
- sqlalchemy 执行sql
关键需要使用text from sqlalchemy import create_engine, text sql = 'SELECT * FROM my_table WHERE account_id ...
- @Autowired注解注入失败,提示could not autowire的解决办法
autowire异常主要由三个情况发生的 像上面的情况是BrandDao没有注入, 1.你的BrandServiceImpl必须以@Service或@Component注解才行. 2.自动写入的时候把 ...
- MyScript 开发文档
一.IInk SDK runtime 1.1 引擎创建 1.2 对象释放 1.3 获取并设置配置 配置 访问配置 配置识别 二.文件存储 2.1 支持的内容的类型 2.2 模型结构 2.3 Conte ...
- N体模拟数据可视化 LightningChart®
N体模拟数据可视化 LightningChart N体模拟也许是目前最先进的数据可视化类型之一.事实上,我们现在谈论的不再是以商业为中心的传统数据的可视化,现在它甚至超越了比如振动分析等先进数据源 ...