更加清晰的TFRecord格式数据生成及读取
TFRecords 格式数据文件处理流程
TFRecords 文件包含了 tf.train.Example 协议缓冲区(protocol buffer),协议缓冲区包含了特征 Features。TensorFlow 通过 Protocol Buffers 定义了 TFRecords 文件中存储的数据记录及其所含字段的数据结构,它们分别定义在 tensorflow/core/example 目录下的 example.proto 和 feature.proto 文件中。因此,我们将数据记录转换后的张量称为样例,将记录包含的字段称为特征域。
TFRecords 文件的样例结构层次非常清晰,一个样例包含一组特征。一组特征由多个特征向量组成的 Python 字典构成。为了说明读取 TFRecords 文件中样例的方法,我们首先使用 tf.python_io.TFRecordWriter 方法将下表中的数据写入 TFRecords 文件 stat.tfrecord 中。
表格如下:
'''writer.py'''
# -*- coding: utf-8 -*-
import tensorflow as tf # 创建向TFRecords文件写数据记录的writer
writer = tf.python_io.TFRecordWriter('stat.tfrecord')
# 2轮循环构造输入样例
for i in range(1,3):
# 创建example.proto中定义的样例
example = tf.train.Example(
features = tf.train.Features(
feature = {
'id': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),
'age': tf.train.Feature(int64_list = tf.train.Int64List(value=[i*24])),
'income': tf.train.Feature(float_list = tf.train.FloatList(value=[i*2048.0])),
'outgo': tf.train.Feature(float_list = tf.train.FloatList(value=[i*1024.0]))
}
)
)
# 将样例序列化为字符串后,写入stat.tfrecord文件
writer.write(example.SerializeToString())
# 关闭输出流
writer.close()
然后使用 tf.TFRecordReader 方法读取 stat.tfrecord 文件中的样例,接着使用 tf.parse_single_example 将样例转换为张量。
tf.parse_single_example 方法的输入参数 features 是一个 Python 字典,具体包括组成样例的所有特征的名称和数据类型,
它们必须与 writer. py 中使用 tf.train.Features 方法定义的特征保持完全一致。tf.FixedLenFeature 方法的输入参数为特征形状和特征数据类型。
因为本例中的4个特征都是标量,所以形状为 [] 。
'''reader.py'''
# -*- coding: utf-8 -*-
import tensorflow as tf # 创建文件名队列filename_queue
filename_queue = tf.train.string_input_producer(['stat.tfrecord'])
# 创建读取TFRecords文件的reader
reader = tf.TFRecordReader()
# 取出stat.tfrecord文件中的一条序列化的样例serialized_example
_, serialized_example = reader.read(filename_queue)
# 将一条序列化的样例转换为其包含的所有特征张量
features = tf.parse_single_example(
serialized_example,
features={
'id': tf.FixedLenFeature([], tf.int64),
'age': tf.FixedLenFeature([], tf.int64),
'income': tf.FixedLenFeature([], tf.float32),
'outgo': tf.FixedLenFeature([], tf.float32),
}
)
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)
# 启动执行入队搡作的后台线程
tf.start_queue_runners(sess=sess)
# 读取第一条数据记录
for i in range(2):
example=sess.run(features)
print(example)
'''
{'age': 24, 'outgo': 1024.0, 'id': 1, 'income': 2048.0}
{'age': 48, 'outgo': 2048.0, 'id': 2, 'income': 4096.0}
'''
在会话执行时,为了使计算任务顺利获取到输入数据,我们需要使用 tf.train.start_queue_runners 方法启动执行入队操作的所有线程,
具体包括将文件名入队到文件名队列的操作,以及将样例入队到样例队列的操作。这些队列操作相关的线程属于 TensorFIow 的后台线程,
它们确保文件名队列和样例队列始终有数据可以供后续操作读取。
————————————————
虽然我们用上面的代码成功读取并输出了 stat.tfrecord 文件中的数据,但是这种方法并不适用于生产环境。因为它的容错性较差,主要体现在队列操作后台线程的生命周期“无入管理",任何线程出现异常都会导致程序崩溃。常见的异常是文件名队列或样例队列越界抛出的 tf.errors.0ut0fRangeError 。队列越界的原因通常是读取的数据记录数量超过了 tf.train_string_input_producer 方法中指定的数据集遍历次数。
为了处理这种异常,我们使用 tf.train.coordinator 方法创建管理多线程生命周期的协调器。协调器的工作原理很简单,它监控 TensorFlow 的所有后台线程。当其中某个线程出现异常时,它的 should_stop 成员方法返回 True,for 循环结束。然后程序执行 finally 中协调器的 request_stop 成员方法,请求所有线程安全退出。
需要注意的是,当我们使用协调器管理多线程前,需要先执行 tf.local_variables_initializer 方法对其进行初始化。为此,我们使用 tf.group 方法将它和 tf.global_variables_initializer 方法聚合生成整个程序的初始化操作 init_op 。
创建协调器
使用协调器的示例如下:
import tensorflow as tf # 创建文件名队列filename_queue,并制定遍历两次数据集
filename_queue = tf.train.string_input_producer(['stat.tfrecord'], num_epochs=2)
# 省略中间过程
#同上面
# 聚合两种初始化操作
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
sess.run(init_op)
# 创建协调器,管理线程
coord = tf.train.Coordinator()
# 启动QueueRunner, 此时文件名队列已经进队。
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 打印程序的后台线程信息
print('Threads: %s' % threads)
try:
for i in range(10):
if not coord.should_stop():
example = sess.run(features)
print(example)
except tf.errors.OutOfRangeError:
print('Catch OutOfRangeError')
finally:
# 请求停止所有后台线程
coord.request_stop()
print('Finishreading')
# 等待所有后台线程安全退出
coord.join(threads)
sess.close() '''
输出:
Threads: [<Thread(Thread-1, started daemon 149248776427264)>, \
<Thread(Thread-2, started daemon 149248768934560)>]
{'age': 24, 'outgo': 1024.0, 'id': 1, 'income': 2048.0}
{'age': 48, 'outgo': 2048.0, 'id': 2, 'income': 4096.0}
{'age': 24, 'outgo': 1024.0, 'id': 1, 'income': 2048.0}
{'age': 48, 'outgo': 2048.0, 'id': 2, 'income': 4096.0}
Catch OutOfRangeError
Finish reading
'''
这两句实现的功能就是创建线程并使用 QueueRunner 对象来提取数据。简单来说:使用 tf.train 函数添加 QueueRunner 到 TensorFlow 中。在运行任何训练步骤之前,需要调用 tf.train.start_queue_runners 函数,否则 TensorFlow 将一直挂起。
前面说过 tf.train.start_queue_runners 这个函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。这种情况下最好配合使用一个 tf.train.Coordinator ,这样可以在发生错误的情况下正确地关闭这些线程。如果你对训练迭代数做了限制,那么需要使用一个训练迭代数计数器,并且需要被初始化。
创建批样例数据
经过之前的介绍,我们最后得到了许多样例,但是这些样例需要打包聚合成批数据才能供模型训练、评价和推理使用。TensorFlow 提供的 tf.train.shuffle_batch 方法不仅能够使用样例创建批数据,而且能顾在打包过程中打乱样例顺序,增加随机性。因此,我们认为完整的输入流水线应该还包括一个批数据队列。
————————————————
代码实例如下:
def get_my_example(filename_queue):
reader = tf.SomeReader()
_, value = reader.read(filename_queue)
features = tf.decodesome(value)
# 对样例进行预处理
processed_example = some_processing(features)
return processed_example def input_pipeline(filenames, batchsize, num_epochs=None):
# 当num_epochs--None时,表示文件名队列总是可用的,一直循环入队
filename_queue.tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example = get_my_example(filename_queue)
# min_after_dequeue表示从样例队列中出队的样例个数,
# 值越大表示打乱顺序效果越好,同时意味着消耗更多内存
min_after_dequeue = 10000
# capacity表示扯数据队列的容量,推荐设置:
# min_after_dequeue + (num_threads + a small safety margin) * batchsize
capacity = min_after_dequeue + 3 * batch_size
# 创建样例example_batch
examplebatch = tf.train.shuffle_batch(
[example], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch
代码模板
# -*- coding: utf-8 -*-
import tensorflow as tf def read_and_decode(filename):
filename_list = tf.gfile.Glob(filename_pattern)
filename_queue = tf.train.string_input_producer(filename_list, shuffle=True) reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'label_raw': tf.FixedLenFeature([], tf.string),
'img_raw': tf.FixedLenFeature([], tf.string),
})
label = tf.decode_raw(features['label_raw'], tf.uint8)
label = tf.reshape(label, [512, 512, 1])
label = tf.cast(label, tf.float32) label_max = tf.reduce_max(label)
label_min = tf.reduce_min(label)
label = (label - label_min) / (label_max - label_min) img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [512, 512, 1])
img = tf.cast(img, tf.float32) img_max = tf.reduce_max(img)
img_min = tf.reduce_min(img)
img = (img - img_min) / (img_max - img_min) example_queue = tf.RandomShuffleQueue(
capacity=16*batch_size,
min_after_dequeue=8*batch_size,
dtypes=[tf.float32, tf.float32],
shapes=[[512, 512, 1], [512, 512, 1]]) num_threads = 16 example_enqueue_op = example_queue.enqueue([img, label]) tf.train.add_queue_runner(tf.train.queue_runner.QueueRunner(
example_queue, [example_enqueue_op]*num_threads)) images, labels = example_queue.dequeue_many(batch_size) return images, labels train_images, train_labels = read_tfrecord('./data/train.tfrecord',
batch_size=train_batch_size)
val_images, val_labels = read_tfrecord('./data/validation.tfrecord',
batch_size=valid_batch_size)
sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer()) sess.run(init_op) coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord) try:
while not coord.should_stop():
example = sess.run(train_op)
print(example) except tf.errors.OutOfRangeError:
print('Catch OutOfRangeError')
finally:
coord.request_stop()
print('Finishreading') coord.join(threads)
sess.close()
原文链接:https://blog.csdn.net/TeFuirnever/article/details/90523253
更加清晰的TFRecord格式数据生成及读取的更多相关文章
- ini格式数据生成与解析具体解释
ini格式数据生成与解析具体解释 1.ini格式数据长啥样? watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQv/font/5a6L5L2T/fontsize/ ...
- Tensorflow 处理libsvm格式数据生成TFRecord (parse libsvm data to TFRecord)
#写libsvm格式 数据 write libsvm #!/usr/bin/env python #coding=gbk # ================================= ...
- tensorflow制作tfrecord格式数据
tf.Example msg tensorflow提供了一种统一的格式.tfrecord来存储图像数据.用的是自家的google protobuf.就是把图像数据序列化成自定义格式的二进制数据. To ...
- [TFRecord格式数据]利用TFRecords存储与读取带标签的图片
利用TFRecords存储与读取带标签的图片 原创文章,转载请注明出处~ 觉得有用的话,欢迎一起讨论相互学习~Follow Me TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是 ...
- "笨方法"学习CNN图像识别(二)—— tfrecord格式高效读取数据
原文地址:https://finthon.com/learn-cnn-two-tfrecord-read-data/-- 全文阅读5分钟 -- 在本文中,你将学习到以下内容: 将图片数据制作成tfre ...
- iOS开发之JSON格式数据的生成与解析
本文将从四个方面对IOS开发中JSON格式数据的生成与解析进行讲解: 一.JSON是什么? 二.我们为什么要用JSON格式的数据? 三.如何生成JSON格式的数据? 四.如何解析JSON格式的数据? ...
- Android使用DOM生成和输出XML格式数据
Android使用DOM生成和输出XML格式数据 本文主要简单解说怎样使用DOM生成和输出XML数据. 1. 生成和输出XML数据 代码及凝视例如以下: try { DocumentBuilderFa ...
- 转载 -- iOS开发之JSON格式数据的生成与解析
本文将从四个方面对IOS开发中JSON格式数据的生成与解析进行讲解: 一.JSON是什么? 二.我们为什么要用JSON格式的数据? 三.如何生成JSON格式的数据? 四.如何解析JSON格式的数据? ...
- PHP生成和获取XML格式数据
在做数据接口时,我们通常要获取第三方数据接口或者给第三方提供数据接口,而这些数据格式通常是以XML或者JSON格式传输,本文将介绍如何使用PHP生成XML格式数据供第三方调用以及如何获取第三方提供的X ...
随机推荐
- git的安装与命令行基本的使用
1.https://git-scm.com/ 点击这个网址进入git的官方网站 2,.进去里面会有提示,64位于32位的,根据自己的电脑安装 3 下载完了过后就直接安装,一般会安装在c盘里面 ,进入安 ...
- linux tasklet工作队列
工作队列是, 表面上看, 类似于 taskets; 它们允许内核代码来请求在将来某个时间调用 一个函数. 但是, 有几个显著的不同在这 2 个之间, 包括: tasklet 在软件中断上下文中运行的结 ...
- vue在html中写动态背景图片
<div class="img" :style="`background: url(`+item.img+'?any_string_is_ok'+`)center ...
- SpringBoot --web 应用开发之文件上传
原文出处: oKong 前言 上一章节,我们讲解了利用模版引擎实现前端页面渲染,从而实现动态网页的功能,同时也提出了兼容jsp项目的解决方案.既然开始讲解web开发了,我们就接着继续往web这个方向继 ...
- JQ绑定事件的叠加和解决,index()方法的坑
JQ绑定事件的叠加和解决,index()方法的坑 前言 在做过几个不大不小的项目后,发现技术这种东西,必须要多多实践,才能发现各种问题,理论的知识掌握的再好终究是纸上谈兵. 因此目前感觉有两点是必须要 ...
- 如何在ClickOnce 应用中使用 GitVersion
https://github.com/GitTools/GitVersion/issues/1153 I'm using GitVersion in an internal ClickOnce app ...
- lambda应用
def test(a, b, func): result = func(a, b) print(result) test(10, 15, lambda x, y: x + y) #coding=utf ...
- python关于MySQL的API -- pymysql模块
1.模块安装 pip install pymysql 2.执行sql语句 import pymysql #添加数据 conn = pymysql.connect(host='127.0.0.1', p ...
- poj3471 - 倍增+LCA+树上差分
题意:一张n节点连通无向图,n-1条树边,m条非树边.若通过先删一条树边,再删一条非树边想操作 将此图划分为不连通的两部分,问有多少种方案. 利用LCA整好区间覆盖,dfs用来求前缀和 需要注意的是, ...
- 使用原生JDBC方式对数据库进行操作
使用原生JDBC方式对数据库进行操作,包括六个步骤: 1.加载JDBC驱动程序 在连接数据库之前,首先要加载想要连接的数据库的驱动到JVM.可以通过java.lang.Class类的静态方法forNa ...