A PyTorch Implementation of ResNet
1. Original Paper
Deep Residual Learning for Image Recognition
2. Abstract
Deeper neural networks are more difficult to train. We present a residual learning framework to ease the training of networks that are substantially deeper than those used previously. We explicitly reformulate the layers as learning residual functions with reference to the layer inputs, instead of learning unreferenced functions. We provide comprehensive empirical evidence showing that these residual networks are easier to optimize, and can gain accuracy from considerably increased depth. On the ImageNet dataset we evaluate residual nets with a depth of 152 layers, 8x deeper than VGG nets but still of lower complexity. An ensemble of these residual nets achieves a 3.57% error rate on ImageNet. This result won first place in the ILSVRC 2015 classification task. We also present analyses on CIFAR-10 with networks of 100 and 1000 layers.
The depth of representations is of central importance for many visual recognition tasks. Solely owing to our extremely deep representations, we obtain a 28% relative improvement on the COCO object detection dataset. Deep residual nets are the foundation of our submissions to the ILSVRC & COCO 2015 competitions, where we won first place on the ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation tasks.
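The core idea fits in a few lines. Below is a minimal sketch of a residual block (ToyResidualBlock is a hypothetical name used only for illustration, not part of the implementation in Section 4): instead of asking the stacked layers to fit a desired mapping H(x) directly, the block fits the residual F(x) = H(x) - x and recovers H(x) by adding the identity shortcut back in.

import torch.nn as nn

class ToyResidualBlock(nn.Module):
    """Two 3x3 convolutions form the residual branch F(x); the output is F(x) + x."""
    def __init__(self, channels):
        super(ToyResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))  # residual branch F(x)
        return self.relu(out + x)        # identity shortcut: H(x) = F(x) + x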
3. Network Architecture
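All of the variants implemented below share the same skeleton: a 7x7 stride-2 convolution with batch normalization and ReLU, a 3x3 stride-2 max pooling, four stages of residual blocks (layer1 to layer4, where each stage after the first halves the spatial resolution and doubles the channel count), global average pooling, and a fully connected classifier. The variants differ only in the block type and in how many blocks each stage stacks:

resnet18:  BasicBlock, [2, 2, 2, 2]
resnet34:  BasicBlock, [3, 4, 6, 3]
resnet50:  Bottleneck, [3, 4, 6, 3]
resnet101: Bottleneck, [3, 4, 23, 3]
resnet152: Bottleneck, [3, 8, 36, 3]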
4. PyTorch Implementation
import torch.nn as nn
from torch.utils.model_zoo import load_url as load_state_dict_from_url

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    # Basic residual block (two 3x3 convolutions), used by ResNet-18/34.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
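# Illustrative shape walk-through for a BasicBlock with stride=2 (hypothetical
# values, for illustration): with inplanes=64, planes=128 and an input of
# shape (N, 64, 56, 56),
#   conv1 (3x3, stride 2): (N, 64, 56, 56) -> (N, 128, 28, 28)
#   conv2 (3x3, stride 1): (N, 128, 28, 28) -> (N, 128, 28, 28)
# and the downsample branch (a strided 1x1 convolution built in _make_layer
# below) maps the identity to (N, 128, 28, 28) so the element-wise addition
# is valid.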
class Bottleneck(nn.Module):
    # Bottleneck residual block (1x1 reduce, 3x3, 1x1 expand), used by
    # ResNet-50/101/152 and the ResNeXt variants.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
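# Worked example of the Bottleneck width computation (illustrative): for
# resnext50_32x4d below, groups=32 and width_per_group=4, so at planes=64
#   width = int(64 * (4 / 64.)) * 32 = 4 * 32 = 128
# and conv2 becomes a grouped 3x3 convolution whose 128 channels are split
# into 32 groups of 4. For plain ResNet (groups=1, base_width=64), width
# simply equals planes.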
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the identity with a strided 1x1 convolution whenever the
            # residual branch changes the spatial size or channel count.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)

        return x
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    # block is the residual block class (BasicBlock or Bottleneck);
    # layers gives the number of blocks in each of the four stages.
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs) def resnet34(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs) def resnet50(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs) def resnet101(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs) def resnet152(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs) def resnext50_32x4d(**kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained=False, progress=True, **kwargs) def resnext101_32x8d(**kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained=False, progress=True, **kwargs)
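A quick sanity check of the constructors above (a minimal sketch; the 224x224 input matches what the ImageNet-pretrained weights expect):

import torch

model = resnet18(pretrained=False)  # set pretrained=True to download the ImageNet weights
model.eval()
x = torch.randn(1, 3, 224, 224)  # dummy batch of one ImageNet-sized image
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1000]): one logit per ImageNet class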
Reference
https://github.com/pytorch/vision/tree/master/torchvision/models