Tensorflow数据读取机制
展示如何将数据输入到计算图中
Dataset
可以看作是相同类型“元素”的有序列表,在实际使用时,单个元素可以是向量、字符串、图片甚至是tuple或dict。
数据集对象实例化:
dataset=tf.data.Dataset.from_tensor_slice(<data>)
迭代器对象实例化:
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
读取结束异常:如果一个dataset
中的元素被读取完毕,再尝试sess.run(one_element)
的话,会抛出tf.errors.OutOfRangeError
异常,这个行为与使用队列方式读取数据是一致的。
高维数据集的使用
tf.data.Dataset.from_tensor_slices
真正作用是切分传入Tensor的第一个维度,生成相应的dataset,即第一维表明数据集中数据的数量,之后切分batch等操作均以第一维为基础。
dataset=tf.data.Dataset.from_tensor_slices(np.random.uniform((5,2)))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session(config=config) as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError as e:
print('end~')
输出:
[0.1,0.2]
[0.3,0.2]
[0.1,0.6]
[0.4,0.3]
[0.5,0.2]
tuple组合数据
dataset=tf.data.Dataset.from_tensor_slices((np.array([1.,2.,3.,4.,5.]),
np.random.uniform(size=(5,2))))
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(sess.run(one_element))
except tf.errors.OutOfRangeError:
print('end~')
输出:
(1.,array(0.1,0.3))
(2.,array(0.2,0.4))
...
数据集处理方法
Dataset
支持一类特殊操作:Transformation
。一个Dataset
通过Transformation
变成一个新的Dataset
。常用的Transformation
:
map
batch
shuffle
repeat
其中,
map
和python中的map
一致,接受一个函数,Dataset
中的每个元素都会作为这个函数的输入,并将函数返回值作为新的Dataset
dataset=dataset.map(lambda x:x+1)
注意:
map
函数可以使用num_parallel_calls
参数并行化batch
就是将多个元素组成batch。dataset=tf.data.Dataset.from_tensor_slices(
{
'a':np.array([1.,2.,3.,4.,5.]),
'b':np.random.uniform(size=(5,2))
})
###
dataset=dataset.batch(2) # batch_size=2
###
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(one_element)
except tf.errors.OutOfRangeError:
print('end~')
输出:
{'a':array([1.,2.]),'b':array([[1.,2.],[3.,4.]])}
{'a':array([3.,4.]),'b':array([[5.,6.],[7.,8.]])}
shuffle
的功能是打乱dataset
中的元素,它有个参数buffer_size
,表示打乱时使用的buffer
的大小,不应设置过小,推荐值1000.dataset=tf.data.Dataset.from_tensor_slices(
{
'a':np.array([1.,2.,3.,4.,5.]),
'b':np.random.uniform(size=(5,2))
})
###
dataset=dataset.shuffle(buffer_size=5)
###
iterator=dataset.make_one_shot_iterator()
one_element=iterator.get_next()
with tf.Session() as sess:
try:
while True:
print(one_element)
except tf.errors.OutOfRangeError:
print('end~')
repeat
的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch
。假设原先的数据是一个epoch
,使用repeat(2)
可以使之变成2个epoch.dataset=tf.data.Dataset.from_tensor_slices({
'a':np.array([1.,2.,3.,4.,5.]),
'b':np.random.uniform(size=(5,2))
})
###
dataset=dataset.repeat(2) # 2epoch
###
# iterator, one_element...
注意:如果直接调用
repeat()
函数的话,生成的序列会无限重复下去,没有结果,因此不会抛出tf.errors.OutOfRangeError
异常。
模拟读入磁盘图片及其Label示例
def _parse_function(filename,label): # 接受单个元素,转换为目标
img_string=tf.read_file(filename)
img_decoded=tf.image.decode_images(img_string)
img_resized=tf.image.resize_images(image_decoded,[28,28])
return image_resized,label
filenames=tf.constant(['data/img1.jpg','data/img2.jpg',...])
labels=tf.constant([1,3,...])
dataset=tf.data.Dataset.from_tensor_slices((filenames,labels))
dataset=dataset.map(_parse_function) # num_parallel_calls 并行
dataset=dataset.shuffle(buffer_size=1000).batch_size(32).repeat(10)
更多Dataset创建方法
tf.data.TextLineDataset()
:函数输入一个文件列表,输出一个Dataset。dataset中的每一个元素对应文件中的一行,可以使用该方法读入csv文件。tf.data.FixedLengthRecordDataset()
:函数输入一个文件列表和record_bytes
参数,dataset中每一个元素是文件中固定字节数record_bytes
的内容,可用来读取二进制保存的文件,如CIFAR10。tf.data.TFRecordDataset()
:读取TFRecord文件,dataset中每一个元素是一个TFExample。
更多Iterator创建方法
最简单的创建Iterator
方法是通过dataset.make_one_shot_iterator()
创建一个iterator。
除了这种iterator之外,还有更复杂的Iterator:
- initializable iterator
- reinitializable iterator
- feedable iterator
其中,initializable iterator方法要在使用前通过sess.run()
进行初始化,initializable iterator还可用于读入较大数组。在使用tf.data.Dataset.from_tensor_slices(array)
时,实际上发生的事情是将array作为一个tf.constants
保存到了计算图中,当array很大时,会导致计算图变得很大,给传输保存带来不便,这时可以使用一个placeholder
取代这里的array,并使用initializable iterator,只在需要时将array传进去,这样即可避免将大数组保存在图里。
features_placeholder=tf.placeholder(<features.dtype>,<features.shape>)
labels_placeholder=tf.placeholder(<labels.dtype>,<labels.shape>)
dataset=tf.data.Dataset.from_tensor_slices((features_placeholder,labels_placeholder))
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()
sess.run(iterator.initializer,feed_dict={features_placeholder:features,labels_placeholder:labels})
Tensorflow内部读取机制
对于文件名队列,使用tf.train.string_input_producer()
函数,tf.train.string_input_producer()
还有两个重要参数,num_epoches
和shuffle
内存队列不需要我们建立,只需要使用reader
对象从文件名队列中读取数据即可,使用tf.train.start_queue_runners()
函数启动队列,填充两个队列的数据。
with tf.Session() as sess:
filenames=['A.jpg','B.jpg','C.jpg']
filename_queue=tf.train.string_input_producer(filenames,shuffle=True,num_epoch=5)
reader=tf.WholeFileReader()
key,value=reader.read(filename_queue)
# tf.train.string_input_producer()定义了一个epoch变量,需要对其进行初始化
tf.local_variables_initializer().run()
threads=tf.train.start_queue_runners(sess=sess)
i=0
while True:
i+=1
image_data=sess.run(value)
with open('reader/test_%d.jpg'%i,'wb') as f:
f.write(image_data)
Tensorflow数据读取机制的更多相关文章
- 十图详解tensorflow数据读取机制(附代码)转知乎
十图详解tensorflow数据读取机制(附代码) - 何之源的文章 - 知乎 https://zhuanlan.zhihu.com/p/27238630
- tensorflow 1.0 学习:十图详解tensorflow数据读取机制
本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...
- 十图详解tensorflow数据读取机制
在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...
- 十图详解TensorFlow数据读取机制(附代码)
在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...
- 【转载】 十图详解tensorflow数据读取机制(附代码)
原文地址: https://zhuanlan.zhihu.com/p/27238630 何之源 深度学习(Deep Learning) 话题的优秀回答者 --------------- ...
- tensorflow数据读取机制tf.train.slice_input_producer 和 tf.train.batch 函数
tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程 ...
- TensorFlow数据读取
TensorFlow高效读取数据的方法 TF Boys (TensorFlow Boys ) 养成记(二): TensorFlow 数据读取 Tensorflow从文件读取数据 极客学院-数据读取 十 ...
- TensorFlow数据读取方式:Dataset API
英文详细版参考:https://www.cnblogs.com/jins-note/p/10243716.html Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服 ...
- 详解Tensorflow数据读取有三种方式(next_batch)
转自:https://blog.csdn.net/lujiandong1/article/details/53376802 Tensorflow数据读取有三种方式: Preloaded data: 预 ...
随机推荐
- 理解Erlang/OTP - Application
http://www.cnblogs.com/me-sa/archive/2011/12/27/erlang0025.html 1>application:start(log4erl). 我们就 ...
- WPF中自动增加行(动画)的TextBox
原文:WPF中自动增加行(动画)的TextBox WPF中自动增加行(动画)的TextBox WPF中的Textbox控件是可以自动换行的,只要设置TextWrapping属性为"Wrap& ...
- 微信上传素材 {"errcode":41005,"errmsg":"media data missing"} 解决方法和思路
哎lol 连跪两把 就来写写博客 今天遇到一个问题 ,微信公众号开发上传素材是提示报错 41005 errcode":41005,"errmsg":&q ...
- ashx 请求的内容似乎是脚本,因而将无法由静态文件处理程序来处理。
1.点击查看ashx在浏览器中显示的信息 2.自定义协议头 这样问题就搞定了.当然只是我遇到的一种.
- Android 它们的定义View
安卓开发过程,安卓官方控制有时来自往往不能满足我们的需求.这一次,我必须定义自己.下面我们就来看看他们的定义View: package com.example.myview; import andro ...
- sql知识收藏小总结
div { background-color: #eee; border-radius: 3px; border: 1px solid #999; padding: 4px; display: blo ...
- RadioButton分组的实现
原文:RadioButton分组的实现 XAML如下 <StackPanel> <RadioButton GroupName="colorgrp"> ...
- jQuery多库共存处理$.noConflict()
如果我们需要同时使用jQuery和其他JavaScript库,我们可以使用 $.noConflict()把$的控制权交给其他库.旧引用的$ 被保存在jQuery的初始化; noConflict() 简 ...
- Managing remote devices
A method and apparatus for managing remote devices. In one embodiment of the present invention, ther ...
- (让你提前知道软件开发33):数据操纵语言(DML)
文章2部分 数据库SQL语言 数据操纵语言(DML) 数据操纵语言(Data Manipulation Language,DML)包含insert.delete和update语句,用于增.删.改数据. ...