~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

转载请注明出处:

http://www.cnblogs.com/darkknightzh/p/8297793.html

参考网址:

http://pytorch.org/docs/0.3.0/nn.html?highlight=kaiming#torch.nn.init.kaiming_normal

https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py

https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua

https://github.com/bamos/densenet.pytorch/blob/master/densenet.py

https://github.com/szagoruyko/wide-residual-networks/blob/master/models/utils.lua

说明:暂时就这么多吧,错误之处请见谅。前两个初始化的方法见pytorch官方文档

1. xavier初始化

torch.nn.init.xavier_uniform(tensor, gain=1)

对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。初始化服从均匀分布U(−a,a)" role="presentation" style="position: relative;">U(−a,a)U(−a,a),其中a=gain×2/(fan_in+fan_out)×3" role="presentation" style="position: relative;">a=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√×3–√a=gain×2/(fan_in+fan_out)×3,该初始化方法也称Glorot initialisation。

参数:

tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

a:可选择的缩放参数

例如:

w = torch.Tensor(3, 5)
nn.init.xavier_uniform(w, gain=nn.init.calculate_gain('relu'))

torch.nn.init.xavier_normal(tensor, gain=1)

对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。初始化服从高斯分布N(0,std)" role="presentation" style="position: relative;">N(0,std)N(0,std),其中std=gain×2/(fan_in+fan_out)" role="presentation" style="position: relative;">std=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√std=gain×2/(fan_in+fan_out),该初始化方法也称Glorot initialisation。

参数:

tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

a:可选择的缩放参数

例如:

w = torch.Tensor(3, 5)
nn.init.xavier_normal(w)

2. kaiming初始化

torch.nn.init.kaiming_uniform(tensor, a=0, mode='fan_in')

对于输入的tensor或者变量,通过论文“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015)的方法初始化数据。初始化服从均匀分布U(−bound,bound)" role="presentation" style="position: relative;">U(−bound,bound)U(−bound,bound),其中bound=2/((1+a2)×fan_in)×3" role="presentation" style="position: relative;">bound=2/((1+a2)×fan_in)−−−−−−−−−−−−−−−−−−√×3–√bound=2/((1+a2)×fan_in)×3,该初始化方法也称He initialisation。

参数:

tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

a:该层后面一层的激活函数中负的斜率(默认为ReLU,此时a=0)

mode:‘fan_in’ (default) 或者 ‘fan_out’. 使用fan_in保持weights的方差在前向传播中不变;使用fan_out保持weights的方差在反向传播中不变。

例如:

w = torch.Tensor(3, 5)
nn.init.kaiming_uniform(w, mode='fan_in')

torch.nn.init.kaiming_normal(tensor, a=0, mode='fan_in')

对于输入的tensor或者变量,通过论文“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015)的方法初始化数据。初始化服从高斯分布N(0,std)" role="presentation" style="position: relative;">N(0,std)N(0,std),其中std=2/((1+a2)×fan_in)" role="presentation" style="position: relative;">std=2/((1+a2)×fan_in)−−−−−−−−−−−−−−−−−−√std=2/((1+a2)×fan_in),该初始化方法也称He initialisation。

参数:

tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据

a:该层后面一层的激活函数中负的斜率(默认为ReLU,此时a=0)

mode:‘fan_in’ (default) 或者 ‘fan_out’. 使用fan_in保持weights的方差在前向传播中不变;使用fan_out保持weights的方差在反向传播中不变。

例如:

w = torch.Tensor(3, 5)
nn.init.kaiming_normal(w, mode='fan_out')

使用的例子(具体参见原始网址):

https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py

from torch.nn import init
self.classifier = nn.Linear(self.stages[3], nlabels)
init.kaiming_normal(self.classifier.weight)
for key in self.state_dict():
if key.split('.')[-1] == 'weight':
if 'conv' in key:
init.kaiming_normal(self.state_dict()[key], mode='fan_out')
if 'bn' in key:
self.state_dict()[key][...] = 1
elif key.split('.')[-1] == 'bias':
self.state_dict()[key][...] = 0

3. 实际使用中看到的初始化

3.1 ResNeXt,densenet中初始化

https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua

https://github.com/bamos/densenet.pytorch/blob/master/densenet.py

conv

n = kW* kH*nOutputPlane
weight:normal(,math.sqrt(/n))
bias:zero()

batchnorm

weight:fill()
bias:zero()

linear

bias:zero()

