系列博客链接:

第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html

第二章第二部分)TensorFlow框架之读取图片数据:https://www.cnblogs.com/kongweisi/p/11050539.html

第二章第三部分)TensorFlow框架之读取二进制数据https://www.cnblogs.com/kongweisi/p/11050546.html

本文概述:

  • 目标

    • 说明Example的结构
    • 应用TFRecordWriter实现TFRecords文件存储器的构造
    • 应用parse_single_example实现解析Example结构
  • 应用
    • CIFAR10类图片的数据的TFRecords存储和读取

1、什么是TFRecords文件

TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。可以获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件

  • 文件格式 *.tfrecords

2、Example结构解析

tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features),Features包含了一个feature字段,feature中包含要写入的数据、并指明数据类型。这是一个样本的结构,批数据需要循环存入这样的结构

  • tf.train.Example(features=None)

    • 写入tfrecords文件
    • features:tf.train.Features类型的特征实例
    • return:example格式协议块
  • tf.train.Features(feature=None)
    • 构建每个样本的信息键值对
    • feature: 字典数据, key为要保存的名字
    • value为tf.train.Feature实例
    • return:Features类型
  • tf.train.Feature(options)
    • options:例如

      • bytes_list=tf.train. BytesList(value=[Bytes])
      • int64_list=tf.train. Int64List(value=[Value])
    • 支持存入的类型如下
    • tf.train.Int64List(value=[Value])
    • tf.train.BytesList(value=[Bytes])
    • tf.train.FloatList(value=[value])

上面的三段API是从上到下,一层一层嵌套使用的,如果理解有困难,可以看代码中的用法,相信就可以理解了。

这种结构很好的解决了数据和标签(训练的类别标签)或者其他属性数据存储在同一个文件中

3、案例:CIFAR10数据,存入TFRecords文件

3.1分析

  • 构造存储实例,tf.python_io.TFRecordWriter(path)

    • 写入tfrecords文件
    • path: TFRecords文件的路径 + 文件名字
    • return:写文件
    • method
    • write(record): 向文件中写入一个example
    • close(): 关闭文件写入器
  • 循环将数据填入到Example协议内存块(protocol buffer)

3.2代码

  1. def write_to_tfrecords(self, image_batch, label_batch):
  2. """
  3. 将数据存进tfrecords,方便管理每个样本的属性
  4. :param image_batch: 特征值
  5. :param label_batch: 目标值
  6. :return: None
  7. """
  8. # 1、构造tfrecords的存储实例
  9. writer = tf.python_io.TFRecordWriter("./tmp/cifar.tfrecords") # 路径 + 名字
  10.  
  11. # 2、循环将每个样本写入到文件当中
  12. for i in range(10):
  13.  
  14. # 一个样本一个样本的处理写入
  15. # 准备特征值,特征值必须是bytes类型 调用tostring()函数
  16. # [10, 32, 32, 3] ,在这里避免tensorflow的坑,取出来的不是真正的值,而是类型,所以要运行结果才能存入
  17. # 出现了eval,那就要在会话当中去运行该行数
  18. image = image_batch[i].eval().tostring()
  19.  
  20. # 准备目标值,目标值是一个Int类型
  21. # eval()-->[6]--->6
  22. label = label_batch[i].eval()[0]
  23.  
  24. # 绑定每个样本的属性
  25. example = tf.train.Example(features=tf.train.Features(feature={
  26. "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
  27. "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
  28. }))
  29.  
  30. # 写入每个样本的example,记住这里要将example序列化,成字符串
  31. writer.write(example.SerializeToString())
  32.  
  33. # 文件需要关闭
  34. writer.close()
  35. return None
  36.  
  37. # 开启会话打印内容
  38. with tf.Session() as sess:
  39. # 创建线程协调器
  40. coord = tf.train.Coordinator()
  41.  
  42. # 开启子线程去读取数据
  43. # 返回子线程实例
  44. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  45.  
  46. # 获取样本数据去训练
  47. print(sess.run([image_batch, label_batch]))
  48.  
  49. # 存入数据
  50. cr.write_to_tfrecords(image_batch, label_batch )
  51.  
  52. # 关闭子线程,回收
  53. coord.request_stop()
  54.  
  55. coord.join(threads)

4、读取TFRecords文件

读取这种文件整个过程与其他文件一样,只不过需要有个解析Example的步骤。从TFRecords文件中读取数据, 可以使用tf.TFRecordReadertf.parse_single_example解析器。这个操作可以Example协议内存块(protocol buffer)解析为张量

  • tf.parse_single_example(serialized, features=None, name=None)

    • 解析一个单一的Example原型
    • serialized:标量字符串Tensor,一个序列化的Example
    • features:dict字典数据,键为读取的名字,值为FixedLenFeature
    • return:一个键值对组成的字典,键为读取的名字
  • tf.FixedLenFeature(shape, dtype)

    • shape:输入数据的形状,一般不指定, 为空列表
    • dtype:输入数据类型,与存储进文件的类型要一致
    • 类型只能是float32, int64, string

