1. Tensorflow高效流水线Pipeline

2. Tensorflow的数据处理中的Dataset和Iterator

3. Tensorflow生成TFRecord

4. Tensorflow的Estimator实践原理

1. 前言

GPU和TPU可以显著缩短执行单个训练步所需的时间。实现最高性能需要高效的输入流水线,以在当前时间步完成之前为下一步提供数据。tf.data API可以帮助我们构建灵活高效的输入流水线。本文档介绍了 tf.data API的功能,以及在各种模型和加速器上构建高性能TensorFlow输入流水线的最佳做法

2. Pipeline Structure输入流水线结构

我们可以将典型的 TensorFlow 训练输入流水线视为 ETL 流程:

  1. Extract:从永久性存储(可以是 HDD 或 SSD 等本地存储或 GCS 或 HDFS 等远程存储)读取数据。
  2. Transform:使用CPU核心解析数据并对其执行预处理操作,例如图像解压缩、数据增强转换(例如随机裁剪、翻转和颜色失真)、重排和批处理。
  3. Load:将转换后的数据加载到执行机器学习模型的加速器设备(例如,GPU 或 TPU)上。

这种模式可高效利用 CPU,同时预留加速器来完成对模型进行训练的繁重工作。此外,将输入流水线视为 ETL 流程可提供便于应用性能优化的结构。

使用 tf.estimator.Estimator API 时,前两个阶段(提取和转换)是在 input_fn(传递给 tf.estimator.Estimator.train)中捕获的。代码可能如以下(简单序列)实现所示:

def parse_fn(example):
"Parse TFExample records and perform simple data augmentation."
example_fmt = {
"image": tf.FixedLengthFeature((), tf.string, ""),
"label": tf.FixedLengthFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.image.decode_image(parsed["image"])
image = _augment_helper(image) # augments image using slice, reshape, resize_bilinear
return image, parsed["label"] def input_fn():
files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset)
dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
dataset = dataset.map(map_func=parse_fn)
dataset = dataset.batch(batch_size=FLAGS.batch_size)
return dataset

2.1 最佳Pipeline步骤

在这里先给出最佳做法,如果同学们只想知道怎么做,直接参考这里就可以啦。

下面的内容是针对每一点优化的原理。

3. 优化性能

由于新型计算设备(例如 GPU 和 TPU)可以不断提高神经网络的训练速度,因此,CPU 处理很容易成为瓶颈。tf.data API 为用户提供构建块来设计可高效利用 CPU 的输入流水线,并优化 ETL 流程的每个步骤。

3.1 prefetch预取数据

要执行训练步骤,您必须首先提取并转换训练数据,然后将其提供给在加速器上运行的模型。但是,在一个简单的同步实现中,当 CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练步的用时是 CPU 预处理时间和加速器训练时间的总和。

流水线将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。

如果不使用流水线,CPU 和 GPU/TPU 在大部分时间都处于空闲状态:

使用流水线可显著减少空闲时间:

tf.data API 通过 tf.data.Dataset.prefetch 转换提供了一种软件流水线机制,该机制可用于将生成数据的时间和使用数据的时间分离开。具体而言,该转换使用后台线程和内部缓冲区,以便在请求元素之前从输入数据集中预取这些元素。因此,为了实现上图所示的流水线效果,您可以将 prefetch() 作为最终转换添加到数据集流水线中(如果单步训练使用 n 个元素,则添加 prefetch(n))。

要将此项更改应用于我们正在运行的示例,请将:

dataset = dataset.batch(batch_size=FLAGS.batch_size)
return dataset

更改为:

dataset = dataset.batch(batch_size=FLAGS.batch_size)
dataset = dataset.prefetch(buffer_size=FLAGS.prefetch_buffer_size)
return dataset

3.2 map并行处理数据转换

准备批次数据时,可能需要预处理输入元素。为此,tf.data API 提供了 tf.data.Dataset.map 转换,以将用户定义的函数(例如,正在运行的示例的 parse_fn)应用于输入数据集的每个元素。由于输入元素彼此独立,因此可以跨多个 CPU 核心并行执行预处理。为实现这一点,map 转换提供了 num_parallel_calls 参数来指定并行处理级别。例如,下图说明了将 num_parallel_calls=2 设置为 map 转换的效果:

并行后,由于数据预处理的时间缩短,整体的时间也减少了。如何为 num_parallel_calls 参数选择最佳值取决于硬件、训练数据的特征(例如其大小和形状)、映射函数的成本以及同时在 CPU 上进行的其他处理;一个简单的启发法是设为可用 CPU 核心的数量。例如,如果执行以上示例的机器有 4 个核心,则设置 num_parallel_calls=4 会更高效。另一方面,将 num_parallel_calls 设置为远大于可用 CPU 数量的值可能会导致调度效率低下,进而减慢速度。

要将此项更改应用于我们正在运行的示例,请将:

dataset = dataset.map(map_func=parse_fn)

更改为:

dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)

