Katharopoulos A, Fleuret F. Not All Samples Are Created Equal: Deep Learning with Importance Sampling[J]. arXiv: Learning, 2018.

@article{katharopoulos2018not,

title={Not All Samples Are Created Equal: Deep Learning with Importance Sampling},

author={Katharopoulos, Angelos and Fleuret, F},

journal={arXiv: Learning},

year={2018}}

本文提出一种删选合适样本的方法, 这种方法基于收敛速度的一个上界, 而并非完全基于gradient norm的方法, 使得计算比较简单, 容易实现.

主要内容

设\((x_i,y_i)\)为输入输出对, \(\Psi(\cdot;\theta)\)代表网络, \(\mathcal{L}(\cdot, \cdot)\)为损失函数, 目标为

\[\tag{1}
\theta^* = \arg \min_{\theta} \frac{1}{N} \sum_{i=1}^N\mathcal{L}(\Psi(x_i;\theta),y_i),
\]

其中\(N\)是总的样本个数.

假设在第\(t\)个epoch的时候, 样本(被选中)的概率分布为\(p_1^t,\ldots,p_N^t\), 以及梯度权重为\(w_1^t, \ldots, w_N^t\), 那么\(P(I_t=i)=p_i^t\)且

\[\tag{2}
\theta_{t+1}=\theta_t-\eta w_{I_t}\nabla_{\theta_t} \mathcal{L}(\Psi(x_{I_t};\theta_t),y_{I_t}),
\]

在一般SGD训练中\(p_i=1/N,w_i=1\).

定义\(S\)为SGD的收敛速度为:

\[\tag{3}
S :=-\mathbb{E}_{P_t}[\|\theta_{t+1}-\theta^*\|_2^2-\|\theta_t-\theta^*\|_2^2],
\]

如果我们令\(w_i=\frac{1}{Np_i}\) 则



定义\(G_i=w_i\nabla_{\theta_t} \mathcal{L}(\Psi(x_{i};\theta_t),y_{i})\)



我们自然希望\(S\)能够越大越好, 此时即负项越小越好.

定义\(\hat{G}_i \ge \|\nabla_{\theta_t} \mathcal{L}(\Psi(x_{i};\theta_t),y_{i})\|_2\), 既然



(7)式我有点困惑,我觉得(7)式右端和最小化(6)式的负项(\(\mathrm{Tr}(\mathbb{V}_{P_t}[G_{I_t}])+\|\mathbb{E}_{P_t}[G_{I_t}]\|_2^2\))是等价的.

于是有

最小化右端(通过拉格朗日乘子法)可得\(p_i \propto \hat{G}_i\), 所以现在我们只要找到一个\(\hat{G}_i\)即可.

这个部分需要引入神经网络的反向梯度的公式, 之前有讲过,只是论文的符号不同, 这里不多赘诉了.

注意\(\rho\)的计算是比较复杂的, 但是\(p_i \propto \hat{G}_i\), 所以我们只需要计算\(\|\cdot\|\)部分, 设此分布为\(g\).

另外, 在最开始的时候, 神经网络没有得到很好的训练, 权重大小相差无几, 这个时候是近似正态分布的, 所以作者考虑设计一个指标,来判断是否需要根据样本分布\(g\)来挑选样本. 作者首先衡量



显然当这部分足够大的时候我们可以采用分布\(g\)而非正态分布\(u\), 但是这个指标不易判断, 作者进步除以\(\mathrm{Tr}(\mathbb{V}_u[G_i])\).



显然\(\tau\)越大越好, 我们自然可以人为设置一个\(\tau_{th}\). 算法如下

最后, 个人认为这个算法能减少计算量主要是因为样本少了, 少在一开始用正态分布抽取了一部分, 所以...

"代码"

主要是\(\hat{G}_i\)部分的计算, 因为涉及到中间变量的导数, 所以需要用到retain_grad().

