1. 权值的方差过大导致梯度爆炸的原因
  2. 方差一致性原则分析Xavier方法与Kaiming初始化方法

    饱和激活函数tanh,非饱和激活函数relu
  3. pytorch提供的十种初始化方法

梯度消失与爆炸

\[H_2 = H_1 * W_2\\
\Delta W_2 = \frac{\partial Loss}{\partial W_2}
=\frac{\partial Loss}{\partial out}
*\frac{\partial out}{\partial H_2}
*\frac{\partial H_2}{\partial W_2}
=\frac{\partial Loss}{\partial out}
*\frac{\partial out}{\partial H_2}*H_1
\]
\[{梯度消失:}H_1 \rightarrow 0 \Rightarrow \Delta W_2 \rightarrow 0\\
{梯度爆炸:}H_1 \rightarrow \infty \Rightarrow \Delta W_2 \rightarrow \infty
\]
\[1. E(X*Y)=E(X)*E(Y)\\
2. D(X)=E(X^2)-[E(X)]^2\\
3. D(X+Y)=D(X)+D(Y)\\
1.2.3. \Rightarrow D(X*Y)=D(X)D(Y)+D(X)*[E(Y)]^2+D(Y)*[E(X)]^2\\
若E(X)=0,E(Y)=0 \Rightarrow D(X*Y)=D(X)*D(Y)
\]
\[H_{11} = \sum ^{n}_{i=0} X_i * W_{1i}\\
D(X*Y) = D(X)*D(Y)\\
D(H_{11})=\sum ^{n}_{i=0} D(X_i)*D(W_1i)=n*(1*1)=n\\
std(H_{11})=\sqrt D(H_11) = \sqrt n\\
D(H_1) = n*D(X)*D(W)=1\\
D(W)=\frac{1}{n}\Rightarrow std(W)=\sqrt \frac {1}{n}
\]

Xavier方法与Kaiming方法

Xavier初始化

方差一致性,保持数据尺度维持在恰当范围,通常方差为1

激活函数:饱和函数,如Sigmoid,Tanh

\[n_i * D(W)=1\\
n_{i+1} *D(W)=1\\
\Rightarrow D(W)=\frac{2}{n_i+n_i+1}
\]
\[W \sim U[-a,a]\\
D(W) = \frac {(-a-a)^2}{12} = \frac {(2a)^2}{12}=\frac {a^2}{3}\\
\frac{2}{n_i+n_{i+1}}=\frac{a^2}{3}\Rightarrow a = \frac{\sqrt 6}{\sqrt {n_i+n_{i+1}}}\\
\Rightarrow W \sim U[-\frac{\sqrt 6}{\sqrt {n_i+n_{i+1}}},\frac{\sqrt 6}{\sqrt {n_i+n_{i+1}}}]
\]

Kaiming初始化

方差一致性:保持数据尺度维持在恰当范围,通常方差为1

激活函数:ReLU及其变种

\[D(W) = \frac{2}{n_i}\\
D(W) = \frac{2}{(1+a^2)*n_i}\\
std(W) = \sqrt{\frac{2}{(1+a^2)*n_i}}
\]

参考文献:

《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》

常用初始化方法

  1. Xavier均匀分布
  2. Xavier正态分布
  3. Kaiming均匀分布
  4. Kaiming正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化
nn.init.calculate_gain(nonlinearity, param=None)

功能:计算激活函数的方差变化尺度

输入数据的方差和输出数据方差的比例。

参数:

  • nonlinearity:激活函数名称
  • param:激活函数参数,Leaky ReLU的negative_slop
# -*- coding: utf-8 -*-
"""
# @file name : grad_vanish_explod.py
# @author : TingsongYu https://github.com/TingsongYu
# @date : 2019-09-30 10:08:00
# @brief : 梯度消失与爆炸实验
"""
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
import torch
import random
import numpy as np
import torch.nn as nn
path_tools = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "tools", "common_tools.py"))
assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools)) import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR) from tools.common_tools import set_seed set_seed(1) # 设置随机种子 class MLP(nn.Module):
def __init__(self, neural_num, layers):
super(MLP, self).__init__()
self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
self.neural_num = neural_num def forward(self, x):
for (i, linear) in enumerate(self.linears):
x = linear(x)
x = torch.relu(x) print("layer:{}, std:{}".format(i, x.std()))
if torch.isnan(x.std()):
print("output is nan in {} layers".format(i))
break return x def initialize(self):
for m in self.modules():
if isinstance(m, nn.Linear):
# nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num)) # normal: mean=0, std=1 # a = np.sqrt(6 / (self.neural_num + self.neural_num))
#
# tanh_gain = nn.init.calculate_gain('tanh')
# a *= tanh_gain
#
# nn.init.uniform_(m.weight.data, -a, a) # nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain) # nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))
nn.init.kaiming_normal_(m.weight.data) flag = 0
# flag = 1 if flag:
layer_nums = 100
neural_nums = 256
batch_size = 16 net = MLP(neural_nums, layer_nums)
net.initialize() inputs = torch.randn((batch_size, neural_nums)) # normal: mean=0, std=1 output = net(inputs)
print(output) # ======================================= calculate gain ======================================= # flag = 0
flag = 1 if flag: x = torch.randn(10000)
out = torch.tanh(x) gain = x.std() / out.std()
print('gain:{}'.format(gain)) tanh_gain = nn.init.calculate_gain('tanh')
print('tanh_gain in PyTorch:', tanh_gain)

