论文地址

channel pruning是指给定一个CNN模型,去掉卷积层的某几个输入channel以及相应的卷积核,
并最小化裁剪channel后与原始输出的误差。

可以分两步来解决:

  1. channel selection
    利用LASSO回归裁剪掉多余的channel,求出每个channel的权重,如果为0即是被裁减。
  2. feature map reconstruction
    利用剩下的channel重建输出,直接使用最小平方误差来拟合原始卷积层的输出,求出新的卷积核W。

二、优化目标

2.1 定义优化目标

输入c个channel,输出n个channel,卷积核W的大小是

我们对输入做了采样,假设对每个输入,对channel采样出来N个大小为块

为了把输入channel从c个剪裁到c’个,我们定义最优化的目标为

其中是每个channel的权重向量,如果是0,意味着裁剪当前channel,相应的也被裁减。

2.2 求最优目标

为了最优化目标,分为如下两步

2.2.1 固定W,求

其中,大小是,

这里之所以加上关于的L1正则项,是为了避免所有的都为1,而是让它们趋于0。

2.2.2 固定,求W

利用剩下的channel重建输出,直接求最小平方误差

其中,大小为,
W’也被reshape为。

2.2.3 多分支的情况

论文只考虑了常见的残差网络,设residual分支的输出为,shortcut 分支的输出为。

这里首先在residual分支的第一层前做了channel采样,从而减少计算量(训练过程中做的,即filter layer)。

设为原始的上一层的输出,
那么channel pruning中,residual分支的输出拟合,其中是上一层裁减后的shortcut。

三、实现

实现的时候,不是按照不断迭代第一步和第二步,因为比较耗时。
而是先不断的迭代第一步,直到裁剪剩下的channel个数为c’,然后执行第二步求出最终的W。

3.1 第一步Channel Selection

如何得到LASSO回归的输入:

(1)首先把输入做转置

# (N, c, hw) --> (c, N, hw)
inputs = np.transpose(inputs, [1, 0, 2])

(2)把weigh做转置

# (n, c, hw) --> (c, hw, n)
weights = np.transpose(weights, [1, 2, 0]))

(3)最后两维做矩阵乘法

# (c, N, n), matmul apply dot on the last two dim
outputs = np.matmul(inputs, weights)

(4)把输出做reshape和转置

# (Nn, c)
outputs = np.transpose(outputs.reshape(outputs.shape[0], -1))

LASSO回归的目标值即是对应的Y,大小为

的大小影响了最终为0的个数,为了找出合适的,需要尝试不同的值,直到裁剪剩下的channel个数为为止。

为了找到合适的可以使用二分查找,
或者不断增大直到裁剪剩下的channel个数,然后降序排序取前,剩下的为0。

while True:
coef = solve(alpha)
if sum(coef != 0) < rank:
break
last_alpha = alpha
last_coef = coef
alpha = 4 * alpha + math.log(coef.shape[0])
if not fast:
# binary search until compression ratio is satisfied
left = last_alpha
right = alpha
while True:
alpha = (left + right) / 2
coef = solve(alpha)
if sum(coef != 0) < rank:
right = alpha
elif sum(coef != 0) > rank:
left = alpha
else:
break
else:
last_coef = np.abs(last_coef)
sorted_coef = sorted(last_coef, reverse=True)
rank_max = sorted_coef[rank - 1]
coef = np.array([c if c >= rank_max else 0 for c in last_coef])

3.2 第二步Feature Map Reconstruction

直接利用最小平方误差,求出最终的卷积核。

from sklearn import linear_model
def LinearRegression(input, output):
clf = linear_model.LinearRegression()
clf.fit(input, output)
return clf.coef_, clf.intercept_
pruned_weights, pruned_bias = LinearRegression(input=inputs, output=real_outputs)

3.3 一些细节

  1. 将Relu层和卷积层分离
    因为Relu一般会使用inplace操作来节省内存/显存,如果不分离开,那么得到的卷积层的输出是经过了Relu激活函数计算后的结果。

  2. 每次裁减完一个卷积层后,需要对该层的bottom和top层的输入或输出大小作相应的改变。

  3. 第一步求出后,若为0,则说明要裁减对应的channel,否则置为1,表示保留channel。

参考链接

https://github.com/yihui-he/channel-pruning

