为什么要进行初始化

首先假设有一个两层全连接网络,第一层的第一个节点值为 \(H_{11}= \sum_{i=0}^n X_i*W_{1i}\),

这个时候,方差为 \(D(H_{11}) = \sum_{i=0}^n D(X_i) * D(W_{1i})\), 这个时候,输入\(X_i\)一般会做归一化,那么其方差为1,而权重W如果不进行归一化的话,H的方差就会变得很大,然后多层累计,下一次的输入会越来越大,使得网络不好收敛,如果权重W进行了初始化,使得其方差保持在1/n附近,那么方差H则会收敛在1附近,从而使得网络变得更好优化。 很多初始化都是使用的这个原理,控制每一层的输出,使其保持在一定的范围内。

一些常见初始化方法

Xavier

Xavier初始化也是类似的原理, 假设输入X 以及做了归一化,其方差为1 ,那么Xavier所希望的就是上述公式D(H) 保持在1左右,那么就可以得到公式

\[H_{layer1} = \sum_{i=0}^n D(X_i) * D(W_{1i})=n_1 *D(W) = 1 \\ H_{layer2} =\sum_{i=0}^n D(X_i) * D(W_{1i}) = n_2 *D(W) = 1
\]

其中n1 和 n2 为网络层的输入输出节点数量,一般情况下,输入输出是不一样的,为了均衡考虑,可以做一个平均操作,于是变得到 \(D(W) = \frac{2}{n_1+n_2}\)

这个时候,我们假设 W服从均匀分布 \(U[-a, a]\), 那么在这个条件下,

\[D(W) = \frac{(-a-a)^2}{12} = \frac{a^2}{3}
\]

推出\(a = \frac{\sqrt{6}}{\sqrt{n_1+n_2+1}}\),从而得到:

\[W \sim U[-\frac{\sqrt{6}}{\sqrt{n_1+n_2+1}},\frac{\sqrt{6}}{\sqrt{n_1+n_2+1}}]
\]

这样就可以得到Xavier初始化,在pytorch中使用Xavier初始化方式如下,值得注意的是,Xavier对于sigmoid和tanh比较好,对于其他的可能效果就不是那么好了

nn.init.xavier_uniform_(m.weight.data)

Kaiming

Kaiming 初始化比较适合ReLU激活函数,其原理也跟上述差不多,也是希望将权重的方差保持在一定的范围内,使得正反向传播的值得到有效的控制,在kaiming初始化中,主要将权重的方差设置为 \(D(w) = \frac{2}{ni}\),由于考虑到ReLU激活函数,将方差调整为\(D(w)= \frac{2}{(1+a^2)*n_i}\), 这里的a是ReLU的斜率。

在pytorch中使用Kaiming初始化

nn.init.kaiming_normal_(m.weight.data)


LSTM初始化

LSTM中,公式和参数值的设定如下所示

在LSTM中,由于很多门控的权重尺寸是一样的,所以可以使用如下方法进行初始化

def _init_lstm(self, weight):
for w in weight.chunk(4, 0):
init.xavier_uniform(w) self._init_lstm(self.lstm.weight_ih_l0)
self._init_lstm(self.lstm.weight_hh_l0)
self.lstm.bias_ih_l0.data.zero_()
self.lstm.bias_hh_l0.data.zero_()

Embedding进行初始化

self.embedding = nn.Embedding(embedding_tokens, embedding_features, padding_idx=0)
init.xavier_uniform(self.embedding.weight)

其他通用初始化方法

遍历初始化

for name, param in net.named_parameters():
if 'weight' in name:
init.normal_(param, mean=0, std=0.01)
print(name, param.data) for name, param in net.named_parameters():
if 'bias' in name:
init.constant_(param, val=0)
print(name, param.data) ## 通过instance 初始化
for m in self.children():
if isinstance(m, nn.Linear):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, -100)
# 也可以判断是否为conv2d,使用相应的初始化方式
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight.item(), 1)
nn.init.constant_(m.bias.item(), 0)

直接使用pytorch内置初始化

from torch.nn import init 

init.normal_(net[0].weight, mean=0, std=0.01) 

init.constant_(net[0].bias, val=0)

自带初始化方法中,会自动消除梯度反向传播,但是手动情况下必须自己设定

def no_grad_uniform(tensor, a, b):

  with torch.no_grad():

    return tensor.uniform_(a, b)

使用apply进行初始化

批量初始化方法,注意net里面的apply函数,可以作用网络的所有module

def weights_init(m):                                               # 1

  classname = m.__class__.__name__                             # 2

  if classname.find('Conv') != -1:                               # 3

    nn.init.kaiming_normal_(m.weight.data)                  # 4

  elif classname.find('BatchNorm') != -1:                        # 5

    nn.init.normal_(m.weight.data, 1.0, 0.02)                  # 6

    nn.init.constant_(m.bias.data, 0)                          # 7 

net.apply(weights_init)

