【模型压缩】MetaPruning:基于元学习和AutoML的模型压缩新方法
- 论文名称:MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning
- 论文地址:https://arxiv.org/abs/1903.10258
- 开源代码:https://github.com/megvii-model/MetaPruning
目录
- 导语
- 简介
- 方法
- PruningNet Training
- Pruned-Network Search
- 实验
- Comparisons with state-of-the-arts
- 结论
- 参考文献
导语
通道剪裁(Channel Pruning)作为一种神经网络压缩/加速方法,其有效性已深获认可,并广泛应用于工业界。
一个经典的剪裁方法包含三步:1)训练一个参数过多的大型网络;2)剪裁较不重要的权重或通道;3)微调或再训练已剪裁的网络。其中第二个阶段是关键,它通常借助迭代式逐层剪裁、快速微调或者权重重建以保持精度。
卷积通道剪裁方法主要依赖于数据驱动的稀疏约束(sparsity constraints)或者人工设计的策略。最近,一些基于反馈闭环或者强化学习的 AutoML 方法可自动剪裁一个迭代模型中的通道。
相较于传统剪裁方法, AutoML 方法不仅可以节省人力,还可以帮助人们在不用知道硬件底层实现的情况下,直接为特定硬件定制化设计在满足该硬件上速度限制的最优网络结构。
MetaPruning 作为利用 AutoML 进行网络裁剪的算法之一,有着 AutoML 所共有的省时省力,硬件定制等诸多优势,同时也创新性地加入了先前 AutoML pruning 所不具备的功能,如轻松裁剪 shortcut 中的通道。
简介
过去的研究往往通过逐层裁剪一个已训练好模型中带有不重要权重的通道来达到裁剪的目的。而一项最新研究发现,不管继不继承原始网络的权重,已剪裁的网络都可获得相同精度。
这一发现表明,通道剪裁的本质是决定逐层的通道数量。基于这个,MetaPruning 跳过选择剪裁哪些通道,而直接决定每层剪裁多少通道——好的剪裁结构。
然而,可能的每层通道数组合数巨大,暴力寻找最优的剪裁结构是计算量所不支的。
受到近期的神经网络架构搜索(NAS)的启发,尤其是 One-Shot 模型,以及 HyperNetwork 中的权重预测机制,旷视研究院提出训练一个 PruningNet,它可生成所有候选的已剪裁网络结构的权重,从而仅仅评估其在验证集上的精度,即可搜索表现良好的结构。这极其有效。
PruningNet 的训练采用随机采样网络结构策略,如图 1 所示,它为带有相应网络编码向量(其数量等于每一层的通道数量)的已剪裁网络生成权重。通过在网络编码向量中的随机输入,PruningNet 逐渐学习为不同的已剪裁结构生成权重。
图 1:MetaPruning 分为两步,1)训练一个 PruningNet,2)搜索最佳的 Pruned Network
训练结束之后,研究员会借助进化算法来搜索表现较好的 Pruned Networks,进化算法中可以灵活加入不同的硬约束(hard constraints),比如浮点数运算次数(FLOPs)或者硬件运行时长(latency)。由于 PruningNet 已学会为各种不同的 Pruned Networks 提供可靠的参数,从而可轻松使用 PruningNet 为 Pruned Networks 结构填入对应参数。
这只需几秒,便可获知 Pruned Network 的精度表现,孰优孰劣,高下立现。这让通道裁剪变的极其省心省力,也是通道剪裁领域的一个新突破,称之为 MetaPruning,其贡献可以归为四个方面:
- MetaPruning 是一种用于通道剪裁的元学习方法,其核心思想是学习一个元网络(称之为 PruningNet ),为不同的剪裁结构生成权重,进而获得不同约束下的多种已剪裁网络。
- 相较于传统剪裁方法,MetaPruning 免除了笨重的超参数人工调节,并可按照想要的指标直接优化。
- 相较于其他 AutoML 方法,MetaPruning 可在搜索目标结构时轻松加入硬约束,而无需手头调节强化学习超参数。
- 在类似于 ResNet 的结构中,short-cut 中通道数往往很难裁剪,因为裁剪这些通道会影响多层。大多通道裁剪算法无法高效裁剪,而 MetaPruning 可以毫不费力剪裁 short-cut 的通道。
方法
MetaPruning 可以自动剪裁深度神经网络中的通道,已剪裁的网络可以满足不同的硬约束。
这一通道剪裁问题可表示为:
本文想要找到在权重训练结束之后,满足约束条件的损失最小的剪裁网络通道宽度组合。
为此,研究员构建了 PruningNet,为不同的剪裁网络结构生成权重,从而只需要在验证集上评估,即可快速获知剪裁网络结构的精度,排序不同剪裁网络结构的表现。接着,配合任意搜索方法便可搜索最优的剪裁网络。
具体的 PruningNet 构建及训练算法和本文采用的进化搜索算法如下:
PruningNet Training
图 2:PruningNet 的随机训练方法图示
PruningNet 包含两个全连接层。在前向传播中,它的输入是网络的编码向量(即每一层的输出通道宽度),输出则是网络的权重矩阵;同时,根据每一层的输出通道宽度构建对应的 Pruned Network。
已生成的权重矩阵被切割以匹配 Pruned Network 输入/输出通道的数量。给定一批输入图像,则可计算带有生成权重的 Pruned Network 的损失。
在反向传播中,不用更新 Pruned Networks 的权重, 而是计算 PruningNet 权重的梯度,由于 PruningNet 全连接层的输出与 Pruned Network 的前一个卷积层的输出之间 reshape 操作和卷积操作也是可微分的, PruningNet 权重的梯度可轻松通过链式法则计算。
PruningNet 是端到端可训练的,其与 Pruned Network 相连的详细结构可参见图 3。
图 3:PruningNet 架构图示
Pruned-Network Search
在 MetaPruning 使用的进化算法中,每个 Pruned Network 被对应网络向量(代表了每层通道数)编码,即 Pruned Network 的基因(Genes)。
在硬约束下,本文首先随机选择大量基因,并通过进化获得相应 Pruned Network 的精度。接着,带有最高精度的 top k 基因被选中以生成带有突变和交叉的新基因。
研究员可通过消除不合格的基因轻松施加硬约束。通过进一步重复 top k 的选择过程和新基因的生成过程,并做若干次迭代,即可获得满足硬约束,同时精度最高的基因。详细算法如下所示。
算法 1:进化搜索算法
实验
本节意在证明 MetaPruning 方法的有效性。第一,说明一下实验设置并介绍如何把 MetaPruning 应用于 MobileNet V1/V2,并可轻松泛化至其他网络结构;第二,把本文方法与一般的 pruning baselines 和当前最优的基于 AutoML 的通道剪裁方法进行对比;第三,可视化由 MetaPruning 生成的 Pruned Network;最后,借助消融实验阐明权重预测在本文方法中的有效性。本文只介绍第二部分,其他部分请参见原论文。
Comparisons with state-of-the-arts
本文把 MetaPruning 与 uniform pruning baselines 以及当前最优的 AutoML 方法做了对比,其结果如下:
表 1:把 MetaPruning top-1 精度与 MobileNet V1 的一般基线做对比
表 2:把 MetaPruning top-1 精度与 MobileNet V2 的一般基线做对比
表 3:把 MetaPruning top-1 精度与当前最优的 AutoML 方法做对比
结论
本文给出了用于模型压缩的新方法——MetaPruning,这一元学习方法有以下优势:1)它比一般的 pruning baselines 精度高很多,比其他基于 AutoML 的通道剪裁方法精度更高或更好;2)它可根据不同的约束做灵活的优化而无需额外的超参数;3)它可高效裁剪类似于 ResNet 一样带有 short-cut 的网络结构;4)整个 pipeline 极其高效。
参考文献
- Z. Liu, M. Sun, T. Zhou, G. Huang, and T. Darrell. Re- thinking the value of network pruning. arXiv preprint arXiv:1810.05270, 2018.
- G. Bender, P.-J. Kindermans, B. Zoph, V. Vasudevan, and Q. Le. Understanding and simplifying one-shot architecture search. In International Conference on Machine Learning, pages 549–558, 2018.
- D. Ha, A. Dai, and Q. V. Le. Hypernetworks. arXiv preprint arXiv:1609.09106, 2016.
- A.G.Howard,M.Zhu,B.Chen,D.Kalenichenko,W.Wang, T. Weyand, M. Andreetto, and H. Adam. Mobilenets: Effi- cient convolutional neural networks for mobile vision appli- cations. arXiv preprint arXiv:1704.04861, 2017.
- M. Sandler, A. Howard, M. Zhu, A. Zhmoginov, and L.-C. Chen. Mobilenetv2: Inverted residuals and linear bottle- necks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 4510–4520, 2018.
- Y. He, J. Lin, Z. Liu, H. Wang, L.-J. Li, and S. Han. Amc: Automl for model compression and acceleration on mobile devices. In Proceedings of the European Conference on Computer Vision (ECCV), pages 784–800, 2018.
- T.-J. Yang, A. Howard, B. Chen, X. Zhang, A. Go, M. San- dler, V. Sze, and H. Adam. Netadapt: Platform-aware neural network adaptation for mobile applications. In Proceedings of the European Conference on Computer Vision (ECCV), pages 285–300, 2018.
【模型压缩】MetaPruning:基于元学习和AutoML的模型压缩新方法的更多相关文章
- 【RS】Deep Learning based Recommender System: A Survey and New Perspectives - 基于深度学习的推荐系统:调查与新视角
[论文标题]Deep Learning based Recommender System: A Survey and New Perspectives ( ACM Computing Surveys ...
- 我用 tensorflow 实现的“一个神经聊天模型”:一个基于深度学习的聊天机器人
概述 这个工作尝试重现这个论文的结果 A Neural Conversational Model (aka the Google chatbot). 它使用了循环神经网络(seq2seq 模型)来进行 ...
- DL4NLP——词表示模型(一)表示学习;syntagmatic与paradigmatic两类模型;基于矩阵的LSA和GloVe
本文简述了以下内容: 什么是词表示,什么是表示学习,什么是分布式表示 one-hot representation与distributed representation(分布式表示) 基于distri ...
- Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(用于深度网络快速适应的元学习)
摘要:我们提出了一种不依赖模型的元学习算法,它与任何梯度下降训练的模型兼容,适用于各种不同的学习问题,包括分类.回归和强化学习.元学习的目标是在各种学习任务上训练一个模型,这样它只需要少量的训练样本就 ...
- 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)
基于深度学习和迁移学习的识花实践(转) 深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...
- 基于.net的分布式系统限流组件 C# DataGridView绑定List对象时,利用BindingList来实现增删查改 .net中ThreadPool与Task的认识总结 C# 排序技术研究与对比 基于.net的通用内存缓存模型组件 Scala学习笔记:重要语法特性
基于.net的分布式系统限流组件 在互联网应用中,流量洪峰是常有的事情.在应对流量洪峰时,通用的处理模式一般有排队.限流,这样可以非常直接有效的保护系统,防止系统被打爆.另外,通过限流技术手段,可 ...
- Jcompress: 一款基于huffman编码和最小堆的压缩、解压缩小程序
前言 最近基于huffman编码和最小堆排序算法实现了一个压缩.解压缩的小程序.其源代码已经上传到github上面: Jcompress下载地址 .在本人的github上面有一个叫Utility的re ...
- 基于深度学习的中文语音识别系统框架(pluse)
目录 声学模型 GRU-CTC DFCNN DFSMN 语言模型 n-gram CBHG 数据集 本文搭建一个完整的中文语音识别系统,包括声学模型和语言模型,能够将输入的音频信号识别为汉字. 声学模型 ...
- 准确率99%!基于深度学习的二进制恶意样本检测——瀚思APT 沙箱恶意文件检测使用的是CNN,LSTM TODO
所以我们的流程如图所示.将正负样本按 1:1 的比例转换为图像.将 ImageNet 中训练好的图像分类模型作为迁移学习的输入.在 GPU 集群中进行训练.我们同时训练了标准模型和压缩模型,对应不同的 ...
随机推荐
- beta版本——第七次冲刺
第七次冲刺 (1)SCRUM部分☁️ 成员描述: 姓名 李星晨 完成了哪个任务 编写个人信息修改界面的js 花了多少时间 3h 还剩余多少时间 0h 遇到什么困难 密码验证部分出现问题 这两天解决的进 ...
- 关于DOM事件流、DOM0级事件与DOM2级事件
一.DOM 事件模型 DOM 事件模型包括捕获和冒泡,捕获是从上往下到达目标元素,冒泡是从当前元素,也就是目标元素往上到 window 二.流 流的概念,在现今的 JavaScript 中随处可见.比 ...
- Lovers(HDU6562+线段树+2018年吉林站)
题目链接 传送门 题意 初始时有\(n\)个空串,然后进行\(q\)次操作,操作分为以下两种: wrap l r x:把\(l,r\)中的每个字符串的首尾都加入\(x\),如\(s_i=121,x=3 ...
- c和c++区别(未整理)
学习完C语言和c++比较一下他们之间的区别: c++是c语言的基础上开发的一种面向对象的编程语言,应用十分广泛,按理说c++可以编译任何c的程序,但是两者还是有细微的差别. c++在c的基础上添加了类 ...
- 转 OJDBC驱动版本区别 [ojdbc14.jar,ojdbc5.jar跟ojdbc6.jar的区别]
OJDBC版本区别 [ojdbc14.jar,ojdbc5.jar和ojdbc6.jar的区别] 在使用Oracle JDBC驱动时,有些问题你是不是通过替换不同版本的Oracle JDBC驱动来解 ...
- shell脚本awk的基本用法
AWK 1 AWK 2 3 linux取IP地址 4 5 ifconfig | grep -w inet | sed -n '1p' | awk '{print $2}' 6 7 eg: 8 9 aw ...
- width: calc(100% - 80px); 屏幕自适应方法
width: calc(100% - 80px); 屏幕自适应方法
- Window IDEA开发工具 杀死指定端口 cmd 命令行 taskkill
Windows平台 两步方法 : 1 查询端口占用,2 强行杀死进程 netstat -aon|findstr "8080" taskkill /pid 4136-t -f ...
- learning java transient 自定义序例化
public class Person implements java.io.Serializable { private String name; private transient int age ...
- ROM
ROM 是 read only memory的简称,表示只读存储器,是一种半导体存储器.只读存储器(ROM)是一种在正常工作时其存储的数据固定不变,其中的数据只能读出,不能写入,即使断电也能够保留数据 ...