MnasNet:经典轻量级神经网络搜索方法 | CVPR 2019
论文提出了移动端的神经网络架构搜索方法,该方法主要有两个思路,首先使用多目标优化方法将模型在实际设备上的耗时融入搜索中,然后使用分解的层次搜索空间,来让网络保持层多样性的同时,搜索空间依然很简洁,能够使得搜索的模型在准确率和耗时中有更好的trade off
来源:【晓飞的算法工程笔记】 公众号
论文: MnasNet: Platform-Aware Neural Architecture Search for Mobile
- 论文地址:https://arxiv.org/abs/1807.11626
- 代码地址:https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
Introduction
在设计移动端卷积网络时,常常会面临着速度与准确率的取舍问题,为了设计更好的移动端卷积网络,论文提出移动网络的神经网络架构搜索方法,大概步骤如图1所示。对比之前的方法,该方法主贡献有3点:
- 将设计问题转化为多目标优化问题(multi-objective optimization),同时考虑准确率和实际推理耗时。由于计算量FLOPS其实和实际推理耗时并不总是一致的(MobileNet,575M,113ms vs NASNet,564 M,183ms),所以论文通过实际移动设备上运行来测量推理耗时
- 之前的搜索方法大都先搜索最优的单元,然后堆叠成网络,虽然这样能优化搜索空间,但抑制了层多样性。为了解决这个问题,论文提出分解的层次搜索空间(factorized hierarchical search space),使得层能存在结构差异的同时,仍然很好地平衡灵活性和搜索空间大小
- 在符合移动端使用的前提下,达到ImageNet和COCO的SOTA,且速度更快,模型更轻量。如图2所示,在准确率更高的前提下,MansNet速度比MobieNet和NASNet-A分别快1.8倍和2.3倍
Problem Formulation
对于模型$m$,$ACC(m)$为模型准确率,$LAT(m)$为目标移动平台的推理耗时,$T$为目标耗时,公式1为在符合耗时前提下,最大化准确率
但公式1仅最优准确率,没有进行多目标优化(multiple Pareto optimal),于是论文改用公式2的加权乘积方法来近似进行多目标优化
$w$是权重因子,$\alpha$和$\beta$为应用特定常数(application-specific constants),这两个值的设定是用来保证符合accuracy-latency trade-offs的有相似的reward,即高准确率稍高耗时和稍低准确率低耗时有相同的reward。例如,凭经验认为两倍耗时通常能带来5%准确率提升,对于模型M1(耗时$l$,准确率$a$),模型M2(耗时$2l$,准确率$a(1+5%)$),他们应该有相同的reward:$Reward(M2)=a\cdot (1+5%)\cdot (2l/T)^\beta\approx Reward(M1)=a\cdot (l/T)^\beta$,得到$\beta=-0.07$。后面实验没说明都使用$\alpha=\beta=-0.07$
图3为不同常数下的目标函数曲线,上图$(\alpha=0,\beta=-1)$意味着符合耗时的直接输出准确率,超过耗时的则大力惩罚,下图$(\alpha=\beta=-0.07)$则是将耗时作为软约束,平滑地调整目标函数
Mobile Neural Architecture Search
Factorized Hierarchical Search Space
论文提出分别的层次搜索空间,整体构造如图4所示,将卷积神经网络模型分解成独立的块(block),逐步降低块的输入以及增加块中的卷积核数。每个块进行独立块搜索,每个块包含多个相同的层,由块搜索来决定。搜索的目的是基于输入和输出的大小,选择最合适的算子以及参数(kernal size, filter size)来达到更好的accurate-latency trade-off
每个块的子搜索包含上面6个步骤,例如图4中的block 4,每层都为inverted bottleneck 5x5 convolution和residual skip path,共$N_4$层
搜索空间选择使用MobileNetV2作为参考,图4的block数与MobileNetV2对应,MobileNetV2的结构如上。在MobileNetV2的基础上,每个block的layer数量进行${0,+1,-1}$进行加减,而卷积核数则选择${0.75,1.0,1.25}$
论文提出的分解的层次搜索空间对于平衡层多样性和搜索空间大小有特别的好处,假设共$B$blocks,每个block的子搜索空间大小为$S$,平均每个block有$N$层,总共的搜索空间则为$SB$,对比按层搜索的空间$S{B*N}$小了很多
Search Algorithm
论文使用NAS的强化学习方法来优化公式2的rewadr期望,在每一轮,controller根据当前参数$\theta$一批模型,每个模型$m$训练后获得准确率$ACC(m)$以及实际推理耗时$LAT(m)$,根据公式2得到reward,然后使用Proximal Policy Optimization来更新controller的参数$\theta$最大化公式5
Experimental Setup
论文先尝试在CIFAR-10上进行架构搜索,然后迁移到大数据集上,但是发现这样不奏效,因为考虑了实际耗时,而应用到大数据集时,网络通常需要放大,耗时就不准确了。因此,论文直接在ImageNet上进行搜索,但每个模型只训练5轮来加速。RNN controller与NASNet保持一致,总共需要64 TPUv2搜索4.5天,每个模型使用Pixel 1手机进行耗时测试,最终大概测试了8K个模型,分别选择了top 15和top 1模型进行完整的ImageNet训练以及COCO迁移,输入图片的分辨率分别为$224\times 224$和$320\times 320$
Results
ImageNet Classification Performance
$T=75ms$,$\alpha=\beta=-0.07$,结果如Table 1所示,MnasNet比MobileNetV2(1.4)快1.8倍,准0.5%,比NASNet-A快2.3倍,准1.2%,而稍大的模型MnasNet-A3比ResNet-50准,但少用了4.8x参数和10x计算量
由于之前的方法没有使用SE模块,论文补充了个对比训练,MnasNet效果依然比之前的方法要好
Model Scaling Performance
缩放模型是调整准确率和耗时的来适应不同设备的常见操作,可以使用depth multiplier(好像叫width multiplier?)来缩放每层的channels数,也可以直接降低输入图片的分辨率。从图5可以看到,MansNet始终保持着比MobileNetV2好的表现
此外,论文提出的方法能够搜索不同耗时的模型,为了比较性能,论文对比了缩放模型和搜索模型的准确率。从Table4看出,搜索出来的模型有更好的准确率
COCO Object Detection Performance
论文对比了MnasNet在COCO上的表现,可以看到MnasNet准确率更高,且少用了7.4x参数和42x计算量
Ablation Study and Discussion
Soft vs. Hard Latency Constraint
多目标搜索方法允许通过设定$\alpha$和$\beta$进行hard和soft的耗时约束,图6展示了$(\alpha=0,\beta=-1)$和$(\alpha=\beta=-0.07)$,目标耗时为75ms,可以看到soft搜索更广的区域,构建了很多接近75ms耗时的模型,也构建了更多小于40ms和大于110ms的模型
Disentangling Search Space and Reward
论文将多目标优化和分解的层次搜索空间进行对比实验,从结果来看,多目标优化能很好平衡低耗和准确率,而论文提出的搜索空间能同时降低耗时和提高准确率
MnasNet Architecture and Layer Diversity
图7(a)为MnasNet-A1的结构,包含了不同的层结构,可以看到该网络同时使用了5x5和3x3的卷积,之前的方法都只使用了3x3卷积
Table 6展示了MansNet模型及其变体,变体上仅用某一层的来构建网络,可以看到MnasNet在准确率和耗时上有了更好的trade-off
CONCLUSION
论文提出了移动端的神经网络架构搜索方法,该方法使用多目标优化方法将模型在实际设备上的耗时融入搜索中,能够使得搜索的模型在准确率和耗时中有更好的trade off。另外该方法使用分解的层次搜索空间,来让网络保持层多样性的同时,搜索空间依然很简洁,也提高了搜索网络的准确率。从实验结果来看,论文搜索到的网络MansNet在准确率和耗时上都比目前的人工构建网络和自动搜索网络要好
如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】
MnasNet:经典轻量级神经网络搜索方法 | CVPR 2019的更多相关文章
- CVPR 2019轨迹预测竞赛冠军方法总结
背景 CVPR 2019 是机器视觉方向最重要的学术会议,本届大会共吸引了来自全世界各地共计 5160 篇论文,共接收 1294 篇论文,投稿数量和接受数量都创下了历史新高,其中与自动驾驶相关的论文. ...
- 自动驾驶研究回顾:CVPR 2019摘要
我们相信开发自动驾驶技术是我们这个时代最大的工程挑战之一,行业和研究团体之间的合作将扮演重要角色.由于这个原因,我们一直在通过参加学术会议,以及最近推出的自动驾驶数据集和基于语义地图的3D对象检测的K ...
- CVPR 2019 | 用异构卷积训练深度CNN:提升效率而不损准确度
对于深度卷积神经网络而言,准确度和计算成本往往难以得兼,研究界也一直在探索通过模型压缩或设计新型高效架构来解决这一问题.印度理工学院坎普尔分校的一篇 CVPR 论文则给出了一个新的思路——使用异构的卷 ...
- CVPR 2019细粒度图像分类竞赛中国团队DeepBlueAI获冠军 | 技术干货分享
[导读]CVPR 2019细粒度图像分类workshop的挑战赛公布了最终结果:中国团队DeepBlueAI获得冠军.本文带来冠军团队解决方案的技术分享. 近日,在Kaggle上举办的CVPR 201 ...
- Relation-Shape Convolutional Neural Network for Point Cloud Analysis(CVPR 2019)
代码:https://github.com/Yochengliu/Relation-Shape-CNN 文章:https://arxiv.org/abs/1904.07601 作者直播:https:/ ...
- zz先睹为快:神经网络顶会ICLR 2019论文热点分析
先睹为快:神经网络顶会ICLR 2019论文热点分析 - lqfarmer的文章 - 知乎 https://zhuanlan.zhihu.com/p/53011934 作者:lqfarmer链接:ht ...
- CVPR 2019 行人检测新思路:
CVPR 2019 行人检测新思路:高级语义特征检测取得精度新突破 原创: CV君 我爱计算机视觉 今天 点击我爱计算机视觉置顶或标星,更快获取CVML新技术 今天跟大家分享一篇昨天新出的CVPR 2 ...
- TensorFlow实战之实现AlexNet经典卷积神经网络
本文根据最近学习TensorFlow书籍网络文章的情况,特将一些学习心得做了总结,详情如下.如有不当之处,请各位大拿多多指点,在此谢过. 一.AlexNet模型及其基本原理阐述 1.关于AlexNet ...
- 经典卷积神经网络(LeNet、AlexNet、VGG、GoogleNet、ResNet)的实现(MXNet版本)
卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现. 其中 文章 详解卷 ...
随机推荐
- TensorFlow从0到1之TensorFlow实现单层感知机(20)
简单感知机是一个单层神经网络.它使用阈值激活函数,正如 Marvin Minsky 在论文中所证明的,它只能解决线性可分的问题.虽然这限制了单层感知机只能应用于线性可分问题,但它具有学习能力已经很好了 ...
- mysql关于group by的用法
原文:https://blog.csdn.net/u014717572/article/details/80687042 先来看下表1,表名为test: 表1 执行如下SQL语句: SELECT na ...
- Android学习笔记样式资源文件
样式资源和主题资源都是写在styles.xml文件里面的 <style name="title"> <item name="android:textSi ...
- Java学习笔记6(集合类)
集合类 集合按照其存储结构可以分为两大类,即单列集合Collection和双列集合Map. Collection:单列集合类的根接口,用于存储一系列符合某种规则的元素,有List和Set两个重要子接口 ...
- Shiro密码重试次数限制
如在 1 个小时内密码最多重试 5 次,如果尝试次数超过 5 次就锁定 1 小时,1 小时后可再次重试,如果还是重试失败,可以锁定如 1 天,以此类推,防止密码被暴力破解.我们通过继承 HashedC ...
- JavaWeb网上图书商城完整项目--12.项目所需jquery函数介绍之ajax
jquery中使用ajax发送异步请求 下面的一个案例在input输入框失去焦点的时候发送一个异步的请求: 我们来看程序的案例: 这里要强调的是返回值最好选择是json,json对应的就是对象,Jav ...
- java 中的线程池
1.实现下面的一个需求,控制一个执行函数只能被五个线程访问 package www.weiyuan.test; public class Test { public static void main( ...
- 9、ssh的集成方式1
集成方式1:核心 我们没有创建applicationContext-action.xml配置文件,在该配置文件里面让Spring去管理我们的AddUserAction,但是AddUserAction的 ...
- redis基础二----操作hash
上面usr就是hash的名字,usr这个hash中存储了key 为id.name和age的值 一个hash相当于一个数据对象,里面可以存储key为id name age的值 2.批量插入一个hash数 ...
- Redis:rdb和aof
由于redis的数据都直接存储在内存里,在服务器发生宕机时内存的数据会瞬间清空,那么必须要有重启时恢复数据的方法. redis通过持久化机制将数据存储到磁盘中从而在服务器重启时恢复数据,这篇文章主要简 ...