文章来自微信公众号【机器学习炼丹术】。我是炼丹兄,有什么问题都可以来找我交流,近期建立了微信交流群,也在朋友圈抽奖赠书十多本了。我的微信是cyx645016617,欢迎各位朋友。

参考目录:

@

上一节课讲解了MobileNet的一个DSC深度可分离卷积的概念,希望大家可以在实际的任务中使用这种方法,现在再来介绍EfficientNet的另外一个基础知识—,Squeeze-and-Excitation Networks压缩-激活网络

1 网络结构

可以看出来,左边的图是一个典型的Resnet的结构,Resnet这个残差结构特征图求和而不是通道拼接,这一点可以注意一下

这个SENet结构式融合在残差网络上的,我来分析一下上图右边的结构:

  • 输出特征图假设shape是\(W \times H \times C\)的;
  • 一般的Resnet就是这个特征图经过残差网络的基本组块,得到了输出特征图,然后输入特征图和输入特征图通过残差结构连在一起(通过加和的方式连在一起);
  • SE模块就是输出特征图先经过一个全局池化层,shape从\(W \times H \times C\)变成了\(1 \times 1 \times C\),这个就变成了一个全连接层的输入啦
    • 压缩Squeeze:先放到第一个全连接层里面,输入\(C\)个元素,输出\(\frac{C}{r}\),r是一个事先设置的参数;

    • 激活Excitation:在接上一个全连接层,输入是\(\frac{C}{r}\)个神经元,输出是\(C\)个元素,实现激活的过程;

  • 现在我们有了一个\(C\)个元素的经过了两层全连接层的输出,这个C个元素,刚好表示的是原来输出特征图\(W \times H \times C\)中C个通道的一个权重值,所以我们让C个通道上的像素值分别乘上全连接的C个输出,这个步骤在图中称为Scale而这个调整过特征图每一个通道权重的特征图是SE-Resnet的输出特征图,之后再考虑残差接连的步骤。

在原文论文中还有另外一个结构图,供大家参考:

2 参数量分析

每一个卷积层都增加了额外的两个全连接层,不够好在全连接层的参数非常小,所以直观来看应该整体不会增加很多的计算量。 Resnet50的参数量为25M的大小,增加了SE模块,增加了2.5M的参数量,所以大概增加了10%左右,而且这2.5M的参数主要集中在final stage的se模块,因为在最后一个卷积模块中,特征图拥有最大的通道数,所以这个final stage的参数量占据了增加的2.5M参数的96%。

这里放一个几个网络结构的对比:

3 PyTorch实现与解析

先上完整版的代码,大家可以复制本地IDE跑一跑,如果代码有什么问题可以联系我:

import torch
import torch.nn as nn
import torch.nn.functional as F class PreActBlock(nn.Module):
def __init__(self, in_planes, planes, stride=1):
super(PreActBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) if stride != 1 or in_planes != planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
) # SE layers
self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) def forward(self, x):
out = F.relu(self.bn1(x))
shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
out = self.conv1(out)
out = self.conv2(F.relu(self.bn2(out))) # Squeeze
w = F.avg_pool2d(out, out.size(2))
w = F.relu(self.fc1(w))
w = F.sigmoid(self.fc2(w))
# Excitation
out = out * w out += shortcut
return out class SENet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(SENet, self).__init__()
self.in_planes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512, num_classes) def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes
return nn.Sequential(*layers) def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out def SENet18():
return SENet(PreActBlock, [2,2,2,2]) net = SENet18()
y = net(torch.randn(1,3,32,32))
print(y.size())
print(net)

输出和注解我都整理了一下:

