PyTorch对ResNet网络的实现解析
PyTorch对ResNet网络的实现解析
1.首先导入需要使用的包
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
# 默认的resnet网络,已预训练
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
2.定义一个3*3的卷积层
def conv3x3(in_channels,out_channels,stride=1):
return nn.Conv2d(
in_channels, # 输入深度(通道)
out_channels, # 输出深度
kernel_size=3,# 滤波器(过滤器)大小为3*3
stride=stride,# 步长,默认为1
padding=1, # 0填充一层
bias=False # 不设偏置
)
下面会重复使用到这个3*3卷积层,虽然只使用了几次...
这里为什么用深度而不用通道,是因为我觉得深度相比通道更有数量上感觉,其实都一样。
3.定义最重要的残差模块
这个是基础块,由两个叠加的3*3卷积组成
class BasicBlock(nn.Module):
expansion = 1 # 是对输出深度的倍乘,在这里等同于忽略
def __init__(self,in_channels,out_channels,stride=1,downsample=None):
super(BasicBlock,self).__init__()
self.conv1 = conv3x3(in_channels,out_channels,stride) # 3*3卷积层
self.bn1 = nn.BatchNorm2d(out_channels) # 批标准化层
self.relu = nn.ReLU(True) # 激活函数
self.conv2 = conv3x3(out_channels,out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.downsample = downsample # 这个是shortcut的操作
self.stride = stride # 得到步长
def forward(self,x):
residual = x # 获得上一层的输出
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None: # 当shortcut存在的时候
residual = self.downsample(x)
# 我们将上一层的输出x输入进这个downsample所拥有一些操作(卷积等),将结果赋给residual
# 简单说,这个目的就是为了应对上下层输出输入深度不一致问题
out += residual # 将bn2的输出和shortcut过来加在一起
out = self.relu(out)
return out
瓶颈块,有三个卷积层分别是1x1,3x3,1x1,分别用来降低维度,卷积处理,升高维度
class Bottleneck(nn.Module): # 由于bottleneck译意为瓶颈,我这里就称它为瓶颈块
expansion = 4 # 若我们输入深度为64,那么扩张4倍后就变为了256
# 其目的在于使得当前块的输出深度与下一个块的输入深度保持一致
# 而为什么是4,这是因为在设计网络的时候就规定了的
# 我想应该可以在保证各层之间的输入输出一致的情况下修改扩张的倍数
def __init__(self,in_channels,out_channels,stride=1,downsample=None):
super(Bottleneck,self).__init__()
self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=1,bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
# 这层1*1卷积层,是为了降维,把输出深度降到与3*3卷积层的输入深度一致
self.conv2 = nn.conv3x3(out_channels,out_channels) # 3*3卷积操作
self.bn2 = nn. BatchNorm2d(out_channels)
# 这层3*3卷积层的channels是下面_make_layer中的第二个参数规定的
self.conv3 = nn.Conv2d(out_channels,out_channels*self.expansion,kernel_size=1,bias=False)
self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)
# 这层1*1卷积层,是在升维,四倍的升
self.relu = nn.ReLU(True) # 激活函数
self.downsample = downsample # shortcut信号
self.stride = stride # 获取步长
def forward(self,x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out) # 连接一个激活函数
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x) # 目的同上
out += residual
out = self.relu(out)
return out
注意:降维只发生在当1*1卷积层的输出深度大于输入深度的时候,当输入输出深度一样时是没有降维的。Resnet中没有降维的情况只发生在刚开始第一个残差块那。
引入Bottleneck的目的是,减少参数的数目,Bottleneck相比较BasicBlock在参数的数目上少了许多,但是精度上却差不多。减少参数同时还会减少计算量,使模型更快的收敛。
4.ResNet主体部分的实现
class ResNet(nn.Module):
def __init__(self,block,layers,num_classes=10):
# block:为上边的基础块BasicBlock或瓶颈块Bottleneck,它其实就是一个对象
# layers:每个大layer中的block个数,设为blocks更好,但每一个block实际上也很是一些小layer
# num_classes:表示最终分类的种类数
super(ResNet,self).__init__()
self.in_channels = 64 # 输入深度为64,我认为把这个理解为每一个残差块块输入深度最好
self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
# 输入深度为3(正好是彩色图片的3个通道),输出深度为64,滤波器为7*7,步长为2,填充3层,特征图缩小1/2
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(True) # 激活函数
self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) # 最大池化,滤波器为3*3,步长为2,填充1层,特征图又缩小1/2
# 此时,特征图的尺寸已成为输入的1/4
# 下面的每一个layer都是一个大layer
# 第二个参数是残差块中3*3卷积层的输入输出深度
self.layer1 = self._make_layer(block,64,layers[0]) # 特征图大小不变
self.layer2 = self._make_layer(block,128,layers[1],stride=2) # 特征图缩小1/2
self.layer3 = self._make_layer(block,256,layers[2],stride=2) # 特征图缩小1/2
self.layer4 = self._make_layer(block,512,layers[3],stride=2) # 特征图缩小1/2
# 这里只设置了4个大layer是设计网络时规定的,我们也可以视情况自己往上加
# 这里可以把4个大layer和上边的一起看成是五个阶段
self.avgpool = nn.AvgPool2d(7,stride=1) # 平均池化,滤波器为7*7,步长为1,特征图大小变为1*1
self.fc = nn.Linear(512*block.expansion,num_classes) # 全连接层
# 这里进行的是网络的参数初始化,可以看出卷积层和批标准化层的初始化方法是不一样的
for m in self.modules():
# self.modules()采取深度优先遍历的方式,存储了网络的所有模块,包括本身和儿子
if isinstance(m,nn.Conv2d): # isinstance()判断一个对象是否是一个已知的类型
nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')
# 9. kaiming_normal 初始化 (这里是nn.init初始化函数的源码,有好几种初始化方法)
# torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
# nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
# tensor([[ 0.2530, -0.4382, 1.5995],
# [ 0.0544, 1.6392, -2.0752]])
elif isinstance(m,nn.BatchNorm2d):
nn.init.constant_(m.weight,1)
nn.init.constant_(m.bias,0)
# 3. 常数 - 固定值 val
# torch.nn.init.constant_(tensor, val)
# nn.init.constant_(w, 0.3)
# tensor([[ 0.3000, 0.3000, 0.3000],
# [ 0.3000, 0.3000, 0.3000]])
def _make_layer(self,block,out_channels,blocks,stride=1):
# 这里的blocks就是该大layer中的残差块数
# out_channels表示的是这个块中3*3卷积层的输入输出深度
downsample = None # shortcut内部的跨层实现
if stride != 1 or self.in_channels != out_channels*block.expansion:
# 判断步长是否为1,判断当前块的输入深度和当前块卷积层深度乘于残差块的扩张
# 为何用步长来判断,我现在还不明白,感觉没有也行
downsample = nn.Sequential(
nn.Conv2d(self.in_channels,out_channels*block.expansion,kernel_size=1,stride=stride,bias=False),
nn.BatchNorm2d(out_channels*block.expansion)
)
# 一旦判断条件成立,那么给downsample赋予一层1*1卷积层和一层批标准化层。并且这一步将伴随这特征图缩小1/2
# 而为何要在shortcut中再进行卷积操作?是因为在残差块之间,比如当要从64深度的3*3卷积层阶段过渡到128深度的3*3卷积层阶段,主分支为64深度的输入已经通过128深度的3*3卷积层变成了128深度的输出,而shortcut分支中x的深度仍为64,而主分支和shortcut分支相加的时候,深度不一致会报错。这就需要进行升维操作,使得shortcut分支中的x从64深度升到128深度。
# 而且需要这样操作的其实只是在基础块BasicBlock中,在瓶颈块Bottleneck中主分支中自己就存在升维操作,那么Bottleneck还在shortcut中引入卷积层的目的是什么?能带来什么帮助?
layers = []
layers.append(block(self.in_channels,out_channels,stride,downsample))
# block()生成上面定义的基础块和瓶颈块的对象,并将dowsample传递给block
self.in_channels = out_channels*block.expansion # 改变下面的残差块的输入深度
# 这使得该阶段下面blocks-1个block,即下面循环内构造的block与下一阶段的第一个block的在输入深度上是相同的。
for i in range(1,blocks): # 这里面所有的block
layers.append(block(self.in_channels,out_channels))
#一定要注意,out_channels一直都是3*3卷积层的深度
return nn.Sequential(*layers) # 这里表示将layers中的所有block按顺序接在一起
def forward(self,x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.maxpool(out) # 写代码时一定要仔细,别把out写成x了,我在这里吃了好大的亏
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.avgpool(out)
out = out.view(out.size(0),-1) # 将原有的多维输出拉回一维
out = self.fc(out)
return out
5.定义各种ResNet网络
resnet18,共有18层卷积层
def resnet18(pretrained=False,**kwargs):
'''
pretrained:若为True,则返回在ImageNet数据集上预先训练的模型
**kwargs:应该只包括两个参数,一个是输入x,一个是输出分类个数num_classes
'''
model = ResNet(BasicBlock,[2,2,2,2],**kwargs)
# block对象为 基础块BasicBlock
# layers列表为 [2,2,2,2],这表示网络中每个大layer阶段都是由两个BasicBlock组成
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
resnet34,共有34层卷积层
def resnet34(pretrained=False,**kwargs):
model = ResNet(BasicBlock,[3,4,6,3],**kwargs)
# block对象为 基础块BasicBlock
# layers列表 [3,4,6,3]
# 这表示layer1、layer2、layer3、layer4分别由3、4、6、3个BasicBlock组成
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
resnet50,共有50层卷积层
def resnet50(pretrained=False,**kwargs):
model = ResNet(Bottleneck,[3,4,6,3],**kwargs)
# block对象为 瓶颈块Bottleneck
# layers列表 [3,4,6,3]
# 这表示layer1、layer2、layer3、layer4分别由3、4、6、3个Bottleneck组成
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
resnet101,共有101层卷积层
def resnet101(pretrained=False,**kwargs):
model = ResNet(Bottleneck,[3,4,23,3],**kwargs)
# block对象为 瓶颈块Bottleneck
# layers列表 [3,4,23,3]
# 这表示layer1、layer2、layer3、layer4分别由3、4、23、3个Bottleneck组成
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
resnet152,共有152层卷积层
def resnet152(pretrained=False,**kwargs):
model = ResNet(Bottleneck,[3,8,36,3],**kwargs)
# block对象为 瓶颈块Bottleneck
# layers列表 [3,8,36,3]
# 这表示layer1、layer2、layer3、layer4分别由3、8、36、3个Bottleneck组成
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
6.总结
我们可以从上面看出:
- resnet18和resnet34只用到了简单的BasicBlock,resnet50、resnet101和resnet152用的是Bottleneck。
- Bottleneck相比较BasicBlock在参数量上减少了16.94倍。
- resnet50、resnet101和resnet152三个网络输入输出大小都一样,只是中间的参数个数不一样。
- resnet网络中第一个残差块的输入深度都为64,其他的为残差块中3*3卷积层的深度乘以block.expansion。
- 从每一个layer阶段到下一个layer阶段都伴随着特征图缩小1/2,特征图深度加深1/2。这发生在除第一个layer外的每个layer中的第一个残差块中。
- resnet网络的四个layer前后的操作都是一样,因此resnet网络输入的图片尺寸固定为224*224(还不确定)。
- 在理解网络的时候最好结合resnet18、resnet50的结构图。
PyTorch对ResNet网络的实现解析的更多相关文章
- 学习笔记-ResNet网络
ResNet网络 ResNet原理和实现 总结 一.ResNet原理和实现 神经网络第一次出现在1998年,当时用5层的全连接网络LetNet实现了手写数字识别,现在这个模型已经是神经网络界的“hel ...
- 0609-搭建ResNet网络
0609-搭建ResNet网络 目录 一.ResNet 网络概述 二.利用 torch 实现 ResNet34 网络 三.torchvision 中的 resnet34网络调用 四.第六章总结 pyt ...
- Resnet网络详细结构(针对Cifar10)
Resnet网络详细结构(针对Cifar10) 结构 具体结构(Pytorch) conv1 (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, ...
- 基于pytorch实现Resnet对本地数据集的训练
本文是使用pycharm下的pytorch框架编写一个训练本地数据集的Resnet深度学习模型,其一共有两百行代码左右,分成mian.py.network.py.dataset.py以及train.p ...
- Android网络之数据解析----使用Google Gson解析Json数据
[声明] 欢迎转载,但请保留文章原始出处→_→ 生命壹号:http://www.cnblogs.com/smyhvae/ 文章来源:http://www.cnblogs.com/smyhvae/p/4 ...
- Android网络之数据解析----SAX方式解析XML数据
[声明] 欢迎转载,但请保留文章原始出处→_→ 生命壹号:http://www.cnblogs.com/smyhvae/ 文章来源:http://www.cnblogs.com/smyhvae/p/ ...
- PrismCDN 网络的架构解析,以及低延迟、低成本的奥秘
5 月 19.20 日,行业精英齐聚的 WebRTCon 2018 在上海举办.又拍云 PrismCDN 项目负责人凌建发在大会做了<又拍云低延时的 WebP2P 直播实践>的精彩分享. ...
- MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
- ResNet网络再剖析
随着2018年秋季的到来,提前批和内推大军已经开始了,自己也成功得当了几次炮灰,不过在总结的过程中,越是了解到自己的不足,还是需要加油. 最近重新复习了resnet网络,又能发现一些新的理念,感觉很f ...
随机推荐
- codeforces 1301C Ayoub's function
题目链接:http://codeforces.com/problemset/problem/1301/C 思路: 纯想想了一次,发现one_cnt >= zero_cnt的时候很简单,就是(n) ...
- 剑指offer-面试题60-n个骰子的点数-动态规划
/* 题目: 计算n个骰子,出现和s的概率. */ #include<iostream> #include<cstdlib> #include<stack> #in ...
- 通过shell模拟redis-trib.rb info的输出
需求:模拟redis-trib.rb info ip:port输出的结果 如: [redis@lxd-vm3 ~]$ redis-trib.rb info 5.5.5.101:29001 /usr/l ...
- 令人抓狂的redis和rediscluster Python驱动包的安装
本文环境:centos 7,Python3编译安装成功,包括pip3,然后需要安装redis相关的Python3驱动包,本的redis指redis包而非redis数据库,rediscluster类似. ...
- 浅谈python的第三方库——numpy(终)
本文作为numpy系列的总结篇,继续介绍numpy中常见的使用小贴士 1 手动转换矩阵规格 转换矩阵规格,就是在保持原矩阵的元素数量和内容不变的情况下,改变原矩阵的行列数目.比如,在得到一个5x4的矩 ...
- spring中JdbcTemplate使用
1.maven依赖 <?xml version="1.0" encoding="UTF-8"?> <project xmlns="h ...
- 11、C++之const类成员变量,const成员函数
//转载 类的成员函数后面加 const,表明这个函数不会对这个类对象的数据成员(准确地说是非静态数据成员)作任何改变. 在设计类的时候,一个原则就是对于不改变数据成员的成员函数都要在后面加 cons ...
- c++并发编程之进程创建(给那些想知道细节的人)
关于多进程创建,此处只讲解一个函数fork(). 1.进程创建 先上代码: #include"iostream" #include<unistd.h> //unix标准 ...
- C#加密与解密(DES\RSA)学习笔记
本笔记摘抄自:https://www.cnblogs.com/skylaugh/archive/2011/07/12/2103572.html,记录一下学习过程以备后续查用. 数据加密技术是网络中最基 ...
- springboot web - 启动(2) run()
接上一篇 在创建 SpringApplication 之后, 调用了 run() 方法. public ConfigurableApplicationContext run(String... arg ...