3.2 wide-residual-networks中初始化(MSRinit

https://github.com/szagoruyko/wide-residual-networks/blob/master/models/utils.lua

conv

n = kW* kH*nInputPlane
weight:normal(,math.sqrt(/n))
bias:zero()

linear

bias:zero()

[PyTorch]PyTorch中模型的参数初始化的几种方法(转)的更多相关文章

  1. java中Map和List初始化的两种方法

    第一种方法(常用方法): //初始化List List<string> list = new ArrayList</string><string>(); list. ...

  2. Pytorch基础(6)----参数初始化

    一.使用Numpy初始化:[直接对Tensor操作] 对Sequential模型的参数进行修改: import numpy as np import torch from torch import n ...

  3. 服务器文档下载zip格式 SQL Server SQL分页查询 C#过滤html标签 EF 延时加载与死锁 在JS方法中返回多个值的三种方法(转载) IEnumerable,ICollection,IList接口问题 不吹不擂,你想要的Python面试都在这里了【315+道题】 基于mvc三层架构和ajax技术实现最简单的文件上传 事件管理

    服务器文档下载zip格式   刚好这次项目中遇到了这个东西,就来弄一下,挺简单的,但是前台调用的时候弄错了,浪费了大半天的时间,本人也是菜鸟一枚.开始吧.(MVC的) @using Rattan.Co ...

  4. Spring3 MVC请求参数获取的几种方法

    Spring3 MVC请求参数获取的几种方法 一.      通过@PathVariabl获取路径中的参数 @RequestMapping(value="user/{id}/{name}&q ...

  5. 获取网页URL地址及参数等的两种方法(js和C#)

    转:获取网页URL地址及参数等的两种方法(js和C#) 一 js 先看一个示例 用javascript获取url网址信息 <script type="text/javascript&q ...

  6. 在Java Web程序中使用监听器可以通过以下两种方法

    之前学习了很多涉及servlet的内容,本小结我们说一下监听器,说起监听器,编过桌面程序和手机App的都不陌生,常见的套路都是拖一个控件,然后给它绑定一个监听器,即可以对该对象的事件进行监听以便发生响 ...

  7. Spring3 MVC请求参数获取的几种方法[转]

    Spring3 MVC请求参数获取的几种方法 Spring3 MVC请求参数获取的几种方法 一.      通过@PathVariabl获取路径中的参数 @RequestMapping(value=& ...

  8. PHP中获取文件扩展名的N种方法

    PHP中获取文件扩展名的N种方法 从网上收罗的,基本上就以下这几种方式: 第1种方法:function get_extension($file){substr(strrchr($file, '.'), ...

  9. 在MySQL中设置事务隔离级别有2种方法:

    在MySQL中设置事务隔离级别有2种方法: 1 在my.cnf中设置,在mysqld选项中如下设置 [mysqld] transaction-isolation = READ-COMMITTED 2 ...

随机推荐

  1. go学习笔记二:运行使用命令行参数

    本文只作为博主的go语言学习笔记. 对命令行参数的解析,只是在运行时使用的,比如以下命令:go run gomain -conf conf.toml 没有办法再go build时使用. 一.运行时命令 ...

  2. Apache mahout 源码阅读笔记--DataModel之FileDataModel

    要做推荐,用户行为数据是基础. 用户行为数据有哪些字段呢? mahout的DataModel支持,用户ID,ItemID是必须的,偏好值(用户对当前Item的评分),时间戳 这四个字段 {@code ...

  3. redis之django-redis

    自定义连接池 这种方式跟普通py文件操作redis一样,代码如下: views.py import redis from django.shortcuts import render,HttpResp ...

  4. Mysql索引长度和区分度

    首先  索引长度和区分度是相互矛盾的, 索引长度太短,那么区分度就很低,吧索引长度加长,区分度就高,但是索引也是要占内存的,所以我们需要找到一个平衡点: 那么这个平衡点怎么来定? 比如用户表有个字段 ...

  5. centos删除乱码名称的文件

    常规方法rm已经木有办法删除该文件了. 原理: 当文件名为乱码的时候,无法通过键盘输入文件名,所以在终端下就不能直接利用rm,mv等命令管理文件了.但是每个文件都有一个i节点号,可以通过i节点号来管理 ...

  6. C# 编写 TensorFlow 人工智能应用

    TensorFlowSharp入门使用C#编写TensorFlow人工智能应用学习. TensorFlow简单介绍 TensorFlow 是谷歌的第二代机器学习系统,按照谷歌所说,在某些基准测试中,T ...

  7. 初识Java集合框架(Iterator、Collection、Map)

    1. Java集合框架提供了一套性能优良.使用方便的接口和类,它们位于java.util包中 注意: 既有接口也有类,图中画实线的是类,画虚线的是接口 使用之前须要到导入java.util包 List ...

  8. 创建JOB定时执行存储过程

    创建JOB定时执行存储过程有两种方式 方式1:通过plsql手动配置job,如下图: 方式2:通过sql语句,如下sql declare job_OpAutoDta pls_integer;--声明一 ...

  9. ftp 工作原理

  10. Linux SSH免登录配置总结(转)

    转载请出自出处:http://eksliang.iteye.com/blog/2187265 一.原理 我们使用ssh-keygen在ServerA上生成私钥跟公钥,将生成的公钥拷贝到远程机器Serv ...