mobilenet v1

论文解读

论文地址:https://arxiv.org/abs/1704.04861

核心思想就是通过depthwise conv替代普通conv.



有关depthwise conv可以参考https://www.cnblogs.com/sdu20112013/p/11759928.html

模型结构:



类似于vgg这种堆叠的结构.

每一层的运算量



可以看到,运算量并不是与参数数量绝对成正比,当然整体趋势而言,参数量更少的模型会运算更快.

代码实现

https://github.com/marvis/pytorch-mobilenet

网络结构:

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__() def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
) def conv_dw(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
) self.model = nn.Sequential(
conv_bn( 3, 32, 2),
conv_dw( 32, 64, 1),
conv_dw( 64, 128, 2),
conv_dw(128, 128, 1),
conv_dw(128, 256, 2),
conv_dw(256, 256, 1),
conv_dw(256, 512, 2),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 1024, 2),
conv_dw(1024, 1024, 1),
nn.AvgPool2d(7),
)
self.fc = nn.Linear(1024, 1000) def forward(self, x):
x = self.model(x)
x = x.view(-1, 1024)
x = self.fc(x)
return x

参考论文中的结构,第一层是普通的卷积层,后面接的都是可分离卷积.

这里注意groups参数的用法. 当groups=输入channel数目时,即对每个channel分别做卷积.默认groups=1,此时即为普通卷积.

训练伪代码

# create model
model = Net() # define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay) # load data
train_loader = torch.utils.data.DataLoader() # train
for every epoch:
input,target=get_from_data #前向传播得到预测值
output = model(input_var) #计算loss
loss = criterion(output, target_var) #反向传播更新网络参数
optimizer.zero_grad()
loss.backward()
optimizer.step()

轻量级CNN模型mobilenet v1的更多相关文章

  1. 轻量级CNN模型之squeezenet

    SqueezeNet 论文地址:https://arxiv.org/abs/1602.07360 和别的轻量级模型一样,模型的设计目标就是在保证精度的情况下尽量减少模型参数.核心是论文提出的一种叫&q ...

  2. CNN 模型压缩与加速算法综述

    本文由云+社区发表 导语:卷积神经网络日益增长的深度和尺寸为深度学习在移动端的部署带来了巨大的挑战,CNN模型压缩与加速成为了学术界和工业界都重点关注的研究领域之一. 前言 自从AlexNet一举夺得 ...

  3. 轻量级卷积神经网络——MobileNet

    谷歌论文题目: MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications 其他参考: CNN ...

  4. keras入门(三)搭建CNN模型破解网站验证码

    项目介绍   在文章CNN大战验证码中,我们利用TensorFlow搭建了简单的CNN模型来破解某个网站的验证码.验证码如下: 在本文中,我们将会用Keras来搭建一个稍微复杂的CNN模型来破解以上的 ...

  5. 总结近期CNN模型的发展(一)---- ResNet [1, 2] Wide ResNet [3] ResNeXt [4] DenseNet [5] DPNet [9] NASNet [10] SENet [11] Capsules [12]

    总结近期CNN模型的发展(一) from:https://zhuanlan.zhihu.com/p/30746099 余俊 计算机视觉及深度学习   1.前言 好久没有更新专栏了,最近因为项目的原因接 ...

  6. Farseer.net轻量级ORM开源框架 V1.x 入门篇:新版本说明

    导航 目   录:Farseer.net轻量级ORM开源框架 目录 上一篇:没有了 下一篇:Farseer.net轻量级ORM开源框架 V1.x 入门篇:数据库配置 前言 V1.x版本终于到来了.本次 ...

  7. 经典分类CNN模型系列其五:Inception v2与Inception v3

    经典分类CNN模型系列其五:Inception v2与Inception v3 介绍 Inception v2与Inception v3被作者放在了一篇paper里面,因此我们也作为一篇blog来对其 ...

  8. 从卷积拆分和分组的角度看CNN模型的演化

    博客:博客园 | CSDN | blog 写在前面 如题,这篇文章将尝试从卷积拆分的角度看一看各种经典CNN backbone网络module是如何演进的,为了视角的统一,仅分析单条路径上的卷积形式. ...

  9. 基于Pre-Train的CNN模型的图像分类实验

    基于Pre-Train的CNN模型的图像分类实验  MatConvNet工具包提供了好几个在imageNet数据库上训练好的CNN模型,可以利用这个训练好的模型提取图像的特征.本文就利用其中的 “im ...

随机推荐

  1. SUSE Ceph 增加节点、减少节点、 删除OSD磁盘等操作 - Storage6

    一.测试环境描述 之前我们已快速部署好一套Ceph集群(3节点),现要测试在现有集群中在线方式增加节点 如下表中可以看到增加节点node004具体配置 主机名 Public网络 管理网络 集群网络 说 ...

  2. Golang:线程 和 协程 的区别

    作者:林冠宏 / 指尖下的幽灵 博客:http://www.cnblogs.com/linguanh/ GitHub : https://github.com/af913337456/ 掘金:http ...

  3. vue2.0项目记住密码和用户名实例

    的今天突来兴致,试了一下将用户名和密码存在cookie和localStorage里如何实现:从代码难易程度来讲,果断选择了将用户名和密码存在localStorage里面.当然菜鸟上这么说的,楼下. 也 ...

  4. Python内置函数之enumerate() 函数

    enumerate() 函数属于python的内置函数之一: python内置函数参考文档:python内置函数 转载自enumerate参考文档:python-enumerate() 函数 描述 e ...

  5. ELK 学习笔记之 Logstash基本语法

    Logstash基本语法: 处理输入的input 处理过滤的filter 处理输出的output 区域 数据类型 条件判断 字段引用 区域: Logstash中,是用{}来定义区域 区域内,可以定义插 ...

  6. 【SQL】sql查询同一字段相同属性列的值合计

    select  type,sum(value) as valueSum from t group by type

  7. Android系统介绍与框架

    一.Andriod是什么? Android系统是Google开发的一款开源移动OS,Android中文名被国内用户俗称“安卓”.Android操作系统基于Linux内核设计,使用了Google公司自己 ...

  8. spring boot参数验证

    必须要知道 简述 JSR303/JSR-349,hibernate validation,spring validation 之间的关系 JSR303 是一项标准,JSR-349 是其的升级版本,添加 ...

  9. Web高性能动画及渲染原理(1)CSS动画和JS动画

    目录 一. CSS动画 和 JS动画 1.1 CSS动画 1.2 JS动画 1.3 小结 二. 使用Velocity.js实现动画 示例代码托管在:http://www.github.com/dash ...

  10. hibernate 搭建框架

    需要用的包 Hibernate的日志记录: * Hibernate日志记录使用了一个slf4j: * SLF4J,即简单日志门面(Simple Logging Facade for Java),不是具 ...