Tensorflow高效读取数据
关于Tensorflow读取数据,官网给出了三种方法:
- 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
- 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
- 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
在使用Tensorflow训练数据时,第一步为准备数据,现在我们只讨论图像数据。其数据读取大致分为:原图读取、二进制文件读取、tf标准存储文件读取。
一、原图文件读取
很多情况下我们的图片训练集就是原始图片本身,并没有像cifar dataset那样存成bin等格式。我们需要根据一个train_list列表,去挨个读取图片。这里我用到的方法是首先获取image list和labellist,然后读入队列中,那么对每次dequeue的内容中可以提取当前图片的路劲和label。
1、获取文件列表
def get_image_list(fileDir):
imageList = []
labelList = [] filelist = os.listdir(fileDir) for var in filelist:
imagename = os.path.join(fileDir, var)
label = int(os.path.basename(var).split('_')[0])
imageList.append(imagename)
labelList.append(label) return imageList, labelList
上述程序是从指定目录中获取文件列表和标签,其我的文件为
总共15个文件,’_’前为文件标签,记得要转化为int类型,否则后面程序或报错。
2、将文件列表加载到内存列表中,并进行读取
步骤分为:
a、列表转化为tensor类型,并存到内存中
b、 从内存列表中读取数据,进行获取图像和label
c、 根据训练要求对数据进行转化
d、利用batch获取批次文件
def input_data_imageslist_slice(fileDir): # 获取文件列表
imageList , labelList = get_image_list(fileDir) # 将文件列表和标签列表转为为tensor,进而能存入内存列表中,记得label在上面转为int,否则下面会出错,这是相对应的
imagesTensor = tf.convert_to_tensor(imageList, dtype = tf.string)
labelsTensor = tf.convert_to_tensor(labelList, dtype = tf.uint8) # 从内存列表中读取文件,此处只读取一个文件,并记录文件位置
queue = tf.train.slice_input_producer([imagesTensor, labelsTensor]) # 提取图片内容和标签内容,一定注意数据之间的转化;
image_content = tf.read_file(queue[0])
imageData = tf.image.decode_jpeg(image_content,channels=3) #channels必须要制定,当时没指定,程序报错
imageData = tf.image.convert_image_dtype(imageData,tf.uint8) # 图片数据进行转化,此处为了显示而转化
labelData = tf.cast(queue[1],tf.uint8) # show_single_data(imageData, labelData)
#根据数据训练尺寸,调整图片大小,此处设置为32*32
new_size = tf.constant([IMAGE_WIDTH,IMAGE_WIDTH], dtype=tf.int32)
image = tf.image.resize_images(imageData, new_size) # 这是数据提取关键,因为设置了batch_size,决定了每次提取数据的个数,比如此处是3,则每次为3个文件
imageBatch, labelBatch = tf.train.shuffle_batch([image, labelData], batch_size = BATCH_SIZE,
capacity = 2000,min_after_dequeue = 1000) return imageBatch, labelBatch
3、文件测试
在文件测试中,必须添加
threads = tf.train.start_queue_runners(sess = sess),会话窗口才会从内存堆栈中读取数据。
def test_record(filename):
image_batch, label_batch = input_data_imageslist_slice(filename) with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
threads = tf.train.start_queue_runners(sess = sess)
for i in range(5):
val, label = sess.run([image_batch, label_batch])
print(val.shape, label)
我在程序中设置了batchsize为3,所以shape为(3,32,32,)后面则是label,此处很好的读取了数据
(3, 32, 32, 3) [12 3 22]
(3, 32, 32, 3) [ 7 2 22]
(3, 32, 32, 3) [ 2 5 15]
(3, 32, 32, 3) [12 5 13]
(3, 32, 32, 3) [12 6 10]
4、校验对比
为了更好的得知shuffle_batch是否让文件和label对应,程序中进行了修改
image = tf.image.resize_images(imageData, new_size)
修改为:
image = tf.cast(queue[0],tf.string)
还有
print(val.shape, label)
修改为
print(val, label)
结果为:
[b'E:\\010_test_tensorflow\\02_produce_data\\images1\\1_1.jpg'
b'E:\\010_test_tensorflow\\02_produce_data\\images1\\12_03.jpg'
b'E:\\010_test_tensorflow\\02_produce_data\\images1\\10_2.jpg'] [ 1 12 10]
[b'E:\\010_test_tensorflow\\02_produce_data\\images1\\2_2.jpg'
b'E:\\010_test_tensorflow\\02_produce_data\\images1\\22_9.jpg'
b'E:\\010_test_tensorflow\\02_produce_data\\images1\\33_0.jpg'] [ 2 22 33]
[b'E:\\010_test_tensorflow\\02_produce_data\\images1\\33_0.jpg'
b'E:\\010_test_tensorflow\\02_produce_data\\images1\\7_0.jpg'
b'E:\\010_test_tensorflow\\02_produce_data\\images1\\13_06.jpg'] [33 7 13]
我们发现不但image和label相对应,而且还打乱了顺序,真的是很完美啊。
二、TFRecords读取
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch
输入网络进行训练(tip:使用这种方法时,结合yield
使用更为简洁,大家自己尝试一下吧,我就不赘述了)。但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue
,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecords。
FRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。
TFRecords文件包含了tf.train.Example
协议内存块(protocol buffer)(协议内存块包含了字段 Features
)。我们可以写一段代码获取你的数据, 将数据填入到Example
协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter
写入到TFRecords文件。
从TFRecords文件中读取数据, 可以使用tf.TFRecordReader
的tf.parse_single_example
解析器。这个操作可以将Example
协议内存块(protocol buffer)解析为张量。
2.1 生成TFRecords文件
class SaveRecord(object):
def __init__(self,recordDir, fileDir, imageSize):
self._imageSize = imageSize trainRecord = os.path.join(recordDir,'train.tfrecord')
validRecord = os.path.join(recordDir,'valid.tfrecord') # 获取文件列表
filenames = os.listdir(fileDir)
np.random.shuffle(filenames)
fileNum = len(filenames)
print('the count of images is ' + str(fileNum)) # 获取训练和测试样本,比例为4:1
splitNum = int(fileNum * 0.8)
trainImages = filenames[ : splitNum]
validImages = filenames[splitNum : ] # 保存数据到制定位置
self.save_data_to_record( fileDir = fileDir, datas = trainImages, recordname = trainRecord)
self.save_data_to_record(fileDir = fileDir,datas = validImages, recordname = validRecord) def save_data_to_record(self,fileDir, datas, recordname):
writer = tf.python_io.TFRecordWriter(recordname) for var in datas:
filename = os.path.join(fileDir, var)
label = int(os.path.basename(var).split('_')[0])
image = Image.open(filename) # 打开图片
image = image.resize((self._imageSize,self._imageSize))
imageArray = image.tobytes() #转为bytes example = tf.train.Example(features = tf.train.Features(feature = {
'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [imageArray]))
,'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))}))
writer.write(example.SerializeToString()) writer.close()
编程为一个类,其核心代码在于save_data_to_records,其主要流程为:
- 初始化写入器writer,来源于tf.python_io.TFRecordWriter。
- 遍历传入的数据,可以为文件名,意味后面二进制解析也是文件名
- 解析文件名,获取label,这是之前处理好的
- 利用IPL的Image读入图像数据,预处理数据:调整大小,且转为化二值化数据
- 利用tf中的Example中获取数据,原理是利用字典对应关系,获取features,当然里面有点绕,仔细读读全是在类型转化而已
- example二进制化,然后写入。
- 关闭写入器
其中里面关键点:图片bytes的转化,以及example的赋值。
基本的,一个Example
中包含Features
,Features
里包含Feature
(这里没s)的字典。最后,Feature
里包含有一个 FloatList
, 或者ByteList
,或者Int64List。
2.2 读取record文件
for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
example = tf.train.Example()
example.ParseFromString(serialized_example) image = example.features.feature['image'].bytes_list.value
label = example.features.feature['label'].int64_list.value
# 可以做一些预处理之类的
print image, label
上面为一个解析文件的一个例子,主要是利用example直接进行解析,简单,但是这样比较耗内存,常用的方法是利用文件队列读取。
即利用string_input_produce,结合tf.recordreader进行数据读取,最后进行解析,其例子为:
def read_and_decode(filename):
#根据文件名生成一个队列
filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader()
_, serialized = reader.read(filename_queue) #返回文件名和文件 features = tf.parse_single_example(serialized = serialized, features = {
'image' : tf.FixedLenFeature([], tf.string),
'label' : tf.FixedLenFeature([], tf.int64)}) image = tf.decode_raw(features['image'], tf.uint8)
image= tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 3])
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
# image = tf.cast(features['image'], tf.string)
label = tf.cast(features['label'], tf.int32) img_batch, label_batch = tf.train.shuffle_batch([image, label],
batch_size=BATCH_SIZE, capacity=2000,
min_after_dequeue=1000) return img_batch, label_batch
其流程为:
- record文件放入文件队列中
- 初始化话RecordReader,发现这个reader和writer初始化方式不一样
- 从队列中读取数据,记得返回两个值,我们只要第二个,此时数据为二进制数据(前面我们存入的二进制数据)
- 根据约定解析数据,类型即为之前存储的格式。
- 利用tf的转化获取image和label
- 关键一步就是tf.train.shuffle_batch,利用此函数可以批量获取数据,当然是在文件列表中。
此处对文件列表中数据读取的过程中,我们发现读取器是不一样的。比如此次是读取record的内存文件,代码为:
filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader()
_, serialized = reader.read(filename_queue) #返回文件名和文件
而之前从文件列表和数据列表读取的时候为:
imageList , labelList = get_image_list(fileDir)
queue = tf.train.string_input_producer(imageList) reader = tf.WholeFileReader()
_, image_content = reader.read(queue)
而我们使用的slice_input_produce的时候,变成了tf.read_file,一定记得各个的不同。
# 从内存列表中读取文件,此处只读取一个文件,并记录文件位置
queue = tf.train.slice_input_producer([imagesTensor, labelsTensor]) # 提取图片内容和标签内容,一定注意数据之间的转化;
image_content = tf.read_file(queue[0])
imageData = tf.image.decode_jpeg(image_content,channels=3)
2.3 测试数据
前面我们解析了shuffle_batch的好处,此处我们即检测是否读取了数据。
def test_record(filename):
image_batch, label_batch = read_and_decode(filename) with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
threads = tf.train.start_queue_runners(sess = sess)
for i in range(5):
val, label = sess.run([image_batch, label_batch])
print(val.shape, label)
此时的输出结果为:
(10, 16, 16, 3) [15 7 33 5 4 10 13 7 1 3]
(10, 16, 16, 3) [33 22 10 10 4 12 7 4 13 10]
(10, 16, 16, 3) [10 10 12 7 5 15 12 22 15 5]
(10, 16, 16, 3) [10 10 5 1 12 10 3 5 33 3]
(10, 16, 16, 3) [12 4 7 15 4 7 4 13 5 10]
结果表明有效的对数据进行了读取。
Tensorflow高效读取数据的更多相关文章
- Tensorflow高效读取数据的方法
最新上传的mcnn中有完整的数据读写示例,可以参考. 关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码 ...
- TensorFlow高效读取数据的方法——TFRecord的学习
关于TensorFlow读取数据,官网给出了三种方法: 供给数据(Feeding):在TensorFlow程序运行的每一步,让python代码来供给数据. 从文件读取数据:在TensorFlow图的起 ...
- Tensorflow中使用TFRecords高效读取数据--结合Attention-over-Attention Neural Network for Reading Comprehension
原文链接:https://arxiv.org/pdf/1607.04423.pdf 本片论文主要讲了Attention Model在完形填空类的阅读理解上的应用. 转载:https://blog.cs ...
- tensorflow批量读取数据
Tensorflow 数据读取有三种方式: Preloaded data: 预加载数据,在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况). Feeding: Pyt ...
- "笨方法"学习CNN图像识别(二)—— tfrecord格式高效读取数据
原文地址:https://finthon.com/learn-cnn-two-tfrecord-read-data/-- 全文阅读5分钟 -- 在本文中,你将学习到以下内容: 将图片数据制作成tfre ...
- 吴裕雄 PYTHON 神经网络——TENSORFLOW MNIST读取数据
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("E ...
- 第十二节,TensorFlow读取数据的几种方法以及队列的使用
TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起 ...
- TensorFlow queue多线程读取数据
一.tensorflow读取机制图解 我们必须要把数据先读入后才能进行计算,假设读入用时0.1s,计算用时0.9s,那么就意味着每过1s,GPU都会有0.1s无事可做,这就大大降低了运算的效率. 解决 ...
- tensorflow读取数据的方式
转载:https://blog.csdn.net/u014038273/article/details/77989221 TensorFlow程序读取数据一共有四种方法(一般针对图像): 供给数据(F ...
随机推荐
- SSD报告 - QRadar远程命令执行
SSD报告 - QRadar远程命令执行 漏洞摘要 QRadar中的多个漏洞允许远程未经身份验证的攻击者使产品执行任意命令.每个漏洞本身并不像链接那么强大 - 这允许用户从未经身份验证的访问更改为经过 ...
- Shell-8--数值运算及处理
RANDOM 默认范围是 0~32767
- [EXP]ThinkPHP 5.0.23/5.1.31 - Remote Code Execution
# Exploit Title: ThinkPHP .x < v5.0.23,v5.1.31 Remote Code Execution # Date: -- # Exploit Author: ...
- 可能比文档还详细--VueRouter完全指北
可能比文档还详细--VueRouter完全指北 前言 关于标题,应该算不上是标题党,因为内容真的很多很长很全面.主要是在官网的基础上又详细总结,举例了很多东西.确保所有新人都能理解!所以实际上很多东西 ...
- zookeeper集群操作【这里只说明简单的操作步骤,zk的相关参数、说明请参考官方文档】
本文版权归 远方的风lyh和博客园共有,欢迎转载,但须保留此段声明,并给出原文链接,谢谢合作. [这里是在一台机器上搭建的 zk伪集群] 1.从官网下载下载zk http://apa ...
- Docker容器绑定外部IP和端口
Docker允许通过外部访问容器或者容器之间互联的方式来提供网络服务. 以下操作通过myfirstapp镜像模拟,如何制作myfirstapp镜像请点击此处. 1.外部访问容器容器启动之后,容器中可以 ...
- Android中不能在子线程中更新View视图的原因
这是一条规律,很多coder知道,但原因是什么呢? 如下: When a process is created for your application, its main thread is ded ...
- Jenkins : 安装 master 和 slave
目录 安装 master 安装 slave 设置 master 与 slave 的通信方式 添加 slave 配置 在 salve 上安装 jre 安装并配置 Jenkins salve Jenkin ...
- docker使用技巧小记
1.在使用docker的时候有很多人习惯使用官方镜像.有的人喜欢自己制作镜像,有的时候都是使用默认的配置启动的服务,或者自己在制作镜像的时候直接将配置文件打包到镜像里面了.有的时候会碰到要修改配置文件 ...
- checkbox在vue中的用法总结
前言 关于checkbox多选框是再常见不过的了,几乎很多地方都会用到,这两天在使用vue框架时需要用到checkbox多选功能,实在着实让我头疼,vue和原生checkbox用法不太一样, 之前对于 ...