【tf.keras】AdamW: Adam with Weight decay
论文 Decoupled Weight Decay Regularization 中提到,Adam 在使用时,L2 regularization 与 weight decay 并不等价,并提出了 AdamW,在神经网络需要正则项时,用 AdamW 替换 Adam+L2 会得到更好的性能。
TensorFlow 2.x 在 tensorflow_addons 库里面实现了 AdamW,可以直接pip install tensorflow_addons进行安装(在 windows 上需要 TF 2.1),也可以直接把这个仓库下载下来使用。
下面是一个利用 AdamW 的示例程序(TF 2.0, tf.keras),在使用 AdamW 的同时,使用 learning rate decay:(以下程序中,AdamW 的结果不如 Adam,这是因为模型比较简单,加多了 regularization 反而影响性能)
import tensorflow as tf
import os
from tensorflow_addons.optimizers import AdamW
import numpy as np
from tensorflow.python.keras import backend as K
from tensorflow.python.util.tf_export import keras_export
from tensorflow.keras.callbacks import Callback
def lr_schedule(epoch):
"""Learning Rate Schedule
Learning rate is scheduled to be reduced after 20, 30 epochs.
Called automatically every epoch as part of callbacks during training.
# Arguments
epoch (int): The number of epochs
# Returns
lr (float32): learning rate
"""
lr = 1e-3
if epoch >= 30:
lr *= 1e-2
elif epoch >= 20:
lr *= 1e-1
print('Learning rate: ', lr)
return lr
def wd_schedule(epoch):
"""Weight Decay Schedule
Weight decay is scheduled to be reduced after 20, 30 epochs.
Called automatically every epoch as part of callbacks during training.
# Arguments
epoch (int): The number of epochs
# Returns
wd (float32): weight decay
"""
wd = 1e-4
if epoch >= 30:
wd *= 1e-2
elif epoch >= 20:
wd *= 1e-1
print('Weight decay: ', wd)
return wd
# just copy the implement of LearningRateScheduler, and then change the lr with weight_decay
@keras_export('keras.callbacks.WeightDecayScheduler')
class WeightDecayScheduler(Callback):
"""Weight Decay Scheduler.
Arguments:
schedule: a function that takes an epoch index as input
(integer, indexed from 0) and returns a new
weight decay as output (float).
verbose: int. 0: quiet, 1: update messages.
```python
# This function keeps the weight decay at 0.001 for the first ten epochs
# and decreases it exponentially after that.
def scheduler(epoch):
if epoch < 10:
return 0.001
else:
return 0.001 * tf.math.exp(0.1 * (10 - epoch))
callback = WeightDecayScheduler(scheduler)
model.fit(data, labels, epochs=100, callbacks=[callback],
validation_data=(val_data, val_labels))
```
"""
def __init__(self, schedule, verbose=0):
super(WeightDecayScheduler, self).__init__()
self.schedule = schedule
self.verbose = verbose
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'weight_decay'):
raise ValueError('Optimizer must have a "weight_decay" attribute.')
try: # new API
weight_decay = float(K.get_value(self.model.optimizer.weight_decay))
weight_decay = self.schedule(epoch, weight_decay)
except TypeError: # Support for old API for backward compatibility
weight_decay = self.schedule(epoch)
if not isinstance(weight_decay, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function '
'should be float.')
K.set_value(self.model.optimizer.weight_decay, weight_decay)
if self.verbose > 0:
print('\nEpoch %05d: WeightDecayScheduler reducing weight '
'decay to %s.' % (epoch + 1, weight_decay))
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
logs['weight_decay'] = K.get_value(self.model.optimizer.weight_decay)
if __name__ == '__main__':
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, enable=True)
print(gpus)
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(16, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10, activation='softmax')
])
optimizer = AdamW(learning_rate=lr_schedule(0), weight_decay=wd_schedule(0))
# optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
tb_callback = tf.keras.callbacks.TensorBoard(os.path.join('logs', 'adamw'),
profile_batch=0)
lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)
wd_callback = WeightDecayScheduler(wd_schedule)
model.compile(optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=40, validation_split=0.1,
callbacks=[tb_callback, lr_callback, wd_callback])
model.evaluate(x_test, y_test, verbose=2)
以上代码实现了在 learning rate decay 时使用 AdamW,虽然只能是在 epoch 层面进行学习率衰减。
在使用 AdamW 时,如果要使用 learning rate decay,那么对 weight_decay 的值要进行同样的学习率衰减,不然训练会崩掉。
References
How to use AdamW correctly? -- wuliytTaotao
Loshchilov, I., & Hutter, F. Decoupled Weight Decay Regularization. ICLR 2019. Retrieved from http://arxiv.org/abs/1711.05101
【tf.keras】AdamW: Adam with Weight decay的更多相关文章
- 【tf.keras】tf.keras使用tensorflow中定义的optimizer
Update:2019/09/21 使用 tf.keras 时,请使用 tf.keras.optimizers 里面的优化器,不要使用 tf.train 里面的优化器,不然学习率衰减会出现问题. 使用 ...
- 【tf.keras】使用手册
目录 0. 简介 1. 安装 1.1 安装 CUDA 和 cuDNN 2. 数据集 2.1 使用 tensorflow_datasets 导入公共数据集 2.2 数据集过大导致内存溢出 2.3 加载 ...
- 【tf.keras】在 cifar 上训练 AlexNet,数据集过大导致 OOM
cifar-10 每张图片的大小为 32×32,而 AlexNet 要求图片的输入是 224×224(也有说 227×227 的,这是 224×224 的图片进行大小为 2 的 zero paddin ...
- 【tf.keras】实现 F1 score、precision、recall 等 metric
tf.keras.metric 里面竟然没有实现 F1 score.recall.precision 等指标,一开始觉得真不可思议.但这是有原因的,这些指标在 batch-wise 上计算都没有意义, ...
- 【tf.keras】TensorFlow 1.x 到 2.0 的 API 变化
TensorFlow 2.0 版本将 keras 作为高级 API,对于 keras boy/girl 来说,这就很友好了.tf.keras 从 1.x 版本迁移到 2.0 版本,需要修改几个地方. ...
- 【tf.keras】Resource exhausted: OOM when allocating tensor with shape [9216,4096] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
运行以下类似代码: while True: inputs, outputs = get_AlexNet() model = tf.keras.Model(inputs=inputs, outputs= ...
- 【tf.keras】tf.keras加载AlexNet预训练模型
目录 从 PyTorch 中导出模型参数 第 0 步:配置环境 第 1 步:安装 MMdnn 第 2 步:得到 PyTorch 保存完整结构和参数的模型(pth 文件) 第 3 步:导出 PyTorc ...
- 【tf.keras】tensorflow datasets,tfds
一些最常用的数据集如 MNIST.Fashion MNIST.cifar10/100 在 tf.keras.datasets 中就能找到,但对于其它也常用的数据集如 SVHN.Caltech101,t ...
- 【tf.keras】ssl.SSLError: [SSL: DECRYPTION_FAILED_OR_BAD_RECORD_MAC] decryption failed or bad record mac (_ssl.c:1977)
问题描述 tf.keras 在加载 cifar10 数据时报错,ssl.SSLError: [SSL: DECRYPTION_FAILED_OR_BAD_RECORD_MAC] decryption ...
随机推荐
- 【Pandas】Pandas求某列字符串的长度,总结经验教训
测试集大小: test.shape(898, 11) 对某列的字符串做统计长度1.for遍历法:start = time.time()for i in test.index.values: test. ...
- java 反射和泛型-反射来获取泛型信息
通过指定对应的Class对象,程序可以获得该类里面所有的Field,不管该Field使用private 方法public.获得Field对象后都可以使用getType()来获取其类型. Class&l ...
- 分析JVM动态生成的类
总结思考:让jvm创建动态类及其实例对象,需要给它提供哪些信息? 三个方面: 1.生成的类中有哪些方法,通过让其实现哪些接口的方式进行告知: 2.产生的类字节码必须有个一个关联的类加载器对象: 3.生 ...
- [学习笔记]Pollard-Rho
之前学的都是假的 %%zzt Miller_Rabin:Miller-Rabin与二次探测 大质数分解: 找到所有质因子,再logn搞出质因子的次数 方法:不断找到一个约数d,递归d,n/d进行分解, ...
- 2019牛客暑期多校训练营(第二场)F.Partition problem
链接:https://ac.nowcoder.com/acm/contest/882/F来源:牛客网 Given 2N people, you need to assign each of them ...
- vue移动端图片上传压缩
上传压缩方法 import {api} from '../../api/api.js'; import axios from 'axios'; export function imgPreview ( ...
- POJ3237 Tree 树链剖分 边权
POJ3237 Tree 树链剖分 边权 传送门:http://poj.org/problem?id=3237 题意: n个点的,n-1条边 修改单边边权 将a->b的边权取反 查询a-> ...
- 查看虚拟机里的Centos7的IP(设置centos网卡)
这里之所以是查看下IP ,是我们后面要建一个Centos远程工具Xshell 连接Centos的时候,需要IP地址,所以我们这里先 学会查看虚拟机里的Centos7的IP地址 首先我们登录操作系统 用 ...
- 使用Git和Github来管理自己的代码和笔记
一.Github注册 1.先注册github.com的账号,官方网站: https://github.com/ 2.登录 3.创建仓库,仓库分公开的和私有的,公开的是免费的,私有的是收费的.我现在创建 ...
- python元祖(tuple)
# 列表:有序,元素可以被修改 # 列表 # list # li = [111,22,33,44] # 元组:元素不可被修改,不能被增加或者删除 # ps: # tuple # tu = (11,22 ...