Tensorflow datasets.shuffle repeat batch方法
机器学习中数据读取是很重要的一个环节,TensorFlow也提供了很多实用的方法,为了避免以后时间久了又忘记,所以写下笔记以备日后查看。
最普通的正常情况
首先我们看看最普通的情况:
# 创建0-10的数据集,每个batch取个数。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(2):
value = sess.run(next_element)
print(value)
输出结果
[0 1 2 3 4 5]
[6 7 8 9]
由结果我们可以知道TensorFlow能很好地帮我们自动处理最后一个batch的数据。
datasets.batch(batch_size)与迭代次数的关系
但是如果上面for循环次数超过2会怎么样呢?也就是说如果 循环次数*批数量 > 数据集数量 会怎么样?我们试试看:
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
>>==for i in range(3):==<<
value = sess.run(next_element)
print(value)
输出结果
[0 1 2 3 4 5]
[6 7 8 9]
---------------------------------------------------------------------------
OutOfRangeError Traceback (most recent call last)
D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1277 try:
...
...省略若干信息...
...
OutOfRangeError (see above for traceback): End of sequence
[[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]
可以知道超过范围了,所以报错了。
datasets.repeat()
为了解决上述问题,repeat方法登场。还是直接看例子吧:
dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
输出结果
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
可以知道repeat其实就是将数据集重复了指定次数,上面代码将数据集重复了2次,所以这次即使for循环次数是4也依旧能正常读取数据,并且都能完整把数据读取出来。同理,如果把for循环次数设置为大于4,那么也还是会报错,这么一来,我每次还得算repeat的次数,岂不是很心累?所以更简便的办法就是对repeat方法不设置重复次数,效果见如下:
dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(6):
value = sess.run(next_element)
print(value)
输出结果:
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
此时无论for循环多少次都不怕啦~~
datasets.shuffle(buffer_size)
仔细看可以知道上面所有输出结果都是有序的,这在机器学习中用来训练模型是浪费资源且没有意义的,所以我们需要将数据打乱,这样每批次训练的时候所用到的数据集是不一样的,这样啊可以提高模型训练效果。
另外shuffle前需要设置buffer_size:
- 不设置会报错,
- buffer_size=1:不打乱顺序,既保持原序
- buffer_size越大,打乱程度越大,演示效果见如下代码:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
输出结果:
[1 0 2 4 3 5]
[7 8 9 6]
[1 2 3 4 0 6]
[7 8 9 5]
注意:shuffle的顺序很重要,一般建议是最开始执行shuffle操作,因为如果是先执行batch操作的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。不信你看:
dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
输出结果:
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
Tensorflow datasets.shuffle repeat batch方法的更多相关文章
- TensorFlow高效读取数据的方法——TFRecord的学习
关于TensorFlow读取数据,官网给出了三种方法: 供给数据(Feeding):在TensorFlow程序运行的每一步,让python代码来供给数据. 从文件读取数据:在TensorFlow图的起 ...
- 【tf.keras】tensorflow datasets,tfds
一些最常用的数据集如 MNIST.Fashion MNIST.cifar10/100 在 tf.keras.datasets 中就能找到,但对于其它也常用的数据集如 SVHN.Caltech101,t ...
- TensorFlow模型保存和加载方法
TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...
- TensorFlow指定CPU和GPU方法
TensorFlow指定CPU和GPU方法 TensorFlow 支持 CPU 和 GPU.它也支持分布式计算.可以在一个或多个计算机系统的多个设备上使用 TensorFlow. TensorFlow ...
- [TensorFlow] Introduction to TensorFlow Datasets and Estimators
Datasets and Estimators are two key TensorFlow features you should use: Datasets: The best practice ...
- Tensorflow高效读取数据的方法
最新上传的mcnn中有完整的数据读写示例,可以参考. 关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码 ...
- TensorFlow加载图片的方法
方法一:直接使用tensorflow提供的函数image = tf.gfile.FastGFile('PATH')来读取一副图片: import matplotlib.pyplot as plt; i ...
- tensorflow中的参数初始化方法
1. 初始化为常量 tf中使用tf.constant_initializer(value)类生成一个初始值为常量value的tensor对象. constant_initializer类的构造函数定义 ...
- TensorFlow 常见错误与解决方法——长期不定时更新
1. TypeError: Cannot interpret feed_dict key as Tensor: Can not convert a builtin_function_or_method ...
随机推荐
- SQL partition (小组排序)
很多时候,我们在SQL中进行数据去重(distinct) 结果发现有2条一样ID,或者name的数据,我们想要最接近的那条数据. 直接看看题目: 原表 select ID,Title,PRICE fr ...
- 洛谷P2516 [HAOI2010]最长公共子序列(LCS,最短路)
洛谷题目传送门 一进来就看到一个多月前秒了此题的ysn和YCB%%% 最长公共子序列的\(O(n^2)\)的求解,Dalao们想必都很熟悉了吧!不过蒟蒻突然发现,用网格图貌似可以很轻松地理解这个东东? ...
- 自学Python1.5-Centos内python2识别中文
自学Python之路 自学Python1.5-Centos内python2识别中文 方法一,python推荐使用utf-8编码方案 经验一:在开头声明: # -*- coding: utf-8 -*- ...
- 洛谷 P2725 邮票 Stamps 解题报告
P2725 邮票 Stamps 题目背景 给一组 N 枚邮票的面值集合(如,{1 分,3 分})和一个上限 K -- 表示信封上能够贴 K 张邮票.计算从 1 到 M 的最大连续可贴出的邮资. 题目描 ...
- Redis中的简单动态字符串
Redis没有直接使用C语言传统的字符串表示(以空字符结尾的字符数组,以下简称C字符串),而是自己构建了一种名为简单动态字符串(simple dynamic string,SDS)的抽象类型,并将SD ...
- P2569 股票交易
题目大意: 你初始时有∞ 元钱,并且每天持有的股票不超过 Maxp . 有 T 天,你知道每一天的买入价格( AP[i] ),卖出价格( Bp[i] ), 买入数量限制( AS[i] ),卖出数量限制 ...
- 使用 <!DOCTYPE html> 让 <div><img></div>中的图片下面产生几个像素的空白间隔
今天算是第一次使用 <!DOCTYPE html> 不经意间发现图片下方有5个像素左右的空白间隔,检查半天也没查出原因. 最后百度了一下,网上说这是 <!DOCTYPE html&g ...
- 转:@ControllerAdvice + @ExceptionHandler 全局处理 Controller 层异常
继承 ResponseEntityExceptionHandler 类来实现针对 Rest 接口 的全局异常捕获,并且可以返回自定义格式: 复制代码 1 @Slf4j 2 @ControllerAdv ...
- Spring MVC程序中怎么得到静态资源文件css,js,图片文件的路径问题
问题描述 在用springmvc开发应用程序的时候.对于像我一样的初学者,而且还是自学的人,有一个很头疼的问题.那就是数据都已经查出来了,但是页面的样式仍然十分简陋,加载不了css.js,图片等资源文 ...
- 自己的Promise
废话不多说,直接上代码: class Promise2{ constructor(fn){ const _this=this; //重点 this.__queue=[]; this.__succ_re ...