【深度学习】基于Pytorch的ResNet实现
1. ResNet理论
论文:https://arxiv.org/pdf/1512.03385.pdf
残差学习基本单元:
在ImageNet上的结果:
效果会随着模型层数的提升而下降,当更深的网络能够开始收敛时,就会出现降级问题:随着网络深度的增加,准确度变得饱和(这可能不足为奇),然后迅速降级。
ResNet模型:
2. pytorch实现
2.1 基础卷积
conv3$\times\(3 和conv1\)\times$1 基础模块
def conv3x3(in_channel, out_channel, stride=1, groups=1, dilation=1):
return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_channel, out_channel, stride=1):
return nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)
参数解释:
in_channel: 输入的通道数目
out_channel:输出的通道数目
stride, padding: 步长和补0
dilation: 空洞卷积中的参数
groups: 从输入通道到输出通道的阻塞连接数
feature size 计算:
output = (intput - filter_size + 2 x padding) / stride + 1
空洞卷积实际卷积核大小:
K = K + (K-1)x(R-1)
K 是原始卷积核大小
R 是空洞卷积参数的空洞率(普通卷积为1)
2.2 模块
- resnet34
- _resnet
- ResNet
- _make_layer
- block
- Bottleneck
- BasicBlock
Bottlenect
class Bottleneck(nn.Module):
expansion = 4
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
BasicBlock
class BasicBlock(nn.Module):
expansion = 1
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
2.3 使用ResNet模块进行迁移学习
import torchvision.models as models
import torch.nn as nn
class RES18(nn.Module):
def __init__(self):
super(RES18, self).__init__()
self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet18(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out
class RES34(nn.Module):
def __init__(self):
super(RES34, self).__init__()
self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet34(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out
class RES50(nn.Module):
def __init__(self):
super(RES50, self).__init__()
self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet50(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out
class RES101(nn.Module):
def __init__(self):
super(RES101, self).__init__()
self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet101(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out
class RES152(nn.Module):
def __init__(self):
super(RES152, self).__init__()
self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet152(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out
使用模块直接生成一个类即可,比如训练的时候:
cnn = RES101()
cnn.train() # 改为训练模式
prediction = cnn(img) #进行预测
目前先写这么多,看过了源码以后感觉写的很好,不仅仅有论文中最基础的部分,还有一些额外的功能,模块的组织也很整齐。
平时使用一般都进行迁移学习,使用的话可以把上述几个类中pretrained=False
参数改为True
.
实战篇:以上迁移学习代码来自我的一个小项目,验证码识别,地址:https://github.com/pprp/captcha_identify.torch
【深度学习】基于Pytorch的ResNet实现的更多相关文章
- 深度学习之PyTorch实战(1)——基础学习及搭建环境
最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...
- 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码
PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...
- 参考《深度学习之PyTorch实战计算机视觉》PDF
计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...
- 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题
在学习陈云的教程<深度学习框架PyTorch:入门与实践>的损失函数构建时代码如下: 可我运行如下代码: output = net(input) target = Variable(t.a ...
- 深度学习|基于LSTM网络的黄金期货价格预测--转载
深度学习|基于LSTM网络的黄金期货价格预测 前些天看到一位大佬的深度学习的推文,内容很适用于实战,争得原作者转载同意后,转发给大家.之后会介绍LSTM的理论知识. 我把code先放在我github上 ...
- Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易
近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作.PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以 ...
- 【新生学习】深度学习与 PyTorch 实战课程大纲
各位20级新同学好,我安排的课程没有教材,只有一些视频.论文和代码.大家可以看看大纲,感兴趣的同学参加即可.因为是第一次开课,大纲和进度会随时调整,同学们可以随时关注.初步计划每周两章,一个半月完成课 ...
- 基于pytorch实现Resnet对本地数据集的训练
本文是使用pycharm下的pytorch框架编写一个训练本地数据集的Resnet深度学习模型,其一共有两百行代码左右,分成mian.py.network.py.dataset.py以及train.p ...
- 深度学习框架PyTorch一书的学习-第六章-实战指南
参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter6-实战指南 希望大家直接到上面的网址去查看代码,下面是本人的笔记 将上面地 ...
- 深度学习框架PyTorch一书的学习-第五章-常用工具模块
https://github.com/chenyuntc/pytorch-book/blob/v1.0/chapter5-常用工具/chapter5.ipynb 希望大家直接到上面的网址去查看代码,下 ...
随机推荐
- 无法复制CSD内容,复制后出现一行长字符串解决
先打开一个linux文件,然后把复制的内容放到linux文件中即可解决
- js删除json指定元素
var obj = {‘id’:1, ‘name’:‘张三’}; delete obj.id; // 或者 delete obj[id];
- QT QML之Label, TextField
现在不是去想缺少什么的时候,该想一想凭现有的东西你能做什么.------ 海明威 <老人与海> Label { id: tipLabel width: 120 height: 40 tex ...
- Win10使用Xmanager6远程桌面连接CentOS7服务器
服务器:CentOS 7.6 GNOME桌面环境(若最小化安装,默认是无桌面的,那么就要安装桌面,参考百度) 个人主机:Windows 10专业版,请安装Xmanager Power Suite 6( ...
- [转帖]TPC-C解析系列04_TPC-C基准测试之数据库事务引擎的挑战
TPC-C解析系列04_TPC-C基准测试之数据库事务引擎的挑战 http://www.itpub.net/2019/10/08/3331/ OceanBase这次TPC-C测试与榜单上Oracl ...
- [SQL] - 报表查询效率优化
背景 系统将数据对象JSON序列化后存放到数据库字段中.Report 模块需要获取实时数据对象数值,当前在SQL中进行数值判断的耗时长,效率低. 分析 当前执行效率低主要是程序结构设计的不合理. SQ ...
- 《Mysql - 读写分离有哪些坑?》
一:读写分离 - 概念 - 读写分离的主要目标就是分摊主库的压力. - 基本架构 - - 二:两种读写分离的架构特点 - 客户端直连方案 - 因为少了一层 proxy 转发,所以查询性能稍 ...
- MATLAB:一个K×M的矩阵,第一列是1,其它都是0,从最后一行开始,每循环一次,最后一行的1往右边移一位,移动到末尾后溢出,重新回到最左边,同时上一行的1往右边移一位
问题:一个K×M的矩阵,第一列是1,其它都是0,从最后一行开始,每循环一次,最后一行的1往右边移一位,移动到末尾后溢出,重新回到最左边,同时上一行的1往右边移一位.上一行溢出时,上上一行的1移动一位, ...
- 转:数据库实例自动crash并报ORA-27157、ORA-27300等错误
rhel7.2上安装12C RAC数据库后,其中一个数据库实例经常会自动crash.查看alert日志发现以下错误信息: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 Errors ...
- 【优先队列】Function
Function 题目描述 wls有n个二次函数Fi(x)=aix2+bix+ci(1≤i≤n).现在他想在且x为正整数的条件下求的最小值.请求出这个最小值. 输入 第一行两个正整数n,m.下面n行, ...