CoRR 2018 | Horovod: Fast and Easy Distributed Deep Learning in Tensorflow
将深度学习模型的训练从单GPU扩展到多GPU主要面临以下问题:(1)训练框架必须支持GPU间的通信,(2)用户必须更改大量代码以使用多GPU进行训练。为了克服这些问题,本文提出了Horovod,它通过Ring Allreduce实现高效的GPU间通信,而且仅仅更改少量代码就可以实现多GPU训练。
TensorFlow中提供了一些分布式训练的API,这些API适用于不同的环境。这就导致用户往往不知道如何更改代码以进行分布式训练,而且debug也很困难。再者,TensorFlow的分布式训练性能与理想的性能相差甚远,尤其是在大规模GPU环境下。如图1所示,随着GPU数量的增加,分布式TensorFlow的吞吐量与理想的吞吐量的差距逐渐增加,加速比逐渐降低。
因为目前单GPU可以容纳大部分深度学习模型,所以本文主要针对数据并行进行优化。首先来看一下数据并行的训练过程:
- 运行多个模型副本
(a) 读取一部分数据
(b) 把数据喂给模型,进行前向传播
(c) 反向传播,计算梯度 - 将多个模型的梯度进行平均
- 更新模型
- 重复上述步骤直到模型收敛
在标准的TensorFlow中,分布式训练使用参数服务器架构,如图3所示。在参数服务器架构中,主要有worker和server两种角色。worker负责处理数据,计算梯度然后把梯度传给server;server负责聚合梯度,更新模型,然后把模型传回worker。
在这上述两种模式下,主要有以下两个挑战:
- 如何确定worker和server的数量。如果只使用1台server,那么这台server可能成为计算和网络瓶颈;如果使用多台server,那么通信模式就类似于all-to-all,这样就不能完整利用网络带宽。
- 处理愈加复杂的TensorFlow程序。在TensorFlow中,必须显式地启动worker和server,传递一堆参数然后更新代码,这就使得分布式训练变得非常繁琐复杂。
所幸的是,2017年百度提出了一种名为Ring Allreduce的算法。在该算法中,所有worker组成一个环,每台worker只和相邻的两台worker通信,如图4所示。
在Ring Allreduce中,如果有\(N\)个节点,那么每个节点会通信\(2\times (N -1)\)次:前\(N-1\)次接收值并把它加到对应的buffer中,后\(N-1\)次接收并替换对应buffer中的值。Ring Allreduce算法是带宽最优的,也就是说,当buffer足够大时,它会最大限度地利用网络带宽。
综上所述,本文取长补短,使用Ring Allreduce算法优化TensorFlow的分布式训练过程。本文的实现流程如下:
- 将代码转换成独立的Python包,名为Horovod
- 将百度的Ring Allreduce实现替换为NCCL
- 增加了对单机多GPU训练的支持
- 根据反馈更新了部分API,还实现了一个广播操作,以在所有worker上进行强制一致性初始化
import tensorflow as tf
import horovod.tensorflow as hvd
# Initialize Horovod
hvd.init()
# Pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
# Build model...
loss = ...
opt = tf.train.AdagradOptimizer(0.01)
# Add Horovod Distributed Optimizer
opt = hvd.DistributedOptimizer(opt)
# Add hook to broadcast variables from rank 0 to all other process
# during initialization.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]
train_op = opt.minimize(loss)
# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing
# when done or an error occurs.
with tf.train.MonitoredTrainingSession(checkpoint="/tmp/train_logs",
config=config, hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
# Perform synchronous training
mon_sess.run(trian_op)
此外,Horovod还提供了一个名为Timeline的分析工具,它可以让用户每个节点在每次迭代时做了什么,效果如图5所示。
使用Timeline对一些模型进行分析后,发现当张量较小时,Ring Allreduce的效率并不高。因此,本文提出一种名为张量融合的技术来解决上述问题。
- 检测哪些张量将会被规约,选择适合缓冲区并具有相同数据类型的前几个张量
- 申请张量融合所需的缓冲区(如果之前没有申请的话),默认大小为64M
- 将选择的张量拷贝到融合缓冲区
- 在融合缓冲区执行allreduce操作
- 将数据从融合缓冲区拷贝到输出张量
- 重复上述步骤直到环中没有要被规约的向量
使用Horovod之后,Inception V3和ResNet-101模型的性能提升了约88%,如图6所示。
如图7,RDMA网络并没有比传统的TCP提升多少性能,只提升了约4%。
未来的工作主要包括:
- 让MPI的安装变得更容易
- 分布式深度学习模型调参经验的收集与分享
- 增加大型模型的示例
CoRR 2018 | Horovod: Fast and Easy Distributed Deep Learning in Tensorflow的更多相关文章
- (转)分布式深度学习系统构建 简介 Distributed Deep Learning
HOME ABOUT CONTACT SUBSCRIBE VIA RSS DEEP LEARNING FOR ENTERPRISE Distributed Deep Learning, Part ...
- 英特尔深度学习框架BigDL——a distributed deep learning library for Apache Spark
BigDL: Distributed Deep Learning on Apache Spark What is BigDL? BigDL is a distributed deep learning ...
- Summary on deep learning framework --- TensorFlow
Summary on deep learning framework --- TensorFlow Updated on 2018-07-22 21:28:11 1. Check failed: s ...
- Distributed Deep Learning
安利一下刘铁岩老师的<分布式机器学习>这本书 以及一个大神的blog: https://zhuanlan.zhihu.com/p/29032307 https://zhuanlan.zhi ...
- Comparing deep learning frameworks: Tensorflow, CNTK, MXNet, & Caffe
https://imaginghub.com/blog/10-a-comparison-of-four-deep-learning-frameworks-tensorflow-cntk-mxnet-a ...
- Install PaddlePaddle (Parallel Distributed Deep Learning)
Step 1: Install docker on your linux system (My linux is fedora) https://docs.docker.com/engine/inst ...
- NeurIPS 2017 | TernGrad: Ternary Gradients to Reduce Communication in Distributed Deep Learning
在深度神经网络的分布式训练中,梯度和参数同步时的网络开销是一个瓶颈.本文提出了一个名为TernGrad梯度量化的方法,通过将梯度三值化为\({-1, 0, 1}\)来减少通信量.此外,本文还使用逐层三 ...
- [ Deep Learning ] Keras & TensorFlow安装依赖包
OS:Mac Python:3.6 一.先安装Keras,再安装TensorFlow 1. 安装Keras Package Version---------- -------h5py 2.7.1 Ke ...
- 【深度学习Deep Learning】资料大全
最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: Free Online Books by Yoshua Bengio, Ian Goodfellow and Aaron C ...
随机推荐
- 【Java】eclipse中的JUnit单元测试
eclipse中的JUnit单元测试 步骤: 选中当前工程 - 右键选择:build path - add libraries - JUnit 4 - 下一步 创建Java类,进行单元测试. 此时的J ...
- 【涨姿势】原来golang的case <-time.After(xxx)还有这样的坑
偶然看到这样一篇文章:<使用 pprof 排查 Golang 内存泄露>https://www.toutiao.com/i6881796351139676680/ 最后一段让我很疑惑: 修 ...
- 【记录一个问题】ndk下使用c++11的condition_variable问题较多
1.存在通知丢失的情况:生产者线程通知196次,消费者线程收到190次,导致部分数据无法被处理. 2.cond.wait()方法后的加锁有问题,导致对空队列进行出队操作然后coredump.一直记得w ...
- manjaro20软件商店无法链接下载
软件商店如果无法链接下载 解决方案1 可以使用terminal慢慢下载,.bashrc中配置代理 如果依然不行,检查网络设置代理是否为自动或者手动设置正确. 解决方案2 检查是否未设置中国社区源或者重 ...
- 实习之bii--配置esxi重启时,虚拟机也跟随重启
由于初创环境不稳定又是服务器会重启,而内部安装的多部虚拟机并不默认跟随启动,需要设置,方法如下: 1.在本地通过vsphere client 登录到esxi的服务器上,然后点击配置找到虚拟机启动/关机 ...
- Qt之QColorDialog
widget.h: #ifndef WIDGET_H #define WIDGET_H #include <QWidget> class Widget : public QWidget { ...
- 3D建模服务提供更高效、专业的能力,“筑”力开发者
3D建模服务(3D Modeling Kit)是HMS Core在图形图像领域又一技术开放.3D建模产品的定位就是要做快速.简洁.低成本的3D制作能力,并陆续开放给有3D模型.动画游戏制作等能力诉求的 ...
- List<Integer>里有可能存String类型元素吗?
这其实是我遇到的一个线上bug,在这里分享给大家. 如果是用反射,那就很简单了,毕竟泛型只是在编译期进行约束,对运行期是无能为力的. 想想看,如果不使用反射,有没有办法做到呢? 问题起因 在我们公司的 ...
- SIFT,SuperPoint在图像特征提取上的对比实验
SIFT,SuperPoint都具有提取图片特征点,并且输出特征描述子的特性,本篇文章从特征点的提取数量,特征点的正确匹配数量来探索一下二者的优劣. 视角变化较大的情况下 原图1 原图2 SuperP ...
- 随机UA
from fake_useragent import UserAgent ua = UserAgent().random headers={ 'User-Agent':ua } print(heade ...