Tensorflow数据读取方式主要包括以下三种

  1. Preloaded data:预加载数据
  2. Feeding: 通过Python代码读取或者产生数据,然后给后端
  3. Reading from file: 通过TensorFlow队列机制,从文件中直接读取数据

前两种方法比较基础而且容易理解,在Tensorflow入门教程、书本中经常可以见到,这里不再进行介绍。

在介绍Tensorflow第三种读取数据方法之前,介绍以下有关队列相关知识

Queue(队列)

队列是用来存放数据的,并且tensorflow中的Queue中已经实现了同步机制,所以我们可以放心的往里面添加数据还有读取数据。如果Queue中的数据满了,那么en_queue(队列添加元素)操作将会阻塞,如果Queue是空的,那么dequeue(队列抛出元素)操作就会阻塞.在常用环境中,一般是有多个en_queue线程同时像Queue中放数据,有一个dequeue操作从Queue中取数据。

Coordinator(协调管理器)

Coordinator主要是用来帮助管理多个线程,协调多线程之间的配合

 # Thread body: loop until the coordinator indicates a stop was requested.
# If some condition becomes true, ask the coordinator to stop.
#将coord传入到线程中,来帮助它们同时停止工作
def MyLoop(coord):
while not coord.should_stop():
...do something...
if ...some condition...:
coord.request_stop()
# Main thread: create a coordinator.
coord = tf.train.Coordinator()
# Create 10 threads that run 'MyLoop()'
threads = [threading.Thread(target=MyLoop, args=(coord,)) for i in xrange(10)]
# Start the threads and wait for all of them to stop.
for t in threads:
t.start()
coord.join(threads)

QueueRunner()

QueueRunner可以创建多个线程对队列(queue)进行插入(enqueue)操作,它是一个op,这些线程可以通过上述的Coordinator协调器来协调工作。

在深度学习中样本数据集有多种存储编码形式,以经典数据集Cifar-10为例,公开共下载的数据有三种存储方式:Bin(二进制)、Python以及Matlab版本。此外,我们常用的还有csv(天池竞赛、百度竞赛等)比较常见或txt等,当然对图片存储最为直观的还是可视化展示的TIF、PNG、JPG等。Tensorflow官方推荐使用他自己的一种文件格式叫TFRecord,具体实现及应用会在以后详细介绍。

从上图中可知,Tensorflow数据读取过程主要包括两个队列(FIFO),一个叫做文件队列,主要用作对输入样本文件的管理(可以想象,所有的训练数据一般不会存储在一个文件内,该部分主要完成对数据文件的管理);另一个叫做数据队列,如果对应的数据是图像可以认为该队列中的每一项都是存储在内存中的解码后的一系列图像像素值。

下面,我们分别新建3个csv文件->A.csv;B.csv;C.csv,每个文件下分别用X_i, y_i代表训练样本的数据及标注信息。

 #-*- coding:gbk -*-
import tensorflow as tf
# 队列1:生成一个先入先出队列和一个QueueRunner,生成文件名队列
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=2)
# 定义Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value, record_defaults=[['string'], ['string']])
with tf.Session() as sess:
coord = tf.train.Coordinator() #创建一个协调器,管理线程
threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。
for i in range(12):
e_val, l_val = sess.run([example, label])
print(e_val, l_val)
coord.request_stop()
coord.join(threads)

程序中,首先根据文件列表,通过tf.train.string_input_producer(filenames, shuffle=False)函数建立了一个对应的文件管理队列,其中shuffle=False表 示不对文件顺序进行打乱(True表示打乱,每次输出顺序将不再一致)。此外,还可通过设置第三个参数num_epochs来控制文件数据多少。
运行结果如下:

上段程序中,主要完成以下几方面工作:

  1. 针对文件名列表,建立对应的文件队列
  2. 使用reader读取对应文件数据集
  3. 解码数据集,得到样本example和标注label

