[PyTorch 学习笔记] 4.1 权值初始化
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/grad_vanish_explod.py
在搭建好网络模型之后,一个重要的步骤就是对网络模型中的权值进行初始化。适当的权值初始化可以加快模型的收敛,而不恰当的权值初始化可能引发梯度消失或者梯度爆炸,最终导致模型无法收敛。下面分 3 部分介绍。第一部分介绍不恰当的权值初始化是如何引发梯度消失与梯度爆炸的,第二部分介绍常用的 Xavier 方法与 Kaiming 方法,第三部分介绍 PyTorch 中的 10 种初始化方法。
梯度消失与梯度爆炸
考虑一个 3 层的全连接网络。
$H_{1}=X \times W_{1}$,$H_{2}=H_{1} \times W_{2}$,$Out=H_{2} \times W_{3}$
其中第 2 层的权重梯度如下:
$\begin{aligned} \Delta \mathrm{W}{2} &=\frac{\partial \mathrm{Loss}}{\partial \mathrm{W}{2}}=\frac{\partial \mathrm{Loss}}{\partial \mathrm{out}} * \frac{\partial \mathrm{out}}{\partial \mathrm{H}{2}} * \frac{\partial \mathrm{H}{2}}{\partial \mathrm{w}{2}} \ &=\frac{\partial \mathrm{Loss}}{\partial \mathrm{out}} * \frac{\partial \mathrm{out}}{\partial \mathrm{H}{2}} * \mathrm{H}_{1} \end{aligned}$
所以 $\Delta \mathrm{W}{2}$ 依赖于前一层的输出 $H{1}$。如果 $H_{1}$ 趋近于零,那么 $\Delta \mathrm{W}{2}$ 也接近于 0,造成梯度消失。如果 $H{1}$ 趋近于无穷大,那么 $\Delta \mathrm{W}_{2}$ 也接近于无穷大,造成梯度爆炸。要避免梯度爆炸或者梯度消失,就要严格控制网络层输出的数值范围。
下面构建 100 层全连接网络,先不使用非线性激活函数,每层的权重初始化为服从 $N(0,1)$ 的正态分布,输出数据使用随机初始化的数据。
import torch
import torch.nn as nn
from common_tools import set_seed
set_seed(1) # 设置随机种子
class MLP(nn.Module):
def __init__(self, neural_num, layers):
super(MLP, self).__init__()
self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
self.neural_num = neural_num
def forward(self, x):
for (i, linear) in enumerate(self.linears):
x = linear(x)
return x
def initialize(self):
for m in self.modules():
# 判断这一层是否为线性层,如果为线性层则初始化权值
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight.data) # normal: mean=0, std=1
layer_nums = 100
neural_nums = 256
batch_size = 16
net = MLP(neural_nums, layer_nums)
net.initialize()
inputs = torch.randn((batch_size, neural_nums)) # normal: mean=0, std=1
output = net(inputs)
print(output)
输出为:
tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], grad_fn=<MmBackward>)
也就是数据太大(梯度爆炸)或者太小(梯度消失)了。接下来我们在forward()
函数中判断每一次前向传播的输出的标准差是否为 nan,如果是 nan 则停止前向传播。
def forward(self, x):
for (i, linear) in enumerate(self.linears):
x = linear(x)
print("layer:{}, std:{}".format(i, x.std()))
if torch.isnan(x.std()):
print("output is nan in {} layers".format(i))
break
return x
输出如下:
layer:0, std:15.959932327270508
layer:1, std:256.6237487792969
layer:2, std:4107.24560546875
.
.
.
layer:29, std:1.322983152787379e+36
layer:30, std:2.0786820453988485e+37
layer:31, std:nan
output is nan in 31 layers
可以看到每一层的标准差是越来越大的,并在在 31 层时超出了数据可以表示的范围。
下面推导为什么网络层输出的标准差越来越大。
首先给出 3 个公式:
$E(X \times Y)=E(X) \times E(Y)$:两个相互独立的随机变量的乘积的期望等于它们的期望的乘积。
$D(X)=E(X^{2}) - [E(X)]^{2}$:一个随机变量的方差等于它的平方的期望减去期望的平方
$D(X+Y)=D(X)+D(Y)$:两个相互独立的随机变量之和的方差等于它们的方差的和。
可以推导出两个随机变量的乘积的方差如下:
$D(X \times Y)=E[(XY)^{2}] - [E(XY)]^{2}=D(X) \times D(Y) + D(X) \times [E(Y)]^{2} + D(Y) \times [E(X)]^{2}$
如果 $E(X)=0$,$E(Y)=0$,那么 $D(X \times Y)=D(X) \times D(Y)$
我们以输入层第一个神经元为例:
$\mathrm{H}{11}=\sum{i=0}^{n} X_{i} \times W_{1 i}$
其中输入 X 和权值 W 都是服从 $N(0,1)$ 的正态分布,所以这个神经元的方差为:
$\begin{aligned} \mathbf{D}\left(\mathrm{H}{11}\right) &=\sum{i=0}^{n} \boldsymbol{D}\left(X_{i}\right) * \boldsymbol{D}\left(W_{1 i}\right) \ &=n *(1 * 1) \ &=n \end{aligned}$
标准差为:$\operatorname{std}\left(\mathrm{H}{11}\right)=\sqrt{\mathbf{D}\left(\mathrm{H}{11}\right)}=\sqrt{n}$,所以每经过一个网络层,方差就会扩大 n 倍,标准差就会扩大 $\sqrt{n}$ 倍,n 为每层神经元个数,直到超出数值表示范围。对比上面的代码可以看到,每层神经元个数为 256,输出数据的标准差为 1,所以第一个网络层输出的标准差为 16 左右,第二个网络层输出的标准差为 256 左右,以此类推,直到 31 层超出数据表示范围。可以把每层神经元个数改为 400,那么每层标准差扩大 20 倍左右。从 $D(\mathrm{H}{11})=\sum{i=0}^{n} D(X_{i}) \times D(W_{1 i})$,可以看出,每一层网络输出的方差与神经元个数、输入数据的方差、权值方差有关,其中比较好改变的是权值的方差 $D(W)$,所以 $D(W)= \frac{1}{n}$,标准差为 $std(W)=\sqrt\frac{1}{n}$。因此修改权值初始化代码为nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num))
,结果如下:
layer:0, std:0.9974957704544067
layer:1, std:1.0024365186691284
layer:2, std:1.002745509147644
.
.
.
layer:94, std:1.031973123550415
layer:95, std:1.0413124561309814
layer:96, std:1.0817031860351562
修改之后,没有出现梯度消失或者梯度爆炸的情况,每层神经元输出的方差均在 1 左右。通过恰当的权值初始化,可以保持权值在更新过程中维持在一定范围之内,不过过大,也不会过小。
上述是没有使用非线性变换的实验结果,如果在forward()
中添加非线性变换tanh
,每一层的输出方差还是会越来越小,会导致梯度消失。因此出现了 Xavier 初始化方法与 Kaiming 初始化方法。
Xavier 方法与 Kaiming 方法
Xavier 方法
Xavier 是 2010 年提出的,针对有非线性激活函数时的权值初始化方法,目标是保持数据的方差维持在 1 左右,主要针对饱和激活函数如 sigmoid 和 tanh 等。同时考虑前向传播和反向传播,需要满足两个等式:$\boldsymbol{n}{\boldsymbol{i}} * \boldsymbol{D}(\boldsymbol{W})=\mathbf{1}$ 和 $\boldsymbol{n}{\boldsymbol{i+1}} * \boldsymbol{D}(\boldsymbol{W})=\mathbf{1}$,可得:$D(W)=\frac{2}{n_{i}+n_{i+1}}$。为了使 Xavier 方法初始化的权值服从均匀分布,假设 $W$ 服从均匀分布 $U[-a, a]$,那么方差 $D(W)=\frac{(-a-a)^{2}}{12}=\frac{(2 a){2}}{12}=\frac{a{2}}{3}$,令 $\frac{2}{n_{i}+n_{i+1}}=\frac{a^{2}}{3}$,解得:$\boldsymbol{a}=\frac{\sqrt{6}}{\sqrt{n_{i}+n_{i+1}}}$,所以 $W$ 服从分布 $U\left[-\frac{\sqrt{6}}{\sqrt{n_{i}+n_{i+1}}}, \frac{\sqrt{6}}{\sqrt{n_{i}+n_{i+1}}}\right]$
所以初始化方法改为:
a = np.sqrt(6 / (self.neural_num + self.neural_num))
# 把 a 变换到 tanh,计算增益
tanh_gain = nn.init.calculate_gain('tanh')
a *= tanh_gain
nn.init.uniform_(m.weight.data, -a, a)
并且每一层的激活函数都使用 tanh,输出如下:
layer:0, std:0.7571136355400085
layer:1, std:0.6924336552619934
layer:2, std:0.6677976846694946
.
.
.
layer:97, std:0.6426210403442383
layer:98, std:0.6407480835914612
layer:99, std:0.6442216038703918
可以看到每层输出的方差都维持在 0.6 左右。
PyTorch 也提供了 Xavier 初始化方法,可以直接调用:
tanh_gain = nn.init.calculate_gain('tanh')
nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)
nn.init.calculate_gain()
上面的初始化方法都使用了tanh_gain = nn.init.calculate_gain('tanh')
。
nn.init.calculate_gain(nonlinearity,param=**None**)
的主要功能是经过一个分布的方差经过激活函数后的变化尺度,主要有两个参数:
- nonlinearity:激活函数名称
- param:激活函数的参数,如 Leaky ReLU 的 negative_slop。
下面是计算标准差经过激活函数的变化尺度的代码。
x = torch.randn(10000)
out = torch.tanh(x)
gain = x.std() / out.std()
print('gain:{}'.format(gain))
tanh_gain = nn.init.calculate_gain('tanh')
print('tanh_gain in PyTorch:', tanh_gain)
输出如下:
gain:1.5982500314712524
tanh_gain in PyTorch: 1.6666666666666667
结果表示,原有数据分布的方差经过 tanh 之后,标准差会变小 1.6 倍左右。
Kaiming 方法
虽然 Xavier 方法提出了针对饱和激活函数的权值初始化方法,但是 AlexNet 出现后,大量网络开始使用非饱和的激活函数如 ReLU 等,这时 Xavier 方法不再适用。2015 年针对 ReLU 及其变种等激活函数提出了 Kaiming 初始化方法。
针对 ReLU,方差应该满足:$\mathrm{D}(W)=\frac{2}{n_{i}}$;针对 ReLu 的变种,方差应该满足:$\mathrm{D}(W)=\frac{2}{n_{i}}$,a 表示负半轴的斜率,如 PReLU 方法,标准差满足 $\operatorname{std}(W)=\sqrt{\frac{2}{\left(1+a^{2}\right) * n_{i}}}$。代码如下:nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))
,或者使用 PyTorch 提供的初始化方法:nn.init.kaiming_normal_(m.weight.data)
,同时把激活函数改为 ReLU。
常用初始化方法
PyTorch 中提供了 10 中初始化方法
- Xavier 均匀分布
- Xavier 正态分布
- Kaiming 均匀分布
- Kaiming 正态分布
- 均匀分布
- 正态分布
- 常数分布
- 正交矩阵初始化
- 单位矩阵初始化
- 稀疏矩阵初始化
每种初始化方法都有它自己适用的场景,原则是保持每一层输出的方差不能太大,也不能太小。
参考资料
如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。
[PyTorch 学习笔记] 4.1 权值初始化的更多相关文章
- [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py 这篇文章来看下 ...
- [PyTorch 学习笔记] 6.2 Normalization
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson6/bn_and_initialize.py https: ...
- caffe中权值初始化方法
首先说明:在caffe/include/caffe中的 filer.hpp文件中有它的源文件,如果想看,可以看看哦,反正我是不想看,代码细节吧,现在不想知道太多,有个宏观的idea就可以啦,如果想看代 ...
- SQL反模式学习笔记14 关于Null值的使用
目标:辨别并使用Null值 反模式:将Null值作为普通的值,反之亦然 1.在表达式中使用Null: Null值与空字符串是不一样的,Null值参与任何的加.减.乘.除等其他运算,结果都是Null: ...
- 神经网络权值初始化方法-Xavier
https://blog.csdn.net/u011534057/article/details/51673458 https://blog.csdn.net/qq_34784753/article/ ...
- pytorch(14)权值初始化
权值的方差过大导致梯度爆炸的原因 方差一致性原则分析Xavier方法与Kaiming初始化方法 饱和激活函数tanh,非饱和激活函数relu pytorch提供的十种初始化方法 梯度消失与爆炸 \[H ...
- 权值初始化 - Xavier和MSRA方法
设计好神经网络结构以及loss function 后,训练神经网络的步骤如下: 初始化权值参数 选择一个合适的梯度下降算法(例如:Adam,RMSprop等) 重复下面的迭代过程: 输入的正向传播 计 ...
- matlab学习笔记13_1 函数返回值
一起来学matlab-matlab学习笔记13函数 13_1 函数返回值 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 https://blog.csdn.net/qq_36556 ...
- [深度学习] pytorch学习笔记(2)(梯度、梯度下降、凸函数、鞍点、激活函数、Loss函数、交叉熵、Mnist分类实现、GPU)
一.梯度 导数是对某个自变量求导,得到一个标量. 偏微分是在多元函数中对某一个自变量求偏导(将其他自变量看成常数). 梯度指对所有自变量分别求偏导,然后组合成一个向量,所以梯度是向量,有方向和大小. ...
随机推荐
- Python编程第四版中文 上下册完整版pdf|网盘下载附提取码
点击此处下载 提取码:drjh 作者简介 Mark Lutz是Python培训的世界的领先者,他是最初和最畅销的Python著作的作者,从1992年起就是Python社区的先锋人物.Mark有25年的 ...
- luogu P3829 [SHOI2012]信用卡凸包 凸包 点的旋转
LINK:信用卡凸包 当 R==0的时候显然是一个点的旋转 之后再求凸包即可. 这里先说点如何旋转 如果是根据原点旋转的话 经过一个繁杂的推导可以得到一个矩阵. [cosw,-sinw] [sinw, ...
- 玩转 SpringBoot2.x 之整合邮件发送
序 在实际项目中,经常需要用到邮件通知功能.比如,用户通过邮件注册,通过邮件找回密码等:又比如通过邮件发送系统情况,通过邮件发送报表信息等等,实际应用场景很多. 原文地址:https://www.mm ...
- VulnHub靶场学习_HA: Natraj
HA: Natraj Vulnhub靶场 下载地址:https://www.vulnhub.com/entry/ha-natraj,489/ 背景: Nataraj is a dancing avat ...
- 当asp.net core偶遇docker二(打造个人docker镜像)
网络上的docker容器总有一些不尽人意的感觉,这个时候,就需要自己diy一个自用的. 比如我们想在163的mysql 5.7内diy一下,结果发现,这个不带vim,我想改造一个自用的mysql镜像, ...
- 极简 Node.js 入门 - 2.1 Path
极简 Node.js 入门系列教程:https://www.yuque.com/sunluyong/node 本文更佳阅读体验:https://www.yuque.com/sunluyong/node ...
- Consul服务治理发现学习记录
Consul 简介 Consul是一个服务网格(微服务间的 TCP/IP,负责服务之间的网络调用.限流.熔断和监控)解决方案,它是一个一个分布式的,高度可用的系统,而且开发使用都很简便.它提供了一个功 ...
- Springboot调用Oracle存储过程的几种方式
因工作需要将公司SSH项目改为Spingboot项目,将项目中部分需要调用存储过程的部分用entityManagerFactory.unwrap(SessionFactory.class).openS ...
- elementUI 表单清空问题
在使用表单的清空方法时,我们需要注意几个问题: 1.我们需要为每个form-item加上prop属性,要不然无法清空(大部分的问题就是出在这) 2.resetFields()方法是重置表单,重置为默认 ...
- Spark优化之小文件是否需要合并?
我们知道,大部分Spark计算都是在内存中完成的,所以Spark的瓶颈一般来自于集群(standalone, yarn, mesos, k8s)的资源紧张,CPU,网络带宽,内存.Spark的性能,想 ...