模型压缩之Channel Pruning的更多相关文章

  1. 【转载】NeurIPS 2018 | 腾讯AI Lab详解3大热点:模型压缩、机器学习及最优化算法

    原文:NeurIPS 2018 | 腾讯AI Lab详解3大热点:模型压缩.机器学习及最优化算法 导读 AI领域顶会NeurIPS正在加拿大蒙特利尔举办.本文针对实验室关注的几个研究热点,模型压缩.自 ...

  2. 【模型压缩】MetaPruning:基于元学习和AutoML的模型压缩新方法

    论文名称:MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning 论文地址:https://arxiv.org/ ...

  3. 模型压缩-Learning Efficient Convolutional Networks through Network Slimming

    Zhuang Liu主页:https://liuzhuang13.github.io/ Learning Efficient Convolutional Networks through Networ ...

  4. [论文分享]Channel Pruning via Automatic Structure Search

    authors: Mingbao Lin, Rongrong Ji, etc. comments: IJCAL2020 cite: [2001.08565v3] Channel Pruning via ...

  5. CNN 模型压缩与加速算法综述

    本文由云+社区发表 导语:卷积神经网络日益增长的深度和尺寸为深度学习在移动端的部署带来了巨大的挑战,CNN模型压缩与加速成为了学术界和工业界都重点关注的研究领域之一. 前言 自从AlexNet一举夺得 ...

  6. 论文笔记——Channel Pruning for Accelerating Very Deep Neural Networks

    论文地址:https://arxiv.org/abs/1707.06168 代码地址:https://github.com/yihui-he/channel-pruning 采用方法 这篇文章主要讲诉 ...

  7. 【DMCP】2020-CVPR-DMCP Differentiable Markov Channel Pruning for Neural Networks-论文阅读

    DMCP 2020-CVPR-DMCP Differentiable Markov Channel Pruning for Neural Networks Shaopeng Guo(sensetime ...

  8. 对抗性鲁棒性与模型压缩:ICCV2019论文解析

    对抗性鲁棒性与模型压缩:ICCV2019论文解析 Adversarial Robustness vs. Model Compression, or Both? 论文链接: http://openacc ...

  9. 模型压缩,模型减枝,tf.nn.zero_fraction,统计0的比例,等。

    我们刚接到一个项目时,一开始并不是如何设计模型,而是去先跑一个现有的模型,看在项目需求在现有模型下面效果怎么样.当现有模型效果不错需要深入挖掘时,仅仅时跑现有模型是不够的,比如,如果你要在嵌入式里面去 ...

随机推荐

  1. Vue专题-组件

    vue.js既然是框架,那就不能只是简单的完成数据模板引擎的任务,它还提供了页面布局的功能.本文详细介绍使用vue.js进行页面布局的强大工具,vue.js组件系统. Vue.js组件系统 每一个新技 ...

  2. 获取文件MD5值(JS、JAVA)

    文章HTML代码翻译于地址:https://www.cnblogs.com/linyihai/p/7040786.html           文件MD5有啥用?                  文 ...

  3. JavaScript学习笔记 - 进阶篇(6)- JavaScript内置对象

    什么是对象 JavaScript 中的所有事物都是对象,如:字符串.数值.数组.函数等,每个对象带有属性和方法. 对象的属性:反映该对象某些特定的性质的,如:字符串的长度.图像的长宽等: 对象的方法: ...

  4. 洛谷P1525 关押罪犯(并查集、二分图判定)

    本人蒟蒻,只能靠题解AC,看到大佬们的解题思路,%%%%%% https://www.luogu.org/problemnew/show/P1525 题目描述 S城现有两座监狱,一共关押着N名罪犯,编 ...

  5. 第04项目:淘淘商城(SpringMVC+Spring+Mybatis)【第八天】(solr服务器搭建、搜索功能实现)

    https://pan.baidu.com/s/1bptYGAb#list/path=%2F&parentPath=%2Fsharelink389619878-229862621083040 ...

  6. E - Ingredients 拓扑排序+01背包

    题源:https://codeforces.com/gym/101635/attachments 题意: n行,每行给定字符串s1,s2,s3代表一些菜谱名.s2和s3是煮成是的必要条件,然后给出c和 ...

  7. 对象数组和for循环遍历输出学生的信息

    public class Student { private int no; private String name; private int age; public Student(int no, ...

  8. 43)PHP,mysql_fetch_row 和mysql_fetch_assoc和mysql_fetch_array

    mysql_fetch_row   提取的结果是没有查询中的字段名了(也就是没有键id,GoodsName,只有值),如下图: mysql_fetch_assoc 提取的结果有键值,如下图: mysq ...

  9. Qt 信息提示框 QMessageBox

    information QMessageBox::information(NULL, "Title","Content",QMessageBox::Yes | ...

  10. linux的进程和管道符(二)

    回顾:进程管理:kill killall pkill问题:1.pkill -u root 禁止2.用户名不要用数字开头或者纯数字windows的用户名不要用中文3.pokit/etc/passwd 6 ...