5、案例:读取CIFAR的TFRecords文件

5.1 分析

  • 使用tf.train.string_input_producer构造文件队列
  • tf.TFRecordReader 读取TFRecords数据并进行解析
    • tf.parse_single_example进行解析
  • tf.decode_raw解码
    • 类型是bytes类型需要解码
    • 其他类型不需要
  • 处理图片数据形状以及数据类型,批处理返回
  • 开启会话线程运行

5.2 代码

  1. def read_tfrecords(self):
  2. """
  3. 读取tfrecords的数据
  4. :return: None
  5. """
  6. # 1、构造文件队列
  7. file_queue = tf.train.string_input_producer(["./tmp/cifar.tfrecords"])
  8.  
  9. # 2、构造tfrecords读取器,读取队列
  10. reader = tf.TFRecordReader()
  11.  
  12. # 默认也是只读取一个样本
  13. key, values = reader.read(file_queue)
  14.  
  15. # tfrecords
  16. # 多了解析example的一个步骤
  17. feature = tf.parse_single_example(values, features={
  18. "image": tf.FixedLenFeature([], tf.string),
  19. "label": tf.FixedLenFeature([], tf.int64)
  20. })
  21.  
  22. # 取出feature里面的特征值和目标值
  23. # 通过键值对获取
  24. image = feature["image"]
  25.  
  26. label = feature["label"]
  27.  
  28. # 3、解码操作
  29. # 对于image是一个bytes类型,所以需要decode_raw去解码成uint8张量
  30. # 对于Label:本身是一个int类型,不需要去解码
  31. image = tf.decode_raw(image, tf.uint8)
  32.  
  33. print(image, label)
  34.  
  35. # 处理image的形状和类型
  36. image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
  37.  
  38. # 处理label的形状和类型
  39. label_cast = tf.cast(label, tf.int32)
  40.  
  41. print(image_reshape, label_cast)
  42.  
  43. # 4、批处理操作
  44. image_batch, label_batch = tf.train.batch([image_reshape, label_cast], batch_size=10, num_threads=1, capacity=10)
  45.  
  46. print(image_batch, label_batch)

  47. return image_batch, label_batch
  48.  
  49. # 从tfrecords文件读取数据
  50. image_batch, label_batch = cr.read_tfrecords()
  51.  
  52. # 开启会话打印内容
  53. with tf.Session() as sess:
  54. # 创建线程协调器
  55. coord = tf.train.Coordinator()
  • 结合第二章第二部分读取商品图片和第三部分读取二进制数据完整代码
  1. import tensorflow as tf
  2. import os
  3.  
  4. def picread(file_list):
  5. """
  6. 读取商品图片数据到张量
  7. :param file_list:路径+文件名的列表
  8. :return:
  9. """
  10. # 构造文件队列
  11. file_queue = tf.train.string_input_producer(file_list)
  12.  
  13. # 利用图片读取器去读取文件队列内容
  14.  
  15. reader = tf.WholeFileReader()
  16. # 默认一次一张图片,没有形状
  17. _, value = reader.read(file_queue)
  18.  
  19. print(value)
  20.  
  21. # 对图片数据进行解码
  22. # string --> uint8
  23. # () ----> (?, ?, ?)
  24. image = tf.image.decode_jpeg(value)
  25.  
  26. print(image)
  27.  
  28. # 图片的形状固定、大小处理
  29. # 把图片大小固定统一大小(算法训练要求样本的特征值数量一样)
  30. # 固定[200, 200]
  31. image_resize = tf.image.resize_images(image, [200, 200])
  32. print(image_resize)
  33.  
  34. # 设置图片图片形状
  35. image_resize.set_shape([200, 200, 3])
  36.  
  37. # 进行批处理
  38. # 3D ----> 4D
  39. image_batch = tf.train.batch([image_resize], batch_size=10, num_threads=1, capacity=10)
  40.  
  41. print(image_batch)
  42. return image_batch
  43.  
  44. class CifarRead(object):
  45. """读取CIFAR10类别的二进制文件
  46. """
  47. def __init__(self):
  48.  
  49. # 每个图片样本的属性
  50. self.height = 32
  51. self.width = 32
  52. self.channel = 3
  53.  
  54. # bytes
  55. # 1
  56. self.label_bytes = 1
  57. # 3072
  58. self.image_bytes = self.height * self.width * self.channel
  59. # 3073
  60. self.all_bytes = self.label_bytes + self.image_bytes
  61.  
  62. def bytes_read(self, file_list):
  63. """读取二进制解码张量
  64. :return:
  65. """
  66. # 1、构造文件队列
  67. file_queue = tf.train.string_input_producer(file_list)
  68.  
  69. # 2、使用tf.FixedLengthRecordReader(bytes)读取
  70. # 默认必须指定读取一个样本
  71. reader = tf.FixedLengthRecordReader(self.all_bytes)
  72.  
  73. _, value = reader.read(file_queue)
  74.  
  75. # 3、解码操作
  76. # (?, ) (3073, ) = label(1, ) + feature(3072, )
  77. label_image = tf.decode_raw(value, tf.uint8)
  78. # 为了训练方便,一般会把特征值和目标值分开处理
  79. print(label_image)
  80.  
  81. # 使用tf.slice进行切片
  82. label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
  83.  
  84. image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
  85.  
  86. print(label, image)
  87.  
  88. # 处理类型和图片数据的形状
  89. # 图片形状
  90. # reshape (3072, )----[channel, height, width]
  91. # transpose [channel, height, width] --->[height, width, channel]
  92. depth_major = tf.reshape(image, [self.channel, self.height, self.width])
  93. print(depth_major)
  94.  
  95. image_reshape = tf.transpose(depth_major, [1, 2, 0])
  96.  
  97. print(image_reshape)
  98.  
  99. # 4、批处理
  100. image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
  101.  
  102. return image_batch, label_batch
  103.  
  104. def write_to_tfrecords(self, image_batch, label_batch):
  105. """
  106. 将数据写入TFRecords文件
  107. :param image_batch: 特征值
  108. :param label_batch: 目标值
  109. :return:
  110. """
  111. # 构造TFRecords存储器
  112. writer = tf.python_io.TFRecordWriter("./tmp/cifar.tfrecords")
  113.  
  114. # 循环将每个样本构造成一个example,然后序列化写入
  115. for i in range(10):
  116.  
  117. # 取出相应的第i个样本的特征值和目标值
  118. # 写入的是具体的张量的值,不是OP的名字
  119. # [10, 32, 32, 3]
  120. # [32, 32, 3]值
  121. image = image_batch[i].eval().tostring()
  122.  
  123. # [10, 1]
  124. # 整形类型
  125. label = int(label_batch[i].eval()[0])
  126.  
  127. # 每个样本的example
  128. example = tf.train.Example(features=tf.train.Features(feature={
  129. "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
  130. "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
  131. }))
  132.  
  133. # 写入第i样本的example
  134. writer.write(example.SerializeToString())
  135.  
  136. writer.close()
  137. return None
  138.  
  139. def read_tfrecords(self):
  140. """
  141. 读取TFRecords文件
  142. :return: None
  143. """
  144. # 1、构造文件队列
  145. file_queue = tf.train.string_input_producer(["./tmp/cifar.tfrecords"])
  146.  
  147. # 2、tf.TFRecordReader 读取TFRecords数据并
  148.  
  149. reader = tf.TFRecordReader()
  150.  
  151. # 默认只读取一个样本
  152. _, value = reader.read(file_queue)
  153.  
  154. # 进行解析example协议
  155. feature = tf.parse_single_example(value, features={
  156. "image": tf.FixedLenFeature([], tf.string),
  157. "label": tf.FixedLenFeature([], tf.int64)
  158. })
  159.  
  160. # 3、解码操作 二进制的格式必须解码
  161. image = tf.decode_raw(feature['image'], tf.uint8)
  162.  
  163. label = tf.cast(feature['label'], tf.int32)
  164.  
  165. # 形状
  166. # [32, 32, 3]--->bytes---> tf.uint8
  167. image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
  168.  
  169. # 4、批处理
  170. image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
  171.  
  172. return image_batch, label_batch
  173.  
  174. if __name__ == '__main__':
  175. filename = os.listdir("./data/cifar10/cifar-10-batches-bin/")
  176.  
  177. file_list = [os.path.join("./data/cifar10/cifar-10-batches-bin/", file) for file in filename if file[-3:] == "bin"]
  178.  
  179. # print(file_list)3
  180. cr = CifarRead()
  181.  
  182. # image_batch, label_batch = cr.bytes_read(file_list)
  183. # 读取TFRecords的结果
  184. image_batch, label_batch = cr.read_tfrecords()
  185.  
  186. with tf.Session() as sess:
  187.  
  188. # 创建线程回收的协调员
  189. coord = tf.train.Coordinator()
  190.  
  191. # 需要手动开启子线程去进行批处理读取到队列操作
  192. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  193.  
  194. print(sess.run([image_batch, label_batch]))
  195.  
  196. # 写入文件
  197. # cr.write_to_tfrecords(image_batch, label_batch)
  198.  
  199. # 回收线程
  200. coord.request_stop()
  201.  
  202. coord.join(threads)