此外,如果批次大小为数百或数千,那么并行处理批次创建过程还可能给流水线带来更大的优势。为此,tf.data API 提供了 tf.contrib.data.map_and_batch 转换,它可以有效地将映射和批次转换“混合”在一起。

要将此项更改应用于我们正在运行的示例,请将:

dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)
dataset = dataset.batch(batch_size=FLAGS.batch_size)

更改为:

dataset = dataset.apply(tf.contrib.data.map_and_batch(
map_func=parse_fn, batch_size=FLAGS.batch_size))

3.3 并行处理远程数据提取

在实际设置中,输入数据可能会远程存储(例如,GCS 或 HDFS),这是因为输入数据不适合本地存储,或因为训练是分布式训练,因此在每台机器上复制输入数据没有意义。非常适合在本地读取数据的数据集流水线在远程读取数据时可能会遇到 I/O 瓶颈,这是因为本地存储和远程存储之间存在以下差异:

  • 首字节时间:与本地存储相比,从远程存储读取文件的首字节所用时间可能要多出几个数量级。
  • 读取吞吐量:虽然远程存储通常可提供较大的聚合带宽,但读取单个文件可能只能利用此带宽的一小部分。

此外,将原始字节读入内存中后,可能还需要对数据进行反序列化或解密(例如,protobuf),这会带来额外的开销。无论数据是在本地还是远程存储,都存在这种开销,但如果未有效预取数据,则在远程存储的情况下可能更糟。

为了降低各种数据提取开销的影响,tf.data API 提供了 tf.contrib.data.parallel_interleave 转换。使用此转换可以并行执行其他数据集(例如数据文件读取器)并交错这些数据集的内容。可以通过 cycle_length 参数指定要重叠的数据集的数量。

下图说明了为 parallel_interleave 转换提供 cycle_length=2 的效果:



要将此项更改应用于我们正在运行的示例,请将:

dataset = files.interleave(tf.data.TFRecordDataset)

更改为:

dataset = files.apply(tf.contrib.data.parallel_interleave(
tf.data.TFRecordDataset, cycle_length=FLAGS.num_parallel_readers))

由于负载或网络事件,远程存储系统的吞吐量可能会随时间而变化。鉴于这种差异,parallel_interleave 转换可以选择使用预取(如需了解详情,请参阅 tf.contrib.data.parallel_interleave)。

默认情况下,parallel_interleave 转换可提供元素的确定性排序以帮助实现可再现性。作为预取的替代方案(在某些情况下可能效率低下),parallel_interleave 转换还提供了一个可提升性能但无法保证排序的选项。特别是,如果 sloppy 参数设为 true,则该转换可在系统请求下一个元素时暂时跳过其元素不可用的文件,从而放弃该转换的确定性排序。

4. 性能考虑因素

tf.data API 围绕可组合转换而设计,旨在为用户提供灵活性。虽然这些转换中有很多都是可以交替的,但某些转换的顺序会对性能产生影响。

4.1 map映射和batch批次

调用传递给 map 转换的用户定义函数具有与调度和执行用户定义函数相关的开销。通常,与函数执行的计算量相比,这种开销很小。但是,如果 map 几乎不起作用,那么这种开销可能会占总成本的很大一部分。在这种情况下,建议向量化用户定义的函数(即,让该函数一次对一批输入进行操作),并在 map 转换之前先应用 batch 转换

或者直接更改为如下代码:

dataset = dataset.apply(tf.contrib.data.map_and_batch(
map_func=parse_fn, batch_size=FLAGS.batch_size))

4.2 map映射和cache缓存

tf.data.Dataset.cache 转换可以在内存或本地存储中缓存数据集。如果传递给 map 转换的用户定义函数代价很高,则只要内存或本地存储仍可以容纳生成的数据集,就可以在映射转换后应用缓存转换。如果用户定义的函数会增加存储数据集所需的空间,并超出缓存容量,请考虑在训练作业之前预处理数据以减少资源消耗量。

4.3 map映射和interleave交错/prefetch预取/shuffle重排

许多转换(包括map interleave、prefetch 和 shuffle)都维持一个内部元素缓冲区。如果传递给 map 转换的用户定义函数改变了元素的大小,那么映射转换的顺序和缓冲元素的转换会影响内存使用量。通常,我们建议选择可以减少内存占用的顺序,除非为了提高性能而需要采用不同的顺序(例如,为了混合映射和批次转换)。

4.4 repeat重复和shuffle重排

tf.data.Dataset.repeat 转换会将输入数据重复有限(或无限)次;每次数据重复通常称为一个周期。tf.data.Dataset.shuffle 转换会随机化数据集样本的顺序。

如果在 shuffle 转换之前应用 repeat 转换,则系统会对周期边界进行模糊处理。也就是说,某些元素可以在其他元素出现之前重复出现。另一方面,如果在重复转换之前应用 shuffle 转换,那么在每个周期开始时性能可能会降低,因为需要初始化 shuffle 转换的内部状态。换言之,前者(repeat 在 shuffle 之前)可提供更好的性能,而后者(repeat 在 shuffle 之前)可提供更强的排序保证。

