squeezenet是16年发布的一款轻量级网络模型,模型很小,只有4.8M,可用于移动设备,嵌入式设备。

关于squeezenet的原理可自行阅读论文或查找博客,这里主要解读下pytorch对squeezenet的官方实现。

地址:https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py

首先定义fire模块,这是squeezenet的核心所在,降低3X3卷积的数量。

class Fire(nn.Module):

    def __init__(self, inplanes, squeeze_planes,
expand1x1_planes, expand3x3_planes):
super(Fire, self).__init__()
self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)#定义压缩层,1X1卷积
self.squeeze_activation = nn.ReLU(inplace=True)
self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,#定义扩展层,1X1卷积
kernel_size=1)
self.expand1x1_activation = nn.ReLU(inplace=True)
self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,#定义扩展层,3X3卷积
kernel_size=3, padding=1)
self.expand3x3_activation = nn.ReLU(inplace=True) def forward(self, x):
x = self.squeeze_activation(self.squeeze(x))
return torch.cat([
self.expand1x1_activation(self.expand1x1(x)),
self.expand3x3_activation(self.expand3x3(x))
], 1)

可以看到首先定义压缩层与两个扩展层,压缩层用的是1X1卷积,扩展层是1X1卷积和3X3卷积的混合使用,网络inference的脉络是先经过压缩层,然后并行经过两个扩展层,最后将扩展层串联。

定义完核心模块,来看网络整体。

class SqueezeNet(nn.Module):

    def __init__(self, version=1.0, num_classes=1000):
super(SqueezeNet, self).__init__()
if version not in [1.0, 1.1]:
raise ValueError("Unsupported SqueezeNet version {version}:"
"1.0 or 1.1 expected".format(version=version))
self.num_classes = num_classes
if version == 1.0:
self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=7, stride=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(96, 16, 64, 64),
Fire(128, 16, 64, 64),
Fire(128, 32, 128, 128),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(256, 32, 128, 128),
Fire(256, 48, 192, 192),
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(512, 64, 256, 256),
)
else:
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(64, 16, 64, 64),
Fire(128, 16, 64, 64),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(128, 32, 128, 128),
Fire(256, 32, 128, 128),
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(256, 48, 192, 192),
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
Fire(512, 64, 256, 256),
)
# Final convolution is initialized differently form the rest
final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
self.classifier = nn.Sequential(
nn.Dropout(p=0.5),
final_conv,
nn.ReLU(inplace=True),
nn.AvgPool2d(13, stride=1)
) for m in self.modules():
if isinstance(m, nn.Conv2d):
if m is final_conv:
init.normal_(m.weight, mean=0.0, std=0.01)
else:
init.kaiming_uniform_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0) def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x.view(x.size(0), self.num_classes)

首先依然是定义网络层,在这里有两个版本,差别不大,都是fire模块的堆积,最后经过全局平均池化输出1000类。这里对卷积层采用了不同的初始化策略,我还没仔细研究过,就不说了。

pytorch实现squeezenet的更多相关文章

  1. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  2. pytorch预训练

    Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...

  3. 生产与学术之Pytorch模型导出为安卓Apk尝试记录

    生产与学术 写于 2019-01-08 的旧文, 当时是针对一个比赛的探索. 觉得可能对其他人有用, 就放出来分享一下 生产与学术, 真实的对立... 这是我这两天对pytorch深度学习->a ...

  4. 深度学习框架PyTorch一书的学习-第六章-实战指南

    参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter6-实战指南 希望大家直接到上面的网址去查看代码,下面是本人的笔记 将上面地 ...

  5. 深度学习框架PyTorch一书的学习-第五章-常用工具模块

    https://github.com/chenyuntc/pytorch-book/blob/v1.0/chapter5-常用工具/chapter5.ipynb 希望大家直接到上面的网址去查看代码,下 ...

  6. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  7. (转) The Incredible PyTorch

    转自:https://github.com/ritchieng/the-incredible-pytorch The Incredible PyTorch What is this? This is ...

  8. PyTorch源码解读之torchvision.models(转)

    原文地址:https://blog.csdn.net/u014380165/article/details/79119664 PyTorch框架中有一个非常重要且好用的包:torchvision,该包 ...

  9. PyTorch深度学习计算机视觉框架

    Taylor Guo @ Shanghai - 2018.10.22 - 星期一 PyTorch 资源链接 图像分类 VGG ResNet DenseNet MobileNetV2 ResNeXt S ...

随机推荐

  1. jquery ajax中error返回错误解决办法

    转自:https://www.jb51.net/article/72198.htm 进入百度搜索此问题,发现有人这么说了一句 Jquery中的Ajax的async默认是true(异步请求),如果想一个 ...

  2. MySQL学习笔记之一---字符编码和字符集

    前言: 一般来说,出现中文乱码,都是客户端和服务端字符集不匹配导致的原因. (默认未指定字符集创建的数据库表,都是latinl字符集, 强烈建议使用utf8字符集)   保证不出现乱码的思想:保证客户 ...

  3. python爬虫实战(2)--爬取百度贴吧

    本篇目标 1.对百度贴吧的任意帖子进行抓取 2.指定是否只抓取楼主发帖内容 3.将抓取到的内容分析并保存到文件 1.URL格式的确定 先观察百度贴吧url格式,以中南财经政法大学迎新帖为例,URL我们 ...

  4. HDU 4912 LCA + 贪心

    题意及思路 说一下为什么按LCA深度从深到浅贪心是对的.我们可以直观感受一下,一条的路径会影响以这个lca为根的这颗树中的链,而深度越深,影响范围越小,所以先选影响范围小的路径. #include & ...

  5. echarts柱状图每个柱子显示不同颜色,并且能够实现点击每种颜色影藏对应柱子的功能

    ---------------------------------------------------------代码区---------------------------------------- ...

  6. unity3d 5.6参考手册

    http://www.vfkjsd.cn/unity3d/Manual/index.html http://www.vfkjsd.cn/unity/unity3d.html

  7. Python 网络爬虫 001 (科普) 网络爬虫简介

    Python 网络爬虫 001 (科普) 网络爬虫简介 1. 网络爬虫是干什么的 我举几个生活中的例子: 例子一: 我平时会将 学到的知识 和 积累的经验 写成博客发送到CSDN博客网站上,那么对于我 ...

  8. 激光SLAM Vs 视觉SLAM

    博客转载自:https://www.leiphone.com/news/201707/ETupJVkOYdNkuLpz.html 雷锋网(公众号:雷锋网)按:本文作者SLAMTEC(思岚科技公号sla ...

  9. 前端学习笔记2017.6.12 DIV布局网页

    DIV的功能就是把网页划分成逻辑块的. 看下豆瓣东西页面的布局,我们来分析下. 按照先从上到下的原则,把这个页面分成几个块: 首先是最顶端的这个条,这是一个DIV,我们给它起个名字,叫banner 然 ...

  10. Linux kdb命令

    一.简介 Linux 内核调试器(KDB)允许您调试 Linux 内核.这个恰如其名的工具实质上是内核代码的补丁,它允许高手访问内核内存和数据结构.KDB 的主要优点之一就是它不需要用另一台机器进行调 ...