(第二章第四部分)TensorFlow框架之TFRecords数据的存储与读取
系列博客链接:
(第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html
(第二章第二部分)TensorFlow框架之读取图片数据:https://www.cnblogs.com/kongweisi/p/11050539.html
(第二章第三部分)TensorFlow框架之读取二进制数据:https://www.cnblogs.com/kongweisi/p/11050546.html
本文概述:
- 目标
- 说明Example的结构
- 应用TFRecordWriter实现TFRecords文件存储器的构造
- 应用parse_single_example实现解析Example结构
- 应用
- CIFAR10类图片的数据的TFRecords存储和读取
1、什么是TFRecords文件
TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。
TFRecords文件包含了tf.train.Example
协议内存块(protocol buffer)(协议内存块包含了字段 Features
)。可以获取你的数据, 将数据填入到Example
协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter
写入到TFRecords文件。
- 文件格式 *.tfrecords
2、Example结构解析
tf.train.Example
协议内存块(protocol buffer)(协议内存块包含了字段 Features
),Features
包含了一个feature
字段,feature
中包含要写入的数据、并指明数据类型。这是一个样本的结构,批数据需要循环存入这样的结构
- tf.train.Example(features=None)
- 写入tfrecords文件
- features:tf.train.Features类型的特征实例
- return:example格式协议块
- tf.train.Features(feature=None)
- 构建每个样本的信息键值对
- feature: 字典数据, key为要保存的名字
- value为tf.train.Feature实例
- return:Features类型
- tf.train.Feature(options)
- options:例如
- bytes_list=tf.train. BytesList(value=[Bytes])
- int64_list=tf.train. Int64List(value=[Value])
- 支持存入的类型如下
- tf.train.Int64List(value=[Value])
- tf.train.BytesList(value=[Bytes])
- tf.train.FloatList(value=[value])
- options:例如
上面的三段API是从上到下,一层一层嵌套使用的,如果理解有困难,可以看代码中的用法,相信就可以理解了。
这种结构很好的解决了数据和标签(训练的类别标签)或者其他属性数据存储在同一个文件中
3、案例:CIFAR10数据,存入TFRecords文件
3.1分析
构造存储实例,tf.python_io.TFRecordWriter(path)
- 写入tfrecords文件
- path: TFRecords文件的路径 + 文件名字
- return:写文件
- method
- write(record): 向文件中写入一个example
- close(): 关闭文件写入器
循环将数据填入到
Example
协议内存块(protocol buffer)
3.2代码
def write_to_tfrecords(self, image_batch, label_batch):
"""
将数据存进tfrecords,方便管理每个样本的属性
:param image_batch: 特征值
:param label_batch: 目标值
:return: None
"""
# 1、构造tfrecords的存储实例
writer = tf.python_io.TFRecordWriter("./tmp/cifar.tfrecords") # 路径 + 名字 # 2、循环将每个样本写入到文件当中
for i in range(10): # 一个样本一个样本的处理写入
# 准备特征值,特征值必须是bytes类型 调用tostring()函数
# [10, 32, 32, 3] ,在这里避免tensorflow的坑,取出来的不是真正的值,而是类型,所以要运行结果才能存入
# 出现了eval,那就要在会话当中去运行该行数
image = image_batch[i].eval().tostring() # 准备目标值,目标值是一个Int类型
# eval()-->[6]--->6
label = label_batch[i].eval()[0] # 绑定每个样本的属性
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
})) # 写入每个样本的example,记住这里要将example序列化,成字符串
writer.write(example.SerializeToString()) # 文件需要关闭
writer.close()
return None # 开启会话打印内容
with tf.Session() as sess:
# 创建线程协调器
coord = tf.train.Coordinator() # 开启子线程去读取数据
# 返回子线程实例
threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 获取样本数据去训练
print(sess.run([image_batch, label_batch])) # 存入数据
cr.write_to_tfrecords(image_batch, label_batch ) # 关闭子线程,回收
coord.request_stop() coord.join(threads)
4、读取TFRecords文件
读取这种文件整个过程与其他文件一样,只不过需要有个解析Example的步骤。从TFRecords文件中读取数据, 可以使用tf.TFRecordReader
的tf.parse_single_example
解析器。这个操作可以将Example
协议内存块(protocol buffer)解析为张量。
tf.parse_single_example(serialized, features=None, name=None)
- 解析一个单一的Example原型
- serialized:标量字符串Tensor,一个序列化的Example
- features:dict字典数据,键为读取的名字,值为FixedLenFeature
- return:一个键值对组成的字典,键为读取的名字
tf.FixedLenFeature(shape, dtype)
- shape:输入数据的形状,一般不指定, 为空列表
- dtype:输入数据类型,与存储进文件的类型要一致
- 类型只能是float32, int64, string
5、案例:读取CIFAR的TFRecords文件
5.1 分析
- 使用tf.train.string_input_producer构造文件队列
- tf.TFRecordReader 读取TFRecords数据并进行解析
- tf.parse_single_example进行解析
- tf.decode_raw解码
- 类型是bytes类型需要解码
- 其他类型不需要
- 处理图片数据形状以及数据类型,批处理返回
- 开启会话线程运行
5.2 代码
def read_tfrecords(self):
"""
读取tfrecords的数据
:return: None
"""
# 1、构造文件队列
file_queue = tf.train.string_input_producer(["./tmp/cifar.tfrecords"]) # 2、构造tfrecords读取器,读取队列
reader = tf.TFRecordReader() # 默认也是只读取一个样本
key, values = reader.read(file_queue) # tfrecords
# 多了解析example的一个步骤
feature = tf.parse_single_example(values, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
}) # 取出feature里面的特征值和目标值
# 通过键值对获取
image = feature["image"] label = feature["label"] # 3、解码操作
# 对于image是一个bytes类型,所以需要decode_raw去解码成uint8张量
# 对于Label:本身是一个int类型,不需要去解码
image = tf.decode_raw(image, tf.uint8) print(image, label) # 处理image的形状和类型
image_reshape = tf.reshape(image, [self.height, self.width, self.channel]) # 处理label的形状和类型
label_cast = tf.cast(label, tf.int32) print(image_reshape, label_cast) # 4、批处理操作
image_batch, label_batch = tf.train.batch([image_reshape, label_cast], batch_size=10, num_threads=1, capacity=10) print(image_batch, label_batch)
return image_batch, label_batch # 从tfrecords文件读取数据
image_batch, label_batch = cr.read_tfrecords() # 开启会话打印内容
with tf.Session() as sess:
# 创建线程协调器
coord = tf.train.Coordinator()
- 结合第二章第二部分读取商品图片和第三部分读取二进制数据的完整代码
import tensorflow as tf
import os def picread(file_list):
"""
读取商品图片数据到张量
:param file_list:路径+文件名的列表
:return:
"""
# 构造文件队列
file_queue = tf.train.string_input_producer(file_list) # 利用图片读取器去读取文件队列内容 reader = tf.WholeFileReader()
# 默认一次一张图片,没有形状
_, value = reader.read(file_queue) print(value) # 对图片数据进行解码
# string --> uint8
# () ----> (?, ?, ?)
image = tf.image.decode_jpeg(value) print(image) # 图片的形状固定、大小处理
# 把图片大小固定统一大小(算法训练要求样本的特征值数量一样)
# 固定[200, 200]
image_resize = tf.image.resize_images(image, [200, 200])
print(image_resize) # 设置图片图片形状
image_resize.set_shape([200, 200, 3]) # 进行批处理
# 3D ----> 4D
image_batch = tf.train.batch([image_resize], batch_size=10, num_threads=1, capacity=10) print(image_batch)
return image_batch class CifarRead(object):
"""读取CIFAR10类别的二进制文件
"""
def __init__(self): # 每个图片样本的属性
self.height = 32
self.width = 32
self.channel = 3 # bytes
# 1
self.label_bytes = 1
# 3072
self.image_bytes = self.height * self.width * self.channel
# 3073
self.all_bytes = self.label_bytes + self.image_bytes def bytes_read(self, file_list):
"""读取二进制解码张量
:return:
"""
# 1、构造文件队列
file_queue = tf.train.string_input_producer(file_list) # 2、使用tf.FixedLengthRecordReader(bytes)读取
# 默认必须指定读取一个样本
reader = tf.FixedLengthRecordReader(self.all_bytes) _, value = reader.read(file_queue) # 3、解码操作
# (?, ) (3073, ) = label(1, ) + feature(3072, )
label_image = tf.decode_raw(value, tf.uint8)
# 为了训练方便,一般会把特征值和目标值分开处理
print(label_image) # 使用tf.slice进行切片
label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32) image = tf.slice(label_image, [self.label_bytes], [self.image_bytes]) print(label, image) # 处理类型和图片数据的形状
# 图片形状
# reshape (3072, )----[channel, height, width]
# transpose [channel, height, width] --->[height, width, channel]
depth_major = tf.reshape(image, [self.channel, self.height, self.width])
print(depth_major) image_reshape = tf.transpose(depth_major, [1, 2, 0]) print(image_reshape) # 4、批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10) return image_batch, label_batch def write_to_tfrecords(self, image_batch, label_batch):
"""
将数据写入TFRecords文件
:param image_batch: 特征值
:param label_batch: 目标值
:return:
"""
# 构造TFRecords存储器
writer = tf.python_io.TFRecordWriter("./tmp/cifar.tfrecords") # 循环将每个样本构造成一个example,然后序列化写入
for i in range(10): # 取出相应的第i个样本的特征值和目标值
# 写入的是具体的张量的值,不是OP的名字
# [10, 32, 32, 3]
# [32, 32, 3]值
image = image_batch[i].eval().tostring() # [10, 1]
# 整形类型
label = int(label_batch[i].eval()[0]) # 每个样本的example
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
})) # 写入第i样本的example
writer.write(example.SerializeToString()) writer.close()
return None def read_tfrecords(self):
"""
读取TFRecords文件
:return: None
"""
# 1、构造文件队列
file_queue = tf.train.string_input_producer(["./tmp/cifar.tfrecords"]) # 2、tf.TFRecordReader 读取TFRecords数据并 reader = tf.TFRecordReader() # 默认只读取一个样本
_, value = reader.read(file_queue) # 进行解析example协议
feature = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
}) # 3、解码操作 二进制的格式必须解码
image = tf.decode_raw(feature['image'], tf.uint8) label = tf.cast(feature['label'], tf.int32) # 形状
# [32, 32, 3]--->bytes---> tf.uint8
image_reshape = tf.reshape(image, [self.height, self.width, self.channel]) # 4、批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10) return image_batch, label_batch if __name__ == '__main__':
filename = os.listdir("./data/cifar10/cifar-10-batches-bin/") file_list = [os.path.join("./data/cifar10/cifar-10-batches-bin/", file) for file in filename if file[-3:] == "bin"] # print(file_list)3
cr = CifarRead() # image_batch, label_batch = cr.bytes_read(file_list)
# 读取TFRecords的结果
image_batch, label_batch = cr.read_tfrecords() with tf.Session() as sess: # 创建线程回收的协调员
coord = tf.train.Coordinator() # 需要手动开启子线程去进行批处理读取到队列操作
threads = tf.train.start_queue_runners(sess=sess, coord=coord) print(sess.run([image_batch, label_batch])) # 写入文件
# cr.write_to_tfrecords(image_batch, label_batch) # 回收线程
coord.request_stop() coord.join(threads)
(第二章第四部分)TensorFlow框架之TFRecords数据的存储与读取的更多相关文章
- (第二章第一部分)TensorFlow框架之文件读取流程
本章概述:在第一章的系列文章中介绍了tf框架的基本用法,从本章开始,介绍与tf框架相关的数据读取和写入的方法,并会在最后,用基础的神经网络,实现经典的Mnist手写数字识别. 有四种获取数据到Tens ...
- (第一章第一部分)TensorFlow框架介绍
接下来会更新一系列博客,介绍TensorFlow的入门使用,尽可能详细. 本文概述: 说明TensorFlow的数据流图结构 1.数据流图介绍 TensorFlow是一个采用数据流图(data fl ...
- tensorflow2.0学习笔记第二章第四节
2.4损失函数损失函数(loss):预测值(y)与已知答案(y_)的差距 nn优化目标:loss最小->-mse -自定义 -ce(cross entropy)均方误差mse:MSE(y_,y) ...
- (第二章第三部分)TensorFlow框架之读取二进制数据
系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html (第二章第二部分)Tens ...
- (第二章第二部分)TensorFlow框架之读取图片数据
系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html 本文概述: 目标 说明图片 ...
- (第一章第五部分)TensorFlow框架之变量OP
系列博客链接: (一)TensorFlow框架介绍:https://www.cnblogs.com/kongweisi/p/11038395.html (二)TensorFlow框架之图与Tensor ...
- (第一章第六部分)TensorFlow框架之实现线性回归小案例
系列博客链接: (一)TensorFlow框架介绍:https://www.cnblogs.com/kongweisi/p/11038395.html (二)TensorFlow框架之图与Tensor ...
- C# Language Specification 5.0 (翻译)第二章 词法结构
程序 C# 程序(program)由至少一个源文件(source files)组成,其正式称谓为编译单元(compilation units)[1].每个源文件都是有序的 Unicode 字符序列.源 ...
- (第一章第四部分)TensorFlow框架之张量
系列博客链接: (一)TensorFlow框架介绍:https://www.cnblogs.com/kongweisi/p/11038395.html (二)TensorFlow框架之图与Tensor ...
随机推荐
- 新年好 takoyaki,期待再次与你相见
一.序 今天是中国农历一年的最后一天,往年都叫年三十,今年没有三十,最后一天是二十九.厨房的柴火味.窗外的鞭炮声还有不远处传来的说笑声,一切都是熟悉味道,新年到了,家乡热闹起来了.平常左邻右舍都是看不 ...
- 学习MyBatis必知必会(6)~Mapper基础的拓展
一.typeAlias 类型别名[自定义别名.系统自带别名] 1.类型别名:为 Java 类型设置一个缩写名字. 它仅用于 XML 配置,意在降低冗余的全限定类名书写 2.配置自定义别名: (1)方式 ...
- 使用Hot Chocolate和.NET 6构建GraphQL应用(5) —— 实现Query过滤功能
系列导航 使用Hot Chocolate和.NET 6构建GraphQL应用文章索引 需求 对于查询来说,还有一大需求是针对查询的数据进行过滤,本篇文章我们准备实现GraphQL中基本的查询过滤. 思 ...
- CF1399F Yet Another Segments Subset
首先注意一下题面要求,使得选出的线段两两要么包含要么不相交,也就是说一条线段可能会出现不相交的几条线段,而这些线段上面也可能继续这样包含线段.然后我们可以发现我们要做的实际上是在这条线段上选取几条线段 ...
- NSURLConnection和Runloop(面试)
(1)两种为NSURLConnection设置代理方式的区别 //第一种设置方式: //通过该方法设置代理,会自动的发送请求 // [[NSURLConnection alloc]initWithRe ...
- Html设置文本换行与不按行操作
图片来源:W3C 部分引自大佬:https://zhidao.baidu.com/question/424920602093167052.html 强制不换行 div{ white-space:now ...
- 数值分析:最小二乘与岭回归(Pytorch实现)
Chapter 4 1. 最小二乘和正规方程 1.1 最小二乘的两种视角 从数值计算视角看最小二乘法 我们在学习数值线性代数时,学习了当方程的解存在时,如何找到\(\textbf{A}\bm{x}=\ ...
- Failed to execute goal org.apache.maven.plugins:maven-surefire-plugin:2.22.2:test (default-test) on project gulimall-common: There are test failures.
对Maven打包时碰见的问题: Failed to execute goal org.apache.maven.plugins:maven-surefire-plugin:2.22.2:test (d ...
- day3 -- 集合、文件操作、函数
1.集合:集合无序,不重复,可以用set(列表) 方法将列表转换为集合,实现去重 对比列表:集合是{}包围,列表是[]包围 对比字典:集合是没有key的,字典是有key的 set_1 = {1, 2, ...
- Solution -「Code+#4」「洛谷 P4370」组合数问题 2
\(\mathcal{Description}\) Link. 给定 \(n,k\),求 \(0\le b\le a\le n\) 的 \(\binom{a}{b}\) 的前 \(k\) 大. ...