对数据集的shuffle处理需要设置相应的buffer_size参数,相当于需要将相应数目的样本读入内存,且这部分内存会在训练过程中一直保持占用。完全的shuffle需要将整个数据集读入内存,这在大规模数据集的情况下是不现实的,故需要结合设备内存以及Batch大小将TFRecord文件随机划分为多个子文件,再对数据集做local shuffle(即设置相对较小的buffer_size,不小于单个子文件的样本数)。

Shuffle和划分

下文以一个异常检测数据集(正负样本不平衡)为例,在生成第一批TFRecord时,我将正负样本分别写入单独的TFrecord文件以备后续在对正负样本有不同处理策略的情况下无需再解析example_proto。比如在以下代码中,我对正负样本有不同的验证集比例,并将他们写入不同的验证集文件。

import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm as tqdm # TFRecord划分
raw_normal_dataset = tf.data.TFRecordDataset("normal_16_256.tfrecords","GZIP")
raw_anomaly_dataset = tf.data.TFRecordDataset("anomaly_16_256.tfrecords","GZIP")
normal_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP")
anomaly_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP")
train_writer_list = [tf.io.TFRecordWriter(r'ex_1/'+'train_16_256_{}.tfrecords'.format(i),"GZIP") for i in range(SUBFILE_NUM+1)]
with tqdm(total=LEN_NORMAL_DATASET+LEN_ANOMALY_DATASET) as pbar:
for example_proto in raw_normal_dataset:
# 划分训练集和测试集
if np.random.random() > 0.99: # 正样本测试集的比例
normal_val_writer.write(example_proto.numpy())
else:
train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
pbar.update(1) for example_proto in raw_anomaly_dataset:
# 划分训练集和测试集
if np.random.random() > 0.7: # 负样本测试集的比例
anomaly_val_writer.write(example_proto.numpy())
else:
train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy())
pbar.update(1)
normal_val_writer.close()
anomaly_val_writer.close()
for train_writer in train_writer_list:
train_writer.close()

读取

raw_train_dataset = tf.data.TFRecordDataset([r'ex_1/'+'train_16_256_{}.tfrecords'.format(i) for i in range(SUBFILE_NUM+1)],"GZIP")
raw_train_dataset = raw_train_dataset.shuffle(buffer_size=100000).batch(BATCH_SIZE)
parsed_train_dataset = raw_train_dataset.map(map_func=map_func) raw_normal_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP")
raw_anomaly_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP")
parsed_nomarl_val_dataset = raw_normal_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
parsed_anomaly_val_dateset = raw_anomaly_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)

TFRecord的Shuffle、划分和读取的更多相关文章

  1. Tensorflow 中(批量)读取数据的案列分析及TFRecord文件的打包与读取

    内容概要: 单一数据读取方式: 第一种:slice_input_producer() # 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中 ...

  2. 更加清晰的TFRecord格式数据生成及读取

    TFRecords 格式数据文件处理流程 TFRecords 文件包含了 tf.train.Example 协议缓冲区(protocol buffer),协议缓冲区包含了特征 Features.Ten ...

  3. Tensorflow中使用tfrecord方式读取数据-深度学习-周振洋

    本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用 ...

  4. Spark技术内幕:Stage划分及提交源码分析

    http://blog.csdn.net/anzhsoft/article/details/39859463 当触发一个RDD的action后,以count为例,调用关系如下: org.apache. ...

  5. Spark技术内幕:Stage划分及提交源代码分析

    当触发一个RDD的action后.以count为例,调用关系例如以下: org.apache.spark.rdd.RDD#count org.apache.spark.SparkContext#run ...

  6. 第十二节,TensorFlow读取数据的几种方法以及队列的使用

    TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起 ...

  7. TensorFlow中数据读取之tfrecords

    关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow ...

  8. tensorflow之数据读取探究(2)

    tensorflow之tfrecord数据读取 Tensorflow关于TFRecord格式文件的处理.模型的训练的架构为: 1.获取文件列表.创建文件队列:http://blog.csdn.net/ ...

  9. spark 笔记 15: ShuffleManager,shuffle map两端的stage/task的桥梁

    无论是Hadoop还是spark,shuffle操作都是决定其性能的重要因素.在不能减少shuffle的情况下,使用一个好的shuffle管理器也是优化性能的重要手段. ShuffleManager的 ...

随机推荐

  1. SQL注入之information_schema

    在学习SQL注入时, 经常拿出来的例子就是PHP+MySQL这一套经典组合. 其中又经常提到的>=5.0版本的MySQL的内置库: information_schema 简单看一下informa ...

  2. SQL表的创建

    ​  一,创建表 1.使用鼠标创建表 1,进入SQL进行连接 ​编辑 2,在左边会有一个对象资源管理器,右键数据库,在弹出的窗口中选择新建数据库 ​编辑 3,给这个包取个名字,在这个界面可以给这个表选 ...

  3. CoaXPress 时间戳 Time Stamping

    背景 在CXP2.0之前,CXP没有定义Time Stamping时间戳的概念,但是用户对Time Stamping是有实际需求的,比如我们要对比多台设备拍摄同一个物体不同角度的照片,或者记录触发完成 ...

  4. 手把手教你使用 Spring Boot 3 开发上线一个前后端分离的生产级系统(一) - 介绍

    项目简介 novel 是一套基于时下最新 Java 技术栈 Spring Boot 3 + Vue 3 开发的前后端分离的学习型小说项目,配备详细的项目教程手把手教你从零开始开发上线一个生产级别的 J ...

  5. Unity-UGUI-无限循环列表

    前言:项目准备新增一个竞技场排行榜,策划规定只显示0-400名的玩家.我一想,生成四百个游戏物体,怕不是得把手机给卡死?回想原来在GitHub上看到过一个实现思路就是无限循环列表,所以就想自己试试.公 ...

  6. [BZOJ5449] 序列

    题目链接:序列 Description 给定一个\(1\)~\(n\)的排列x,每次你可以将 \(x_1, x_2, ..., x_i\) 翻转. 你需要求出将序列变为升序的最小操作次数. 多组数据. ...

  7. 浅谈Javascript单线程和事件循环

    单线程 Javascript 是单线程的,意味着不会有其他线程来竞争.为什么是单线程呢? 假设 Javascript 是多线程的,有两个线程,分别对同一个元素进行操作: function change ...

  8. SpringBoot官方支持任务调度框架,轻量级用起来也挺香!

    大家好,我是二哥呀.定时任务的应用场景其实蛮常见的,比如说: 数据备份 订单未支付则自动取消 定时爬取数据 定时推送信息 定时发布文章 等等(想不出来了,只能等等来凑,,反正只要等的都需要定时,怎么样 ...

  9. MathType7安装使用及please restart word to load mathtype addin properly的问题

    MathType7安装使用及please restart word to load mathtype addin properly的问题.最近在自己的电脑上安装Mathtype7,把遇到的问题和解决办 ...

  10. php 访问控制可见性 public protected private

    对属性或方法的访问控制,是通过在前面添加关键字public(公有),protected(受保护的),private(私有)来实现. 被定义为公有的类成员可以在任何地方被访问. 被定义为受保护的类成员则 ...