Tensorflow2 深度学习十必知
博主根据自身多年的深度学习算法研发经验,整理分享以下十条必知。
含参考资料链接,部分附上相关代码实现。
独乐乐不如众乐乐,希望对各位看客有所帮助。
待回头有时间再展开细节说一说深度学习里的那些道道。
有什么技术需求需要有偿解决的也可以邮件或者QQ联系博主。
邮箱QQ同ID:gaozhihan@vip.qq.com
当然除了这十条,肯定还有其他“必知”,
欢迎评论分享更多,这里只是暂时拟定的十条,别较真哈。
主要学习其中的思路,切记,以下思路在个别场景并不适用 。
1.数据回流
[1907.05550] Faster Neural Network Training with Data Echoing
def data_echoing(factor):
return lambda image, label: tf.data.Dataset.from_tensors((image, label)).repeat(factor)
作用:
数据集加载后,在数据增广前后重复当前批次进模型的次数,减少数据的加载耗时。
等价于让模型看n次当前的数据,或者看n个增广后的数据样本。
2.AMP 自动精度混合
在bert4keras中使用混合精度和XLA加速训练 - 科学空间|Scientific Spaces
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
作用:
降低显存占用,加速训练,将部分网络计算转为等价的低精度计算,以此降低计算量。
3.优化器节省显存
3.1 [1804.04235]Adafactor: Adaptive Learning Rates with Sublinear Memory Cost
mesh/optimize.py at master · tensorflow/mesh · GitHub
3.2 [1901.11150] Memory-Efficient Adaptive Optimization
google-research/sm3 at master · google-research/google-research (github.com)
作用:
节省显存,加速训练,
主要是对二阶动量进行特例化解构,减少显存存储。
4.权重标准化(归一化)
[2102.06171] High-Performance Large-Scale Image Recognition Without Normalization
deepmind-research/nfnets at master · deepmind/deepmind-research · GitHub
class WSConv2D(tf.keras.layers.Conv2D):
def __init__(self, *args, **kwargs):
super(WSConv2D, self).__init__(
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=1.0, mode='fan_in', distribution='untruncated_normal',
),
use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(1e-4), *args, **kwargs
)
self.gain = self.add_weight(
name='gain',
shape=(self.filters,),
initializer="ones",
trainable=True,
dtype=self.dtype
) def standardize_weight(self, eps):
mean, var = tf.nn.moments(self.kernel, axes=[0, 1, 2], keepdims=True)
fan_in = np.prod(self.kernel.shape[:-1])
# Manually fused normalization, eq. to (w - mean) * gain / sqrt(N * var)
scale = tf.math.rsqrt(
tf.math.maximum(
var * fan_in,
tf.convert_to_tensor(eps, dtype=self.dtype)
)
) * self.gain
shift = mean * scale
return self.kernel * scale - shift def call(self, inputs):
eps = 1e-4
weight = self.standardize_weight(eps)
return tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
) if self.bias is None else tf.nn.bias_add(
tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
), self.bias)
作用:
通过对kernel进行标准化或归一化,相当于对kernel做一个先验约束,以此加速模型训练收敛。
5.自适应梯度裁剪
deepmind-research/agc_optax.py at master · deepmind/deepmind-research · GitHub
def unitwise_norm(x):
if len(tf.squeeze(x).shape) <= 1: # Scalars and vectors
axis = None
keepdims = False
elif len(x.shape) in [2, 3]: # Linear layers of shape IO
axis = 0
keepdims = True
elif len(x.shape) == 4: # Conv kernels of shape HWIO
axis = [0, 1, 2, ]
keepdims = True
else:
raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
square_sum = tf.reduce_sum(tf.square(x), axis, keepdims=keepdims)
return tf.sqrt(square_sum) def gradient_clipping(grad, var):
clipping = 0.01
max_norm = tf.maximum(unitwise_norm(var), 1e-3) * clipping
grad_norm = unitwise_norm(grad)
trigger = (grad_norm > max_norm)
clipped_grad = (max_norm / tf.maximum(grad_norm, 1e-6))
return grad * tf.where(trigger, clipped_grad, tf.ones_like(clipped_grad))
作用:
防止梯度爆炸,稳定训练。通过梯度和参数的关系,对梯度进行裁剪,约束学习率。
6.recompute_grad
[1604.06174] Training Deep Nets with Sublinear Memory Cost
google-research/recompute_grad.py at master · google-research/google-research (github.com)
bojone/keras_recompute: saving memory by recomputing for keras (github.com)
作用:
通过梯度重计算,节省显存。
7.归一化
[2003.05569] Extended Batch Normalization (arxiv.org)
from keras.layers.normalization.batch_normalization import BatchNormalizationBase class ExtendedBatchNormalization(BatchNormalizationBase):
def __init__(self,
axis=-1,
momentum=0.99,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
trainable=True,
name=None,
**kwargs):
# Currently we only support aggregating over the global batch size.
super(ExtendedBatchNormalization, self).__init__(
axis=axis,
momentum=momentum,
epsilon=epsilon,
center=center,
scale=scale,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
moving_mean_initializer=moving_mean_initializer,
moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=beta_constraint,
gamma_constraint=gamma_constraint,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum,
fused=False,
trainable=trainable,
virtual_batch_size=None,
name=name,
**kwargs) def _calculate_mean_and_var(self, x, axes, keep_dims):
with tf.keras.backend.name_scope('moments'):
y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x
replica_ctx = tf.distribute.get_replica_context()
if replica_ctx:
local_sum = tf.math.reduce_sum(y, axis=axes, keepdims=True)
local_squared_sum = tf.math.reduce_sum(tf.math.square(y), axis=axes,
keepdims=True)
batch_size = tf.cast(tf.shape(y)[0], tf.float32)
y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
y_squared_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
local_squared_sum)
global_batch_size = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
batch_size)
axes_vals = [(tf.shape(y))[i] for i in range(1, len(axes))]
multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
multiplier = multiplier * global_batch_size
mean = y_sum / multiplier
y_squared_mean = y_squared_sum / multiplier
# var = E(x^2) - E(x)^2
variance = y_squared_mean - tf.math.square(mean)
else:
# Compute true mean while keeping the dims for proper broadcasting.
mean = tf.math.reduce_mean(y, axes, keepdims=True, name='mean')
variance = tf.math.reduce_mean(
tf.math.squared_difference(y, tf.stop_gradient(mean)),
axes,
keepdims=True,
name='variance')
if not keep_dims:
mean = tf.squeeze(mean, axes)
variance = tf.squeeze(variance, axes)
variance = tf.math.reduce_mean(variance)
if x.dtype == tf.float16:
return (tf.cast(mean, tf.float16),
tf.cast(variance, tf.float16))
else:
return mean, variance
作用:
一个简易改进版的Batch Normalization,思路简单有效。
8.学习率策略
[1506.01186] Cyclical Learning Rates for Training Neural Networks (arxiv.org)
作用:
一个推荐的学习率策略方案,特定情况下可以取得更好的泛化。
9.重参数化
https://zhuanlan.zhihu.com/p/361090497
作用:
通过同时训练多份参数,合并权重的思路来提升模型泛化性。
10.长尾学习
[2110.04596] Deep Long-Tailed Learning: A Survey (arxiv.org)
Jorwnpay/A-Long-Tailed-Survey: 本项目是 Deep Long-Tailed Learning: A Survey 文章的中译版 (github.com)
作用:
解决长尾问题,可以加速收敛,提升模型泛化,稳定训练。
Tensorflow2 深度学习十必知的更多相关文章
- 对比深度学习十大框架:TensorFlow 并非最好?
http://www.oschina.net/news/80593/deep-learning-frameworks-a-review-before-finishing-2016 TensorFlow ...
- 推荐系统遇上深度学习(十)--GBDT+LR融合方案实战
推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...
- 《TensorFlow2深度学习》学习笔记(一)Tensorflow基础
本系列笔记记录了学习TensorFlow2的过程,主要依据 https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book 进行学习 ...
- mysql学习--mysql必知必会1
例如以下为mysql必知必会第九章開始: 正則表達式用于匹配特殊的字符集合.mysql通过where子句对正則表達式提供初步的支持. keywordregexp用来表示后面跟的东西作为正則表達式 ...
- mysql学习--mysql必知必会
上图为数据库操作分类: 下面的操作參考(mysql必知必会) 创建数据库 运行脚本建表: mysql> create database mytest; Query OK, 1 row ...
- 学习axios必知必会(2)~axios基本使用、使用axios前必知细节、axios和实例对象区别、拦截器、取消请求
一.axios的基本使用: ✿ 使用axios前必知细节: 1.axios 函数对象(可以作为axios(config)函数使用去发送请求,也可以作为对象调用方法axios.request(confi ...
- 【Android Api 翻译4】android api 完整翻译之Contacts Provider (学习安卓必知的api,中英文对照)
Contacts Provider 电话簿(注:联系人,联络人.通信录)提供者 ------------------------------- QUICKVIEW 快速概览 * Android's r ...
- 《TensorFlow2深度学习》学习笔记(四)对笔记二中的模型增加正确率展示
全部代码如下:(红色部分为与笔记二不同之处) #1.Import the neccessary libraries needed import numpy as np import tensorflo ...
- 学习MyBatis必知必会(2)~MyBatis基本介绍和MyBatis基本使用
一.MyBatis框架基本介绍: 1.认识 MyBatis: MyBatis 是支持普通 SQL 查询,存储过程和高级映射的持久层框架,严格上说应该是一个 SQL 映射框架. 其前身是 iBatis, ...
随机推荐
- 【HarmonyOS学习笔记】记第一次使用IDE
哈喽大家好我是脸皮贼厚的小威 愚人节刚过先给大家拜个早年吧 最近在HarmonyOS官网下载了IDE,并抱着学(wan)习(wan)的心态试着跑出了Hello World,并安装到手机上 这是一个简单 ...
- 简单说一说jsonp原理
背景:由于浏览器同源策略的限制,非同源下的请求,都会产生跨域问题,jsonp即是为了解决这个问题出现的一种简便解决方案. 同源策略即:同一协议,同一域名,同一端口号.当其中一个不满足时,我们的请求即会 ...
- MySQL 的 GRANT和REVOKE 命令
MySQL 的 GRANT和REVOKE 命令 GRANT - 授权 将指定 操作对象 的指定 操作权限 授予指定的 用户; 发出该 GRANT语句的可以是数据库管理员,也可以是该数据库对象的创建者; ...
- windows下的操作
1.java -jar启动war包 将打好的war包丢到tomcat的webapps目录,然后进入tomcat的bin目录双击运行startup.bat会自动解压war包,在浏览器直接可访问web项目
- screen使用小结
目录 安装 shell-screen-window关系 常用参数 快捷键 离开当前screen 打开一个新的窗口 查看窗口列表 窗口的快速切换 回到行首 关闭窗口 关闭所有窗口 关闭screen 删除 ...
- Linux操作系统基本知识
1.Linux开发环境 2.GCC 2.1GCC工作流程 预处理:只运行 C 预编译器. 宏去掉了,注释没有了 汇编 编译 链接 2.2GCC常用参数选择 选项 解释 -ansi 只支持 ANSI 标 ...
- Linux-3作业练习
1.自建yum仓库,分别为网络源和本地源 请移步: yum源配置 2.编译安装http2.4,实现可以正常访问,并将编译步骤和结果提交. 请移步:http2.4编译安装 总结参照https ...
- C++基础-3-函数
3. 函数 3.1 函数默认参数 1 #include<iostream> 2 using namespace std; 3 4 //函数的默认参数 5 //自己传参,就用自己的,如果没有 ...
- viewport布局
1.viewport实例 <!DOCTYPE html> <html xmlns="http://www.w3.org/1999/xhtml"> <h ...
- Python图像处理:如何获取图像属性、兴趣ROI区域及通道处理
摘要:本篇文章主要讲解Python调用OpenCV获取图像属性,截取感兴趣ROI区域,处理图像通道. 本文分享自华为云社区<[Python图像处理] 三.获取图像属性.兴趣ROI区域及通道处理 ...