1. ResNet理论

论文:https://arxiv.org/pdf/1512.03385.pdf

残差学习基本单元:

在ImageNet上的结果:

效果会随着模型层数的提升而下降,当更深的网络能够开始收敛时,就会出现降级问题:随着网络深度的增加,准确度变得饱和(这可能不足为奇),然后迅速降级。

ResNet模型:

2. pytorch实现

2.1 基础卷积

conv3$\times\(3 和conv1\)\times$1 基础模块

  1. def conv3x3(in_channel, out_channel, stride=1, groups=1, dilation=1):
  2. return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
  3. def conv1x1(in_channel, out_channel, stride=1):
  4. return nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)

参数解释:

in_channel: 输入的通道数目

out_channel:输出的通道数目

stride, padding: 步长和补0

dilation: 空洞卷积中的参数

groups: 从输入通道到输出通道的阻塞连接数

  1. feature size 计算:
  2. output = (intput - filter_size + 2 x padding) / stride + 1

空洞卷积实际卷积核大小:

  1. K = K + (K-1)x(R-1)
  2. K 是原始卷积核大小
  3. R 是空洞卷积参数的空洞率(普通卷积为1)

2.2 模块

  1. - resnet34
  2. - _resnet
  3. - ResNet
  4. - _make_layer
  5. - block
  6. - Bottleneck
  7. - BasicBlock

Bottlenect

  1. class Bottleneck(nn.Module):
  2. expansion = 4
  3. __constants__ = ['downsample']
  4. def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
  5. base_width=64, dilation=1, norm_layer=None):
  6. super(Bottleneck, self).__init__()
  7. if norm_layer is None:
  8. norm_layer = nn.BatchNorm2d
  9. width = int(planes * (base_width / 64.)) * groups
  10. # Both self.conv2 and self.downsample layers downsample the input when stride != 1
  11. self.conv1 = conv1x1(inplanes, width)
  12. self.bn1 = norm_layer(width)
  13. self.conv2 = conv3x3(width, width, stride, groups, dilation)
  14. self.bn2 = norm_layer(width)
  15. self.conv3 = conv1x1(width, planes * self.expansion)
  16. self.bn3 = norm_layer(planes * self.expansion)
  17. self.relu = nn.ReLU(inplace=True)
  18. self.downsample = downsample
  19. self.stride = stride
  20. def forward(self, x):
  21. identity = x
  22. out = self.conv1(x)
  23. out = self.bn1(out)
  24. out = self.relu(out)
  25. out = self.conv2(out)
  26. out = self.bn2(out)
  27. out = self.relu(out)
  28. out = self.conv3(out)
  29. out = self.bn3(out)
  30. if self.downsample is not None:
  31. identity = self.downsample(x)
  32. out += identity
  33. out = self.relu(out)
  34. return out

BasicBlock

  1. class BasicBlock(nn.Module):
  2. expansion = 1
  3. __constants__ = ['downsample']
  4. def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
  5. base_width=64, dilation=1, norm_layer=None):
  6. super(BasicBlock, self).__init__()
  7. if norm_layer is None:
  8. norm_layer = nn.BatchNorm2d
  9. if groups != 1 or base_width != 64:
  10. raise ValueError('BasicBlock only supports groups=1 and base_width=64')
  11. if dilation > 1:
  12. raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
  13. # Both self.conv1 and self.downsample layers downsample the input when stride != 1
  14. self.conv1 = conv3x3(inplanes, planes, stride)
  15. self.bn1 = norm_layer(planes)
  16. self.relu = nn.ReLU(inplace=True)
  17. self.conv2 = conv3x3(planes, planes)
  18. self.bn2 = norm_layer(planes)
  19. self.downsample = downsample
  20. self.stride = stride
  21. def forward(self, x):
  22. identity = x
  23. out = self.conv1(x)
  24. out = self.bn1(out)
  25. out = self.relu(out)
  26. out = self.conv2(out)
  27. out = self.bn2(out)
  28. if self.downsample is not None:
  29. identity = self.downsample(x)
  30. out += identity
  31. out = self.relu(out)
  32. return out

2.3 使用ResNet模块进行迁移学习

  1. import torchvision.models as models
  2. import torch.nn as nn
  3. class RES18(nn.Module):
  4. def __init__(self):
  5. super(RES18, self).__init__()
  6. self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
  7. self.base = torchvision.models.resnet18(pretrained=False)
  8. self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
  9. def forward(self, x):
  10. out = self.base(x)
  11. return out
  12. class RES34(nn.Module):
  13. def __init__(self):
  14. super(RES34, self).__init__()
  15. self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
  16. self.base = torchvision.models.resnet34(pretrained=False)
  17. self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
  18. def forward(self, x):
  19. out = self.base(x)
  20. return out
  21. class RES50(nn.Module):
  22. def __init__(self):
  23. super(RES50, self).__init__()
  24. self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
  25. self.base = torchvision.models.resnet50(pretrained=False)
  26. self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
  27. def forward(self, x):
  28. out = self.base(x)
  29. return out
  30. class RES101(nn.Module):
  31. def __init__(self):
  32. super(RES101, self).__init__()
  33. self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
  34. self.base = torchvision.models.resnet101(pretrained=False)
  35. self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
  36. def forward(self, x):
  37. out = self.base(x)
  38. return out
  39. class RES152(nn.Module):
  40. def __init__(self):
  41. super(RES152, self).__init__()
  42. self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
  43. self.base = torchvision.models.resnet152(pretrained=False)
  44. self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
  45. def forward(self, x):
  46. out = self.base(x)
  47. return out

