(第二章第四部分)TensorFlow框架之TFRecords数据的存储与读取
系列博客链接:
(第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html
(第二章第二部分)TensorFlow框架之读取图片数据:https://www.cnblogs.com/kongweisi/p/11050539.html
(第二章第三部分)TensorFlow框架之读取二进制数据:https://www.cnblogs.com/kongweisi/p/11050546.html
本文概述:
- 目标
- 说明Example的结构
- 应用TFRecordWriter实现TFRecords文件存储器的构造
- 应用parse_single_example实现解析Example结构
- 应用
- CIFAR10类图片的数据的TFRecords存储和读取
1、什么是TFRecords文件
TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。
TFRecords文件包含了tf.train.Example
协议内存块(protocol buffer)(协议内存块包含了字段 Features
)。可以获取你的数据, 将数据填入到Example
协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter
写入到TFRecords文件。
- 文件格式 *.tfrecords
2、Example结构解析
tf.train.Example
协议内存块(protocol buffer)(协议内存块包含了字段 Features
),Features
包含了一个feature
字段,feature
中包含要写入的数据、并指明数据类型。这是一个样本的结构,批数据需要循环存入这样的结构
- tf.train.Example(features=None)
- 写入tfrecords文件
- features:tf.train.Features类型的特征实例
- return:example格式协议块
- tf.train.Features(feature=None)
- 构建每个样本的信息键值对
- feature: 字典数据, key为要保存的名字
- value为tf.train.Feature实例
- return:Features类型
- tf.train.Feature(options)
- options:例如
- bytes_list=tf.train. BytesList(value=[Bytes])
- int64_list=tf.train. Int64List(value=[Value])
- 支持存入的类型如下
- tf.train.Int64List(value=[Value])
- tf.train.BytesList(value=[Bytes])
- tf.train.FloatList(value=[value])
- options:例如
上面的三段API是从上到下,一层一层嵌套使用的,如果理解有困难,可以看代码中的用法,相信就可以理解了。
这种结构很好的解决了数据和标签(训练的类别标签)或者其他属性数据存储在同一个文件中
3、案例:CIFAR10数据,存入TFRecords文件
3.1分析
构造存储实例,tf.python_io.TFRecordWriter(path)
- 写入tfrecords文件
- path: TFRecords文件的路径 + 文件名字
- return:写文件
- method
- write(record): 向文件中写入一个example
- close(): 关闭文件写入器
循环将数据填入到
Example
协议内存块(protocol buffer)
3.2代码
def write_to_tfrecords(self, image_batch, label_batch):
"""
将数据存进tfrecords,方便管理每个样本的属性
:param image_batch: 特征值
:param label_batch: 目标值
:return: None
"""
# 1、构造tfrecords的存储实例
writer = tf.python_io.TFRecordWriter("./tmp/cifar.tfrecords") # 路径 + 名字 # 2、循环将每个样本写入到文件当中
for i in range(10): # 一个样本一个样本的处理写入
# 准备特征值,特征值必须是bytes类型 调用tostring()函数
# [10, 32, 32, 3] ,在这里避免tensorflow的坑,取出来的不是真正的值,而是类型,所以要运行结果才能存入
# 出现了eval,那就要在会话当中去运行该行数
image = image_batch[i].eval().tostring() # 准备目标值,目标值是一个Int类型
# eval()-->[6]--->6
label = label_batch[i].eval()[0] # 绑定每个样本的属性
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
})) # 写入每个样本的example,记住这里要将example序列化,成字符串
writer.write(example.SerializeToString()) # 文件需要关闭
writer.close()
return None # 开启会话打印内容
with tf.Session() as sess:
# 创建线程协调器
coord = tf.train.Coordinator() # 开启子线程去读取数据
# 返回子线程实例
threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 获取样本数据去训练
print(sess.run([image_batch, label_batch])) # 存入数据
cr.write_to_tfrecords(image_batch, label_batch ) # 关闭子线程,回收
coord.request_stop() coord.join(threads)
4、读取TFRecords文件
读取这种文件整个过程与其他文件一样,只不过需要有个解析Example的步骤。从TFRecords文件中读取数据, 可以使用tf.TFRecordReader
的tf.parse_single_example
解析器。这个操作可以将Example
协议内存块(protocol buffer)解析为张量。
tf.parse_single_example(serialized, features=None, name=None)
- 解析一个单一的Example原型
- serialized:标量字符串Tensor,一个序列化的Example
- features:dict字典数据,键为读取的名字,值为FixedLenFeature
- return:一个键值对组成的字典,键为读取的名字
tf.FixedLenFeature(shape, dtype)
- shape:输入数据的形状,一般不指定, 为空列表
- dtype:输入数据类型,与存储进文件的类型要一致
- 类型只能是float32, int64, string
5、案例:读取CIFAR的TFRecords文件
5.1 分析
- 使用tf.train.string_input_producer构造文件队列
- tf.TFRecordReader 读取TFRecords数据并进行解析
- tf.parse_single_example进行解析
- tf.decode_raw解码
- 类型是bytes类型需要解码
- 其他类型不需要
- 处理图片数据形状以及数据类型,批处理返回
- 开启会话线程运行
5.2 代码
def read_tfrecords(self):
"""
读取tfrecords的数据
:return: None
"""
# 1、构造文件队列
file_queue = tf.train.string_input_producer(["./tmp/cifar.tfrecords"]) # 2、构造tfrecords读取器,读取队列
reader = tf.TFRecordReader() # 默认也是只读取一个样本
key, values = reader.read(file_queue) # tfrecords
# 多了解析example的一个步骤
feature = tf.parse_single_example(values, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
}) # 取出feature里面的特征值和目标值
# 通过键值对获取
image = feature["image"] label = feature["label"] # 3、解码操作
# 对于image是一个bytes类型,所以需要decode_raw去解码成uint8张量
# 对于Label:本身是一个int类型,不需要去解码
image = tf.decode_raw(image, tf.uint8) print(image, label) # 处理image的形状和类型
image_reshape = tf.reshape(image, [self.height, self.width, self.channel]) # 处理label的形状和类型
label_cast = tf.cast(label, tf.int32) print(image_reshape, label_cast) # 4、批处理操作
image_batch, label_batch = tf.train.batch([image_reshape, label_cast], batch_size=10, num_threads=1, capacity=10) print(image_batch, label_batch)
return image_batch, label_batch # 从tfrecords文件读取数据
image_batch, label_batch = cr.read_tfrecords() # 开启会话打印内容
with tf.Session() as sess:
# 创建线程协调器
coord = tf.train.Coordinator()
- 结合第二章第二部分读取商品图片和第三部分读取二进制数据的完整代码
import tensorflow as tf
import os def picread(file_list):
"""
读取商品图片数据到张量
:param file_list:路径+文件名的列表
:return:
"""
# 构造文件队列
file_queue = tf.train.string_input_producer(file_list) # 利用图片读取器去读取文件队列内容 reader = tf.WholeFileReader()
# 默认一次一张图片,没有形状
_, value = reader.read(file_queue) print(value) # 对图片数据进行解码
# string --> uint8
# () ----> (?, ?, ?)
image = tf.image.decode_jpeg(value) print(image) # 图片的形状固定、大小处理
# 把图片大小固定统一大小(算法训练要求样本的特征值数量一样)
# 固定[200, 200]
image_resize = tf.image.resize_images(image, [200, 200])
print(image_resize) # 设置图片图片形状
image_resize.set_shape([200, 200, 3]) # 进行批处理
# 3D ----> 4D
image_batch = tf.train.batch([image_resize], batch_size=10, num_threads=1, capacity=10) print(image_batch)
return image_batch class CifarRead(object):
"""读取CIFAR10类别的二进制文件
"""
def __init__(self): # 每个图片样本的属性
self.height = 32
self.width = 32
self.channel = 3 # bytes
# 1
self.label_bytes = 1
# 3072
self.image_bytes = self.height * self.width * self.channel
# 3073
self.all_bytes = self.label_bytes + self.image_bytes def bytes_read(self, file_list):
"""读取二进制解码张量
:return:
"""
# 1、构造文件队列
file_queue = tf.train.string_input_producer(file_list) # 2、使用tf.FixedLengthRecordReader(bytes)读取
# 默认必须指定读取一个样本
reader = tf.FixedLengthRecordReader(self.all_bytes) _, value = reader.read(file_queue) # 3、解码操作
# (?, ) (3073, ) = label(1, ) + feature(3072, )
label_image = tf.decode_raw(value, tf.uint8)
# 为了训练方便,一般会把特征值和目标值分开处理
print(label_image) # 使用tf.slice进行切片
label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32) image = tf.slice(label_image, [self.label_bytes], [self.image_bytes]) print(label, image) # 处理类型和图片数据的形状
# 图片形状
# reshape (3072, )----[channel, height, width]
# transpose [channel, height, width] --->[height, width, channel]
depth_major = tf.reshape(image, [self.channel, self.height, self.width])
print(depth_major) image_reshape = tf.transpose(depth_major, [1, 2, 0]) print(image_reshape) # 4、批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10) return image_batch, label_batch def write_to_tfrecords(self, image_batch, label_batch):
"""
将数据写入TFRecords文件
:param image_batch: 特征值
:param label_batch: 目标值
:return:
"""
# 构造TFRecords存储器
writer = tf.python_io.TFRecordWriter("./tmp/cifar.tfrecords") # 循环将每个样本构造成一个example,然后序列化写入
for i in range(10): # 取出相应的第i个样本的特征值和目标值
# 写入的是具体的张量的值,不是OP的名字
# [10, 32, 32, 3]
# [32, 32, 3]值
image = image_batch[i].eval().tostring() # [10, 1]
# 整形类型
label = int(label_batch[i].eval()[0]) # 每个样本的example
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
})) # 写入第i样本的example
writer.write(example.SerializeToString()) writer.close()
return None def read_tfrecords(self):
"""
读取TFRecords文件
:return: None
"""
# 1、构造文件队列
file_queue = tf.train.string_input_producer(["./tmp/cifar.tfrecords"]) # 2、tf.TFRecordReader 读取TFRecords数据并 reader = tf.TFRecordReader() # 默认只读取一个样本
_, value = reader.read(file_queue) # 进行解析example协议
feature = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
}) # 3、解码操作 二进制的格式必须解码
image = tf.decode_raw(feature['image'], tf.uint8) label = tf.cast(feature['label'], tf.int32) # 形状
# [32, 32, 3]--->bytes---> tf.uint8
image_reshape = tf.reshape(image, [self.height, self.width, self.channel]) # 4、批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10) return image_batch, label_batch if __name__ == '__main__':
filename = os.listdir("./data/cifar10/cifar-10-batches-bin/") file_list = [os.path.join("./data/cifar10/cifar-10-batches-bin/", file) for file in filename if file[-3:] == "bin"] # print(file_list)3
cr = CifarRead() # image_batch, label_batch = cr.bytes_read(file_list)
# 读取TFRecords的结果
image_batch, label_batch = cr.read_tfrecords() with tf.Session() as sess: # 创建线程回收的协调员
coord = tf.train.Coordinator() # 需要手动开启子线程去进行批处理读取到队列操作
threads = tf.train.start_queue_runners(sess=sess, coord=coord) print(sess.run([image_batch, label_batch])) # 写入文件
# cr.write_to_tfrecords(image_batch, label_batch) # 回收线程
coord.request_stop() coord.join(threads)
(第二章第四部分)TensorFlow框架之TFRecords数据的存储与读取的更多相关文章
- (第二章第一部分)TensorFlow框架之文件读取流程
本章概述:在第一章的系列文章中介绍了tf框架的基本用法,从本章开始,介绍与tf框架相关的数据读取和写入的方法,并会在最后,用基础的神经网络,实现经典的Mnist手写数字识别. 有四种获取数据到Tens ...
- (第一章第一部分)TensorFlow框架介绍
接下来会更新一系列博客,介绍TensorFlow的入门使用,尽可能详细. 本文概述: 说明TensorFlow的数据流图结构 1.数据流图介绍 TensorFlow是一个采用数据流图(data fl ...
- tensorflow2.0学习笔记第二章第四节
2.4损失函数损失函数(loss):预测值(y)与已知答案(y_)的差距 nn优化目标:loss最小->-mse -自定义 -ce(cross entropy)均方误差mse:MSE(y_,y) ...
- (第二章第三部分)TensorFlow框架之读取二进制数据
系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html (第二章第二部分)Tens ...
- (第二章第二部分)TensorFlow框架之读取图片数据
系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html 本文概述: 目标 说明图片 ...
- (第一章第五部分)TensorFlow框架之变量OP
系列博客链接: (一)TensorFlow框架介绍:https://www.cnblogs.com/kongweisi/p/11038395.html (二)TensorFlow框架之图与Tensor ...
- (第一章第六部分)TensorFlow框架之实现线性回归小案例
系列博客链接: (一)TensorFlow框架介绍:https://www.cnblogs.com/kongweisi/p/11038395.html (二)TensorFlow框架之图与Tensor ...
- C# Language Specification 5.0 (翻译)第二章 词法结构
程序 C# 程序(program)由至少一个源文件(source files)组成,其正式称谓为编译单元(compilation units)[1].每个源文件都是有序的 Unicode 字符序列.源 ...
- (第一章第四部分)TensorFlow框架之张量
系列博客链接: (一)TensorFlow框架介绍:https://www.cnblogs.com/kongweisi/p/11038395.html (二)TensorFlow框架之图与Tensor ...
随机推荐
- Redis 学习笔记(一)redis 数据类型和对象机制
Redis 简介 Redis 是(key-value)的 NoSQL 数据库,所有的 key 都是 String ,它的 value 可以是 String.hash.list.set.zset(有序集 ...
- docker简单介绍。
docker是啥? 一.概念? // 和运维有关的工具,和开发没有很大的关系.只需要去调试项目,将项目运行更迅速. 二.作用? 1.只需要关心项目的编写和调试,不需要关心具体的项目需要运行在哪里,并且 ...
- Linux 集群 和免秘钥登录的方法。
/* 1.1.什么是集群? 很多台服务器(计算机)做相同的事,就称之为集群 服务器和服务器之间必须要处于联通状态(linux01和linux02可以相互访问并且传输数据) 服务器的配置和常见的计算机没 ...
- Objects、Arrays、Collectors、System工具类
Objects类 定义 位于java.util包中,JDK1.7以后操作对象的类,对对象的空,对象是否相等进行判断. 常用方法 1.public static boolean equals(Objec ...
- Mac 屏幕录制Gif 制作 By-胡罗
一.视频录制 1)使用Mac系统自带的QuickTime进行屏幕录像 手动打开(如下图) 详细 Mac 基础教程:如何使用 Mac 系统原生的屏幕录制功能 相关快捷键 option+command+n ...
- linux_19
haproxy https实现 总结tomcat的核心组件以及根目录结构 tomcat实现多虚拟主机 nginx实现后端tomcat的负载均衡调度 简述memcached的工作原理
- java+selenium自动化脚本编写
实训项目:创盟后台管理,页面自动化脚本编写 使用工具:java+selenium 1)java+selenium环境搭建文档 2)创盟项目后台管理系统链接 java+selenium环境搭建 一.Se ...
- Vue2.0源码学习(2) - 数据和模板的渲染(下)
vm._render是怎么实现的 上述updateComponent方法调用是运行了一个函数: // src\core\instance\lifecycle.js updateComponent = ...
- IO_FILE——leak 任意读
在堆题没有show函数时,我们可以用 IO_FILE 进行leak,本文就记录一下如何实现这一手法. 拿一个输出函数 puts 来说,它在源码里的表现形式为 _IO_puts . _IO_puts ( ...
- Spring是什么? 核心总结
Spring是一个开源框架,它由Rod Johnson创建.它是为了解决企业应用开发的复杂性而创建的. Spring使用基本的JavaBean来完成以前只可能由EJB完成的事情. 然而,Spring ...