以往的TensorFLow模型数据的导入方法可以分为两个主要方法,一种是使用feed_dict另外一种是使用TensorFlow中的Queues。前者使用起来比较灵活,可以利用Python处理各种输入数据,劣势也比较明显,就是程序运行效率较低;后面一种方法的效率较高,但是使用起来较为复杂,灵活性较差。

Dataset作为新的API,比以上两种方法的速度都快,并且使用难度要远远低于使用Queues。tf.data中包含了两个用于TensorFLow程序的接口:Dataset和Iterator。

Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道,原因如下:

Dataset

Dataset表示一个元素的集合,可以看作函数式编程中的 lazy list, 元素是tensor tuple。创建Dataset的方式可以分为两种,分别是:

Source

Apply transformation
Source
这里 source 指的是从tf.Tensor对象创建Dataset,常见的方法又如下几种:

  1. tf.data.Dataset.from_tensors((features, labels))
  2. tf.data.Dataset.from_tensor_slices((features, labels))
  3. tf.data.TextLineDataset(filenames)
  4. tf.data.TFRecordDataset(filenames)

作用分别为:

  1.从一个tensor tuple创建一个单元素的dataset;

  2.从一个tensor tuple创建一个包含多个元素的dataset;

  3.读取一个文件名列表,将每个文件中的每一行作为一个元素,构成一个dataset;

  4.读取硬盘中的TFRecord格式文件,构造dataset。

Apply transformation

第二种方法就是通过转化已有的dataset来得到新的dataset,TensorFLow tf.data.Dataset支持很多中变换,在这里介绍常见的几种:

  1. dataset.map(lambda x: tf.decode_jpeg(x))
  2. dataset.repeat(NUM_EPOCHS)
  3. dataset.batch(BATCH_SIZE)

以上三种方式分别表示了:使用map对dataset中的每个元素进行处理,这里的例子是对图片数据进行解码;将dataset重复一定数目的次数用于多个epoch的训练;将原来的dataset中的元素按照某个数量叠在一起,生成mini batch。

将以上代码组合起来,我们可以得到一个常用的代码片段:

  1. # 从一个文件名列表读取 TFRecord 构成 dataset
  2. dataset = TFRecordDataset(["file1.tfrecord", "file2.tfrecord"])
  3. # 处理 string,将 string 转化为 tf.Tensor 对象
  4. dataset = dataset.map(lambda record: tf.parse_single_example(record))
  5. # buffer 大小设置为 10000,打乱 dataset
  6. dataset = dataset.shuffle(10000)
  7. # dataset 将被用来训练 100 个 epoch
  8. dataset = dataset.repeat(100)
  9. # 设置 batch size 为 128
  10. dataset = dataset.batch(128)

Iterator

定义好了数据集以后可以通过Iterator接口来访问数据集中的tensor tuple,iterator保持了数据在数据集中的位置,提供了访问数据集中数据的方法。

可以通过调用 dataset 的 make iterator 方法来构建 iterator。

替换了place_holder,直接在原来开始的x,y处使用.get_next(),然后在sess.run时加个while true,在try里面放sess.run,exception 放OutofRangeError:

  1. X, y = dataset.get_next()
  2.  
  3. while True:
  4. try:
  5. sess.run(accuracy)
  6. except tf.errors.OutOfRangeError:
  7. break

API 支持以下四种 iterator,复杂程度递增:

  • one-shot
  • initializable
  • reinitializable
  • feedable

one-shot

one-shot iterator 谁最简单的一种 iterator,仅支持对整个数据集访问一遍,不需要显式的初始化。one-shot iterator 不支参数化。以下代码使用tf.data.Dataset.range生成数据集,作用与 python 中的 range 类似。

  1. dataset = tf.data.Dataset.range(100)
  2. iterator = dataset.make_one_shot_iterator()
  3. next_element = iterator.get_next()
  4.  
  5. for i in range(100):
  6. value = sess.run(next_element)
  7. assert i == value

initializable

Initializable iterator 要求在使用之前显式的通过调用iterator.initializer操作初始化,这使得在定义数据集时可以结合tf.placeholder传入参数,如:

  1. max_value = tf.placeholder(tf.int64, shape=[])
  2. dataset = tf.data.Dataset.range(max_value)
  3. iterator = dataset.make_initializable_iterator()
  4. next_element = iterator.get_next()
  5.  
  6. # Initialize an iterator over a dataset with 10 elements.
  7. sess.run(iterator.initializer, feed_dict={max_value: 10})
  8. for i in range(10):
  9. value = sess.run(next_element)
  10. assert i == value
  11.  
  12. # Initialize the same iterator over a dataset with 100 elements.
  13. sess.run(iterator.initializer, feed_dict={max_value: 100})
  14. for i in range(100):
  15. value = sess.run(next_element)
  16. assert i == value