如果可能,建议您使用 tf.contrib.data.shuffle_and_repeat 混合转换,这样可以达到两全其美的效果(良好的性能和强大的排序保证)。否则,我们建议在repeat重复之前进行shuffle重排

1. Tensorflow高效流水线Pipeline的更多相关文章

  1. jenkins的流水线pipeline+项目实验php

    声明:实验环境使用Jenkins的应用与搭建的环境 新建一个流水线 pipeline脚本语法架构 node('slave节点名'){ def 变量 #def可以进行变量声明 stage('阶段名A') ...

  2. Tensorflow高效读取数据的方法

    最新上传的mcnn中有完整的数据读写示例,可以参考. 关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码 ...

  3. Jenkins流水线(pipeline)实战之:从部署到体验

    关于Jenkins流水线(pipeline) Jenkins 流水线 (pipeline) 是一套插件,让Jenkins可以实现持续交付管道的落地和实施. 关于blueocean Blue Ocean ...

  4. Hadoop架构: 流水线(PipeLine)

    该系列总览: Hadoop3.1.1架构体系——设计原理阐述与Client源码图文详解 : 总览 流水线(PipeLine),简单地理解就是客户端向DataNode传输数据(Packet)和接收Dat ...

  5. 吴裕雄 python 机器学习——数据预处理流水线Pipeline模型

    from sklearn.svm import LinearSVC from sklearn.pipeline import Pipeline from sklearn import neighbor ...

  6. 8.Jenkins进阶之流水线pipeline基础使用实践(1)

    ​目录一览: 0x01 基础实践 (1) Maven 构建之 Pipeline Script (2) Maven 构建之 Pipeline Script from SCM (3) Jenkins pi ...

  7. TensorFlow高效读取数据的方法——TFRecord的学习

    关于TensorFlow读取数据,官网给出了三种方法: 供给数据(Feeding):在TensorFlow程序运行的每一步,让python代码来供给数据. 从文件读取数据:在TensorFlow图的起 ...

  8. Redis附加功能之Redis流水线pipeline

    流水线功能的目的:通过减少客户端与服务器之间的通信次数来提高程序的执行效率. 一.通信 在一般情况下, 用户每执行一个 Redis 命令,客户端与服务器都需要进行一次通信:客户端会将命令请求发送给服务 ...

  9. Tensorflow高效读取数据

    关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow ...

随机推荐

  1. 基于 Jenkins+Docker+Git 的CI流程初探

    在如今的互联网时代,随着软件开发复杂度的不断提高,软件开发和发布管理也越来越重要.目前已经形成一套标准的流程,最重要的组成部分就是持续集成(Continuous Integration,CI)及持续部 ...

  2. IdentityServer4-前后端分离之Vue(七)

    前言 之前文章讲到如何使用Node.js+Express构建JavaScript客户端,实现前后端分离.本节将介绍如何使用Vue实现前后端分离,文中介绍Vue的知识比较基础,适合新手学习. 一.搭建V ...

  3. spring整合ssmXML版

    以下是一个简单的ssm项目:如果中途报错,肯定是tomcat配置或者数据库配置有问题,在程序中注意将包名等配置换成自己的.数据库表需要提前建好,并加入数据,注意表结构要和实体对象对应. 1.开发条件: ...

  4. Emit学习笔记

    1,给字段设置值,并返回 static void Main(string[] args) { //给字段设置值,并返回 AssemblyName assemblyName = new Assembly ...

  5. php页面静态化,ob缓存方法

    <?php ob_start();//开启缓存 //要生成静态网页的内容开始 ?> 中间的html代码 <?php //要生成静态网页的内容结束 //把生成的静态内容保存到文件,而不 ...

  6. android 按钮特效 波纹 Android button effects ripple

    android 按钮特效 波纹 Android button effects ripple 作者:韩梦飞沙 Author:han_meng_fei_sha 邮箱:313134555@qq.com E- ...

  7. C#中的快捷键,可以更方便的编写代码

    C#中的快捷键,可以更方便的编写代码 CTRL + SHIFT + B 生成解决方案 CTRL + F7 生成编译 CTRL + O 打开文件 CTRL + SHIFT + O 打开项目 CTRL + ...

  8. bzoj4503: 两个串 bitset

    目录 题目链接 题解 代码 题目链接 bzoj4503: 两个串 题解 暴一发bitset f[i][j] 表示 S[1..i] 是否有个后缀能匹配 T[1..j] 那么假设 S[i+1] 能匹配 T ...

  9. 5.27 Test

    1.COGS.2039. 树的统计 思路: 各种方法. 代码: 1.遍历树1   时间 0.314 s   平均内存 2.96 MB #include<cstdio> using name ...

  10. PHP is_numeric 检测变量是否为数字或数字字符串

    bool is_numeric ( mixed $var ) 如果 var 是数字和数字字符串则返回 TRUE,否则返回 FALSE. For example 1: <?php $v = is_ ...