TensorFlow读取二进制文件数据到队列
2016-11-03 09:30:00      0个评论    来源:diligent_321的博客  
收藏   我要投稿

TensorFlow是一种符号编程框架(与theano类似),先构建数据流图再输入数据进行模型训练。Tensorflow支持很多种样例输入的方式。最容易的是使用placeholder,但这需要手动传递numpy.array类型的数据。第二种方法就是使用二进制文件和输入队列的组合形式。这种方式不仅节省了代码量,避免了进行data augmentation和读文件操作,可以处理不同类型的数据, 而且也不再需要人为地划分开“预处理”和“模型计算”。在使用TensorFlow进行异步计算时,队列是一种强大的机制。

队列使用概述

正如TensorFlow中的其他组件一样,队列就是TensorFlow图中的节点。这是一种有状态的节点,就像变量一样:其他节点可以修改它的内容。具体来说,其他节点可以把新元素插入到队列后端(rear),也可以把队列前端(front)的元素删除。队列,如FIFOQueue和RandomShuffleQueue(A queue implementation that dequeues elements in a random order.)等对象,在TensorFlow的tensor异步计算时都非常重要。例如,一个典型的输入结构是使用一个RandomShuffleQueue来作为模型训练的输入,多个线程准备训练样本,并且把这些样本压入队列,一个训练线程执行一个训练操作,此操作会从队列中移除最小批次的样本(mini-batches),这种结构具有许多优点。

TensorFlow的Session对象是可以支持多线程的,因此多个线程可以很方便地使用同一个会话(Session)并且并行地执行操作。然而,在Python程序实现这样的并行运算却并不容易。所有线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止的时候, 队列必须能被正确地关闭。所幸TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。从设计上这两个类必须被一起使用。Coordinator类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常。QueueRunner类用来协调多个工作线程同时将多个tensor压入同一个队列中。

(1)读二进制文件数据到队列中

同很多其他的深度学习框架一样,TensorFlow有它自己的二进制格式。它使用了a mixture of its Records 格式和protobuf。Protobuf是一种序列化数据结构的方式,给出了关于数据的一些描述。TFRecords是tensorflow的默认数据格式,一个record就是一个包含了序列化tf.train.Example 协议缓存对象的二进制文件,可以使用python创建这种格式,然后便可以使用tensorflow提供的函数来输入给机器学习模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import tensorflow as tf
 
def read_and_decode_single_example(filename_queue):
# 定义一个空的类对象,类似于c语言里面的结构体定义
    class Image(self):
    pass
    image = Image()
    image.height = 32
    image.width = 32
    image.depth = 3
    label_bytes = 1
     
    Bytes_to_read = label_bytes+image.heigth*image.width*3
    # A Reader that outputs fixed-length records from a file
    reader = tf.FixedLengthRecordReader(record_bytes=Bytes_to_read)
    # Returns the next record (key, value) pair produced by a reader, key 和value都是字符串类型的tensor
    # Will dequeue a work unit from queue if necessary (e.g. when the
    # Reader needs to start reading from a new file since it has
    # finished with the previous file).
    image.key, value_str = reader.read(filename_queue)
    # Reinterpret the bytes of a string as a vector of numbers,每一个数值占用一个字节,在[0, 255]区间内,因此out_type要取uint8类型
    value = tf.decode_raw(bytes=value_str, out_type=tf.uint8)
    # Extracts a slice from a tensor, value中包含了label和feature,故要对向量类型tensor进行'parse'操作
    image.label = tf.slice(input_=value, begin=[0], size=[1])
    value = value.slice(input_=value, begin=[1], size=[-1]).reshape((image.depth, image.height, image.width))
    transposed_value = tf.transpose(value, perm=[2, 0, 1])
    image.mat = transposed_value
    return image

接下来我们便可以调用这个函数了,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
filenames =[os.path.join(data_dir, 'test_batch.bin')]
# Output strings (e.g. filenames) to a queue for an input pipeline
filename_queue = tf.train.string_input_producer(string_tensor=filenames)
# returns symbolic label and image
img_obj = read_and_decode_single_example("filename_queue")
Label = img_obj.label
Image = img_obj.mat
sess = tf.Session()
# 初始化tensorflow图中的所有状态,如待读取的下一个记录tfrecord的位置,variables等
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
# grab examples back.
# first example from file
label_val_1, image_val_1 = sess.run([label, image])
# second example from file
label_val_2, image_val_2 = sess.run([label, image])