reinitializable

reinitializable iterator 可以被不同的 dataset 对象初始化,比如对于训练集进行了shuffle的操作,对于验证集则没有处理,通常这种情况会使用两个具有相同结构的dataset对象,如:

  1. # Define training and validation datasets with the same structure.
  2. training_dataset = tf.data.Dataset.range(100).map(
  3. lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
  4. validation_dataset = tf.data.Dataset.range(50)
  5.  
  6. # A reinitializable iterator is defined by its structure. We could use the
  7. # `output_types` and `output_shapes` properties of either `training_dataset`
  8. # or `validation_dataset` here, because they are compatible.
  9. iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
  10. training_dataset.output_shapes)
  11. next_element = iterator.get_next()
  12.  
  13. training_init_op = iterator.make_initializer(training_dataset)
  14. validation_init_op = iterator.make_initializer(validation_dataset) # 如果后面初始化的是这个,那么就将循环这个数据集
  15.  
  16. # Run 20 epochs in which the training dataset is traversed, followed by the
  17. # validation dataset.
  18. for _ in range(20):
  19. # Initialize an iterator over the training dataset.
  20. sess.run(training_init_op)
  21. for _ in range(100):
  22. sess.run(next_element)
  23.  
  24. # Initialize an iterator over the validation dataset.
  25. sess.run(validation_init_op) # 替换init_op,相当于替换数据集
  26. for _ in range(50):
  27. sess.run(next_element)

feedable