(第二章第四部分)TensorFlow框架之TFRecords数据的存储与读取的更多相关文章

  1. (第二章第一部分)TensorFlow框架之文件读取流程

    本章概述:在第一章的系列文章中介绍了tf框架的基本用法,从本章开始,介绍与tf框架相关的数据读取和写入的方法,并会在最后,用基础的神经网络,实现经典的Mnist手写数字识别. 有四种获取数据到Tens ...

  2. (第一章第一部分)TensorFlow框架介绍

    接下来会更新一系列博客,介绍TensorFlow的入门使用,尽可能详细. 本文概述: 说明TensorFlow的数据流图结构 1.数据流图介绍  TensorFlow是一个采用数据流图(data fl ...

  3. tensorflow2.0学习笔记第二章第四节

    2.4损失函数损失函数(loss):预测值(y)与已知答案(y_)的差距 nn优化目标:loss最小->-mse -自定义 -ce(cross entropy)均方误差mse:MSE(y_,y) ...

  4. (第二章第三部分)TensorFlow框架之读取二进制数据

    系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html (第二章第二部分)Tens ...

  5. (第二章第二部分)TensorFlow框架之读取图片数据

    系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html 本文概述: 目标 说明图片 ...

  6. (第一章第五部分)TensorFlow框架之变量OP

    系列博客链接: (一)TensorFlow框架介绍:https://www.cnblogs.com/kongweisi/p/11038395.html (二)TensorFlow框架之图与Tensor ...

  7. (第一章第六部分)TensorFlow框架之实现线性回归小案例

    系列博客链接: (一)TensorFlow框架介绍:https://www.cnblogs.com/kongweisi/p/11038395.html (二)TensorFlow框架之图与Tensor ...

  8. C# Language Specification 5.0 (翻译)第二章 词法结构

    程序 C# 程序(program)由至少一个源文件(source files)组成,其正式称谓为编译单元(compilation units)[1].每个源文件都是有序的 Unicode 字符序列.源 ...

  9. (第一章第四部分)TensorFlow框架之张量

    系列博客链接: (一)TensorFlow框架介绍:https://www.cnblogs.com/kongweisi/p/11038395.html (二)TensorFlow框架之图与Tensor ...

随机推荐

  1. 如何在pyqt中实现窗口磨砂效果

    磨砂效果的实现思路 这两周一直在思考怎么在pyqt上实现窗口磨砂效果,网上搜了一圈,全都是 C++ 的实现方法.正好今天查python的官方文档的时候看到了 ctypes 里面的 HWND,想想倒不如 ...

  2. Java BigDecimal 的舍入模式(RoundingMode)详解

    BigDecimal.divide方法中必须设置roundingMode,不然会报错. ROUND_UP:向正无穷方向对齐(转换为正无穷方向最接近的所需数值) ROUND_DOWN:向负无穷方向对齐 ...

  3. 新一代Python包管理工具来了

    1 简介 说起Python的包管理工具,大家第一时间想到的肯定是pip.conda等经典工具.但最近我发现了一款新颖的Python包管理工具--pdm,它受到PEP582(https://www.py ...

  4. 「JOISC 2014 Day1」 历史研究

    「JOISC 2014 Day1」 历史研究 Solution 子任务2 暴力,用\(cnt\)记录每种权值出现次数. 子任务3 这不是一个尺取吗... 然后用multiset维护当前的区间,动态加, ...

  5. Swift数组

    数组的介绍 数组(Array)是一串有序的由相同类型元素构成的集合 数组中的集合元素是有序的,可以重复出现 Swift中的数组 swift数组类型是Array,是一个泛型集合 数组的初始化 数组分成: ...

  6. 通过bindservice方式调用服务方法里面的过程

    为什么要引入bindService:目的为了调用服务里面的方法 (1)定义一个服务 服务里面有一个方法需要Activity调用 (2)定义一个中间人对象(IBinder) 继承Binder (3)在o ...

  7. DbUnit入门实战

    原文地址: http://yangzb.iteye.com/blog/947292 相信做过单元测试的人都会对JUnit 非常的熟悉了, 今天要介绍的DbUnit(http://dbunit.sour ...

  8. TestNG--@Factory

    原文地址:http://blog.csdn.net/wanghantong TestNg的@Factory注解从字面意思上来讲就是采用工厂的方法来创建测试数据并配合完成测试 其主要应对的场景是:对于某 ...

  9. Netty入门使用教程

    原创:转载需注明原创地址 https://www.cnblogs.com/fanerwei222/p/11827026.html 本文介绍Netty的使用, 结合我本人的一些理解和操作来快速的让初学者 ...

  10. Apache——网页优化与安全

    Apache--网页优化与安全 1.Apache 网页优化概述 2.网页压缩 3.网页缓存 4.隐藏版本信息 5.Apache 防盗链 1.Apache 网页优化概述: 企业中,部署Apache后只采 ...