值得一提的是,TFRecordReader总是作用于文件名队列。它将会从队列中弹出文件名并使用该文件名,直到tfrecord为空时停止,此时它将从文件名队列中弹出下一个filename。然而,文件名队列又是怎么得来的呢?起初这个队列是空的,QueueRunners的概念即源于此。QueueRunners本质上就是一个线程thread,这个线程负责使用会话session并不断地调用enqueue操作。Tensorflow把这个模式封装在tf.train.QueueRunner对象里面。入队列操作99%的时间都可以被忽略掉,因为这个操作是由后台负责运行。(比如在上面的例子中,tf.train.string_input_producer创建了一个这样的线程,添加QueueRunner到数据流图中)。

可想而知,在你运行任何训练步骤之前,我们要告知tensorflow去启动这些线程,否则这些队列会因为等待数据入队而被堵塞,导致数据流图将一直处于挂起状态。我们可以调用tf.train.start_queue_runners(sess=sess)来启动所有的QueueRunners。这个调用并不是符号化的操作,它会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。另外,必须要先运行初始化操作再创建这些线程。如果这些队列未被初始化,tensorflow会抛出错误。

(2)从二进制文件中读取mini-batchs

在训练机器学习模型时,使用单个样例更新参数属于“online learning”,然而在线下环境下,我们通常采用基于mini-batchs 随机梯度下降法(SGD),但是在tensorflow中如何利用queuerunners返回训练块数据呢?请参见下面的程序:

1
2
3
4
5
image_batch, label_batch = tf.train.shuffle_batch(tensor_list=[image, label]],
                                                  batch_size=batch_size,
                                                  num_threads=24,
                                                  min_after_dequeue=min_samples_in_queue,
                                                  capacity=min_samples_in_queue+3*batch_size)

读取batch数据需要使用新的队列queues和QueueRunners(大致流程图如下)。Shuffle_batch构建了一个RandomShuffleQueue,并不断地把单个的(image,labels)对送入队列中,这个入队操作是通过QueueRunners启动另外的线程来完成的。这个RandomShuffleQueue会顺序地压样例到队列中,直到队列中的样例个数达到了batch_size+min_after_dequeue个。它然后从队列中选择batch_size个随机的元素进行返回。事实上,shuffle_batch返回的值就是RandomShuffleQueue.dequeue_many()的结果。有了这个batches变量,就可以开始训练机器学习模型了。

函数 tf.train.shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, shared_name=None, name=None)的使用说明:

作用:Creates batches by randomly shuffling tensors.(从队列中随机筛选多个样例返回给image_batch和label_batch);

参数说明:

tensor_list: The list of tensors to enqueue.(待入队的tensor list);
batch_size: The new batch size pulled from the queue;
capacity: An integer. The maximum number of elements in the queue(队列长度);
min_after_dequeue: Minimum number elements in the queue after a dequeue, used to ensure a level of mixing of elements.(随机取样的样本总体最小值,用于保证所取mini-batch的随机性);
num_threads: The number of threads enqueuing `tensor_list`.(session会话支持多线程,这里可以设置多线程加速样本的读取)
seed: Seed for the random shuffling within the queue.
enqueue_many: Whether each tensor in `tensor_list` is a single example.(为False时表示tensor_list是一个样例,压入时占用队列中的一个元素;为True时表示tensor_list中的每一个元素都是一个样例,压入时占用队列中的一个元素位置,可以看作为一个batch);
shapes: (Optional) The shapes for each example. Defaults to the inferred shapes for `tensor_list`.
shared_name: (Optional) If set, this queue will be shared under the given name across multiple sessions.

name: (Optional) A name for the operations.