Pytorch系列:(七)模型初始化的更多相关文章

  1. 计算广告CTR预估系列(七)--Facebook经典模型LR+GBDT理论与实践

    计算广告CTR预估系列(七)--Facebook经典模型LR+GBDT理论与实践 2018年06月13日 16:38:11 轻春 阅读数 6004更多 分类专栏: 机器学习 机器学习荐货情报局   版 ...

  2. Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager)

    Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager) 本篇主要讲解iOS开发中的网络监控 前言 在开发中,有时候我们需要获取这些信息: 手机是否联网 ...

  3. [Asp.net MVC]Asp.net MVC5系列——在模型中添加验证规则

    目录 概述 在模型中添加验证规则 自定义验证规则 伙伴类的使用 总结 系列文章 [Asp.net MVC]Asp.net MVC5系列——第一个项目 [Asp.net MVC]Asp.net MVC5 ...

  4. WCF编程系列(七)信道及信道工厂

    WCF编程系列(七)信道及信道工厂   信道及信道栈 前面已经提及过,WCF中客户端与服务端的交互都是通过消息来进行的.消息从客户端传送到服务端会经过多个处理动作,在WCF编程模型中,这些动作是按层 ...

  5. Asp.net MVC]Asp.net MVC5系列——在模型中添加

    目录 概述 在模型中添加验证规则 自定义验证规则 伙伴类的使用 总结 系列文章 [Asp.net MVC]Asp.net MVC5系列——第一个项目 [Asp.net MVC]Asp.net MVC5 ...

  6. iOS流布局UICollectionView系列七——三维中的球型布局

      摘要: 类似标签云的球状布局,也类似与魔方的3D布局 iOS流布局UICollectionView系列七——三维中的球型布局 一.引言 通过6篇的博客,从平面上最简单的规则摆放的布局,到不规则的瀑 ...

  7. [源码解析] PyTorch分布式(6) -------- DistributedDataParallel -- 初始化&store

    [源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store 目录 [源码解析] PyTorch分布式(6) ---Distribu ...

  8. Keil MDK STM32系列(七) STM32F4基于HAL的PWM和定时器

    Keil MDK STM32系列 Keil MDK STM32系列(一) 基于标准外设库SPL的STM32F103开发 Keil MDK STM32系列(二) 基于标准外设库SPL的STM32F401 ...

  9. SQL Server 2008空间数据应用系列七:基于Bing Maps(Silverlight) 的空间数据展现

    原文:SQL Server 2008空间数据应用系列七:基于Bing Maps(Silverlight) 的空间数据展现 友情提示,您阅读本篇博文的先决条件如下: 1.本文示例基于Microsoft ...

随机推荐

  1. GPU加速计算

    GPU加速计算 NVIDIA A100 Tensor Core GPU 可针对 AI.数据分析和高性能计算 (HPC),在各种规模上实现出色的加速,应对极其严峻的计算挑战.作为 NVIDIA 数据中心 ...

  2. 车联网V-2X智能汽车驾驶

    车联网V-2X智能汽车驾驶 早期的功能互联汽车无法满足全球车主针对不同应用和定制移动服务的各种需求.这导致较低的客户续订率,较高的建造和运营成本以及较低的组装率.通常,在没有统一平台的情况下,不同的车 ...

  3. 对SpringBoot和SpringCloud的理解

    1.SpringCloud是什么 SpringCloud基于SpringBoot提供了一整套微服务的解决方案,包括服务注册与发现,配置中心,全链路监控,服务网关,负载均衡,熔断器等组件,除了基于Net ...

  4. [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入

    [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 目录 [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 0x00 摘要 0 ...

  5. 孟老板 ListAdapter封装, 告别Adapter代码 (上)

    BaseAdapter封装(一) 简单封装 BaseAdapter封装(二) Header,footer BaseAdapter封装(三) 空数据占位图 BaseAdapter封装(四) PageHe ...

  6. 查找文件与cron计划任务

    查找文件 • 根据预设的条件递归查找对应的文件 find [目录] [条件1] [-a|-o] [条件2] ... -type  类型(f文件.d目录.l快捷方式) -name  "文档名称 ...

  7. 排查bug:竟然是同事把Redis用成这鬼样子,坑了我

    首先说下问题现象:内网sandbox环境API持续1周出现应用卡死,所有api无响应现象 刚开始当测试抱怨环境响应慢的时候 ,我们重启一下应用,应用恢复正常,于是没做处理.但是后来问题出现频率越来越频 ...

  8. NOIP模拟测试2「排列 (搜索)·APIO划艇」

    排序 内存限制:128 MiB 时间限制:1000 ms 标准输入输出     题目描述 输入格式 数据范围与提示 对于30%的数据,1<=N<=4: 对于全部的数据,1<=N< ...

  9. c#创建windows服务(创建,安装,删除)

    一.在vs中创建一个window服务 二.进入Service1.cs页面后 右击----创建安装程序,安装程序创建成功后---会出现ProjectInstaller.cs文件 三.进入ProjectI ...

  10. 四、JavaSE语言基础之运算符

    什么是是运算符 运算符:用于数据运算的符号,运算是一种处理.(注:浮点型数据(float.double)进行运算会出现精度丢失的情况) 运算符大致可分为以下六种: 一.算术运算符:+.-.*./.%. ...