1. Tensorflow高效流水线Pipeline

2. Tensorflow的数据处理中的Dataset和Iterator

3. Tensorflow生成TFRecord

4. Tensorflow的Estimator实践原理

1. 前言

我们在训练模型的时候,必须经过的第一个步骤是数据处理。在机器学习领域有一个说法,数据处理的好坏直接影响了模型结果的好坏。数据处理是至关重要的一步。

我们今天关注数据处理的另一个问题:假设我们做深度学习,数据的量随随便便就到GB的级别,那数据处理的速度对于模型的训练也很重要。经常遇到的一个情况是,数据处理的时间占了训练整个模型的大部分。

今天介绍的是Tensorflow官方推荐的数据处理方式是用Dataset API同时支持从内存和硬盘的读取,相比之前的两种方法在语法上更加简洁易懂

2. Dataset原理

Google官方给出的Dataset API中的类图如下所示:

2.1 Dataset创建方法

Dataset API还提供了四种创建Dataset的方式:

  • tf.data.Dataset.from_tensor_slices():这个函数直接从内存中读取数据,数据的形式可以是数组、矩阵、dict等。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
#实例化make_one_shot_iterator对象,该对象只能读取一次
iterator = dataset.make_one_shot_iterator()
# 从iterator里取出一个元素
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
print(sess.run(one_element))
  • tf.data.TFRecordDataset():顾名思义,这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample。
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
  • tf.data.TextLineDataset():这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
  • tf.data.FixedLengthRecordDataset():这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。

2.2 Dataset数据进行转换(Transformation)

一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作,常用的Transformation有:

  • map:接收一个函数对象,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
  • apply:应用一个转换函数到dataset。
dataset = dataset.apply(group_by_window(key_func, reduce_func, window_size))
  • batch:根据接收的整数值将该数个元素组合成batch,如下面的程序将dataset中的元素组成了大小为32的batch。
dataset = dataset.batch(32)
  • shuffle:打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小。
dataset = dataset.shuffle(buffer_size=10000)
  • repeat:整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch。
dataset = dataset.repeat(5)
# 如果repeat没有参数,则一直重复循环数据
dataset = dataset.repeat()
  • padded_batch:对dataset中的数据进行padding到一定的长度。
dataset.padded_batch(
batch_size,
padded_shapes=(
tf.TensorShape([None]), # src
tf.TensorShape([]), # tgt_output
tf.TensorShape([]),
tf.TensorShape([src_max_len])), # src_len
padding_values=(
src_eos_id, # src
0, # tgt_len -- unused
0, # src_len -- unused
0)) # mask
  • shard:根据多GPU进行分片操作。
dataset.shard(num_shards, shard_index)

比较完整的生成dataset的代码。

def parse_fn(example):
"Parse TFExample records and perform simple data augmentation."
example_fmt = {
"image": tf.FixedLengthFeature((), tf.string, ""),
"label": tf.FixedLengthFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.image.decode_image(parsed["image"])
image = _augment_helper(image) # augments image using slice, reshape, resize_bilinear
return image, parsed["label"] #简单的生成input_fn
def input_fn():
files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset)
dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
dataset = dataset.map(map_func=parse_fn)
dataset = dataset.batch(batch_size=FLAGS.batch_size)
return dataset

3. Iterator原理

3.1 Iterator Init初始化

生成Iterator一共有4种,复杂程度递增,个人觉得掌握前两种应该够用了,Iterator还有一个优势,目前,单次迭代器是唯一易于与 Estimator 搭配使用的类型

  • one shot Iterator:one shot Iterator是最简单的一种Iterator,仅支持对整个数据集访问一遍,不需要显式的初始化。one-shot Iterator不支参数化。
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next() for i in range(100):
value = sess.run(next_element)
assert i == value
  • initializable Iterator:Initializable Iterator 要求在使用之前显式的通过调用Iterator.initializer操作初始化,这使得在定义数据集时可以结合tf.placeholder传入参数。
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next() sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
  • reinitializable Iterator:可以被不同的dataset对象初始化,比如对于训练集进行了shuffle的操作,对于验证集则没有处理,通常这种情况会使用两个具有相同结构的dataset对象。
  • feedable Iterator:可以通过和tf.placeholder结合在一起,同通过feed_dict机制来选择在每次调用tf.Session.run的时候选择哪种Iterator。

3.2 Iterator get_next遍历数据

Iterator.get_next() 方法tf.Tensor 对象,每次tf.Session.run(Iterator.get_next())都会获取底层数据集中下一个元素的值。

如果迭代器到达数据集的末尾,则执行 Iterator.get_next() 操作会产生 tf.errors.OutOfRangeError。在此之后,迭代器将处于不可用状态;如果需要继续使用,则必须对其重新初始化。

sess.run(iterator.initializer)
while True:
try:
sess.run(getNextTensor)
except tf.errors.OutOfRangeError:
sess.run(iterator.initializer)

3.3 Iterator Save保存

tf.contrib.data.make_saveable_from_iterator 函数通过迭代器创建一个 SaveableObject,该对象可用于保存和恢复迭代器(实际上是整个输入管道)的当前状态。

# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator) # Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver() with tf.Session() as sess: if should_checkpoint:
saver.save(path_to_checkpoint) # Restore the iterator state.
with tf.Session() as sess:
saver.restore(sess, path_to_checkpoint)

4. 总结

本文介绍了创建不同种类的Dataset和Iterator对象的基础知识,熟悉这个数据处理的步骤后,不仅复用性比较强,而且效率也能成倍的提升。

2. Tensorflow的数据处理中的Dataset和Iterator的更多相关文章

  1. TensorFlow数据读取方式:Dataset API

    英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...

  2. Web Service 中返回DataSet结果大小改进

    http://www.cnblogs.com/scottckt/archive/2012/11/10/2764496.html Web Service 中返回DataSet结果方法: 1)直接返回Da ...

  3. Web Service 中返回DataSet结果的几种方法

    Web Service 中返回DataSet结果的几种方法: 1)直接返回DataSet对象    特点:通常组件化的处理机制,不加任何修饰及处理:    优点:代码精减.易于处理,小数据量处理较快: ...

  4. Oracle存储过程实现返回多个结果集 在构造函数方法中使用 dataset

    原文 Oracle存储过程实现返回多个结果集 在构造函数方法中使用 dataset DataSet相当你用的数据库: DataTable相当于你的表.一个 DataSet 可以包含多个 DataTab ...

  5. 深度学习基础系列(五)| 深入理解交叉熵函数及其在tensorflow和keras中的实现

    在统计学中,损失函数是一种衡量损失和错误(这种损失与“错误地”估计有关,如费用或者设备的损失)程度的函数.假设某样本的实际输出为a,而预计的输出为y,则y与a之间存在偏差,深度学习的目的即是通过不断地 ...

  6. 深度学习利器:TensorFlow在智能终端中的应用——智能边缘计算,云端生成模型给移动端下载,然后用该模型进行预测

    前言 深度学习在图像处理.语音识别.自然语言处理领域的应用取得了巨大成功,但是它通常在功能强大的服务器端进行运算.如果智能手机通过网络远程连接服务器,也可以利用深度学习技术,但这样可能会很慢,而且只有 ...

  7. 第5章分布式系统模式 在 .NET 中使用 DataSet 实现 Data Transfer Object

    要在 .NET Framework 中实现分布式应用程序.客户端应用程序需要显示一个窗体,该窗体要求对 ASP.NET Web Service 进行多个调用以满足单个用户请求.基于性能方面的考虑,我们 ...

  8. 大数据处理中的Lambda架构和Kappa架构

    首先我们来看一个典型的互联网大数据平台的架构,如下图所示: 在这张架构图中,大数据平台里面向用户的在线业务处理组件用褐色标示出来,这部分是属于互联网在线应用的部分,其他蓝色的部分属于大数据相关组件,使 ...

  9. 吴裕雄--天生自然python TensorFlow图片数据处理:解决TensorFlow2.0 module ‘tensorflow’ has no attribute ‘python_io’

    tf.python_io出错 TensorFlow 2.0 中使用 Python_io 暂时使用如下指令: tf.compat.v1.python_io.TFRecordWriter(filename ...

随机推荐

  1. P2376 [USACO09OCT]津贴Allowance

    P2376 [USACO09OCT]津贴Allowance一开始想的是多重背包,但是实践不了.实际是贪心,让多c尽可能少,所以先放大的,最后让小的来弥补. #include<iostream&g ...

  2. Angular 个人深究(四)【生命周期钩子】

    Angular 个人深究(四)[生命周期钩子] 定义: 每个组件都有一个被 Angular 管理的生命周期. Angular 创建它,渲染它,创建并渲染它的子组件,在它被绑定的属性发生变化时检查它,并 ...

  3. idea颜色主题

    作者:韩梦飞沙 Author:han_meng_fei_sha 邮箱:313134555@qq.com E-mail: 313134555 @qq.com IDEA 主题样式 === 这个垂直线的 颜 ...

  4. 解决AD9中“......has no driver”的问题

  5. poj 1184

    经典的宽搜题目,感觉最好的办法应该是双向广搜. 不过用简单的启发式搜索可以飘过. #include <iostream> #include <cstdio> #include ...

  6. api日常总结

    异步加载JS和CSS <script type="text/javascript"> (function () { var s = document.createEle ...

  7. 安装NVIDIA驱动时禁用自带nouveau驱动

    安装英伟达驱动时,一般需要禁用自带nouveau驱动,按如下命令操作: sudo vim /etc/modprobe.d/blacklist-nouveau.conf 添加如下内容: blacklis ...

  8. [leetcode]Minimum Window Substring @ Python

    原题地址:https://oj.leetcode.com/problems/minimum-window-substring/ 题意: Given a string S and a string T, ...

  9. Revit中如何给不同构件着色

    在Revit构件密集,默认的显示模式难以区分不同构件的区别,比如建筑立面有很多不同的机电管道,风管.水管,电缆桥架等,可一个给不同的机电管线添加不同的颜色,以示其区别,如下图所示,完成着色后,各种不同 ...

  10. Spark LDA实战

    选取了10个文档,其中4个来自于一篇论文,3篇来自于一篇新闻,3篇来自于另一篇新闻. 首先在pom文件中加入mysql-connector-java: <dependency> <g ...