感兴趣的读者可以打开tf.train.string_input_producer(...)函数,可以看到如下代码

     """
@compatibility(eager)
Input pipelines based on Queues are not supported when eager execution is
enabled. Please use the `tf.data` API to ingest data under eager execution.
@end_compatibility
"""
if context.in_eager_mode():
raise RuntimeError(
"Input pipelines based on Queues are not supported when eager execution"
" is enabled. Please use tf.data to ingest data into your model"
" instead.")
with ops.name_scope(name, "input_producer", [input_tensor]):
input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
element_shape = input_tensor.shape[1:].merge_with(element_shape)
if not element_shape.is_fully_defined():
raise ValueError("Either `input_tensor` must have a fully defined shape "
"or `element_shape` must be specified")
if shuffle:
input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)
input_tensor = limit_epochs(input_tensor, num_epochs)
q = data_flow_ops.FIFOQueue(capacity=capacity,
dtypes=[input_tensor.dtype.base_dtype],
shapes=[element_shape],
shared_name=shared_name, name=name)
enq = q.enqueue_many([input_tensor])
queue_runner.add_queue_runner(
queue_runner.QueueRunner(
q, [enq], cancel_op=cancel_op))
if summary_name is not None:
summary.scalar(summary_name,
math_ops.to_float(q.size()) * (1. / capacity))
return q

可以看到该段代码主要完成以下工作:

  1. 创建队列Queue
  2. 创建线程enqueue_many
  3. 添加QueueRunner到collection中
  4. 返回队列Queue

数据解析

 # 定义Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value, record_defaults=[['string'], ['string']])

这里,我们通过定义一个reader来读取每个数据文件内容,也可图中也展示了TensorFlow支持定义多个reader并且读取文件队列中文件内容,从而提供数据读取效率。然后,采用一个decoder_csv函数对读取的原始CSV文件内容进行解码,平时我们也可根据自己数据存储格式选择不同数据解码方式。在这里需要指出的是,上述程序中并没有用到图中展示的第二个数据队列,这是为什么呢。

实际上做深度学习or机器学习训练过程中,为了保证训练过程的高效性通常不采用单个样本数据给训练模型,而是采用一组N个数据(称作mini-batch),并把每组样本个数N成为batch-size。现在假设我们每组需要喂给模型N个数据,需通过N次循环读入内存,然后再通过GPU进行前向or返向传播运算,这就意味着GPU每次运算都需要一段时间等待CPU读取数据,从而大大降低了训练效率。而第二个队列(数据队列)就是为了解决这个问题提出来的,代码实现即为:tf.train.batch()和 tf.train.shuffle_batch,这两个函数的主要区别在于是否需要将列表中数据进行随机打乱。

 #-*- coding:gbk -*-
import tensorflow as tf
# 生成一个先入先出队列和一个QueueRunner,生成文件名队列
filenames = ['A.csv', 'B.csv', 'C.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=3)
# 定义Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value, record_defaults=[['string'], ['string']])
#example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=16, capacity=200, min_after_dequeue=100, num_threads=2)
example_batch, label_batch = tf.train.batch([example,label], batch_size=8, capacity=200, num_threads=2)
#example_list = [tf.decode_csv(value, record_defaults=[['string'], ['string']])
# for _ in range(2)] # Reader设置为2
### 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。
#example_batch, label_batch = tf.train.batch_join(
# example_list, batch_size=5)
init_local_op = tf.initialize_local_variables()
with tf.Session() as sess:
sess.run(init_local_op)
coord = tf.train.Coordinator() #创建一个协调器,管理线程
threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。
for i in range(5):
# Retrieve a single instance:
e_val, l_val = sess.run([example_batch, label_batch])
print(e_val, l_val)
coord.request_stop()
coord.join(threads)

使用tf.train.batch()函数,每次根据自己定义大小会返回一组训练数据,从而避免了往内存中循环读取数据的麻烦,提高了效率。并且还可以通过设置reader个数,实现多线程高效地往数据队列(或叫内存队列)中填充数据,直到文件队列读完所有文件(或文件数据不足一个batch size)。
tf.train.batch()程序运行结果如下

注:tf.train.batch([example,label], batch_size=8, capacity=200, num_threads=2)参数中,capacity表示队列大小,每次读出数据后队尾会按顺序依次补充。num_treads=2表示两个线程(据说在一个reader下可达到最优),batch_size=8表示每次返回8组训练数据,即batch size大小。tf.train.shuffle_batch()比tf.train.bathc()多一个min_after_dequeue参数,意思是在每次抛出一个batch后,剩余数据样本不少于多少个。

Tensorflow学习-数据读取的更多相关文章

  1. AI学习---数据读取&神经网络

    AI学习---数据读取&神经网络 fa

  2. Tensorflow学习教程------读取数据、建立网络、训练模型,小巧而完整的代码示例

    紧接上篇Tensorflow学习教程------tfrecords数据格式生成与读取,本篇将数据读取.建立网络以及模型训练整理成一个小样例,完整代码如下. #coding:utf-8 import t ...

  3. tensorflow之数据读取探究(2)

    tensorflow之tfrecord数据读取 Tensorflow关于TFRecord格式文件的处理.模型的训练的架构为: 1.获取文件列表.创建文件队列:http://blog.csdn.net/ ...

  4. tensorflow之数据读取探究(1)

    Tensorflow中之前主要用的数据读取方式主要有: 建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用.使用这种方法十分灵活,可以一下子将所有数据 ...

  5. 关于Tensorflow 的数据读取环节

    Tensorflow读取数据的一般方式有下面3种: preloaded直接创建变量:在tensorflow定义图的过程中,创建常量或变量来存储数据 feed:在运行程序时,通过feed_dict传入数 ...

  6. 机器学习: TensorFlow 的数据读取与TFRecords 格式

    最近学习tensorflow,发现其读取数据的方式看起来有些不同,所以又重新系统地看了一下文档,总得来说,tensorflow 有三种主流的数据读取方式: 1) 传送 (feeding): Pytho ...

  7. tensorflow学习--数据加载

    文章主要来自Tensorflow官方文档,同时加入了自己的理解以及部分代码 数据读取 TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每 ...

  8. 『TensorFlow』数据读取类_data.Dataset

    一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...

  9. TensorFlow的数据读取机制

    一.tensorflow读取机制图解 首先需要思考的一个问题是,什么是数据读取?以图像数据为例,读取的过程可以用下图来表示 假设我们的硬盘中有一个图片数据集0001.jpg,0002.jpg,0003 ...

