TensorFlow多线程输入数据处理框架(三)——组合训练数据
参考书
《TensorFlow:实战Google深度学习框架》(第2版)
通过TensorFlow提供的tf.train.batch和tf.train.shuffle_batch函数来将单个的样例组织成batch的形式输出。
- #!/usr/bin/env python
- # -*- coding: UTF-8 -*-
- # coding=utf-8
- """
- @author: Li Tian
- @contact: 694317828@qq.com
- @software: pycharm
- @file: sample_data_deal2.py
- @time: 2019/2/4 11:15
- @desc: 通过TensorFlow提供的tf.train.batch和tf.train.shuffle_batch函数来将单个的样例组织成batch的形式输出。
- """
- import tensorflow as tf
- # 使用tf.train.match_filenames_once函数获取文件列表
- files = tf.train.match_filenames_once('./data.tfrecords-*')
- # 通过tf.train.string_input_producer函数创建输入队列,输入队列中的文件列表为
- # tf.train.match_filenames_once函数获取的文件列表。这里将shuffle参数设为False
- # 来避免随机打乱读文件的顺序。但一般在解决真实问题时,会将shuffle参数设置为True
- filename_queue = tf.train.string_input_producer(files, shuffle=False)
- # 如前面所示读取并解析一个样本
- reader = tf.TFRecordReader()
- _, serialized_example = reader.read(filename_queue)
- features = tf.parse_single_example(
- serialized_example,
- features={
- 'i': tf.FixedLenFeature([], tf.int64),
- 'j': tf.FixedLenFeature([], tf.int64),
- }
- )
- # 使用前面的方法读取并解析得到的样例。这里假设Example结构中i表示一个样例的特征向量
- # 比如一张图像的像素矩阵。而j表示该样例对应的标签。
- example, label = features['i'], features['j']
- # 一个batch中样例的个数。
- batch_size = 3
- # 组合样例的队列中最多可以存储的样例个数。这个队列如果太大,那么需要占用很多内存资源;
- # 如果太小,那么出队操作可能会因为没有数据而被阻碍(block),从而导致训练效率降低。
- # 一般来说这个队列的大小会和每一个batch的大小相关,下面一行代码给出了设置队列大小的一种方式。
- capacity = 1000 + 3 * batch_size
- # 使用tf.train.batch函数来组合样例。[example, label]参数给出了需要组合的元素,
- # 一般example和label分别代表训练样本和这个样本对应的正确标签。batch_size参数给出了
- # 每个batch中样例的个数。capacity给出了队列的最大容量。每当队列长度等于容量时,
- # TensorFlow将暂停入队操作,而只是等待元素出队。当元素个数小于容量时,
- # TensorFlow将自动重新启动入队操作。
- # example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)
- # 使用tf.train.shuffle_batch函数来组合样例。tf.train.shuffle_batch函数的参数
- # 大部分都和tf.train.batch函数相似,但是min_after_dequeue参数是tf.train.shuffle_batch
- # 函数特有的。min_after_dequeue参数限制了出队时队列中元素的最少个数。当队列中元素太少时,
- # 随机打乱样例顺序的作用就不大了。所以tf.train.shuffle_batch函数提供了限制出队时最少元素的个数
- # 来保证随机打乱顺序的作用。当出队函数被调用但是队列中元素不够时,出队操作将等待更多的元素入队
- # 才会完成。如果min_after_dequeue参数被设定,capacity也应该相应调整来满足性能需求。
- example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=30)
- with tf.Session() as sess:
- tf.local_variables_initializer().run()
- tf.global_variables_initializer().run()
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(sess=sess, coord=coord)
- # 获取并打印组合之后的样例。在真实问题中,这个输出一般会作为神经网络的输入。
- for i in range(2):
- cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])
- print(cur_example_batch, cur_label_batch)
- coord.request_stop()
- coord.join(threads)
运行结果:
1. 使用tf.train.batch函数来组合样例
2. 使用tf.train.shuffle_batch函数来组合样例
3. 两个函数的区别
tf.train.batch函数不会随机打乱顺序,而tf.train.shuffle_batch会随机打乱顺序。
TensorFlow多线程输入数据处理框架(三)——组合训练数据的更多相关文章
- TensorFlow多线程输入数据处理框架(四)——输入数据处理框架
参考书 <TensorFlow:实战Google深度学习框架>(第2版) 输入数据处理的整个流程. #!/usr/bin/env python # -*- coding: UTF-8 -* ...
- Tensorflow多线程输入数据处理框架
Tensorflow提供了一系列的对图像进行预处理的方法,但是复杂的预处理过程会减慢整个训练过程,所以,为了避免图像的预处理成为训练神经网络效率的瓶颈,Tensorflow提供了多线程处理输入数据的框 ...
- TensorFlow多线程输入数据处理框架(二)——输入文件队列
参考书 <TensorFlow:实战Google深度学习框架>(第2版) 一个简单的程序来生成样例数据. #!/usr/bin/env python # -*- coding: UTF-8 ...
- Tensorflow多线程输入数据处理框架(一)——队列与多线程
参考书 <TensorFlow:实战Google深度学习框架>(第2版) 对于队列,修改队列状态的操作主要有Enqueue.EnqueueMany和Dequeue.以下程序展示了如何使用这 ...
- tensorflow学习笔记——多线程输入数据处理框架
之前我们学习使用TensorFlow对图像数据进行预处理的方法.虽然使用这些图像数据预处理的方法可以减少无关因素对图像识别模型效果的影响,但这些复杂的预处理过程也会减慢整个训练过程.为了避免图像预处理 ...
- 吴裕雄 python 神经网络——TensorFlow 输入数据处理框架
import tensorflow as tf files = tf.train.match_filenames_once("E:\\MNIST_data\\output.tfrecords ...
- 吴裕雄--天生自然 pythonTensorFlow图形数据处理:输入数据处理框架
import tensorflow as tf # 1. 创建文件列表,通过文件列表创建输入文件队列 files = tf.train.match_filenames_once("F:\\o ...
- (第二章第三部分)TensorFlow框架之读取二进制数据
系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html (第二章第二部分)Tens ...
- Hadoop 1.0 和 2.0 中的数据处理框架 - MapReduce
1. MapReduce - 映射.化简编程模型 1.1 MapReduce 的概念 1.1.1 map 和 reduce 1.1.2 shufftle 和 排序 MapReduce 保证每个 red ...
随机推荐
- POJ 1151 HDU 1542 Atlantis(扫描线)
题目大意就是:去一个地方探险,然后给你一些地图描写叙述这个地方,每一个描写叙述是一个矩形的右下角和左上角.地图有些地方是重叠的.所以让你求出被描写叙述的地方的总面积. 扫描线的第一道题,想了又想,啸爷 ...
- IO管理与磁盘调度
- 关于Android滑动冲突的解决方法(二)
之前的一遍学习笔记主要就Android滑动冲突中,在不同方向的滑动所造成冲突进行了了解,这样的冲突非常easy理解,当然也非常easy解决.今天,就同方向的滑动所造成的冲突进行一下了解,这里就先以垂直 ...
- Codeforces Round #335 (Div. 2) 606B Testing Robots(模拟)
B. Testing Robots time limit per test 2 seconds memory limit per test 256 megabytes input standard i ...
- openwrt gstreamer实例学习笔记(三.深入了解gstreamer 的 Element)
在前面的部分,我们简要介绍过 GstElementFactory 可以用来创建一个element的实例,但是GstElementFactory不仅仅只能做这件事,GstElementFactory作为 ...
- Mono 和 .NET Core比翼双飞
大家好,今天给大家分享.NET 蓝图之下的Mono和.NET Core 话题,微软在Build 2019 大会上给.NET 做了一个五年规划,所以分享的主题就是<Mono和.NET Core 比 ...
- TC SRM 583 DIV 2
做了俩,rating涨了80.第二个题是关于身份证的模拟题,写的时间比较长,但是我认真检查了... 第三个题是最短路,今天写了写,写的很繁琐,写的很多错. #include <cstring&g ...
- jquery1.9是最后支持IE678
bootstrap 需要 jquery 1.9.1或更高 jquery1.9是最后支持IE678
- Java面试必会知识点
1.== 和 equals()比较: (1)== 是运算符,equals()是Object中定义的方法: (2)== 比较的是 数值 是否相同,基本类型比较数值,引用类型比较对象地址的数值:且变量类型 ...
- HDU1520 Anniversary party —— 树形DP
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1520 Anniversary party Time Limit: 2000/1000 MS (Java ...