torch.optim优化算法理解之optim.Adam()
torch.optim是一个实现了多种优化算法的包,大多数通用的方法都已支持,提供了丰富的接口调用,未来更多精炼的优化算法也将整合进来。
为了使用torch.optim,需先构造一个优化器对象Optimizer,用来保存当前的状态,并能够根据计算得到的梯度来更新参数。
要构建一个优化器optimizer,你必须给它一个可进行迭代优化的包含了所有参数(所有的参数必须是变量s)的列表。 然后,您可以指定程序优化特定的选项,例如学习速率,权重衰减等。
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)
self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
- 1
- 2
- 3
Optimizer还支持指定每个参数选项。 只需传递一个可迭代的dict来替换先前可迭代的Variable。dict中的每一项都可以定义为一个单独的参数组,参数组用一个params键来包含属于它的参数列表。其他键应该与优化器接受的关键字参数相匹配,才能用作此组的优化选项。
optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
- 1
- 2
- 3
- 4
如上,model.base.parameters()将使用1e-2的学习率,model.classifier.parameters()将使用1e-3的学习率。0.9的momentum作用于所有的parameters。
优化步骤:
所有的优化器Optimizer都实现了step()方法来对所有的参数进行更新,它有两种调用方法:
optimizer.step()
- 1
这是大多数优化器都支持的简化版本,使用如下的backward()方法来计算梯度的时候会调用它。
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
- 1
- 2
- 3
- 4
- 5
- 6
optimizer.step(closure)
- 1
一些优化算法,如共轭梯度和LBFGS需要重新评估目标函数多次,所以你必须传递一个closure以重新计算模型。 closure必须清除梯度,计算并返回损失。
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
return loss
optimizer.step(closure)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
Adam算法:
adam算法来源:Adam: A Method for Stochastic Optimization
Adam(Adaptive Moment Estimation)本质上是带有动量项的RMSprop,它利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。它的优点主要在于经过偏置校正后,每一次迭代学习率都有个确定范围,使得参数比较平稳。其公式如下:
其中,前两个公式分别是对梯度的一阶矩估计和二阶矩估计,可以看作是对期望E|gt|,E|gt^2|的估计;
公式3,4是对一阶二阶矩估计的校正,这样可以近似为对期望的无偏估计。可以看出,直接对梯度的矩估计对内存没有额外的要求,而且可以根据梯度进行动态调整。最后一项前面部分是对学习率n形成的一个动态约束,而且有明确的范围。
class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
- 1
参数:
params(iterable):可用于迭代优化的参数或者定义参数组的dicts。
lr (float, optional) :学习率(默认: 1e-3)
betas (Tuple[float, float], optional):用于计算梯度的平均和平方的系数(默认: (0.9, 0.999))
eps (float, optional):为了提高数值稳定性而添加到分母的一个项(默认: 1e-8)
weight_decay (float, optional):权重衰减(如L2惩罚)(默认: 0)
- 1
- 2
- 3
- 4
- 5
step(closure=None)函数:执行单一的优化步骤
closure (callable, optional):用于重新评估模型并返回损失的一个闭包
- 1
- 2
torch.optim.adam源码:
import math
from .optimizer import Optimizer
class Adam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps,weight_decay=weight_decay)
super(Adam, self).__init__(params, defaults)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = grad.new().resize_as_(grad).zero_()
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
if group['weight_decay'] != 0:
grad = grad.add(group['weight_decay'], p.data)
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
return loss
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
Adam的特点有:
1、结合了Adagrad善于处理稀疏梯度和RMSprop善于处理非平稳目标的优点;
2、对内存需求较小;
3、为不同的参数计算不同的自适应学习率;
4、也适用于大多非凸优化-适用于大数据集和高维空间。
torch.optim优化算法理解之optim.Adam()的更多相关文章
- 优化深度神经网络(二)优化算法 SGD Momentum RMSprop Adam
Coursera吴恩达<优化深度神经网络>课程笔记(2)-- 优化算法 深度机器学习中的batch的大小 深度机器学习中的batch的大小对学习效果有何影响? 1. Mini-batch ...
- Adam优化算法
Question? Adam 算法是什么,它为优化深度学习模型带来了哪些优势? Adam 算法的原理机制是怎么样的,它与相关的 AdaGrad 和 RMSProp 方法有什么区别. Adam 算法应该 ...
- PyTorch官方中文文档:torch.optim 优化器参数
内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...
- PyTorch-Adam优化算法原理,公式,应用
概念:Adam 是一种可以替代传统随机梯度下降过程的一阶优化算法,它能基于训练数据迭代地更新神经网络权重.Adam 最开始是由 OpenAI 的 Diederik Kingma 和多伦多大学的 Jim ...
- [DeeplearningAI笔记]改善深层神经网络_优化算法2.6_2.9Momentum/RMSprop/Adam优化算法
Optimization Algorithms优化算法 觉得有用的话,欢迎一起讨论相互学习~Follow Me 2.6 动量梯度下降法(Momentum) 另一种成本函数优化算法,优化速度一般快于标准 ...
- 优化算法:AdaGrad | RMSProp | AdaDelta | Adam
0 - 引入 简单的梯度下降等优化算法存在一个问题:目标函数自变量的每一个元素在相同时间步都使用同一个学习率来迭代,如果存在如下图的情况(不同自变量的梯度值有较大差别时候),存在如下问题: 选择较小的 ...
- 机器学习中几种优化算法的比较(SGD、Momentum、RMSProp、Adam)
有关各种优化算法的详细算法流程和公式可以参考[这篇blog],讲解比较清晰,这里说一下自己对他们之间关系的理解. BGD 与 SGD 首先,最简单的 BGD 以整个训练集的梯度和作为更新方向,缺点是速 ...
- 改善深层神经网络_优化算法_mini-batch梯度下降、指数加权平均、动量梯度下降、RMSprop、Adam优化、学习率衰减
1.mini-batch梯度下降 在前面学习向量化时,知道了可以将训练样本横向堆叠,形成一个输入矩阵和对应的输出矩阵: 当数据量不是太大时,这样做当然会充分利用向量化的优点,一次训练中就可以将所有训练 ...
- zz:一个框架看懂优化算法之异同 SGD/AdaGrad/Adam
首先定义:待优化参数: ,目标函数: ,初始学习率 . 而后,开始进行迭代优化.在每个epoch : 计算目标函数关于当前参数的梯度: 根据历史梯度计算一阶动量和二阶动量:, 计算当前时刻的下降 ...
随机推荐
- Mysql 5.7.17安装后登录mysql的教程方法
在运行 ./bin/mysqld Cinitialize 初始化数据库时,会生成随机密码,示例: [Note] A temporary password is generated for root@l ...
- 导入pymysql模块出错:No module named 'pymysql'
前提: 使用的版本为:Python 3.6.4 pymysql已经被成功安装了,并通过命令行的方式验证已成功安装. 但在pycharm中运行工程时候时候报错:No module named 'pymy ...
- Ubuntu的网络共享
实际场景 公司项目中遇到一个场景:Ubuntu的主机上装了个4G卡(USB模式),需要将这个4G网共享给一个AP,使得所有连接AP的移动设备都可以通过4G上外网 方法很简单: 1. 将4G网口之外的另 ...
- Minimal coverage (贪心,最小覆盖)
题目大意:先确定一个M, 然后输入多组线段的左端和右端的端点坐标,然后让你求出来在所给的线段中能够 把[0, M] 区域完全覆盖完的最少需要的线段数,并输出这些线段的左右端点坐标. 思路分析: 线段区 ...
- React项目动态设置title标题
在React搭建的SPA项目中页面的title是直接写在入口index.html中,当路由在切换不用页面时,title是不会动态变化的.那么怎么让title随着路由的切换动态变化呢?1.在定义路由时增 ...
- vue项目开发过程常见问题
更新时间:2018-07-29 1.data functions should return an object // 这个问题是 Vue 实例内,单组件的data必须返回一个对象;如下 <sc ...
- JS判断PC 手机端显示不同的内容
方法一: function goPAGE() { if ((navigator.userAgent.match(/(phone|pad|pod|iPhone|iPod|ios|iPad|Android ...
- SPSS统计分析过程包括描述性统计、均值比较、一般线性模型、相关分析、回归分析、对数线性模型、聚类分析、数据简化、生存分析、时间序列分析、多重响应等几大类
https://www.zhihu.com/topic/19582125/top-answershttps://wenku.baidu.com/search?word=spss&ie=utf- ...
- XML配置里的Bean自动装配与Bean之间的关系
需要在<bean>的autowire属性里指定自动装配的模式 byType(根据类型自动装配) byName(根据名称自动装配) constructor(通过构造器自动装配) 名字须与属性 ...
- Hibernate→ 《Hibernate程序开发》教材大纲
Hibernate ORM 概览 Hibernate 简介 Hibernate 架构 Hibernate 环境 Hibernate 配置 Hibernate 会话 Hibernate 持久化类 Hib ...