Accelerating Deep Learning by Focusing on the Biggest Losers
Overview
The idea is simple. During training every sample produces a loss \(\mathcal{L}(f(x_i),y_i)\), and training is usually done in mini-batches: the gradients of the batch loss \(\sum_i \mathcal{L}(f(x_i),y_i)\) are propagated back and the parameters are updated. The paper argues that some samples \((x_i,y_i)\) are so redundant that the network recognizes them easily, so the corresponding \(\mathcal{L}(f(x_i),y_i)\) is relatively small. It therefore designs a mechanism under which samples with large losses are selected with high probability while unimportant samples are likely skipped, thereby reducing computation time. Experiments show that this method shortens training time while keeping accuracy essentially unchanged.
Related work
The authors say the idea was first proposed in Are Loss Functions All the Same?, but that paper mainly discusses the advantages of the hinge loss and analyzes other loss functions.
The authors consider the most closely related work to be Not All Samples Are Created Equal: Deep Learning with Importance Sampling, which tackles the problem from a preprocessing angle (though it also has to compute losses) and contains more theory than this paper.
Main content
The idea of Algorithm 1 is clear; the main difficulty lies in the probability computation of Algorithm 2. Suppose we have already computed and stored the losses of \(n\) samples, and the loss of the next sample is \(\mathcal{L}_c\). If \(k\) of the \(n\) stored losses are smaller than \(\mathcal{L}_c\), then this sample is selected with probability
\[
P = \max\Big\{\Big(\frac{k}{n}\Big)^{\beta},\ s\Big\},
\]
where \(s\in[0,1]\) is set by hand so that every sample keeps some chance of being selected, and \(\beta>0\) (the beta in the code below) controls how strongly high-loss samples are favored.
We can also cap the history at a maximum length \(r\) by storing past losses in a double-ended queue (deque): once \(n=r\), storing a new loss evicts the oldest one, which bounds the bookkeeping cost.
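A minimal sketch of this selection rule on top of a bounded deque (the function name select_prob and the toy loss values are mine, not the paper's):

import collections
import random

def select_prob(history, new_loss, beta=1.0, s=0.1):
    """Probability of keeping a sample with loss new_loss, given the recent losses."""
    k = sum(1 for old in history if old < new_loss)  # stored losses smaller than the new one
    n = max(len(history), 1)
    return max((k / n) ** beta, s)  # the floor s keeps every sample selectable

history = collections.deque(maxlen=3000)  # r = 3000: the oldest loss is evicted automatically
for loss in [0.3, 1.2, 0.7, 0.05, 2.4]:
    keep = random.random() < select_prob(history, loss)
    history.append(loss)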
A[sample x] --> C(network f)
C --> D[loss l]
D--update-->E[loss history]
D-->F[compute probability]
F-->G(form batch)
G--backward pass-->C
E-->F
In the figure at the beginning of the paper, the second column corresponds to this algorithm, and the third column adds extra handling of the forward pass: the losses are refreshed only once every \(n\) epochs, and during the intervening \(n-1\) epochs the stale losses are reused to select samples (presumably the samples are filtered before they are even fed into the network, otherwise no time would be saved).
Among randomized algorithms there is a single-pass method for sampling from a stream, but it only selects one item, so it does not apply directly when many samples must be selected; picking many items in a single pass seems harder to arrange.
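For reference, that single-pass method is the classic reservoir sampling (Algorithm R); this sketch is a textbook algorithm, not code from the paper:

import random

def reservoir_sample_one(stream):
    """Pick one item uniformly at random from a stream of unknown length, in one pass."""
    chosen = None
    for n, item in enumerate(stream, start=1):
        if random.random() < 1.0 / n:  # the n-th item replaces the current choice with probability 1/n
            chosen = item
    return chosen

print(reservoir_sample_one(range(100)))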
Code
Owing to limited resources, the code below has not been tested; the paper also comes with a very good reference implementation.
"""
OptInput.py
纯粹是为了便于交互一些, 直接用argparse也可以
"""
class Unit:
def __init__(self, command, type=str,
default=None):
if default is None:
default = type()
self.command = command
self.type = type
self.default = default
class Opi:
"""
>>> parser = Opi()
>>> parser.add_opt(command="lr", type=float)
>>> parser.add_opt(command="epochs", type=int)
"""
def __init__(self):
self.store = []
self.infos = {}
def add_opt(self, **kwargs):
self.store.append(
Unit(**kwargs)
)
def acquire(self):
s = "Acquire args {0.command} [" \
"type:{0.type.__name__} " \
"default:{0.default}] : "
for unit in self.store:
while True:
inp = input(s.format(
unit
))
try:
if inp: #若有输入
inp = unit.type(inp)
else:
inp = unit.default
self.infos.update(
{unit.command:inp}
)
self.__setattr__(unit.command, inp)
break
except:
print("Type {0} should be given".format(
unit.type.__name__
))
if __name__ == "__main__":
parser = Opi()
parser.add_opt(command = "x", type=int)
parser.add_opt(command="y", type=str)
parser.acquire()
print(parser.infos)
print(parser.x)
'''
calcprob.py
Turn per-sample losses into selection probabilities.
'''
import collections


class Calcprob:

    def __init__(self, beta, sample_min, max_len=3000):
        assert 0. <= sample_min <= 1., "Invalid sample_min"
        assert beta > 0, "Invalid beta"
        self.beta = beta
        self.sample_min = sample_min
        self.max_len = max_len
        self.history = collections.deque(maxlen=max_len)
        self.num_slot = 1000
        self.hist = [0] * self.num_slot
        self.count = 0

    def update_history(self, losses):
        """
        Bounded histogram: keep at most max_len recent losses and
        maintain a coarse histogram with num_slot slots over them.
        Note that the slot index wraps modulo num_slot, so losses are
        implicitly assumed to lie roughly in (0, 1).
        :param losses: iterable of positive scalar losses
        :return:
        """
        for loss in losses:
            assert loss > 0
            if self.count == self.max_len:  # history is full: evict the oldest loss
                loss_old = self.history.popleft()
                slot_old = int(loss_old * self.num_slot) % self.num_slot
                self.hist[slot_old] -= 1
            else:
                self.count += 1
            self.history.append(loss)
            slot = int(loss * self.num_slot) % self.num_slot
            self.hist[slot] += 1

    def get_probability(self, loss):
        # percentile of this loss among the stored losses, raised to the power beta
        assert loss > 0
        slot = int(loss * self.num_slot) % self.num_slot
        prob = sum(self.hist[:slot]) / self.count
        assert isinstance(prob, float), "int division error..."
        return prob ** self.beta

    def calc_probability(self, losses):
        if isinstance(losses, float):
            losses = (losses,)
        losses = tuple(losses)  # materialize in case a generator was passed
        self.update_history(losses)
        probs = (
            max(
                self.get_probability(loss),
                self.sample_min
            )
            for loss in losses
        )
        return probs

    def __call__(self, losses):
        return self.calc_probability(losses)


if __name__ == "__main__":
    pass
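A quick usage sketch for Calcprob (the loss values are made up for illustration):

import calcprob

cp = calcprob.Calcprob(beta=0.9, sample_min=0.3, max_len=3000)
probs = cp([0.12, 0.85, 0.40])  # updates the history and yields one probability per loss
print(list(probs))              # higher losses get higher probabilities, floored at sample_min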
'''
selector.py
Accumulate the selected per-sample losses and backpropagate them in batches.
'''
import numpy as np

import calcprob


class Selector:

    def __init__(self, batch_size,
                 beta, sample_min, max_len=3000):
        self.batch_size = batch_size
        self.calcprob = calcprob.Calcprob(beta,
                                          sample_min,
                                          max_len)
        self.reset()

    def backward(self):
        # backpropagate the sum of the selected per-sample losses
        loss = sum(self.batch)
        loss.backward()
        self.reset()

    def reset(self):
        self.batch = []
        self.length = 0

    def select(self, losses):
        losses = list(losses)  # materialize so the losses can be indexed below
        probs = self.calcprob(losses)
        for i, prob in enumerate(probs):
            if np.random.rand() < prob:  # high-loss samples are kept with high probability
                self.batch.append(losses[i])
                self.length += 1
        if self.length >= self.batch_size:
            self.backward()

    def __call__(self, losses):
        self.select(losses)
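And a toy usage sketch for Selector, with made-up scalar losses that carry gradients (not from the paper):

import torch
import selector

sel = selector.Selector(batch_size=2, beta=0.9, sample_min=0.5)
x = torch.randn(4, requires_grad=True)
losses = [xi ** 2 for xi in x]  # four toy per-sample losses
sel(losses)                     # selected losses are summed and backpropagated
print(x.grad)                   # populated once at least batch_size losses were selected, else None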
'''
main.py
'''
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import selector


class Train:

    def __init__(self, model, lossfunc,
                 bpsize, beta, sample_min, max_len=3000,
                 lr=0.01, momentum=0.9, weight_decay=0.0001):
        self.net = self.choose_net(model)
        self.criterion = self.choose_lossfunc(lossfunc)
        self.opti = torch.optim.SGD(self.net.parameters(),
                                    lr=lr, momentum=momentum,
                                    weight_decay=weight_decay)
        self.selector = selector.Selector(bpsize, beta,
                                          sample_min, max_len)
        self.gpu()
        self.generate_path()
        self.acc_rates = []
        self.errors = []

    def choose_net(self, model):
        net = getattr(
            torchvision.models,
            model,
            None
        )
        if net is None:
            raise ValueError("no such model")
        return net()

    def choose_lossfunc(self, lossfunc):
        lossfunc = getattr(
            nn,
            lossfunc,
            None
        )
        if lossfunc is None:
            raise ValueError("no such lossfunc")
        return lossfunc()  # instantiate the loss module

    def gpu(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() > 1:
            print("Let's use %d GPUs" % torch.cuda.device_count())
            self.net = nn.DataParallel(self.net)
        self.net = self.net.to(self.device)

    def generate_path(self):
        """
        Create the directories and file names used for saving
        parameters, logs, and training statistics.
        :return:
        """
        try:
            os.makedirs('./paras')
            os.makedirs('./logs')
            os.makedirs('./infos')
        except FileExistsError:
            pass
        name = self.net.__class__.__name__
        paras = os.listdir('./paras')
        logs = os.listdir('./logs')
        infos = os.listdir('./infos')
        number = max((len(paras), len(logs), len(infos)))
        self.para_path = "./paras/{0}{1}.pt".format(
            name,
            number
        )
        self.log_path = "./logs/{0}{1}.txt".format(
            name,
            number
        )
        self.info_path = "./infos/{0}{1}.npy".format(
            name,
            number
        )

    def log(self, strings):
        """
        Append a line to the run log.
        :param strings:
        :return:
        """
        # 'a': append to the end of the file
        with open(self.log_path, 'a', encoding='utf8') as f:
            f.write(strings)

    def save(self):
        """
        Save the network parameters.
        :return:
        """
        torch.save(self.net.state_dict(), self.para_path)

    def decrease_lr(self, multi=0.96):
        """
        Decay the learning rate.
        :param multi:
        :return:
        """
        self.opti.param_groups[0]['lr'] *= multi

    def train(self, trainloader, epochs=50):
        data_size = len(trainloader) * trainloader.batch_size
        part = int(trainloader.batch_size / 2)
        for epoch in range(epochs):
            running_loss = 0.
            total_loss = 0.
            acc_count = 0.
            if (epoch + 1) % 8 == 0:
                self.decrease_lr()
                self.log(  # log the change
                    "learning rate change!!!\n"
                )
            for i, data in enumerate(trainloader):
                imgs, labels = data
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                out = self.net(imgs)
                _, pre = torch.max(out, 1)  # predicted classes
                acc_count += (pre == labels).sum().item()  # count correct predictions
                losses = [
                    self.criterion(out[j:j + 1], labels[j:j + 1])  # per-sample losses (keep the batch dim)
                    for j in range(len(labels))
                ]
                self.opti.zero_grad()
                self.selector(losses)  # select high-loss samples and backpropagate them
                self.opti.step()
                running_loss += sum(losses).item()
                if (i + 1) % part == 0:
                    strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
                        epoch, i, running_loss / part
                    )
                    self.log(strings)  # log progress
                    total_loss += running_loss
                    running_loss = 0.
            self.acc_rates.append(acc_count / data_size)
            self.errors.append(total_loss / data_size)
            self.log(  # log accuracy
                "Accuracy of the network on %d train images: %d %%\n" % (
                    data_size, acc_count / data_size * 100
                )
            )
            self.save()  # save the network parameters
            # save statistics for plotting later
            np.save(self.info_path, {
                'acc_rates': np.array(self.acc_rates),
                'errors': np.array(self.errors)
            })


if __name__ == "__main__":
    import OptInput
    args = OptInput.Opi()
    args.add_opt(command="model", default="resnet34")
    args.add_opt(command="lossfunc", default="CrossEntropyLoss")
    args.add_opt(command="bpsize", type=int, default=32)
    args.add_opt(command="beta", type=float, default=0.9)
    args.add_opt(command="sample_min", type=float, default=0.3)
    args.add_opt(command="max_len", type=int, default=3000)
    args.add_opt(command="lr", type=float, default=0.001)
    args.add_opt(command="momentum", type=float, default=0.9)
    args.add_opt(command="weight_decay", type=float, default=0.0001)
    args.acquire()

    root = "C:/Users/pkavs/1jupiterdata/data"

    trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                            download=False,
                                            transform=transforms.Compose(
                                                [transforms.Resize(224),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
                                            ))

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                               shuffle=True, num_workers=8,
                                               pin_memory=True)

    dog = Train(**args.infos)
    dog.train(train_loader, epochs=1000)