CoRR 2018 | Horovod: Fast and Easy Distributed Deep Learning in Tensorflow
将深度学习模型的训练从单GPU扩展到多GPU主要面临以下问题:(1)训练框架必须支持GPU间的通信,(2)用户必须更改大量代码以使用多GPU进行训练。为了克服这些问题,本文提出了Horovod,它通过Ring Allreduce实现高效的GPU间通信,而且仅仅更改少量代码就可以实现多GPU训练。
TensorFlow中提供了一些分布式训练的API,这些API适用于不同的环境。这就导致用户往往不知道如何更改代码以进行分布式训练,而且debug也很困难。再者,TensorFlow的分布式训练性能与理想的性能相差甚远,尤其是在大规模GPU环境下。如图1所示,随着GPU数量的增加,分布式TensorFlow的吞吐量与理想的吞吐量的差距逐渐增加,加速比逐渐降低。
因为目前单GPU可以容纳大部分深度学习模型,所以本文主要针对数据并行进行优化。首先来看一下数据并行的训练过程:
- 运行多个模型副本
(a) 读取一部分数据
(b) 把数据喂给模型,进行前向传播
(c) 反向传播,计算梯度 - 将多个模型的梯度进行平均
- 更新模型
- 重复上述步骤直到模型收敛
在标准的TensorFlow中,分布式训练使用参数服务器架构,如图3所示。在参数服务器架构中,主要有worker和server两种角色。worker负责处理数据,计算梯度然后把梯度传给server;server负责聚合梯度,更新模型,然后把模型传回worker。
在这上述两种模式下,主要有以下两个挑战:
- 如何确定worker和server的数量。如果只使用1台server,那么这台server可能成为计算和网络瓶颈;如果使用多台server,那么通信模式就类似于all-to-all,这样就不能完整利用网络带宽。
- 处理愈加复杂的TensorFlow程序。在TensorFlow中,必须显式地启动worker和server,传递一堆参数然后更新代码,这就使得分布式训练变得非常繁琐复杂。
所幸的是,2017年百度提出了一种名为Ring Allreduce的算法。在该算法中,所有worker组成一个环,每台worker只和相邻的两台worker通信,如图4所示。
在Ring Allreduce中,如果有\(N\)个节点,那么每个节点会通信\(2\times (N -1)\)次:前\(N-1\)次接收值并把它加到对应的buffer中,后\(N-1\)次接收并替换对应buffer中的值。Ring Allreduce算法是带宽最优的,也就是说,当buffer足够大时,它会最大限度地利用网络带宽。
综上所述,本文取长补短,使用Ring Allreduce算法优化TensorFlow的分布式训练过程。本文的实现流程如下:
- 将代码转换成独立的Python包,名为Horovod
- 将百度的Ring Allreduce实现替换为NCCL
- 增加了对单机多GPU训练的支持
- 根据反馈更新了部分API,还实现了一个广播操作,以在所有worker上进行强制一致性初始化
import tensorflow as tf
import horovod.tensorflow as hvd
# Initialize Horovod
hvd.init()
# Pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
# Build model...
loss = ...
opt = tf.train.AdagradOptimizer(0.01)
# Add Horovod Distributed Optimizer
opt = hvd.DistributedOptimizer(opt)
# Add hook to broadcast variables from rank 0 to all other process
# during initialization.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]
train_op = opt.minimize(loss)
# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing
# when done or an error occurs.
with tf.train.MonitoredTrainingSession(checkpoint="/tmp/train_logs",
config=config, hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
# Perform synchronous training
mon_sess.run(trian_op)
此外,Horovod还提供了一个名为Timeline的分析工具,它可以让用户每个节点在每次迭代时做了什么,效果如图5所示。
使用Timeline对一些模型进行分析后,发现当张量较小时,Ring Allreduce的效率并不高。因此,本文提出一种名为张量融合的技术来解决上述问题。
- 检测哪些张量将会被规约,选择适合缓冲区并具有相同数据类型的前几个张量
- 申请张量融合所需的缓冲区(如果之前没有申请的话),默认大小为64M
- 将选择的张量拷贝到融合缓冲区
- 在融合缓冲区执行allreduce操作
- 将数据从融合缓冲区拷贝到输出张量
- 重复上述步骤直到环中没有要被规约的向量
使用Horovod之后,Inception V3和ResNet-101模型的性能提升了约88%,如图6所示。
如图7,RDMA网络并没有比传统的TCP提升多少性能,只提升了约4%。
未来的工作主要包括:
- 让MPI的安装变得更容易
- 分布式深度学习模型调参经验的收集与分享
- 增加大型模型的示例
CoRR 2018 | Horovod: Fast and Easy Distributed Deep Learning in Tensorflow的更多相关文章
- (转)分布式深度学习系统构建 简介 Distributed Deep Learning
HOME ABOUT CONTACT SUBSCRIBE VIA RSS DEEP LEARNING FOR ENTERPRISE Distributed Deep Learning, Part ...
- 英特尔深度学习框架BigDL——a distributed deep learning library for Apache Spark
BigDL: Distributed Deep Learning on Apache Spark What is BigDL? BigDL is a distributed deep learning ...
- Summary on deep learning framework --- TensorFlow
Summary on deep learning framework --- TensorFlow Updated on 2018-07-22 21:28:11 1. Check failed: s ...
- Distributed Deep Learning
安利一下刘铁岩老师的<分布式机器学习>这本书 以及一个大神的blog: https://zhuanlan.zhihu.com/p/29032307 https://zhuanlan.zhi ...
- Comparing deep learning frameworks: Tensorflow, CNTK, MXNet, & Caffe
https://imaginghub.com/blog/10-a-comparison-of-four-deep-learning-frameworks-tensorflow-cntk-mxnet-a ...
- Install PaddlePaddle (Parallel Distributed Deep Learning)
Step 1: Install docker on your linux system (My linux is fedora) https://docs.docker.com/engine/inst ...
- NeurIPS 2017 | TernGrad: Ternary Gradients to Reduce Communication in Distributed Deep Learning
在深度神经网络的分布式训练中,梯度和参数同步时的网络开销是一个瓶颈.本文提出了一个名为TernGrad梯度量化的方法,通过将梯度三值化为\({-1, 0, 1}\)来减少通信量.此外,本文还使用逐层三 ...
- [ Deep Learning ] Keras & TensorFlow安装依赖包
OS:Mac Python:3.6 一.先安装Keras,再安装TensorFlow 1. 安装Keras Package Version---------- -------h5py 2.7.1 Ke ...
- 【深度学习Deep Learning】资料大全
最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: Free Online Books by Yoshua Bengio, Ian Goodfellow and Aaron C ...
随机推荐
- 微信小程序输入框上移问题解决
微信小程序的输入框在上面还好,如果不是,在聚焦的时候页面就会上移,上方的页面信息会看不到,影响用户操作 在这里可以手动设置并获取输入框的高度来解决 这种方式虽然有的机子有点卡,但是已经算是比较完美的解 ...
- 下载并搭建maven环境
1.下载maven 1.在官网下载maven http://maven.apache.org/download.cgi 2.将下载maven解压.复制路径. 2.搭建maven环境 1.新建M2_H ...
- 面试官:为什么 TCP 三次握手期间,客户端和服务端的初始化序列号要求不一样?
大家好,我是小林. 为什么 TCP 三次握手期间,客户端和服务端的初始化序列号要求不一样的呢? 接下来,我一步一步给大家讲明白,我觉得应该有不少人会有类似的问题,所以今天在肝一篇! 正文 为什么 TC ...
- JAVA并发-AQS知识笔记
概述 AQS是AbstractQueuedSynchronizer的缩写,翻译成中文就是抽象队列同步器,AbstractQueuedSynchronizer这个类也是在java.util.concur ...
- Android官方文档翻译 十七 4.1Starting an Activity
Starting an Activity 开启一个Activity This lesson teaches you to 这节课教给你 Understand the Lifecycle Callbac ...
- Rust 实现Netty HashedWheelTimer时间轮
目录 一.背景 二.延迟队列-时间轮 三.Netty 时间轮源码分析 四.Rust实现HashedWheelTimer 五.总结思考 一.背景 近期在内网上看到一篇文章,文中提到的场景是 系统自动取消 ...
- gin框架中图形验证码的生成和验证
功能和验证码使用原理 本案例中没有使用redis作为缓存,而是使用的内存存储方法 github链接地址 下载命令 go get github.com/mojocn/base64Captcha 请求处理 ...
- Kubernetes中部署wordpress+mysql(六)
经过前面的内容其实对k8s已经有了服务迁移的能力了,下面这篇文章主要是用来搭建一些后面要用的组件 一.创建wordpress命名空间 kubectl create namespace wordpres ...
- 使用jQuery所需添加的头文件
<script src="https://code.jquery.com/jquery-3.3.1.min.js"></script> 在html文件中输入 ...
- redis集群运维
Redis 的数据类型? Redis 支持五种数据类型:string(字符串),hash(哈希),list(列表),set(集合)及 zsetsorted set(有序集合) redis优势 速度快, ...