TensorFlow读写数据
前言
只有光头才能变强。
文本已收录至我的GitHub仓库,欢迎Star:https://github.com/ZhongFuCheng3y/3y
回顾前面:
众所周知,要训练出一个模型,首先我们得有数据。我们第一个例子中,直接使用dataset的api去加载mnist的数据。(minst的数据要么我们是提前下载好,放在对应的目录上,要么就根据他给的url直接从网上下载)。
一般来说,我们使用TensorFlow是从TFRecord文件中读取数据的。
TFRecord 文件格式是一种面向记录的简单二进制格式,很多 TensorFlow 应用采用此格式来训练数据
所以,这篇文章来聊聊怎么读取TFRecord文件的数据。
一、入门对数据集的数据进行读和写
首先,我们来体验一下怎么造一个TFRecord文件,怎么从TFRecord文件中读取数据,遍历(消费)这些数据。
1.1 造一个TFRecord文件
现在,我们还没有TFRecord文件,我们可以自己简单写一个:
def write_sample_to_tfrecord():
gmv_values = np.arange(10)
click_values = np.arange(10)
label_values = np.arange(10)
with tf.python_io.TFRecordWriter("/Users/zhongfucheng/data/fashin/demo.tfrecord", options=None) as writer:
for _ in range(10):
feature_internal = {
"gmv": tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])),
"click": tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]]))
}
features_extern = tf.train.Features(feature=feature_internal)
# 使用tf.train.Example将features编码数据封装成特定的PB协议格式
# example = tf.train.Example(features=tf.train.Features(feature=features_extern))
example = tf.train.Example(features=features_extern)
# 将example数据系列化为字符串
example_str = example.SerializeToString()
# 将系列化为字符串的example数据写入协议缓冲区
writer.write(example_str)
if __name__ == '__main__':
write_sample_to_tfrecord()
我相信大家代码应该是能够看得懂的,其实就是分了几步:
- 生成TFRecord Writer
- tf.train.Feature生成协议信息
- 使用tf.train.Example将features编码数据封装成特定的PB协议格式
- 将example数据系列化为字符串
- 将系列化为字符串的example数据写入协议缓冲区
参考资料:
ok,现在我们就有了一个TFRecord文件啦。
1.2 读取TFRecord文件
其实就是通过
tf.data.TFRecordDataset
这个api来读取到TFRecord文件,生成处dataset对象对dataset进行处理(shape处理,格式处理...等等)
使用迭代器对dataset进行消费(遍历)
demo代码如下:
import tensorflow as tf
def read_tensorflow_tfrecord_files():
# 定义消费缓冲区协议的parser,作为dataset.map()方法中传入的lambda:
def _parse_function(single_sample):
features = {
"gmv": tf.FixedLenFeature([1], tf.float32),
"click": tf.FixedLenFeature([1], tf.int64), # ()或者[]没啥影响
"label": tf.FixedLenFeature([1], tf.int64)
}
parsed_features = tf.parse_single_example(single_sample, features=features)
# 对parsed 之后的值进行cast.
gmv = tf.cast(parsed_features["gmv"], tf.float64)
click = tf.cast(parsed_features["click"], tf.float64)
label = tf.cast(parsed_features["label"], tf.float64)
return gmv, click, label
# 开始定义dataset以及解析tfrecord格式
filenames = tf.placeholder(tf.string, shape=[None])
# 定义dataset 和 一些列trasformation method
dataset = tf.data.TFRecordDataset(filenames)
parsed_dataset = dataset.map(_parse_function) # 消费缓冲区需要定义在dataset 的map 函数中
batchd_dataset = parsed_dataset.batch(3)
# 创建Iterator
sample_iter = batchd_dataset.make_initializable_iterator()
# 获取next_sample
gmv, click, label = sample_iter.get_next()
training_filenames = [
"/Users/zhongfucheng/data/fashin/demo.tfrecord"]
with tf.Session() as session:
# 初始化带参数的Iterator
session.run(sample_iter.initializer, feed_dict={filenames: training_filenames})
# 读取文件
print(session.run(gmv))
if __name__ == '__main__':
read_tensorflow_tfrecord_files()
无意外的话,我们可以输出这样的结果:
[[0.]
[1.]
[2.]]
ok,现在我们已经大概知道怎么写一个TFRecord文件,以及怎么读取TFRecord文件的数据,并且消费这些数据了。
二、epoch和batchSize术语解释
我在学习TensorFlow翻阅资料时,经常看到一些机器学习的术语,由于自己没啥机器学习的基础,所以很多时候看到一些专业名词就开始懵逼了。
2.1epoch
当一个完整的数据集通过了神经网络一次并且返回了一次,这个过程称为一个epoch。
这可能使我们跟dataset.repeat()
方法联系起来,这个方法可以使当前数据集重复一遍。比如说,原有的数据集是[1,2,3,4,5]
,如果我调用dataset.repeat(2)
的话,那么我们的数据集就变成了[1,2,3,4,5],[1,2,3,4,5]
- 所以会有个说法:假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch
2.2batchSize
一般来说我们的数据集都是比较大的,无法一次性将整个数据集的数据喂进神经网络中,所以我们会将数据集分成好几个部分。每次喂多少条样本进神经网络,这个叫做batchSize。
在TensorFlow也提供了方法给我们设置:dataset.batch()
,在API中是这样介绍batchSize的:
representing the number of consecutive elements of this dataset to combine in a single batch
我们一般在每次训练之前,会将整个数据集的顺序打乱,提高我们模型训练的效果。这里我们用到的api是:dataset.shffle();
三、再来聊聊dataset
我从官网的介绍中截了一个dataset的方法图(部分):
dataset的功能主要有以下三种:
- 创建dataset实例
- 通过文件创建(比如TFRecord)
- 通过内存创建
- 对数据集的数据进行变换
- 比如上面的batch(),常见的
map(),flat_map(),zip(),repeat()
等等 - 文档中一般都有给出例子,跑一下一般就知道对应的意思了。
- 比如上面的batch(),常见的
- 创建迭代器,遍历数据集的数据
3.1 聊聊迭代器
迭代器可以分为四种:
- 单次。对数据集进行一次迭代,不支持参数化
- 可初始化迭代
- 使用前需要进行初始化,支持传入参数。面向的是同一个DataSet
- 可重新初始化:同一个Iterator从不同的DataSet中读取数据
- DataSet的对象具有相同的结构,可以使用
tf.data.Iterator.from_structure
来进行初始化 - 问题:每次 Iterator 切换时,数据都从头开始打印了
- DataSet的对象具有相同的结构,可以使用
- 可馈送(也是通过对象相同的结果来创建的迭代器)
- 可让您在两个数据集之间切换的可馈送迭代器
- 通过一个string handler来实现。
- 可馈送的 Iterator 在不同的 Iterator 切换的时候,可以做到不从头开始。
简单总结:
- 1、 单次 Iterator ,它最简单,但无法重用,无法处理数据集参数化的要求。
- 2、 可以初始化的 Iterator ,它可以满足 Dataset 重复加载数据,满足了参数化要求。
- 3、可重新初始化的 Iterator,它可以对接不同的 Dataset,也就是可以从不同的 Dataset 中读取数据。
- 4、可馈送的 Iterator,它可以通过 feeding 的方式,让程序在运行时候选择正确的 Iterator,它和可重新初始化的 Iterator 不同的地方就是它的数据在不同的 Iterator 切换时,可以做到不重头开始读取数据。
string handler(可馈送的 Iterator)这种方式是最常使用的,我当时也写了一个Demo来使用了一下,代码如下:
def read_tensorflow_tfrecord_files():
# 开始定义dataset以及解析tfrecord格式.
train_filenames = tf.placeholder(tf.string, shape=[None])
vali_filenames = tf.placeholder(tf.string, shape=[None])
# 加载train_dataset batch_inputs这个方法每个人都不一样的,这个方法我就不给了。
train_dataset = batch_inputs([
train_filenames], batch_size=5, type=False,
num_epochs=2, num_preprocess_threads=3)
# 加载validation_dataset batch_inputs这个方法每个人都不一样的,这个方法我就不给了。
validation_dataset = batch_inputs([vali_filenames
], batch_size=5, type=False,
num_epochs=2, num_preprocess_threads=3)
# 创建出string_handler()的迭代器(通过相同数据结构的dataset来构建)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, train_dataset.output_types, train_dataset.output_shapes)
# 有了迭代器就可以调用next方法了。
itemid = iterator.get_next()
# 指定哪种具体的迭代器,有单次迭代的,有初始化的。
training_iterator = train_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# 定义出placeholder的值
training_filenames = [
"/Users/zhongfucheng/tfrecord_test/data01aa"]
validation_filenames = ["/Users/zhongfucheng/tfrecord_validation/part-r-00766"]
with tf.Session() as sess:
# 初始化迭代器
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
for _ in range(2):
sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames})
print("this is training iterator ----")
for _ in range(5):
print(sess.run(itemid, feed_dict={handle: training_handle}))
sess.run(validation_iterator.initializer,
feed_dict={vali_filenames: validation_filenames})
print("this is validation iterator ")
for _ in range(5):
print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle}))
if __name__ == '__main__':
read_tensorflow_tfrecord_files()
参考资料:
3.2 dataset参考资料
在翻阅资料时,发现写得不错的一些博客:
- https://www.jianshu.com/p/91803a119f18
- https://irvingzhang0512.github.io/2018/04/19/tensorflow-api-2/
- http://www.feiguyunai.com/index.php/2017/12/25/pyhtonai-ml-dataprocess-datasetapi/
最后
乐于输出干货的Java技术公众号:Java3y。公众号内有200多篇原创技术文章、海量视频资源、精美脑图,不妨来关注一下!
下一篇文章打算讲讲如何理解axis~
觉得我的文章写得不错,不妨点一下赞!
TensorFlow读写数据的更多相关文章
- 第十二节,TensorFlow读取数据的几种方法以及队列的使用
TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起 ...
- HDFS读写数据块--${dfs.data.dir}选择策略
最近工作需要,看了HDFS读写数据块这部分.不过可能跟网上大部分帖子不一样,本文主要写了${dfs.data.dir}的选择策略,也就是block在DataNode上的放置策略.我主要是从我们工作需要 ...
- win10 svchost.exe (LocalSystemNetworkRestricted)大量读写数据
博主的笔记本联想Y50开机完毕后会不停滴读硬盘/写硬盘,导致开机后一段时间内无法正常使用电脑(硬盘读写高峰期).打开资源监视器发现是"svchost.exe (LocalSystemNetw ...
- inputstream和outputstream读写数据模板代码
//读写数据模板代码 byte buffer[] = new byte[1024]; int len=0; while((len=in.read(buffer))>0){ out.write(b ...
- 百度地图LBS云平台读写数据操作类
最近写了个叫<行踪记录仪>的手机软件,用了百度云来记录每个用户的最近位置,以便各用户能在地图上找到附近的人,为此写了个类来读写数据,大致如下: import java.util.Array ...
- 01. SQL Server 如何读写数据
原文:01. SQL Server 如何读写数据 一. 数据读写流程简要SQL Server作为一个关系型数据库,自然也维持了事务的ACID特性,数据库的读写冲突由事务隔离级别控制.无论有没有显示开启 ...
- SQL Server 如何读写数据
01. SQL Server 如何读写数据 一. 数据读写流程简要SQL Server作为一个关系型数据库,自然也维持了事务的ACID特性,数据库的读写冲突由事务隔离级别控制.无论有没有显示开启事 ...
- STM32F10X SPI操作flash MX25L64读写数据(转)
源:STM32F10X SPI操作flash MX25L64读写数据 前一段时间在弄SPI,之前没接触过嵌入式外围应用,就是单片机也只接触过串口通信,且也是在学校的时候了.从离开手机硬件测试岗位后,自 ...
- .net环境下跨进程、高频率读写数据
一.需求背景 1.最近项目要求高频次地读写数据,数据量也不是很大,多表总共加起来在百万条上下. 单表最大的也在25万左右,历史数据表因为不涉及所以不用考虑, 难点在于这个规模的热点数据,变化非常频繁. ...
随机推荐
- 第二课:Hadoop集群环境配置
一.Yum配置 1.检查Yum是否安装 rpm -qa|grep yum 2.修改yum源,我使用的是163的镜像源(http://mirrors.163.com/),根据自己的系统选择源, #进入目 ...
- leetCode刷题(找到两个数组拼接后的中间数)
There are two sorted arrays nums1 and nums2 of size m and n respectively. Find the median of the two ...
- Sublime中文编码问题
1. 改配置 Preferences->Settings 三个全部加上 "default_encoding": "UTF-8" 2. 代码编写 2.1 ...
- .NET开发微信小程序-上传图片到服务器
1.上传图片分为几种: a:上传图片到本地(永久保存) b:上传图片到本地(临时保存) c:上传图片到服务器 a和b在小程序的api文档里面有.直接说C:上传图片到服务器 前端代码: /* 上传图片到 ...
- Caffe 编译后 make runtest 出现locale::facet::_S_create_c_locale 错误
You might need to append LC_ALL="en_US.UTF-8" to file: /etc/default/locale and reboot your ...
- 由于github仓库中提前建立readme文件,导致git push报错error: failed to push some refs to 'git@github.com:
$ git push -u origin master To git@github.com:xxx/xxx.git ! [rejected] master -> master (fetch fi ...
- mysql 创建用户
以管理员方式打开cmd命令提示符进入MySql的Bin目录下 一.以管理员身份登录mysql 密码不隐藏的登录方式:mysql -u root -p 123456 密码隐藏的登录方式:mysql -u ...
- SSM-SpringMVC-23:SpringMVC中初探异常解析器
------------吾亦无他,唯手熟尔,谦卑若愚,好学若饥------------- 本篇博客要讲的是异常解析器,SimpleMappingExceptionResolver简单映射异常解析器 可 ...
- 获取GRIDVIEW中的TemplateField显示的文本值
GRIDVIEW中数据源绑定后的属性绑定我一般采取2种办法 一个是BoundField,只要设置DataField的对应属性名即可: 如: <asp:BoundField HeaderText ...
- app的安装与卸载测试点
安装 1)软件在不同操作系统(Palm OS.Symbian.Linux.Android.iOS.Black Berry OS .Windows Phone )下安装是否正常. 2)软件安装后的是否能 ...