【小白学PyTorch】12 SENet详解及PyTorch实现的更多相关文章

  1. 【小白学PyTorch】11 MobileNet详解及PyTorch实现

    文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...

  2. html5--1.12表格详解

    html5--1.12表格详解 一.总结 一句话总结: 二.详解 1.表格构成三个基本要素 table:表格的范围,外框:用来定义表格,表格的其他元素包含在table标签里面: tr: 表格的行: t ...

  3. Pytorch autograd,backward详解

    平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...

  4. SENet详解及Keras复现代码

    转: SENet详解及Keras复现代码 论文地址:https://arxiv.org/pdf/1709.01507.pdf 代码地址:https://github.com/hujie-frank/S ...

  5. 【小白学PyTorch】13 EfficientNet详解及PyTorch实现

    参考目录: 目录 1 EfficientNet 1.1 概述 1.2 把扩展问题用数学来描述 1.3 实验内容 1.4 compound scaling method 1.5 EfficientNet ...

  6. 【小白学PyTorch】10 pytorch常见运算详解

    参考目录: 目录 1 矩阵与标量 2 哈达玛积 3 矩阵乘法 4 幂与开方 5 对数运算 6 近似值运算 7 剪裁运算 这一课主要是讲解PyTorch中的一些运算,加减乘除这些,当然还有矩阵的乘法这些 ...

  7. Pytorch数据读取详解

    原文:http://studyai.com/article/11efc2bf#%E9%87%87%E6%A0%B7%E5%99%A8%20Sampler%20&%20BatchSampler ...

  8. javaweb基础(12)_session详解

    一.Session简单介绍 在WEB开发中,服务器可以为每个用户浏览器创建一个会话对象(session对象),注意:一个浏览器独占一个session对象(默认情况下).因此,在需要保存用户数据时,服务 ...

  9. 【Linux】一步一步学Linux——Linux系统目录详解(09)

    目录 00. 目录 01. 文件系统介绍 02. 常用目录介绍 03. /etc目录文件 04. /dev目录文件 05. /usr目录文件 06. /var目录文件 07. /proc 08. 比较 ...

随机推荐

  1. Mybatis 中判断参数长度

    <if test="params.length()!=2">

  2. Robot Framework(8)——脚本语法示例记录

    大神写了一个Robot Framework的脚本,好多语法之前没接触过,就有了这篇,记录下来一起学习,欢迎纠错 第二三四五列,一般是入参,红色的表示必填的入参.浅灰色表示选填的入参.深灰色表示无需填写 ...

  3. 第五篇Scrum冲刺博客--Interesting-Corps

    第五篇Scrum冲刺博客 站立式会议 1.会议照片 2.队友完成情况 团队成员 昨日完成 今日计划 鲍鱼铭 音乐详情页面跳转.设计及布局实现设计 搜索页面以及音乐详情页面数据导入及测试 叶学涛 编写分 ...

  4. PyTorch学习笔记及问题处理

    1.torch.nn.state_dict(): 返回一个字典,保存着module的所有状态(state). parameters和persistent_buffers都会包含在字典中,字典的key就 ...

  5. TCP/IP的三次握手, 四次挥手

    三次握手: 1. X初始序号, SYN:   , 发送  将syn=1, X发送至client 2. 服务器发送 ACK(确认包)=1, SYN=1, 接受顺序号(acknowledge number ...

  6. 细说强网杯Web辅助

    本文首发于“合天智汇”公众号 作者:Ch3ng 这里就借由强网杯的一道题目“Web辅助”,来讲讲从构造POP链,字符串逃逸到最后获取flag的过程 题目源码 index.php 获取我们传入的user ...

  7. windows下安装jdk+tomcat+maven并配置

    一.下载安装jdk并配置 1.1 进行JDK下载 下载地址:一键直达 一般下载后,安装位置默认,一路下一步,一直到安装完毕-"关闭". 1.2 环境变量配置 不要管是不是一般情况, ...

  8. [U3D + GAD]Egametang开源服务器框架资源管理系统

    Egametang开源服务器框架资源管理系统详解 http://m.gad.qq.com/article/detail/36409 ET GitHub https://github.com/egame ...

  9. 安装Android Studio之后无法直接打开SDK Manager

    之前安装的android studio之后,SDK Manager和AVD Manager两个运行程序双击都打不开页面了,之前都是正常的,所以java环境变量的问题是不存在的. SDK Manager ...

  10. AcWing 285. 没有上司的舞会(树形dp入门)

    Ural大学有N名职员,编号为1~N. 他们的关系就像一棵以校长为根的树,父节点就是子节点的直接上司. 每个职员有一个快乐指数,用整数 HiHi 给出,其中 1≤i≤N1≤i≤N. 现在要召开一场周年 ...