pytorch(14)权值初始化的更多相关文章

  1. [PyTorch 学习笔记] 4.1 权值初始化

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/grad_vanish_explod.py 在搭建好网络 ...

  2. caffe中权值初始化方法

    首先说明:在caffe/include/caffe中的 filer.hpp文件中有它的源文件,如果想看,可以看看哦,反正我是不想看,代码细节吧,现在不想知道太多,有个宏观的idea就可以啦,如果想看代 ...

  3. 神经网络权值初始化方法-Xavier

    https://blog.csdn.net/u011534057/article/details/51673458 https://blog.csdn.net/qq_34784753/article/ ...

  4. 权值初始化 - Xavier和MSRA方法

    设计好神经网络结构以及loss function 后,训练神经网络的步骤如下: 初始化权值参数 选择一个合适的梯度下降算法(例如:Adam,RMSprop等) 重复下面的迭代过程: 输入的正向传播 计 ...

  5. PyTorch 学习笔记(四):权值初始化的十种方法

    pytorch在torch.nn.init中提供了常用的初始化方法函数,这里简单介绍,方便查询使用. 介绍分两部分: 1. Xavier,kaiming系列: 2. 其他方法分布 Xavier初始化方 ...

  6. 【5】激活函数的选择与权值w的初始化

    激活函数的选择: 西格玛只在二元分类的输出层还可以用,但在二元分类中,其效果不如tanh,效果不好的原因是当Z大时,斜率变化很小,会导致学习效率很差,从而很影响运算的速度.绝大多数情况下用的激活函数是 ...

  7. 2019.01.14 bzoj5343: [Ctsc2018]混合果汁(整体二分+权值线段树)

    传送门 整体二分好题. 题意简述:nnn种果汁,每种有三个属性:美味度,单位体积价格,购买体积上限. 现在有mmm个询问,每次问能否混合出总体积大于某个值,总价格小于某个值的果汁,如果能,求所有方案中 ...

  8. 【机器学习的Tricks】随机权值平均优化器swa与pseudo-label伪标签

    文章来自公众号[机器学习炼丹术] 1 stochastic weight averaging(swa) 随机权值平均 这是一种全新的优化器,目前常见的有SGB,ADAM, [概述]:这是一种通过梯度下 ...

  9. 51nod1459(带权值的dijkstra)

    题目链接:https://www.51nod.com/onlineJudge/questionCode.html#!problemId=1459 题意:中文题诶- 思路:带权值的最短路,这道题数据也没 ...

随机推荐

  1. AcWing 247. 亚特兰蒂斯 (线段树,扫描线,离散化)

    题意:给你\(n\)个矩形,求矩形并的面积. 题解:我们建立坐标轴,然后可以对矩形的横坐标进行排序,之后可以遍历这些横坐标,这个过程可以想像成是一条线从左往右扫过x坐标轴,假如这条线是第一次扫过矩形的 ...

  2. POJ 3281 Dining(最大流板子)

    牛是很挑食的.每头牛都偏爱特定的食物和饮料,其他的就不吃了. 农夫约翰为他的牛做了美味的饭菜,但他忘了根据它们的喜好检查菜单.虽然他不可能喂饱所有的人,但他想让尽可能多的奶牛吃上一顿有食物和水的大餐. ...

  3. hdu2126 Buy the souvenirs

    Problem Description When the winter holiday comes, a lot of people will have a trip. Generally, ther ...

  4. HDU 3605 Escape 最大流

    题意: 如果这是2012年世界末日怎么办?我不知道该怎么做.但是现在科学家们已经发现,有些星球上的人可以生存,但有些人却不适合居住.现在科学家们需要你的帮助,就是确定所有人都能在这些行星上生活.输入多 ...

  5. Codeforces Round #654 (Div. 2) C. A Cookie for You (思维)

    题意:有\(a\)个蛋糕,\(b\)个巧克力,第一类人有\(n\)个,总是吃多的东西(若\(a>b\),吃蛋糕,否则吃巧克力),第二类人有\(m\)个,总是吃少的,可以随便调整这两类人吃的顺序, ...

  6. 📚C#/.NET/.NET Core推荐学习书籍(升职加薪,你值得拥有)

    前言: 作为一名程序员,我们无时无刻都要考虑着如何通过不断地学习来提升自己的核心竞争力.古人有云:"书中自有黄金屋,书中只有颜如玉",说明了书籍的重要性,没错工作多年来,发现身边那 ...

  7. Open3d之交互式可视化

    本篇教程介绍了Open3D的可视化窗口的交互功能. # -*- coding:utf-8 -*- import copy import numpy as np import open3d as o3d ...

  8. dll的注册与反注册

    regsvr32.exe是32位系统下使用的DLL注册和反注册工具,使用它必须通过命令行的方式使用,格式是:regsvr32 [/i[:cmdline]] DLL文件名命令可以在"开始→运行 ...

  9. 1009E Intercity Travelling 【数学期望】

    题目:戳这里 题意:从0走到n,难度分别为a1~an,可以在任何地方休息,每次休息难度将重置为a1开始.求总难度的数学期望. 解题思路: 跟这题很像,利用期望的可加性,我们分析每个位置的状态,不管怎么 ...

  10. php 文件包含base64读取文件 preg_replace函数

    解题部分题目来源攻防世界web高手进阶区1.拿到题目以后,发现是一个index.php的页面,并且设备-没有显示完全,此位置可疑.2.源代码中发现?page=index,出现page这个get参数,联 ...