DL基础补全计划(五)---数值稳定性及参数初始化(梯度消失、梯度爆炸)
PS:要转载请注明出处,本人版权所有。
PS: 这个只是基于《我自己》的理解,
如果和你的原则及想法相冲突,请谅解,勿喷。
前置说明
本文作为本人csdn blog的主站的备份。(BlogID=109)
环境说明
- Windows 10
- VSCode
- Python 3.8.10
- Pytorch 1.8.1
- Cuda 10.2
前言
如果有计算机背景的相关童鞋,都应该知道数值计算中的上溢和下溢的问题。关于计算机中的数值表示,在我的《数与计算机 (编码、原码、反码、补码、移码、IEEE 754、定点数、浮点数)》 (https://blog.csdn.net/u011728480/article/details/100277582) 一文中有比较好的介绍。计算机中的数值表示,相对于实数数轴来说是离散且有限的,意思就是计算机中的能表示的数有最大值和最小值以及最小单位,特别是浮点数表示,有兴趣的可以看看上文。
其实很好理解,深度学习里面具有大量的乘法加法,一不小心你就会遇见上溢和下溢的问题,因此我们一不小心就会遇见NAN和INF的问题(NAN和INF详见上文提到的文章)。此外,由于一些特殊的情况,可能会导致我们的参数的偏导数接近于0,让我们的模型收敛的非常的慢。因此我们可能需要从模型的初始化以及相关的模型构造方面来好好的讨论一下我们在训练过程中可能出现的问题。
一般来说,我们训练的时候都非常的关注我们的损失函数,如果损失函数值异常,会导致相关的偏导数出现接近于0或者接近于无限大,那么就会直接导致模型训练及其困难。此外,我们的权重参数也会参与网络计算,按照上述的描述,权重参数的初始值也可能导致损失函数的值异常。因此大佬们也引入了另外一种常见的初始化方式Xavier,比较具有普适性。下面我们简单的验证一下我们训练过程中出现梯度接近于0和接近于无限大的情况,这里也就是说的梯度消失和梯度爆炸问题。同时也简单说明参数初始化相关的问题。
梯度消失(gradient vanishing)
在深度学习中有一个激活层叫做Sigmoid层,其定义如下是:\(Sigmoid(x)=1/(1+\exp(-x))\),如果我们的模型里面接入了这种激活函数,很容易构造出梯度消失的情况,下面我们看一下其导数和函数值相对于X的相关关系。
代码如下:
import torch
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
xdata, ydata = [[], []], [[], []]
line0, = ax.plot([], [], 'r-', label='sigmoid')
line1, = ax.plot([], [], 'b-', label='gradient-sigmoid')
def init_and_show(xlim_min, xlim_max, ylim_min, ylim_max):
ax.set_xlabel('x')
ax.set_ylabel('sigmoid(x)')
ax.set_title('sigmoid/gradient-sigmoid')
ax.set_xlim(xlim_min, xlim_max)
ax.set_ylim(ylim_min, ylim_max)
ax.legend([line0, line1], ('sigmoid', 'gradient-sigmoid'))
line0.set_data(xdata[0], ydata[0])
line1.set_data(xdata[1], ydata[1])
plt.show()
def sigmoid_test():
x = np.arange(-10.0, 10.0, 0.1)
x = torch.tensor(x, dtype=torch.float, requires_grad=True)
sig_fun = torch.nn.Sigmoid()
y = sig_fun(x)
y.backward(torch.ones_like(y))
xdata[0] = x.detach().numpy()
xdata[1] = x.detach().numpy()
ydata[0] = y.detach().numpy()
ydata[1] = x.grad.detach().numpy()
init_and_show(-10.0, 10.0, 0, 1)
def multi_mat_dot():
M = np.random.normal(size=(4, 4))
print('⼀个矩阵\n', M)
for i in range(10000):
M = np.dot(M, np.random.normal(size=(4, 4)))
print('乘以100个矩阵后\n', M)
if __name__ == '__main__':
sigmoid_test()
结果图如下
我们可以从图中看到,当x小于-5和大于+5的时候,其导数的值接近于0,导致bp的时候,参数更新小,模型收敛的特别的慢。
梯度爆炸(gradient exploding)
现在我们假设我们有一个模型,其有N个线性层构成,定义输入为X,标签为Y,模型为 \(M(X) = X*W_1 .... W_{n-2}*W_{n-1}*W_n\),损失函数为\(L(X) = M(X) - Y = X*W_1 .... W_{n-2}*W_{n-1}*W_n - Y\),求W1关于损失函数的偏导数\(\frac{dL(X)}{dW_1} = X*W_2 .... W_{n-2}*W_{n-1}*W_n\)。从这里我们可以看到W2到Wn与输入的X的乘积构成了W1的偏导数。
下面我们简单的构造一个矩阵,然后让他计算100次乘法。代码如下:
import torch
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
xdata, ydata = [[], []], [[], []]
line0, = ax.plot([], [], 'r-', label='sigmoid')
line1, = ax.plot([], [], 'b-', label='gradient-sigmoid')
def init_and_show(xlim_min, xlim_max, ylim_min, ylim_max):
ax.set_xlabel('x')
ax.set_ylabel('sigmoid(x)')
ax.set_title('sigmoid/gradient-sigmoid')
ax.set_xlim(xlim_min, xlim_max)
ax.set_ylim(ylim_min, ylim_max)
ax.legend([line0, line1], ('sigmoid', 'gradient-sigmoid'))
line0.set_data(xdata[0], ydata[0])
line1.set_data(xdata[1], ydata[1])
plt.show()
def sigmoid_test():
x = np.arange(-10.0, 10.0, 0.1)
x = torch.tensor(x, dtype=torch.float, requires_grad=True)
sig_fun = torch.nn.Sigmoid()
y = sig_fun(x)
y.backward(torch.ones_like(y))
xdata[0] = x.detach().numpy()
xdata[1] = x.detach().numpy()
ydata[0] = y.detach().numpy()
ydata[1] = x.grad.detach().numpy()
init_and_show(-10.0, 10.0, 0, 1)
def multi_mat_dot():
M = np.random.normal(size=(4, 4))
print('⼀个矩阵\n', M)
for i in range(100):
M = np.dot(M, np.random.normal(size=(4, 4)))
print('乘以100个矩阵后\n', M)
if __name__ == '__main__':
multi_mat_dot()
他计算100次乘法后结果如下:
我们可以看到,经过100次乘法后,其值已经非常大(小)了指数都是到了25了。这个时候算出来的损失非常大的,这个时候梯度也非常大,很容易导致训练异常。
参数初始化之Xavier
文首我们提到,我们之前的参数初始化都是基于期望为0,方差为一个指定值初始化的,这里面的指定值是随个人定义的,这个可能会给我们的训练过程带来困扰。
但是我们可以从以下的角度来看待这个事情,我们的权重参数W是一个期望为0,方差为\(\delta^2\)的特定分布。我们的输入特征X是一个期望为0,方差为\(\lambda^2\)的特定分布(注意这里不仅仅是正态分布)。我们假设我们的模型是线性模型,那么其输出为:\(O_i = \sum\limits_{j=1}^{n}W_{ij}X_{j}\),\(O_i\)是代表第i层的输出。这个时候,我们求出\(O_i\)的期望是:\(E(O_i) = \sum\limits_{j=1}^{n}E(W_{ij}X_{j}) = \sum\limits_{j=1}^{n}E(W_{ij})E(X_{j}) = 0\),其方差为:\(Variance(O_i) = E(O_i^2) - (E(O_i))^2 = \sum\limits_{j=1}^{n}E(W_{ij}^2X_{j}^2) - 0 = \sum\limits_{j=1}^{n}E(W_{ij}^2)E(X_{j}^2) = n*\delta^2*\lambda^2\)。我们现在假设如果要\(O_i\)的方差等于X的方差,那么\(n*\delta^2 = 1\)才能够满足要求。现在我们考虑BP的时候,也需要\(n_{out}*\delta^2 = 1\)才能够保证方差不会变,至少从数值稳定性来说,我们应该保证方差尽量稳定,不应该放大。我们同时考虑n和\(n_{out}\),那么我们可以认为当\(1/2*(n+n_{out})*\delta^2 = 1\)时,我们保证了输出O的方差在约定范围内,尽量保证了其数值的稳定性,这就是Xavier方法的核心内容。
初始化方法有很多,但是Xavier方法有较大的普适性。对于某些模型,特定的初始化方法有奇效。
后记
到本文结束,其实我们可以训练一些简单的模型了,但是本文所介绍的3个概念会一直伴随着我们以后的学习过程,如果训练出现了INF,NAN这些特殊的值,基本我们就需要往这方面去想和解决问题。
参考文献
- https://github.com/d2l-ai/d2l-zh/releases (V1.0.0)
- https://github.com/d2l-ai/d2l-zh/releases (V2.0.0 alpha1)
- https://blog.csdn.net/u011728480/article/details/100277582 《数与计算机 (编码、原码、反码、补码、移码、IEEE 754、定点数、浮点数)》
打赏、订阅、收藏、丢香蕉、硬币,请关注公众号(攻城狮的搬砖之路)
PS: 请尊重原创,不喜勿喷。
PS: 要转载请注明出处,本人版权所有。
PS: 有问题请留言,看到后我会第一时间回复。
DL基础补全计划(五)---数值稳定性及参数初始化(梯度消失、梯度爆炸)的更多相关文章
- DL基础补全计划(二)---Softmax回归及示例(Pytorch,交叉熵损失)
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明 本文作为本人csdn blog的主站的备份.(Bl ...
- DL基础补全计划(三)---模型选择、欠拟合、过拟合
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明 本文作为本人csdn blog的主站的备份.(Bl ...
- DL基础补全计划(六)---卷积和池化
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明 本文作为本人csdn blog的主站的备份.(Bl ...
- DL基础补全计划(一)---线性回归及示例(Pytorch,平方损失)
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明 本文作为本人csdn blog的主站的备份.(Bl ...
- OSPF补全计划-0 preface
哇靠,一看日历吓了我一跳,我这一个月都没写任何东西,好吧,事情的确多了点儿,同事离职,我需要处理很多untechnical的东西,弄得我很烦,中间学的一点小东西(关于Linux的)也没往这里记,但是我 ...
- 【hjmmm网络流24题补全计划】
本文食用方式 按ABC--分层叙述思路 可以看完一步有思路后自行思考 飞行员配对问题 题目链接 这可能是24题里最水的一道吧... 很显然分成两个集合 左外籍飞行员 右皇家飞行员 跑二分图最大匹配 输 ...
- 2018.我的NOIP补全计划
code: efzoi.tk @ shleodai noip2011 D1 选择客栈 这道题是一道大水题,冷静分析一会就会发现我们需要维护最后一个不合法点和前缀和. 维护最后一个不合法点只要边扫描边维 ...
- OSPF补全计划-2
想起来几个面试题: 1. OSPF在什么情况下会stuck in Exstart /Exchange状态? 我知道的一个答案是两个端口的mtu不一致.当然整个也不是绝对,因为可以用ip ospf mt ...
- OSPF补全计划-1
OSPF全称是啥我就不絮叨了,什么迪杰斯特拉,什么开放最短路径优先算法都是人尽皆知的事儿,尤其是一提算法还会被学数据结构的童鞋鄙视,干脆就不提了,直接开整怎么用吧.(不过好像真有人不知道OSPF里的F ...
随机推荐
- NTLM协议与Pass the Hash的爱情
0x01.前言 NTLM使用在Windows NT和Windows 2000 Server或者之后的工作组环境中(Kerberos用在域模式下).在AD域环境中,如果需要认证Windows NT系统, ...
- 3、mysql的多实例配置(2)
4.设置mysql多实例启动脚本: (1)3306: [root@backup application]# cat /data/3306/mysql #!/bin/sh . /etc/init.d/f ...
- 使用Flex实现图片旋转。
当用flex实现图片旋转的时候,遇到了这样的问题:截图之后,图片还是会继续旋转,应该是canvas这个还有旋转的角度,所以看到效果跟你截图保存下来的效果不一样. 函数: 角度转换为弧度,这里面涉及到了 ...
- LeSS 的诞生(一):大规模团队该何去何从
<敏捷宣言>发布后,"敏捷"被越来越多的小型开发团队认可.与此同时,另一个问题也逐渐暴露了出来:以 Scrum 为首的敏捷方法论对那些大规模的开发团队并不友好. 基于此 ...
- drf-路由和认证
目录 一.路由Routers SimpleRouter DefaultRouter action的使用 二.认证 认证的写法 认证源码分析 认证组件的使用 一.路由Routers 在 Rest Fra ...
- ansible 任务执行
ansible 任务执行模式 Ansible 系统由控制主机对被管节点的操作方式可分为两类,即adhoc和playbook: ad-hoc模式(点对点模式)• ad-hoc模式(点对点模式) 使用单个 ...
- Kafka:docker安装Kafka消息队列
安装之前先看下图 Kafka基础架构及术语 Kafka基本组成 Kafka cluster: Kafka消息队列(存储消息的队列组件) Zookeeper: 注册中心(kafka集群依赖zookee ...
- shell 调用其他shell脚本中的变量、函数
在Shell中要如何调用别的shell脚本,或别的脚本中的变量,函数呢? 方法一: . ./subscript.sh (两个点之间,有空格) 方法二: source ./subscript. ...
- awk中printf的用法
printf函数 打印输出时,可能需要指定字段间的空格数,从而把列排整齐.在print函数中使用制表符并不能保证得到想要的输出,因此,可以用printf函数来格式化特别的输出. printf函数返 ...
- leetcode156场周赛5206
思路分析: 1.两个数组,一个保存字符,一个保存字符出现次数 2.遍历一遍字符串,出现相同的字符,次数加一,且次数到k的话,那么就剔除,没到k,就次数加一.如果不同,就把它加入到字符的数组里面,对应次 ...