TensorFlow.org教程笔记(二) DataSets 快速入门
本文翻译自www.tensorflow.org的英文教程。
tf.data
模块包含一组类,可以让你轻松加载数据,操作数据并将其输入到模型中。本文通过两个简单的例子来介绍这个API
- 从内存中的numpy数组读取数据。
- 从csv文件中读取行
基本输入
对于刚开始使用tf.data
,从数组中提取切片(slices)是最简单的方法。
笔记(1)TensorFlow初上手里提到了训练输入函数train_input_fn
,该函数将数据传输到Estimator
中:
def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
# Build the Iterator, and return the read end of the pipeline.
return dataset.make_one_shot_iterator().get_next()
让我们进一步来看看这个过程。
参数
这个函数需要三个参数。期望“array”的参数几乎可以接受任何可以使用numpy.array
转换为数组的东西。其中有一个例外是对Datasets
有特殊意义的元组(tuple)。
- features :一个包含原始特征输入的
{'feature_name':array}
的字典(或者pandas.DataFrame
) - labels :一个包含每个样本标签的数组
- batch_size:指示所需批量大小的整数。
在前面的笔记中,我们使用iris_data.load_data()
函数加载了鸢尾花的数据。你可以运行下面的代码来获取结果:
import iris_data
# Fetch the data.
train, test = iris_data.load_data()
features, labels = train
然后你可以将数据输入到输入函数中,类似这样:
batch_size = 100
iris_data.train_input_fn(features, labels, batch_size)
我们来看看这个train_input_fn
切片(Slices)
在最简单的情况下,tf.data.Dataset.from_tensor_slices
函数接收一个array
并返回一个表示array
切片的tf.data.Dataset
。例如,mnist训练集的shape是(60000, 28, 28)
。将这个array
传递给from_tensor_slices
将返回一个包含60000个切片的数据集对象,每个切片大小为28X28
的图像。(其实这个API就是把array的第一维切开)。
这个例子的代码如下:
train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train
mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)
将产生下面的结果:显示数据集中项目的type和shape。注意,数据集不知道它含有多少个sample。
<TensorSliceDataset shapes: (28,28), types: tf.uint8>
上面的数据集代表了简单数组的集合,但Dataset
的功能还不止如此。Dataset
能够透明地处理字典或元组的任何嵌套组合。例如,确保features
是一个标准的字典,你可以将数组字典转换为字典数据集。
先来回顾下features
,它是一个pandas.DataFrame
类型的数据:
SepalLength | SepalWidth | PetalLength | PetalWidth |
---|---|---|---|
0.6 | 0.8 | 0.9 | 1 |
... | ... | ... | ... |
而dict(features)
是一个字典,它的形式如下:
{key:value,key:value...} # key是string类型的列名,即SepalLength等
# value是pandas.core.series.Series类型的变量,即数据的一个列,是一个标量
对它进行切片
dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
结果如下:
<TensorSliceDataset
shapes: {
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
types: {
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64}
>
这里我们看到,当数据集包含结构化元素时,数据集的形状和类型采用相同的结构。该数据集包含标量字典,所有类型为tf.float64。
train_input_fn
的第一行使用了相同的函数,但它增加了一层结构-----创建了一个包含(feature, labels)
对的数据集
我们继续回顾labels
的结构,它其实是一个pandas.core.series.Series
类型的变量,即它与dict(features)
的value是同一类型。且维度一致,都是标量,长度也一致。
以下代码展示了这个dataset
的形状:
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
()),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
操纵
对于目前的数据集,将以固定的顺序遍历数据一次,并且每次只生成一个元素。在它可以被用来训练之前,还需做进一步处理。幸运的是,tf.data.Dataset
类提供了接口以便于更好地在训练之前准备数据。输入函数的下一行利用了以下几种方法:
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
shuffle方法使用一个固定大小的缓冲区来随机对数据进行shuffle。设置大于数据集中sample数目的buffer_size可以确保数据完全混洗。鸢尾花数据集只包含150个数据。
repeat方法在读取到组后的数据时重启数据集。要限制epochs的数量,可以设置count
参数。
batch方法累计样本并堆叠它们,从而创建批次。这个操作的结果为这批数据的形状增加了一个维度。新维度被添加为第一维度。以下代码是早期使用mnist数据集上的批处理方法。这使得28x28
的图像堆叠为三维的数据批次。
print(mnist_ds.batch(100))
<BatchDataset
shapes: (?, 28, 28),
types: tf.uint8>
请注意,数据集具有未知的批量大小,因为最后一批的元素数量较少。
在train_input_fn
中,批处理后,数据集包含一维向量元素,其中每个标量先前都是:
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (?,), PetalWidth: (?,),
PetalLength: (?,), SepalWidth: (?,)},
(?,)),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
返回值
每个Estimator
的train
、predict
、evaluate
方法都需要输入函数返回一个包含Tensorflow张量的(features, label)
对。train_input_fn
使用以下代码将数据集转换为预期的格式:
# Build the Iterator, and return the read end of the pipeline.
features_result, labels_result = dataset.make_one_shot_iterator().get_next()
结果是TensorFlow张量的结构,匹配数据集中的项目层。
print((features_result, labels_result))
({
'SepalLength': <tf.Tensor 'IteratorGetNext:2' shape=(?,) dtype=float64>,
'PetalWidth': <tf.Tensor 'IteratorGetNext:1' shape=(?,) dtype=float64>,
'PetalLength': <tf.Tensor 'IteratorGetNext:0' shape=(?,) dtype=float64>,
'SepalWidth': <tf.Tensor 'IteratorGetNext:3' shape=(?,) dtype=float64>},
Tensor("IteratorGetNext_1:4", shape=(?,), dtype=int64))
读取CSV文件
Dataset
最常见的实际用例是按流的方式从磁盘上读取文件。tf.data
模块包含各种文件读取器。让我们来看看如何使用Dataset
从csv文件中分析鸢尾花数据集。
以下对iris_data.maybe_download
函数的调用在需要时会下载数据,并返回下载结果文件的路径名称:
import iris_data
train_path, test_path = iris_data.maybe_download()
iris_data.csv_input_fn
函数包含使用Dataset
解析csv文件的替代实现。
构建数据集
我们首先构建一个TextLineDataset
对象,一次读取一行文件。然后,我们调用skip
方法跳过文件第一行,它包含一个头部,而不是样本:
ds = tf.data.TextLineDataset(train_path).skip(1)
构建csv行解析器
最终,我们需要解析数据集中的每一行,以产生必要的(features, label)
对。
我们将开始构建一个函数来解析单个行。
下面的iris_data.parse_line
函数使用tf.decode_csv
函数和一些简单的代码完成这个任务:
我们必须解析数据集中的每一行以生成必要的(features, label)
对。以下的_parse_line
函数调用tf.decode_csv
将单行解析为其features
和label
。由于Estimator
要求将特征表示为字典,因此我们依靠python的内置字典和zip
函数来构建该字典。特征名是该字典的key。然后我们调用字典的pop
方法从特征字典中删除标签字段。
# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
'PetalLength', 'PetalWidth',
'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
# Decode the line into its fields
fields = tf.decode_csv(line, FIELD_DEFAULTS)
# Pack the result into a dictionary
features = dict(zip(COLUMNS, fields))
# Separate the label from the features
label = features.pop('label')
return features, label
解析行
Datasets
有很多方法用于在数据传输到模型时处理数据。最常用的方法是map,它将转换应用于Dataset
的每个元素。
map
方法使用一个map_func
参数来描述Dataset
中每个项目应该如何转换。
因此为了解析流出csv文件的行,我们将_parse_line
函数传递给map
方法:
ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: (
{SepalLength: (), PetalWidth: (), ...},
()),
types: (
{SepalLength: tf.float32, PetalWidth: tf.float32, ...},
tf.int32)>
现在的数据集不是简单的标量字符串,而是包含了(features, label)
对。
iris_data.csv_input_fn
函数其余部分与基本输入部分中涵盖的iris_data.train_input_fn
相同。
试试看
该函数可以用来替代iris_data.train_input_fn
。它可以用来提供一个如下的Estimator
:
train_path, test_path = iris_data.maybe_download()
# All the inputs are numeric
feature_columns = [
tf.feature_column.numeric_column(name)
for name in iris_data.CSV_COLUMN_NAMES[:-1]]
# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
n_classes = 3)
# Train the estimator
batch_size = 100
est.train(
steps=1000
input_fn=lambda:iris_data.csv_input_fn(train_path, batch_size))
Estimator
期望input_fn
不带任何参数。为了解除这个限制,我们使用lambda
来捕获参数并提供预期的接口。
总结
tf.data
模块提供了一组用于轻松读取各种来源数据的类和函数。此外,tf.data
具有简单强大的方法来应用各种标准和自定义转换。
现在你已经了解如何有效地将数据加载到Estimator
中的基本想法。接下来考虑以下文档:
TensorFlow.org教程笔记(二) DataSets 快速入门的更多相关文章
- Apache Superset 1.2.0教程 (二)——快速入门(可视化王者英雄数据)
上一篇我们已经成功的安装了superset,那么该如何可视化我们的数据呢?本文将可视化王者英雄的数据,快速的入门Superset. 一.连接数据源 首先确保mysql可以正常连接使用,并且准备好数据. ...
- git-github-TortoiseGit综合使用教程(二)快速入门
:建立版本库 在github网站上创建一个版本库,并复制clone地址. git@github.com:jackadam1981/Flask_Base.git https://github.com/j ...
- Expression Blend实例中文教程(8) - 动画设计快速入门StoryBoard http://silverlightchina.net/html/tips/2010/0329/934.html
Expression Blend实例中文教程(8) - 动画设计快速入门StoryBoard 时间:2010-03-29 11:13来源:SilverlightChina.Net 作者:jv9 点击: ...
- Yii2框架RESTful API教程(一) - 快速入门
前不久做一个项目,是用Yii2框架写一套RESTful风格的API,就去查了下<Yii 2.0 权威指南 >,发现上面写得比较简略.所以就在这里写一篇教程贴,希望帮助刚接触Yii2框架RE ...
- 【Python】【学习笔记】1.快速入门
1.软件安装 从官网下载相应版本的安装包,一般不大. https://www.python.org/ 安装一路默认即可 2. 参考教程:快速入门:十分钟学会Python 本文的内容介于教程(Totur ...
- 二:Redis快速入门及应用
Redis的使用难吗?不难,Redis用好容易吗?不容易.Redis的使用虽然不难,但与业务结合的应用场景特别多.特别紧,用好并不容易.我们希望通过一篇文章及Demo,即可轻松.快速入门并学会应用. ...
- MyBatis学习笔记(一)——MyBatis快速入门
转自孤傲苍狼的博客:http://www.cnblogs.com/xdp-gacl/p/4261895.html 一.Mybatis介绍 MyBatis是一个支持普通SQL查询,存储过程和高级映射的优 ...
- 【笔记】PyTorch快速入门:基础部分合集
PyTorch快速入门 Tensors Tensors贯穿PyTorch始终 和多维数组很相似,一个特点是可以硬件加速 Tensors的初始化 有很多方式 直接给值 data = [[1,2],[3, ...
- 利用python 数据分析入门,详细教程,教小白快速入门
这是一篇的数据的分析的典型案列,本人也是经历一次从无到有的过程,倍感珍惜,所以将其详细的记录下来,用来帮助后来者快速入门,,希望你能看到最后! 需求:对obo文件进行解析,输出为json字典格式 数据 ...
随机推荐
- 【pG&&CYH-01】元旦联欢会
题解: t1: 题解是循环矩阵 但我并没有往矩阵上想下去... 这个东西比较显然的是可以把它看成生成函数 然后就可以任意模数fft了 复杂度比题解优 $nlog^2$ t2: 随便推推式子就好了 t3 ...
- 白话大数据 | Spark和Hadoop到底谁更厉害?
要想搞清楚spark跟Hadoop到底谁更厉害,首先得明白spark到底是什么鬼. 经过之前的介绍大家应该非常了解什么是Hadoop了(不了解的点击这里:白话大数据 | hadoop究竟是什么鬼),简 ...
- MVC中ztree异步加载
var setting = { async: { enable: true, url: "*****/LoadChild", autoParam: ["id"] ...
- this指向及改变this指向的方法
一.函数的调用方式决定了 this 的指向不同,但总的原则,this指的是调用函数的那个对象: 1.普通函数调用,此时 this 指向 全局对象window function fn() { conso ...
- 1.XGBOOST算法推导
最近因为实习的缘故,所以开始复习各种算法推导~~~就先拿这个xgboost练练手吧. (参考原作者ppt 链接:https://pan.baidu.com/s/1MN2eR-4BMY-jA5SIm6W ...
- python入门编程之mysql编程
python关于mysql方面的连接编程 前提:引入mysql模块MySQLdb,即:MySQL_python-1.2.5-cp27-none-win_amd64.whl 如果要用线程池,则要引用模块 ...
- ImCash:论拥有靠谱数字钱包的重要性!
数字货币被盗已经不是什么新鲜事,前有交易所币安被黑客攻击,Youbit破产,后有“钓鱼邮件“盗号木马,安全对于数字货币用户来讲至关重要. 现行市场痛点: 2017年9月以太坊Parity钱包的漏洞 ...
- Django“少折腾”
1.Django中文语言.时区 修改项目setting文件 LANGUAGE_CODE = 'zh-hans' TIME_ZONE = 'Asia/Shanghai'
- Metasploit运行环境内存不要低于2GB
Metasploit运行环境内存不要低于2GB Metasploit启用的时候,会占用大量的内存.如果所在系统剩余内存不足(非磁盘剩余空间),会直接导致运行出错.这种情况特别容易发生在虚拟机Kali ...
- 让java代码在Idea外面运行起来
今天在写聊天程序,终于写到双方通信的时候,发现idea只能开一个客户端.虽说可以开多线程来实现多开,但是懒得改动代码,所以我就试试能不能把jar包导出来运行.首先我用maven自带的工具打了jar包, ...