TFRecord的Shuffle、划分和读取
对数据集的shuffle处理需要设置相应的buffer_size参数,相当于需要将相应数目的样本读入内存,且这部分内存会在训练过程中一直保持占用。完全的shuffle需要将整个数据集读入内存,这在大规模数据集的情况下是不现实的,故需要结合设备内存以及Batch大小将TFRecord文件随机划分为多个子文件,再对数据集做local shuffle(即设置相对较小的buffer_size,不小于单个子文件的样本数)。
Shuffle和划分
下文以一个异常检测数据集(正负样本不平衡)为例,在生成第一批TFRecord时,我将正负样本分别写入单独的TFrecord文件以备后续在对正负样本有不同处理策略的情况下无需再解析example_proto。比如在以下代码中,我对正负样本有不同的验证集比例,并将他们写入不同的验证集文件。
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm as tqdm
# TFRecord划分
raw_normal_dataset = tf.data.TFRecordDataset("normal_16_256.tfrecords","GZIP")
raw_anomaly_dataset = tf.data.TFRecordDataset("anomaly_16_256.tfrecords","GZIP")
normal_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP")
anomaly_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP")
train_writer_list = [tf.io.TFRecordWriter(r'ex_1/'+'train_16_256_{}.tfrecords'.format(i),"GZIP") for i in range(SUBFILE_NUM+1)]
with tqdm(total=LEN_NORMAL_DATASET+LEN_ANOMALY_DATASET) as pbar:
for example_proto in raw_normal_dataset:
# 划分训练集和测试集
if np.random.random() > 0.99: # 正样本测试集的比例
normal_val_writer.write(example_proto.numpy())
else:
train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
pbar.update(1)
for example_proto in raw_anomaly_dataset:
# 划分训练集和测试集
if np.random.random() > 0.7: # 负样本测试集的比例
anomaly_val_writer.write(example_proto.numpy())
else:
train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
pbar.update(1)
normal_val_writer.close()
anomaly_val_writer.close()
for train_writer in train_writer_list:
train_writer.close()
读取
raw_train_dataset = tf.data.TFRecordDataset([r'ex_1/'+'train_16_256_{}.tfrecords'.format(i) for i in range(SUBFILE_NUM+1)],"GZIP")
raw_train_dataset = raw_train_dataset.shuffle(buffer_size=100000).batch(BATCH_SIZE)
parsed_train_dataset = raw_train_dataset.map(map_func=map_func)
raw_normal_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP")
raw_anomaly_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP")
parsed_nomarl_val_dataset = raw_normal_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
parsed_anomaly_val_dateset = raw_anomaly_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
TFRecord的Shuffle、划分和读取的更多相关文章
- Tensorflow 中(批量)读取数据的案列分析及TFRecord文件的打包与读取
内容概要: 单一数据读取方式: 第一种:slice_input_producer() # 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中 ...
- 更加清晰的TFRecord格式数据生成及读取
TFRecords 格式数据文件处理流程 TFRecords 文件包含了 tf.train.Example 协议缓冲区(protocol buffer),协议缓冲区包含了特征 Features.Ten ...
- Tensorflow中使用tfrecord方式读取数据-深度学习-周振洋
本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用 ...
- Spark技术内幕:Stage划分及提交源码分析
http://blog.csdn.net/anzhsoft/article/details/39859463 当触发一个RDD的action后,以count为例,调用关系如下: org.apache. ...
- Spark技术内幕:Stage划分及提交源代码分析
当触发一个RDD的action后.以count为例,调用关系例如以下: org.apache.spark.rdd.RDD#count org.apache.spark.SparkContext#run ...
- 第十二节,TensorFlow读取数据的几种方法以及队列的使用
TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起 ...
- TensorFlow中数据读取之tfrecords
关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow ...
- tensorflow之数据读取探究(2)
tensorflow之tfrecord数据读取 Tensorflow关于TFRecord格式文件的处理.模型的训练的架构为: 1.获取文件列表.创建文件队列:http://blog.csdn.net/ ...
- spark 笔记 15: ShuffleManager,shuffle map两端的stage/task的桥梁
无论是Hadoop还是spark,shuffle操作都是决定其性能的重要因素.在不能减少shuffle的情况下,使用一个好的shuffle管理器也是优化性能的重要手段. ShuffleManager的 ...
随机推荐
- SQL注入之information_schema
在学习SQL注入时, 经常拿出来的例子就是PHP+MySQL这一套经典组合. 其中又经常提到的>=5.0版本的MySQL的内置库: information_schema 简单看一下informa ...
- SQL表的创建
一,创建表 1.使用鼠标创建表 1,进入SQL进行连接 编辑 2,在左边会有一个对象资源管理器,右键数据库,在弹出的窗口中选择新建数据库 编辑 3,给这个包取个名字,在这个界面可以给这个表选 ...
- CoaXPress 时间戳 Time Stamping
背景 在CXP2.0之前,CXP没有定义Time Stamping时间戳的概念,但是用户对Time Stamping是有实际需求的,比如我们要对比多台设备拍摄同一个物体不同角度的照片,或者记录触发完成 ...
- 手把手教你使用 Spring Boot 3 开发上线一个前后端分离的生产级系统(一) - 介绍
项目简介 novel 是一套基于时下最新 Java 技术栈 Spring Boot 3 + Vue 3 开发的前后端分离的学习型小说项目,配备详细的项目教程手把手教你从零开始开发上线一个生产级别的 J ...
- Unity-UGUI-无限循环列表
前言:项目准备新增一个竞技场排行榜,策划规定只显示0-400名的玩家.我一想,生成四百个游戏物体,怕不是得把手机给卡死?回想原来在GitHub上看到过一个实现思路就是无限循环列表,所以就想自己试试.公 ...
- [BZOJ5449] 序列
题目链接:序列 Description 给定一个\(1\)~\(n\)的排列x,每次你可以将 \(x_1, x_2, ..., x_i\) 翻转. 你需要求出将序列变为升序的最小操作次数. 多组数据. ...
- 浅谈Javascript单线程和事件循环
单线程 Javascript 是单线程的,意味着不会有其他线程来竞争.为什么是单线程呢? 假设 Javascript 是多线程的,有两个线程,分别对同一个元素进行操作: function change ...
- SpringBoot官方支持任务调度框架,轻量级用起来也挺香!
大家好,我是二哥呀.定时任务的应用场景其实蛮常见的,比如说: 数据备份 订单未支付则自动取消 定时爬取数据 定时推送信息 定时发布文章 等等(想不出来了,只能等等来凑,,反正只要等的都需要定时,怎么样 ...
- MathType7安装使用及please restart word to load mathtype addin properly的问题
MathType7安装使用及please restart word to load mathtype addin properly的问题.最近在自己的电脑上安装Mathtype7,把遇到的问题和解决办 ...
- php 访问控制可见性 public protected private
对属性或方法的访问控制,是通过在前面添加关键字public(公有),protected(受保护的),private(私有)来实现. 被定义为公有的类成员可以在任何地方被访问. 被定义为受保护的类成员则 ...