Tensorflow读取文件到队列文件的更多相关文章

  1. 第十二节,TensorFlow读取数据的几种方法以及队列的使用

    TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起 ...

  2. 文件 FIFO队列

    <?php /** * Filefifo.php 文件型FIFO队列 */ class Filefifo { /** * $_file_data, 数据文件的路径 */ private $_fi ...

  3. C#内存映射文件消息队列实战演练(MMF—MQ)

    一.课程介绍 本次分享课程属于<C#高级编程实战技能开发宝典课程系列>中的一部分,阿笨后续会计划将实际项目中的一些比较实用的关于C#高级编程的技巧分享出来给大家进行学习,不断的收集.整理和 ...

  4. 使用java读取文件夹中文件的行数

    使用java统计某文件夹下所有文件的行数 经理突然交代一个任务:要求统计某个文件夹下所有文件的行数.在网上查了一个多小时没有解决.后来心里不爽就决定自己写一个java类用来统计文件的行数,于是花了两个 ...

  5. Android想服务器传图片,透过流的方式。还有读取服务器图片(文件),也通过流的方式。

    /** * Created by Administrator on 2016/7/19. */ import android.util.Log; import com.gtercn.asPolice. ...

  6. Java读取Level-1行情dbf文件极致优化(3)

    最近架构一个项目,实现行情的接入和分发,需要达到极致的低时延特性,这对于证券系统是非常重要的.接入的行情源是可以配置,既可以是Level-1,也可以是Level-2或其他第三方的源.虽然Level-1 ...

  7. Java读取Level-1行情dbf文件极致优化(2)

    最近架构一个项目,实现行情的接入和分发,需要达到极致的低时延特性,这对于证券系统是非常重要的.接入的行情源是可以配置,既可以是Level-1,也可以是Level-2或其他第三方的源.虽然Level-1 ...

  8. Java读取Level-1行情dbf文件极致优化(1)

    最近架构一个项目,实现行情的接入和分发,需要达到极致的低时延特性,这对于证券系统是非常重要的.接入的行情源是可以配置,既可以是Level-1,也可以是Level-2或其他第三方的源.虽然Level-1 ...

  9. php读取指定结束指针文件内容

    fopen操作时文件读取开始指针位于文件开始部分, fseek 以指定文件大小以及开始指针位置确定结束指针位置 具体案例: <?php//打开文件流,fopen不会把文件整个加载到内存$f = ...

随机推荐

  1. 具体解释linux文件处理的的经常使用命令

    原创Blog.转载请注明出处 附上之前訪问量比較高的几篇linux博客 本人使用shell的8个小技巧 grep的九个经典使用场景 sed命令具体解释 awk命令具体解释 linux中全部的东西都是文 ...

  2. RDLC报表钻取空白页问题

    在改动报表查询条件时,钻取页突然空白了,百思不得其解,之前好好的,研究了一个下午和一个晚上.查资料等等,网上非常多资料都是设置报表的 ConsumeConteinerWhitespace = True ...

  3. 点击TButton后的执行OnClick和OnMouseDown两个事件的过程(其实是通过WM_COMMAND执行程序员的代码)

    问题的来源:在李维的<深入浅出VCL>一书中提到了点击TButton会触发WM_COMMAND消息,正是它真正执行了程序员的代码.也许是我比较笨,没有理解他说的含义.但是后来经过追踪代码和 ...

  4. HDU3535 AreYouBusy 混合背包

    题目大意 给出几组物品的体积和价值,每组分为三种:0.组内物品至少选一个:1.组内物品最多选一个:2.组内物品任意选.给出背包容量,求所能得到的最大价值. 注意 仔细审题,把样例好好看完了再答题,否则 ...

  5. Swift - 可编辑表格样例(可直接编辑单元格中内容、移动删除单元格)

    (本文代码已升级至Swift3)   本文演示如何制作一个可以编辑单元格内容的表格(UITableView). 1,效果图 (1)默认状态下,表格不可编辑,当点击单元格的时候会弹出提示框显示选中的内容 ...

  6. 【NOIP2011 Day 1】选择客栈

    [问题描述] 丽江河边有n家客栈,客栈按照其位置顺序从1到n编号.每家客栈都按照某一种色调进行装饰(总共k种,用整数0 ~ k-1表示),且每家客栈都设有一家咖啡店,每家咖啡店均有各自的最低消费.两位 ...

  7. GObject调用父类函数

    最近在分析Gstreamer的代码时,发现GstPipeline中有如下代码: result = GST_ELEMENT_CLASS (parent_class)->change_state ( ...

  8. 大数字运算——1、BigInteger

    package com.wh.BigInteger; import java.math.BigInteger; import java.util.Arrays; /** * @author 王恒 * ...

  9. C# net winform wpf 发送post数据和xml到网页

    由于项目需要发送数据到网页 这里用aspx做测试 采用post以及get发送数据,页面进行数据  首先这个东西很简单很简单,基本上学过的都会,但是原谅一直搞cs几乎不搞bs的猿类吧.三四年没接触bs. ...

  10. 每条sql语句实际上都是一个事物(事物多种类型解读)

    事务(数据库引擎) 事务是作为单个逻辑工作单元执行的一系列操作.一个逻辑工作单元必须有四个属性,称为原子性.一致性.隔离性和持久性 (ACID) 属性,只有这样才能成为一个事务.原子性事务必须是原子工 ...