使用模块直接生成一个类即可,比如训练的时候:

  1. cnn = RES101()
  2. cnn.train() # 改为训练模式
  3. prediction = cnn(img) #进行预测

目前先写这么多,看过了源码以后感觉写的很好,不仅仅有论文中最基础的部分,还有一些额外的功能,模块的组织也很整齐。

平时使用一般都进行迁移学习,使用的话可以把上述几个类中pretrained=False参数改为True.

实战篇:以上迁移学习代码来自我的一个小项目,验证码识别,地址:https://github.com/pprp/captcha_identify.torch

【深度学习】基于Pytorch的ResNet实现的更多相关文章

  1. 深度学习之PyTorch实战(1)——基础学习及搭建环境

    最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...

  2. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

  3. 参考《深度学习之PyTorch实战计算机视觉》PDF

    计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...

  4. 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题

    在学习陈云的教程<深度学习框架PyTorch:入门与实践>的损失函数构建时代码如下: 可我运行如下代码: output = net(input) target = Variable(t.a ...

  5. 深度学习|基于LSTM网络的黄金期货价格预测--转载

    深度学习|基于LSTM网络的黄金期货价格预测 前些天看到一位大佬的深度学习的推文,内容很适用于实战,争得原作者转载同意后,转发给大家.之后会介绍LSTM的理论知识. 我把code先放在我github上 ...

  6. Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易

    近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作.PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以 ...

  7. 【新生学习】深度学习与 PyTorch 实战课程大纲

    各位20级新同学好,我安排的课程没有教材,只有一些视频.论文和代码.大家可以看看大纲,感兴趣的同学参加即可.因为是第一次开课,大纲和进度会随时调整,同学们可以随时关注.初步计划每周两章,一个半月完成课 ...

  8. 基于pytorch实现Resnet对本地数据集的训练

    本文是使用pycharm下的pytorch框架编写一个训练本地数据集的Resnet深度学习模型,其一共有两百行代码左右,分成mian.py.network.py.dataset.py以及train.p ...

  9. 深度学习框架PyTorch一书的学习-第六章-实战指南

    参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter6-实战指南 希望大家直接到上面的网址去查看代码,下面是本人的笔记 将上面地 ...

  10. 深度学习框架PyTorch一书的学习-第五章-常用工具模块

    https://github.com/chenyuntc/pytorch-book/blob/v1.0/chapter5-常用工具/chapter5.ipynb 希望大家直接到上面的网址去查看代码,下 ...

随机推荐

  1. Bmp格式图片与16进制的互相转换简解 Python

    BMP TO HEX 首先介绍Github上一个简单的Bmp转成16进制的py: https://github.com/robertgallup/bmp2hex 网上这种例子很多.思路也简单:将bmp ...

  2. Go语言学习笔记——Go语言数据类型

    布尔型 布尔型的值只可以是常量 true 或者 false.一个简单的例子:var b bool = true. 数字类型 整型 int 和浮点型 float32.float64,Go 语言支持整型和 ...

  3. Cas(04)——更改认证方式

    在Cas Server的WEB-INF目录下有一个deployerConfigContext.xml文件,该文件是基于Spring的配置文件,里面存放的内容常常是部署人员需要修改的内容.其中认证方式也 ...

  4. LeetCode:三数之和【15】

    LeetCode:三数之和[15] 题目描述 给定一个包含 n 个整数的数组 nums,判断 nums 中是否存在三个元素 a,b,c ,使得 a + b + c = 0 ?找出所有满足条件且不重复的 ...

  5. AWS 数据库(七)

    数据库概念 关系型数据库 关系数据库提供了一个通用接口,使用户可以使用使用 编写的命令或查询从数据库读取和写入数据. 关系数据库由一个或多个表格组成,表格由与电子表格相似的列和行组成. 以行列形式存储 ...

  6. 微服务Consul系列之集群搭建

    在上一篇中讲解了Consul的安装.部署.基本的使用,使得大家有一个基本的了解,本节开始重点Consul集群搭建,官方推荐3-5台Server,因为在异常处理中,如果出现Leader挂了,只要有超过一 ...

  7. element组件 MessageBox不能显示确认和取消按钮,记录正确使用方法!

    这里是局部引入 调用方式:

  8. box-shadow 用法总结

    一.基础知识 box-shadow 属性向框添加一个或多个阴影. 语法 box-shadow: offset-x offset-y blur spread color inset; box-shado ...

  9. python基础学习(九)

    19.解包 # 解包 unpacking user1 = ["张三", 21, "1999.1.1"] # tuple 类型 user2 = ("李四 ...

  10. (四)Spring Boot官网文档学习

    文章目录 关于默认包的问题 加载启动类 配置 Bean管理和依赖注入 @SpringBootApplication Developer Tools 关于 Developer Tools 的一些细节 原 ...