随机推荐

  1. JavaScript中对象数组 根据某个属性值 然后push到新的数组

    原文链接 https://segmentfault.com/q/1010000010075035 将下列对象数组中,工资大于1w的员工,增加到对象数组 WanSalary中 var BaiduUser ...

  2. orderBy新写法

    通常,我们处理排序规则的处理方法是在sql 语句中order by create_time desc, 但是这时我们需要从控制器中一步步找到该方法,操作多. 我们试着将业务逻辑拆分到控制器 中, 把排 ...

  3. 使用Python自动提取内容摘要

    https://www.biaodianfu.com/automatic-text-summarizer.html 利用计算机将大量的文本进行处理,产生简洁.精炼内容的过程就是文本摘要,人们可通过阅读 ...

  4. Install OpenCV on Ubuntu or Debian

    http://milq.github.io/install-OpenCV-ubuntu-debian/转注:就用第一个方法吧,第二个方法的那个sh文件执行失败,因为我价格kurento.org的源,在 ...

  5. android 开发中,经常遇到http://dl-ssl.google.com/ 无法访问的问题解决

    window - android sdk manager 在选择某个版本sdk安装时,总是出现http://dl-ssl.google.com/无法链接的问题,那是因为qiang太高了,不过也还是有办 ...

  6. springAOP之代理

    AOP是指面向切面编程. 在学习AOP之前先来了解一下代理,因为传说中的AOP其实也对代理的一种应用. 首先来看这样一段代码: public interface Hello { void say(St ...

  7. gitlab钩子搭建

    目标:在本地开发机上push代码到GitLab仓库时,通过钩子同步到测试服务器 准备工作GitLab 服务器一台测试服务器一台本地开发服务器一台 1.在gitlab上新建一个项目,名称test2.在本 ...

  8. ws-trust、域、webservice接口的总结

    最近燃料公司门户做了一个待办的汇总,从三个数据源拿数据汇总到首页,这三个数据源分别是域认证的接口,域认证的webservices,证书加密的接口,下面就这些接口,做一下简单总结 1 pfx证书的探索过 ...

  9. Coursera-AndrewNg(吴恩达)机器学习笔记——第三周

    一.逻辑回归问题(分类问题) 生活中存在着许多分类问题,如判断邮件是否为垃圾邮件:判断肿瘤是恶性还是良性等.机器学习中逻辑回归便是解决分类问题的一种方法.二分类:通常表示为yϵ{0,1},0:&quo ...

  10. C#实现的HttpGet请求

    话不多说,代码贴上: /// <summary> /// HTTP Get请求 /// </summary> /// <param name="url" ...