"""
这里只是一个例子
""" import torch
import torch.nn as nn class Net(nn.Module): def __init__(self):
super(Net, self).__init__()
self.dense = nn.Sequential(
nn.Linear(10, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
self.final = nn.ReLU() def forward(self, x):
z = self.dense(x)
z.retain_grad()
out = self.final(z)
return out, z if __name__ == "__main__": net = Net()
criterion = nn.MSELoss() x = torch.rand((2, 10))
y = torch.rand((2, 10)) out, z = net(x)
loss = criterion(out, y)
loss.backward()
print(z.grad) #这便是我们所需要的

Not All Samples Are Created Equal: Deep Learning with Importance Sampling的更多相关文章

  1. Accelerating Deep Learning by Focusing on the Biggest Losers

    目录 概 相关工作 主要内容 代码 Accelerating Deep Learning by Focusing on the Biggest Losers 概 思想很简单, 在训练网络的时候, 每个 ...

  2. (转) The major advancements in Deep Learning in 2016

    The major advancements in Deep Learning in 2016 Pablo Tue, Dec 6, 2016 in MACHINE LEARNING DEEP LEAR ...

  3. Deep Learning in R

    Introduction Deep learning is a recent trend in machine learning that models highly non-linear repre ...

  4. Summary on deep learning framework --- PyTorch

    Summary on deep learning framework --- PyTorch  Updated on 2018-07-22 21:25:42  import osos.environ[ ...

  5. [C3] Andrew Ng - Neural Networks and Deep Learning

    About this Course If you want to break into cutting-edge AI, this course will help you do so. Deep l ...

  6. Deep Learning 5_深度学习UFLDL教程:PCA and Whitening_Exercise(斯坦福大学深度学习教程)

    前言 本文是基于Exercise:PCA and Whitening的练习. 理论知识见:UFLDL教程. 实验内容:从10张512*512自然图像中随机选取10000个12*12的图像块(patch ...

  7. (转)WHY DEEP LEARNING IS SUDDENLY CHANGING YOUR LIFE

    Main Menu Fortune.com       E-mail Tweet Facebook Linkedin Share icons By Roger Parloff Illustration ...

  8. (转)Deep Learning Research Review Week 1: Generative Adversarial Nets

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...

  9. (转)The 9 Deep Learning Papers You Need To Know About (Understanding CNNs Part 3)

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About The 9 Deep Learning Papers You Need To Know Abo ...

随机推荐

  1. 用户体验再升级!Erda 1.2 版本正式发布

    来源|尔达 Erda 公众号 Erda v1.2 Changelog: https://github.com/erda-project/erda/blob/master/CHANGELOG/CHANG ...

  2. pymongdb入门

    Pymongo入门 安装 pip install pymongo 连接 实际就是实例化一个客户端对象,然后客户端对象中指定一个库作为库对象,库对象中的集合对象就是之后常用来执行操作的对象 1 ''' ...

  3. 爬虫系列:连接网站与解析 HTML

    这篇文章是爬虫系列第三期,讲解使用 Python 连接到网站,并使用 BeautifulSoup 解析 HTML 页面. 在 Python 中我们使用 requests 库来访问目标网站,使用 Bea ...

  4. Android 利用Settings.Global属性跨应用定义标志位

    https://blog.csdn.net/ouzhuangzhuang/article/details/82258148 需求 需要在不同应用中定义一个标志位,这里介绍下系统级别的应用和非系统级别应 ...

  5. 100个Shell脚本——【脚本4】自定义rm命令

    [脚本4]自定义rm命令 linux系统的rm命令太危险,一不小心就会删除掉系统文件. 写一个shell脚本来替换系统的rm命令,要求当删除一个文件或者目录时,都要做一个备份,然后再删除.下面分两种情 ...

  6. Largest Rectangle in Histogram及二维解法

    昨天看岛娘直播解题,看到很经典的一题Largest Rectangle in Histogram 题目地址:https://leetcode.com/problems/largest-rectangl ...

  7. Git上项目代码拉到本地方法

    1.先在本地打开workspace文件夹,或者自定义的文件夹,用来保存项目代码的地方. 2.然后登陆GitHub账号,点击复制项目路径 3.在刚才文件夹下空白处点击鼠标右键,打开Git窗口 4.在以下 ...

  8. 【Java 8】函数式接口(二)—— 四大函数接口介绍

    前言 Java8中函数接口有很多,大概有几十个吧,具体究竟是多少我也数不清,所以一开始看的时候感觉一脸懵逼,不过其实根本没那么复杂,毕竟不应该也没必要把一个东西设计的很复杂. 几个单词 在学习了解之前 ...

  9. 转:addChildViewController实现网易新闻首页切换

    本来只是打算介绍一下addChildViewController这个方法的,正好今天朋友去换工作面试问到网易新闻标签栏效果的实现,就结合它,用个小Demo实例介绍一下:(具体解释都写在了Demo里面的 ...

  10. [BUUCTF]PWN——bjdctf_2020_babyrop2

    bjdctf_2020_babyrop2 附件 步骤: 例行检查,64位程序,开启了NX和canary保护 2. 试运行一下程序,看看大概的情况 提示我们去泄露libc 3. 64位ida载入,从ma ...