2. Tensorflow的数据处理中的Dataset和Iterator
1. Tensorflow高效流水线Pipeline
2. Tensorflow的数据处理中的Dataset和Iterator
3. Tensorflow生成TFRecord
4. Tensorflow的Estimator实践原理
1. 前言
我们在训练模型的时候,必须经过的第一个步骤是数据处理。在机器学习领域有一个说法,数据处理的好坏直接影响了模型结果的好坏。数据处理是至关重要的一步。
我们今天关注数据处理的另一个问题:假设我们做深度学习,数据的量随随便便就到GB的级别,那数据处理的速度对于模型的训练也很重要。经常遇到的一个情况是,数据处理的时间占了训练整个模型的大部分。
今天介绍的是Tensorflow官方推荐的数据处理方式是用Dataset API同时支持从内存和硬盘的读取,相比之前的两种方法在语法上更加简洁易懂
2. Dataset原理
Google官方给出的Dataset API中的类图如下所示:
2.1 Dataset创建方法
Dataset API还提供了四种创建Dataset的方式:
- tf.data.Dataset.from_tensor_slices():这个函数直接从内存中读取数据,数据的形式可以是数组、矩阵、dict等。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
#实例化make_one_shot_iterator对象,该对象只能读取一次
iterator = dataset.make_one_shot_iterator()
# 从iterator里取出一个元素
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
print(sess.run(one_element))
- tf.data.TFRecordDataset():顾名思义,这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample。
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
- tf.data.TextLineDataset():这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
- tf.data.FixedLengthRecordDataset():这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。
2.2 Dataset数据进行转换(Transformation)
一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作,常用的Transformation有:
- map:接收一个函数对象,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1。
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
- apply:应用一个转换函数到dataset。
dataset = dataset.apply(group_by_window(key_func, reduce_func, window_size))
- batch:根据接收的整数值将该数个元素组合成batch,如下面的程序将dataset中的元素组成了大小为32的batch。
dataset = dataset.batch(32)
- shuffle:打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小。
dataset = dataset.shuffle(buffer_size=10000)
- repeat:整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch。
dataset = dataset.repeat(5)
# 如果repeat没有参数,则一直重复循环数据
dataset = dataset.repeat()
- padded_batch:对dataset中的数据进行padding到一定的长度。
dataset.padded_batch(
batch_size,
padded_shapes=(
tf.TensorShape([None]), # src
tf.TensorShape([]), # tgt_output
tf.TensorShape([]),
tf.TensorShape([src_max_len])), # src_len
padding_values=(
src_eos_id, # src
0, # tgt_len -- unused
0, # src_len -- unused
0)) # mask
- shard:根据多GPU进行分片操作。
dataset.shard(num_shards, shard_index)
比较完整的生成dataset的代码。
def parse_fn(example):
"Parse TFExample records and perform simple data augmentation."
example_fmt = {
"image": tf.FixedLengthFeature((), tf.string, ""),
"label": tf.FixedLengthFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.image.decode_image(parsed["image"])
image = _augment_helper(image) # augments image using slice, reshape, resize_bilinear
return image, parsed["label"]
#简单的生成input_fn
def input_fn():
files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset)
dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
dataset = dataset.map(map_func=parse_fn)
dataset = dataset.batch(batch_size=FLAGS.batch_size)
return dataset
3. Iterator原理
3.1 Iterator Init初始化
生成Iterator一共有4种,复杂程度递增,个人觉得掌握前两种应该够用了,Iterator还有一个优势,目前,单次迭代器是唯一易于与 Estimator 搭配使用的类型。
- one shot Iterator:one shot Iterator是最简单的一种Iterator,仅支持对整个数据集访问一遍,不需要显式的初始化。one-shot Iterator不支参数化。
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
assert i == value
- initializable Iterator:Initializable Iterator 要求在使用之前显式的通过调用Iterator.initializer操作初始化,这使得在定义数据集时可以结合tf.placeholder传入参数。
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
- reinitializable Iterator:可以被不同的dataset对象初始化,比如对于训练集进行了shuffle的操作,对于验证集则没有处理,通常这种情况会使用两个具有相同结构的dataset对象。
- feedable Iterator:可以通过和tf.placeholder结合在一起,同通过feed_dict机制来选择在每次调用tf.Session.run的时候选择哪种Iterator。
3.2 Iterator get_next遍历数据
Iterator.get_next() 方法tf.Tensor 对象,每次tf.Session.run(Iterator.get_next())都会获取底层数据集中下一个元素的值。
如果迭代器到达数据集的末尾,则执行 Iterator.get_next() 操作会产生 tf.errors.OutOfRangeError。在此之后,迭代器将处于不可用状态;如果需要继续使用,则必须对其重新初始化。
sess.run(iterator.initializer)
while True:
try:
sess.run(getNextTensor)
except tf.errors.OutOfRangeError:
sess.run(iterator.initializer)
3.3 Iterator Save保存
tf.contrib.data.make_saveable_from_iterator 函数通过迭代器创建一个 SaveableObject,该对象可用于保存和恢复迭代器(实际上是整个输入管道)的当前状态。
# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
with tf.Session() as sess:
if should_checkpoint:
saver.save(path_to_checkpoint)
# Restore the iterator state.
with tf.Session() as sess:
saver.restore(sess, path_to_checkpoint)
4. 总结
本文介绍了创建不同种类的Dataset和Iterator对象的基础知识,熟悉这个数据处理的步骤后,不仅复用性比较强,而且效率也能成倍的提升。
2. Tensorflow的数据处理中的Dataset和Iterator的更多相关文章
- TensorFlow数据读取方式:Dataset API
英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...
- Web Service 中返回DataSet结果大小改进
http://www.cnblogs.com/scottckt/archive/2012/11/10/2764496.html Web Service 中返回DataSet结果方法: 1)直接返回Da ...
- Web Service 中返回DataSet结果的几种方法
Web Service 中返回DataSet结果的几种方法: 1)直接返回DataSet对象 特点:通常组件化的处理机制,不加任何修饰及处理: 优点:代码精减.易于处理,小数据量处理较快: ...
- Oracle存储过程实现返回多个结果集 在构造函数方法中使用 dataset
原文 Oracle存储过程实现返回多个结果集 在构造函数方法中使用 dataset DataSet相当你用的数据库: DataTable相当于你的表.一个 DataSet 可以包含多个 DataTab ...
- 深度学习基础系列(五)| 深入理解交叉熵函数及其在tensorflow和keras中的实现
在统计学中,损失函数是一种衡量损失和错误(这种损失与“错误地”估计有关,如费用或者设备的损失)程度的函数.假设某样本的实际输出为a,而预计的输出为y,则y与a之间存在偏差,深度学习的目的即是通过不断地 ...
- 深度学习利器:TensorFlow在智能终端中的应用——智能边缘计算,云端生成模型给移动端下载,然后用该模型进行预测
前言 深度学习在图像处理.语音识别.自然语言处理领域的应用取得了巨大成功,但是它通常在功能强大的服务器端进行运算.如果智能手机通过网络远程连接服务器,也可以利用深度学习技术,但这样可能会很慢,而且只有 ...
- 第5章分布式系统模式 在 .NET 中使用 DataSet 实现 Data Transfer Object
要在 .NET Framework 中实现分布式应用程序.客户端应用程序需要显示一个窗体,该窗体要求对 ASP.NET Web Service 进行多个调用以满足单个用户请求.基于性能方面的考虑,我们 ...
- 大数据处理中的Lambda架构和Kappa架构
首先我们来看一个典型的互联网大数据平台的架构,如下图所示: 在这张架构图中,大数据平台里面向用户的在线业务处理组件用褐色标示出来,这部分是属于互联网在线应用的部分,其他蓝色的部分属于大数据相关组件,使 ...
- 吴裕雄--天生自然python TensorFlow图片数据处理:解决TensorFlow2.0 module ‘tensorflow’ has no attribute ‘python_io’
tf.python_io出错 TensorFlow 2.0 中使用 Python_io 暂时使用如下指令: tf.compat.v1.python_io.TFRecordWriter(filename ...
随机推荐
- P2376 [USACO09OCT]津贴Allowance
P2376 [USACO09OCT]津贴Allowance一开始想的是多重背包,但是实践不了.实际是贪心,让多c尽可能少,所以先放大的,最后让小的来弥补. #include<iostream&g ...
- Angular 个人深究(四)【生命周期钩子】
Angular 个人深究(四)[生命周期钩子] 定义: 每个组件都有一个被 Angular 管理的生命周期. Angular 创建它,渲染它,创建并渲染它的子组件,在它被绑定的属性发生变化时检查它,并 ...
- idea颜色主题
作者:韩梦飞沙 Author:han_meng_fei_sha 邮箱:313134555@qq.com E-mail: 313134555 @qq.com IDEA 主题样式 === 这个垂直线的 颜 ...
- 解决AD9中“......has no driver”的问题
- poj 1184
经典的宽搜题目,感觉最好的办法应该是双向广搜. 不过用简单的启发式搜索可以飘过. #include <iostream> #include <cstdio> #include ...
- api日常总结
异步加载JS和CSS <script type="text/javascript"> (function () { var s = document.createEle ...
- 安装NVIDIA驱动时禁用自带nouveau驱动
安装英伟达驱动时,一般需要禁用自带nouveau驱动,按如下命令操作: sudo vim /etc/modprobe.d/blacklist-nouveau.conf 添加如下内容: blacklis ...
- [leetcode]Minimum Window Substring @ Python
原题地址:https://oj.leetcode.com/problems/minimum-window-substring/ 题意: Given a string S and a string T, ...
- Revit中如何给不同构件着色
在Revit构件密集,默认的显示模式难以区分不同构件的区别,比如建筑立面有很多不同的机电管道,风管.水管,电缆桥架等,可一个给不同的机电管线添加不同的颜色,以示其区别,如下图所示,完成着色后,各种不同 ...
- Spark LDA实战
选取了10个文档,其中4个来自于一篇论文,3篇来自于一篇新闻,3篇来自于另一篇新闻. 首先在pom文件中加入mysql-connector-java: <dependency> <g ...