pytorch对模型参数初始化
1.使用apply()
举例说明:
- Encoder :设计的编码其模型
- weights_init(): 用来初始化模型
- model.apply():实现初始化
# coding:utf-
from torch import nn def weights_init(mod):
"""设计初始化函数"""
classname=mod.__class__.__name__
# 返回传入的module类型
print(classname)
if classname.find('Conv')!= -: #这里的Conv和BatchNnorm是torc.nn里的形式
mod.weight.data.normal_(0.0,0.02)
elif classname.find('BatchNorm')!= -:
mod.weight.data.normal_(1.0,0.02) #bn层里初始化γ,服从(,0.02)的正态分布
mod.bias.data.fill_() #bn层里初始化β,默认为0 class Encoder(nn.Module):
def __init__(self, input_size, input_channels, base_channnes, z_channels): super(Encoder, self).__init__()
# input_size必须为16的倍数
assert input_size % == , "input_size has to be a multiple of 16" models = nn.Sequential()
models.add_module('Conv2_{0}_{1}'.format(input_channels, base_channnes), nn.Conv2d(input_channels, base_channnes, , , , bias=False))
models.add_module('LeakyReLU_{0}'.format(base_channnes), nn.LeakyReLU(0.2, inplace=True))
# 此时图片大小已经下降一倍
temp_size = input_size/ # 直到特征图高宽为4
# 目的是保证无论输入什么大小的图片,经过这几层后特征图大小为4*
while temp_size > :
models.add_module('Conv2_{0}_{1}'.format(base_channnes, base_channnes*), nn.Conv2d(base_channnes, base_channnes*, , , , bias=False))
models.add_module('BatchNorm2d_{0}'.format(base_channnes*), nn.BatchNorm2d(base_channnes*))
models.add_module('LeakyReLU_{0}'.format(base_channnes*), nn.LeakyReLU(0.2, inplace=True))
base_channnes *=
temp_size /= # 特征图高宽为4后面则添加上最后一层
# 让输出为1*
models.add_module('Conv2_{0}_{1}'.format(base_channnes, z_channels), nn.Conv2d(base_channnes, z_channels, , , , bias=False))
self.models = models def forward(self, x):
x = self.models(x)
return x if __name__ == '__main__':
e = Encoder(, , , )
# 对e模型中的每个module和其本身都会调用一次weights_init函数,mod参数的值即这些module
e.apply(weights_init)
# 根据名字来查看参数
for name, param in e.named_parameters():
print(name)
# 举个例子看看是否按照设计进行初始化
# 可见BatchNorm2d的weight是正态分布形的参数,bias参数都是0
if name == 'models.BatchNorm2d_128.weight' or name == 'models.BatchNorm2d_128.bias':
print(param)
返回:
# 返回的是依次传入初始化函数的module
Conv2d
LeakyReLU
Conv2d
BatchNorm2d
LeakyReLU
Conv2d
BatchNorm2d
LeakyReLU
Conv2d
BatchNorm2d
LeakyReLU
Conv2d
BatchNorm2d
LeakyReLU
Conv2d
BatchNorm2d
LeakyReLU
Conv2d
Sequential
Encoder # 输出name的格式,并根据条件打印出BatchNorm2d-128的两个参数
models.Conv2_3_64.weight
models.Conv2_64_128.weight
models.BatchNorm2d_128.weight
Parameter containing:
tensor([1.0074, 0.9865, 1.0188, 1.0015, 0.9757, 1.0393, 0.9813, 1.0135, 1.0227,
0.9903, 1.0490, 1.0102, 0.9920, 0.9878, 1.0060, 0.9944, 0.9993, 1.0139,
0.9987, 0.9888, 0.9816, 0.9951, 1.0017, 0.9818, 0.9922, 0.9627, 0.9883,
0.9985, 0.9759, 0.9962, 1.0183, 1.0199, 1.0033, 1.0475, 0.9586, 0.9916,
1.0354, 0.9956, 0.9998, 1.0022, 1.0307, 1.0141, 1.0062, 1.0082, 1.0111,
0.9683, 1.0372, 0.9967, 1.0157, 1.0299, 1.0352, 0.9961, 0.9901, 1.0274,
0.9727, 1.0042, 1.0278, 1.0134, 0.9648, 0.9887, 1.0225, 1.0175, 1.0002,
0.9988, 0.9839, 1.0023, 0.9913, 0.9657, 1.0404, 1.0197, 1.0221, 0.9925,
0.9962, 0.9910, 0.9865, 1.0342, 1.0156, 0.9688, 1.0015, 1.0055, 0.9751,
1.0304, 1.0132, 0.9778, 0.9900, 1.0092, 0.9745, 1.0067, 1.0077, 1.0057,
1.0117, 0.9850, 1.0309, 0.9918, 0.9945, 0.9935, 0.9746, 1.0366, 0.9913,
0.9564, 1.0071, 1.0370, 0.9774, 1.0126, 1.0040, 0.9946, 1.0080, 1.0126,
0.9761, 0.9811, 0.9974, 0.9992, 1.0338, 1.0104, 0.9931, 1.0204, 1.0230,
1.0255, 0.9969, 1.0079, 1.0127, 0.9816, 1.0132, 0.9884, 0.9691, 0.9922,
1.0166, 0.9980], requires_grad=True)
models.BatchNorm2d_128.bias
Parameter containing:
tensor([., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., .,
., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., .,
., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., .,
., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., .,
., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., ., .,
., ., ., ., ., ., ., .], requires_grad=True)
models.Conv2_128_256.weight
models.BatchNorm2d_256.weight
models.BatchNorm2d_256.bias
models.Conv2_256_512.weight
models.BatchNorm2d_512.weight
models.BatchNorm2d_512.bias
models.Conv2_512_1024.weight
models.BatchNorm2d_1024.weight
models.BatchNorm2d_1024.bias
models.Conv2_1024_2048.weight
models.BatchNorm2d_2048.weight
models.BatchNorm2d_2048.bias
models.Conv2_2048_100.weight
2.直接在定义网络时定义
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F class Discriminator(nn.Module):
"""
6层全连接层
"""
def __init__(self, z_dim):
super(Discriminator, self).__init__()
self.z_dim = z_dim
self.net = nn.Sequential(
nn.Linear(z_dim, ),
nn.LeakyReLU(0.2, True),
nn.Linear(, ),
nn.LeakyReLU(0.2, True),
nn.Linear(, ),
nn.LeakyReLU(0.2, True),
nn.Linear(, ),
nn.LeakyReLU(0.2, True),
nn.Linear(, ),
nn.LeakyReLU(0.2, True),
nn.Linear(, ),
)
self.weight_init() # 参数初始化
def weight_init(self, mode='normal'):
if mode == 'kaiming':
initializer = kaiming_init
elif mode == 'normal':
initializer = normal_init for block in self._modules:
for m in self._modules[block]:
initializer(m) def forward(self, z):
return self.net(z).squeeze() def kaiming_init(m):
if isinstance(m, (nn.Linear, nn.Conv2d)):
init.kaiming_normal_(m.weight)
if m.bias is not None:
m.bias.data.fill_()
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
m.weight.data.fill_()
if m.bias is not None:
m.bias.data.fill_() def normal_init(m):
if isinstance(m, (nn.Linear, nn.Conv2d)):
init.normal_(m.weight, , 0.02)
if m.bias is not None:
m.bias.data.fill_()
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
m.weight.data.fill_()
if m.bias is not None:
m.bias.data.fill_()
然后调用即可
pytorch对模型参数初始化的更多相关文章
- PyTorch保存模型与加载模型+Finetune预训练模型使用
Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- Pytorch基础(6)----参数初始化
一.使用Numpy初始化:[直接对Tensor操作] 对Sequential模型的参数进行修改: import numpy as np import torch from torch import n ...
- pytorch和tensorflow的爱恨情仇之参数初始化
pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch和tensorflow的爱恨情仇之定义可训练的参数 pytorch版本 ...
- PyTorch常用参数初始化方法详解
1. 均匀分布 torch.nn.init.uniform_(tensor, a=0, b=1) 从均匀分布U(a, b)中采样,初始化张量. 参数: tensor - 需要填充的张量 a - 均匀分 ...
- 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?
原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...
- ubuntu之路——day15.1 只用python的numpy在底层检验参数初始化对模型的影响
首先感谢这位博主整理的Andrew Ng的deeplearning.ai的相关作业:https://blog.csdn.net/u013733326/article/details/79827273 ...
- [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题
[深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...
- pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件
本文分为两部分,第一部分讲如何保存模型参数,优化器参数等等,第二部分则讲如何读取. 假设网络为model = Net(), optimizer = optim.Adam(model.parameter ...
随机推荐
- 原生JavaScript和jQuery的较量
JavaScript和jQuery有很多相似知促,那么二者又是如何进行较量,我们先了解一下什么是JavaScript和jQuery,知其源头,才能知其所以然. 简介: [JavaScript] 一种直 ...
- house买房原理,2019,第一版
,购买框架 1,通过自己的买房预算金额 和 pre-approval 确定你要的房屋总价, 估计到自己可以接受的房子,卖方也喜欢这样的买家,但不一定能拿全额贷款 2,pre-approval对信用分数 ...
- IP地址与Mac地址绑定错误
有个application,有时候可以正常访问,有时候又返回404错误,百思不得其解.刚开始以为是文件夹权限问题,折腾了好久. 后来没在服务器上monitor到包,所以猜想是到了错误的mac地址,用a ...
- 【JS】基础知识
引言 在互联网的演化过程中,网页制作是Web1.0时代的产物,那时网站的主要内容都是静态的,用户使用网站的行为也以浏览为主. 2005年以后,互联网进入了Web2.0时代,各类似桌面软件的Web应用大 ...
- mysqli扩展和持久化连接
mysqli扩展的持久化连接在PHP5.3中被引入.支持已经存在于PDO MYSQL 和ext/mysql中. 持久化连接背后的思想是客户端进程和数据库之间的连接可以通过一个客户端进程来保持重用, 而 ...
- apt-get 和dpkg命令
软件包下载:apt-get 1.apt-get install vim 下载vim 2.apt-get upgrade vim 升级vim 3.apt-get update 列出更新 debian软 ...
- 关于css3属性filter
今天看百度百科,看到其中一页所有图片背景全都设置为了灰白色,于是研究了番,发现是应用了filter滤镜这个属性. // 修改所有图片的颜色为黑白 (100% 灰度): img { -webkit-fi ...
- Flower(规律+逆向思维)
Flower: 传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6486 题解: 逆向思维+规律 因为每次剪n-1,所以逆向就是控制n-1朵不变,每次增高1 ...
- Mybatis mapper接口与xml文件路径分离
为什么分离 对于Maven项目,IntelliJ IDEA默认是不处理src/main/java中的非java文件的,不专门在pom.xml中配置<resources>是会报错的,参考这里 ...
- [NOI.AC]NOI2019省选模拟赛 第二场
传送门 Solution A. 一共有\(T\)组数据 每次询问你\([l,r]\)中有多少个数能被他的所有数位整除(如果数位中含有\(0\)忽略掉) 数位dp,咕咕咕 B. 题面略 考虑一个个只有两 ...