机器学习: TensorFlow 的数据读取与TFRecords 格式
最近学习tensorflow,发现其读取数据的方式看起来有些不同,所以又重新系统地看了一下文档,总得来说,tensorflow 有三种主流的数据读取方式:
1) 传送 (feeding): Python 可以在程序的运行过程中,将数据传送进定义好的 tensor 变量中
2) 从文件读取 (reading from files): 一个输入流从文件中直接读取数据
3) 预加载数据 (preloaded data): 这个很好理解,就是将所有的数据一次性全部读进内存里。
对于第三种方式,在数据量小的时候,是非常高效的,但是如果数据量很大的时候,这种方法显然非常耗内存,所以在数据量很大的时候,一般选择第二种读取方式,即从文件读取。在利用第二种方式读取的时候,我们常常会用到一种 TFRecords 的格式来保存读取的文件。TFRecords 是一种二进制文件。可以在TensorFlow 中方便的进行各种存取操作以及预处理。
我们先来看看,如何将一张图片转换成字符流
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import skimage.io as io
dir_path = 'Face'
file_list = os.listdir(dir_path)
print file_list
for f in file_list:
print (dir_path + os.sep + f)
img_1 = io.imread(dir_path + os.sep + file_list[0])
#plt.imshow(img_1, cmap='gray')
#plt.show()
# 将图像转换成字符
img_str = img_1.tostring()
# 将字符流还原成图像
img_rec_vec = np.fromstring(img_str, dtype=np.uint8)
img_rec = img_rec_vec.reshape(img_1.shape)
#plt.imshow(img_rec, cmap='gray')
#plt.show()
接下来,我们看看,如何生成 TFRecords 文件:
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
tfrecords_filename = 'Face.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename)
for img_path in file_list:
img = np.array(io.imread(dir_path + os.sep + img_path))
# 从文件夹里读取图像
# 获取图像的宽和高,图像的维度需要存入 TFRecords 文件中
# 以方便后续的处理
#
height = img.shape[0]
width = img.shape[1]
# 将图像转换成字符流
img_raw = img.tostring()
# 将字符流以及图像的尺度信息存入TFRecords 文件
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(height),
'width': _int64_feature(width),
'image_raw': _bytes_feature(img_raw)}))
writer.write(example.SerializeToString())
writer.close()
最后,我们看看如何从 TFrecords 文件中读数据,并且做批处理:
# 可以重新定义图像的宽和高,
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
# 定义读取与解码函数
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# 获取 features,包含图像,以及图像宽和高
features = tf.parse_single_example(
serialized_example,
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),
})
# 获取图像信息
image = tf.decode_raw(features['image_raw'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
# 将图像转换成多维数组的形式
image_shape = [height, width, 1]
image = tf.reshape(image, image_shape)
# 重新定义图像的尺度
image_size_const = tf.constant((IMAGE_HEIGHT, IMAGE_WIDTH, 1), dtype=tf.int32)
# Random transformations can be put here: right before you crop images
# to predefined size. To get more information look at the stackoverflow
# question linked above.
# 对图像进行预处理,包括裁剪,增边等
resized_image = tf.image.resize_image_with_crop_or_pad(image=image,
target_height=IMAGE_HEIGHT,
target_width=IMAGE_WIDTH)
return resized_image
#
filename_queue = tf.train.string_input_producer(
[tfrecords_filename], num_epochs=10)
# Even when reading in multiple threads, share the filename
# queue.
train_images = read_and_decode(filename_queue)
# 要注意 min_after_dequeue 不能超过 capacity
image = tf.train.shuffle_batch([train_images],
batch_size=1,
capacity=5,
num_threads=2,
min_after_dequeue=1)
# The op for initializing the variables.
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# Let's read off 3 batches just for example
for i in xrange(1):
img = sess.run([image])
img_batch = img[0]
img_1 = tf.reshape(img_batch[0, :, :, :], [IMAGE_HEIGHT, IMAGE_WIDTH])
print (img_1.shape)
plt.imshow(sess.run(img_1), cmap='gray')
# coord.request_stop()
# coord.join(threads)
plt.show()
print 'all is well'
参考来源:
http://codecloud.net/16485.html
http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
https://www.tensorflow.org/programmers_guide/reading_data
机器学习: TensorFlow 的数据读取与TFRecords 格式的更多相关文章
- TensorFlow中数据读取之tfrecords
关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow ...
- 由浅入深之Tensorflow(3)----数据读取之TFRecords
转载自http://blog.csdn.net/u012759136/article/details/52232266 原文作者github地址 概述 关于Tensorflow读取数据,官网给出了三种 ...
- DataTable to Excel(使用NPOI、EPPlus将数据表中的数据读取到excel格式内存中)
/// <summary> /// DataTable to Excel(将数据表中的数据读取到excel格式内存中) /// </summary> /// <param ...
- tensorflow之数据读取探究(1)
Tensorflow中之前主要用的数据读取方式主要有: 建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用.使用这种方法十分灵活,可以一下子将所有数据 ...
- 关于Tensorflow 的数据读取环节
Tensorflow读取数据的一般方式有下面3种: preloaded直接创建变量:在tensorflow定义图的过程中,创建常量或变量来存储数据 feed:在运行程序时,通过feed_dict传入数 ...
- Tensorflow学习-数据读取
Tensorflow数据读取方式主要包括以下三种 Preloaded data:预加载数据 Feeding: 通过Python代码读取或者产生数据,然后给后端 Reading from file: 通 ...
- 『TensorFlow』数据读取类_data.Dataset
一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...
- tensorflow之数据读取探究(2)
tensorflow之tfrecord数据读取 Tensorflow关于TFRecord格式文件的处理.模型的训练的架构为: 1.获取文件列表.创建文件队列:http://blog.csdn.net/ ...
- TensorFlow的数据读取机制
一.tensorflow读取机制图解 首先需要思考的一个问题是,什么是数据读取?以图像数据为例,读取的过程可以用下图来表示 假设我们的硬盘中有一个图片数据集0001.jpg,0002.jpg,0003 ...
随机推荐
- 关于hive里安装mysql出现错误,如何删除指定的主机或用户?(解决Access denied)
前期博客 你可以按照我写的这篇博客去,按照hive的mysql. 1 复习ha相关 + weekend110的hive的元数据库mysql方式安装配置(完全正确配法)(CentOS版本)(包含卸载系统 ...
- Android 为什么要有handler机制?handler机制的原理
为什么要有handler机制? 在Android的UI开发中,我们经常会使用Handler来控制主UI程序的界面变化.有关Handler的作用,我们总结为:与其他线程协同工作,接收其他线程的消息并通过 ...
- 关于python中数组的问题,序列格式转换
https://blog.csdn.net/sinat_34474705/article/details/74458605?utm_source=blogxgwz1 https://www.cnblo ...
- Android圆环控件
Android圆环控件 近期在做一个功能.界面效果要求例如以下: 看到这个界面,我首先想到了曾经在做phone模块的时候,我们定制的来电界面InCallTouchUi,界面效果是相似的. 来电控件使用 ...
- jemter--录制的脚本设置循环次数不起作用
以下是比较jmeter线程组中设置循环次数和循环控制器中设置循环次数的区别 1.jmeter生成的脚本没有step1(循环控制器)控制器,故循环在线程组中设置 2.badboy录制的脚本有setp ...
- 【Codeforces Round #442 (Div. 2) D】Olya and Energy Drinks
[链接] 我是链接,点我呀:) [题意] 给一张二维点格图,其中有一些点可以走,一些不可以走,你每次可以走1..k步,问你起点到终点的最短路. [题解] 不能之前访问过那个点就不访问了.->即k ...
- Android新控件RecyclerView剖析
传智·没羽箭(传智播客北京校区Java学院高级讲师) 个人简单介绍:APKBUS专家之中的一个,黑马技术沙龙会长,在移动领域有多年的实际开发和研究经验.精通HTML5.Oracle.J2EE .Jav ...
- cocos2D(一)----第一个cocos2D程序
简单介绍 我们这个专题要学习的是一款iOS平台的2D游戏引擎cocos2d.严格来说叫做cocos2d-iphone,由于cocos2d有非常多个版本号.我们学习的是iphone版本号的.既然是个游戏 ...
- php实现 24点游戏算法
php实现 24点游戏算法 一.总结 一句话总结:把多元运算转化为两元运算,先从四个数中取出两个数进行运算,然后把运算结果和第三个数进行运算,再把结果与第四个数进行运算.在求表达式的过程中,最难处理的 ...
- 每日技术总结:setInterval,setTimeout,文本溢出,小程序,wepy
前言: 项目背景:vue,电商,商品详情页 1.倒计时,倒计到0秒时停止 data () { return { n: 10 } }, created () { let int = setInterva ...