tensorflow学习笔记--dataset使用,创建自己的数据集
数据读入需求
我们在训练模型参数时想要从训练数据集中一次取出一小批数据(比如50条、100条)做梯度下降,不断地分批取出数据直到损失函数基本不再减小并且在训练集上的正确率足够高,取出的n条数据还要是预处理过的,一次取出的要包含输入数据和对应的lable,并且希望在达到训练效果之前可以不断地取出数据而不会因数据集取空了提前结束训练,最好取出的数据还是乱序的。
基于上面的要求,我们可以利用TensorFlow的dataset模块创建我们所需的数据集。
Dataset简介
TensorFlow程序数据导入的方法有多种。一是通过 feed_dict 传入具体值。二是利用tf的Queues创建数据队列,一次取出batch个数据进行训练,队列可以用多线程读数据,速度比较快,但是队列模块的用法比较复杂,要修改程序的时候就感觉很乱。
Dataset与队列相比就简单多了,Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道。
dataset用法
import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
创建了一个dataset,这个dataset中含有5个元素1….,5,为了将5个元素取出,方法是从Dataset中示例化一个iterator,然后对iterator进行迭代。
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
print(sess.run(one_element))
语句iterator = dataset.make_one_shot_iterator()从dataset中实例化了一个Iterator,这个Iterator是一个“one shot iterator”,即只能从头到尾读取一次。one_element = iterator.get_next()表示从iterator里取出一个元素。这里取5次后dataset里的元素就空了,再取的话就就会抛出tf.errors.OutOfRangeError异常。
除了one-hot iterator,tf还支持其他三种iterator
- initializable
- reinitializable
- feedable
这三个迭代器比one-hot复杂,这里就不介绍他们了。
dataset元素变换
dataset数据集API还有一些操作元素的函数来满足我们的对输入数据的需求。
- map
- shuffle
- batch
- repeat
1. map
map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:
def add1(x):
return x+1 dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(add1)
2. shuffle
shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:
dataset = dataset.shuffle(500)
3. batch
使用一次iterator返回一批数据的数量:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(10):
print(sess.run(one_element)) # 这样就一次获取两个数,可以取3次,第三次取到一个数
4. repeat
上面的代码取3次数就取完了,再取得话就会抛出异常,如果想重复取数,可以用dataset.repeat(count),count的值表示将全部的数在dataset中重复几次:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
dataset = dataset.batch(2).repeat(2)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.next()
with tf.Session() as sess:
for i in range(10):
print(sess.run(one_element))
这样就将5个数重复了两遍。这里需要注意的一点是它虽然重复了两次,但并不是可以取5次,一次取两个数,而是:[1,2], [3,4] , [5], [1,2], [3,4] , [5] 。这样再取到数据集末尾的时候得到的数据数量不是我们设置的batch_size 条数据。要想重复取数并且每次得到的都是batch_size条数据,可以设置batch_size的大小能被总数据量整除。
repeat()中的参数如果是None,则可以无限取数。
读入图片和lable,创建自己的数据集
import tensorflow as tf
import os batch_size = 50
img_resize = [100,100]
epoch_num = None # dataset.repeat() 的参数,设置为None,可以不断取数
# 传入图片名,返回正则化后的图片的像素值
def read_img(img_name, lable):
image = tf.read_file(img_name)
image = tf.image.decode_jpeg(image)
image = tf.image.resize_images(image, img_resize)
image = tf.image.per_image_standardization(image)
return image,lable
# 传入图片所在的文件夹,图片名含有图片的lable,返回利用文件夹中图片创建的dataset
def create_dataset(path):
files = os.listdir(path) # 列出文件夹中所有的图片
img_names = []
lables = []
for f in files:
img_names.append(os.path.join(path,f)) # 图片的完整路径append到文件名list中
lable = f.split('.')[0]
lables.append([int(i) for i in lable]) # 根据规则得到图片的lable img_names = tf.convert_to_tensor(img_names, dtype=tf.string)
lables = tf.convert_to_tensor(lables, dtype=tf.float32) # 将图片名list和lable的list转换成Tensor类型
dataset = tf.data.Dataset.from_tensor_slices((img_names,lables)) # 创建dataset,传入的需要是tensor类型
dataset = dataset.map(read_img) # 传入read_img函数,将图片名转为像素
# 将dataset打乱,设置一次获取batch_size条数据
dataset = dataset.shuffle(buffer_size=800).batch(batch_size).repeat(epoch_num)
return dataset
dataset = create_dataset('./img') # 图片所在的路径为./img
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next() # 创建dataset是batch_size 为多少这里一次就能获取多少个数据
在程序中,sess.run(one_element) 一次就能获取到batch_size条数据和对应的lable
参考链接
https://blog.csdn.net/ssmixi/article/details/80572813
https://www.jianshu.com/p/d80ea5d73446
tensorflow学习笔记--dataset使用,创建自己的数据集的更多相关文章
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- tensorflow学习笔记——自编码器及多层感知器
1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- Tensorflow学习笔记No.5
tf.data卷积神经网络综合应用实例 使用tf.data建立自己的数据集,并使用CNN卷积神经网络实现对卫星图像的二分类问题. 数据下载链接:https://pan.baidu.com/s/141z ...
- Tensorflow学习笔记No.7
tf.data与自定义训练综合实例 使用tf.data自定义猫狗数据集,并使用自定义训练实现猫狗数据集的分类. 1.使用tf.data创建自定义数据集 我们使用kaggle上的猫狗数据以及tf.dat ...
- Tensorflow学习笔记No.8
使用VGG16网络进行迁移学习 使用在ImageNet数据上预训练的VGG16网络模型对猫狗数据集进行分类识别. 1.预训练网络 预训练网络是一个保存好的,已经在大型数据集上训练好的卷积神经网络. 如 ...
- Tensorflow学习笔记No.10
多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...
- Tensorflow学习笔记No.11
图像定位 图像定位是指在图像中将我们需要识别的部分使用定位框进行定位标记,本次主要讲述如何使用tensorflow2.0实现简单的图像定位任务. 我所使用的定位方法是训练神经网络使它输出定位框的四个顶 ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
随机推荐
- Day1-B-CF-1144B
简述:有一个n个元素的序列,选奇数下一个就选偶数,偶数则下一个就是奇数,问能否取完,能取完输出0,否则输出能剩下的最小的之和 思路:统计奇偶数个数,若相等或相差一则取完,否则排列后取出最小的前x个(x ...
- Java程序生成exe可执行文件
Java程序打包成exe可执行文件,分为两大步骤. 第一步:将Java程序通过Eclipse或者Myeclipse导成Jar包 第二步:通过exe4j讲Jar包程序生成exe可执行文件 第一步详解: ...
- C# Connection:连接数据库---转载
C# 语言中 Connection 类是 ADO.NET 组件连接数据库时第一个要使用的类,也是通过编程访问数据库的第一步. 接下来我们来了解一下 Connection 类中的常用属性和方法,以及如何 ...
- S7-300 与TP900 组态 棒图 量表 滚动条 滚动条设置的值通过IO输出域显示出来
切换编程语言 注意 一定要 先选中 某一个组织块 例如 OB1 然后单击 菜单 编辑 切换编程语言 组态 300 PLC 的CPU 点击 SIMENSE LOGO 查看 循环 中断 OB35 可以 在 ...
- Day2-O-Coloring a Tree CodeForces-902B
You are given a rooted tree with n vertices. The vertices are numbered from 1 to n, the root is the ...
- django中添加日志功能
官方文档 猛戳这里 在settings中配置以下代码 #LOGGING_DIR 日志文件存放目录 LOGGING_DIR = "logs" # 日志存放路径 if not os.p ...
- JAVA(windows)安装教程
JAVA(windows)安装教程 一.下载: https://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133 ...
- Jackson自定义反序列化
// 设置jackson时间反系列化格式 SimpleModule module = new SimpleModule(); module.addDeserializer(Date.class, ne ...
- 图片转换到指定大小PDF
1.首先转换为eps jpeg2ps compile to exec file ./jpeg2ps -p a4 a.jpg -o x.eps2.从eps转换到pdf ps2pdf -dDownsa ...
- jdk环境变量、maven环境变量、Mysql环境变量配置
jdk官网地址:http://www.oracle.com/index.htmlhttp://www.java.sun.com 一.配置 jdk环境变量1.新建JAVA_HOME,在变量值复制JDK安 ...