feedable iterator 可以通过和tf.placeholder结合在一起,同通过feed_dict机制来选择在每次调用tf.Session.run的时候选择哪种Iterator。它提供了与 reinitilizable iterator 类似的功能,并且在切换数据集的时候不需要在开始的时候初始化iterator,还是上面的例子,通过tf.data.Iterator.from_string_handle来定义一个 feedable iterator,达到切换数据集的目的:

  1. # Define training and validation datasets with the same structure.
  2. training_dataset = tf.data.Dataset.range(100).map(
  3. lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
  4. validation_dataset = tf.data.Dataset.range(50)
  5.  
  6. # A feedable iterator is defined by a handle placeholder and its structure. We
  7. # could use the `output_types` and `output_shapes` properties of either
  8. # `training_dataset` or `validation_dataset` here, because they have
  9. # identical structure.
  10. handle = tf.placeholder(tf.string, shape=[])
  11. iterator = tf.data.Iterator.from_string_handle(
  12. handle, training_dataset.output_types, training_dataset.output_shapes)
  13. next_element = iterator.get_next()
  14. # You can use feedable iterators with a variety of different kinds of iterator
  15. # (such as one-shot and initializable iterators).
  16. training_iterator = training_dataset.make_one_shot_iterator()
  17. validation_iterator = validation_dataset.make_initializable_iterator()
  18.  
  19. # The `Iterator.string_handle()` method returns a tensor that can be evaluated
  20. # and used to feed the `handle` placeholder.
  21. training_handle = sess.run(training_iterator.string_handle())
  22. validation_handle = sess.run(validation_iterator.string_handle())
  23.  
  24. # Loop forever, alternating between training and validation.
  25. while True:
  26. # Run 200 steps using the training dataset. Note that the training dataset is
  27. # infinite, and we resume from where we left off in the previous `while` loop
  28. # iteration.
  29. for _ in range(200):
  30. sess.run(next_element, feed_dict={handle: training_handle})
  31.  
  32. # Run one pass over the validation dataset.
  33. sess.run(validation_iterator.initializer)
  34. for _ in range(50):
  35. sess.run(next_element, feed_dict={handle: validation_handle})

使用实例:

  1. def get_encodes(x):
  2. # x is `batch_size` of lines, each of which is a json object
  3. samples = [json.loads(l) for l in x]
  4. text = [s['fact'] for s in samples]
  5. # get a client from available clients
  6. bc_client = bc_clients.pop()
  7. features = bc_client.encode(text)
  8. # after use, put it back
  9. bc_clients.append(bc_client)
  10. labels = [0 for _ in text]
  11. return features, labels
  12.  
  13. data_node = (tf.data.TextLineDataset(train_fp).batch(batch_size)
  14. .map(lambda x: tf.py_func(get_encodes, [x], [tf.float32, tf.int64], name='bert_client'), num_parallel_calls=num_parallel_calls)
  15. .map(lambda x, y: {'feature': x, 'label': y})
  16. .make_one_shot_iterator().get_next())

tf.data的更多相关文章

  1. python3 zip 与tf.data.Data.zip的用法

    ###python自带的zip函数 与 tf.data.Dataset.zip函数 功能用法相似 ''' zip([iterator1,iterator2,]) 将可迭代对象中对应的元素打包成一个元祖 ...

  2. Tensorflow2(二)tf.data输入模块

    代码和其他资料在 github 一.tf.data模块 数据分割 import tensorflow as tf dataset = tf.data.Dataset.from_tensor_slice ...

  3. tf.data(二) —— 并行化 tf.data.Dataset 生成器

    在处理大规模数据时,数据无法全部载入内存,我们通常用两个选项 使用tfrecords 使用 tf.data.Dataset.from_generator() tfrecords的并行化使用前文已经有过 ...

  4. tf.contrib.slim.data数据加载(1) reader

    reader: 适用于原始数据数据形式的Tensorflow Reader 在库中parallel_reader.py是与reader相关的,它使用多个reader并行处理来提高速度,但文件中定义的类 ...

  5. TensorFlow走过的坑之---数据读取和tf中batch的使用方法

    首先介绍数据读取问题,现在TensorFlow官方推荐的数据读取方法是使用tf.data.Dataset,具体的细节不在这里赘述,看官方文档更清楚,这里主要记录一下官方文档没有提到的坑,以示" ...

  6. tf更新tensor/自定义层

    修改Tensor特定位置的值 如 stack overflow 中提到的方案. TensorFlow不让你直接单独改指定位置的值,但是留了个歪门儿,就是tf.scatter_update这个方法,它可 ...

  7. TF常用知识

    命名空间及变量共享 # coding=utf-8 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt; ...

  8. Tensorflow1.4 高级接口使用(estimator, data, keras, layers)

    TensorFlow 高级接口使用简介(estimator, keras, data, experiment) TensorFlow 1.4正式添加了keras和data作为其核心代码(从contri ...

  9. 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)

    1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...

随机推荐

  1. 利用ASIHTTPRequest访问网络

    ASIHTTPRequest是第三方类库,ASIHTTPRequest对CFNetwork API进行了封装. 有如下特点: l 通过简单的接口,即可完成向服务端提交数据和从服务端获取数据的工作 l ...

  2. 1. Two Sum [Array] [Easy]

    Given an array of integers, return indices of the two numbers such that they add up to a specific ta ...

  3. 酒店订房系统:如何使用mysql来确定一个时间段内的房间都是可订的

    需要解决的问题: 假设一个用户选择了日期范围来进行订房,例如:2014-04-25至2014-04-30 ,那么现在问题就出现,你必须要确认在这个时间段内某个房间是否都是有房间的,如果没有那么当然不能 ...

  4. [FMX]将 Android 程序切换到后台及从后台切换到前台实现

    有时候,我们需要将自己的Android程序切换到后台运行,在必要时,将其切换到前台运行.下面提供了一种实现方式,首先需要引用三个单元:   1 uses Androidapi.JNI.App,Andr ...

  5. Spring Boot 2 实践记录之 封装依赖及尽可能不创建静态方法以避免在 Service 和 Controller 的单元测试中使用 Powermock

    在前面的文章中(Spring Boot 2 实践记录之 Powermock 和 SpringBootTest)提到了使用 Powermock 结合 SpringBootTest.WebMvcTest ...

  6. Dalsa线扫相机SDK下载和安装

    1.首先去官方网站下载SDK Support Downloads - Teledyne DALSA http://www.teledynedalsa.com/imaging/support/downl ...

  7. 先装VS2008之后,又装了2013,然后启动VS2008提示“Tools Version”有问题?

    这个网上资料一搜很多,我就是按照下面这个链接去解决的,删除 “14.0” 整个键值文件夹之后重启VS2008就好了, 注意:上面第一张图是我在网上找的08和10版本弹出的错误,我自己弹出的是提示14. ...

  8. qt linux下配置安装

    linux版本: qt卸载: 1. 先找到qt的安装位置: 2.然后执行其下面的文件MaintenanceTool: 3. 然后会出现图形界面: 卸载完成. 安装qt 下载地址: https://ww ...

  9. day 81 天 ORM 操作复习总结

    # ###############基于对象查询(子查询)############## 一.对多查询  正向查询 from django.shortcuts import render,HttpResp ...

  10. Map容器中keySet()、entrySet()

    1.定义 keySet(): 返回的是只存放key值的Set集合,使用迭代器方式遍历该Set集合,在迭代器中再使用get方法获取每一个键对应的值.使用get方法获取键对应的值时就需要遍历Map集合,主 ...