在使用kears训练model的时候,一般会将所有的训练数据加载到内存中,然后喂给网络,但当内存有限,且数据量过大时,此方法则不再可用。此博客,将介绍如何在多核(多线程)上实时的生成数据,并立即的送入到模型当中训练。 本篇文章由圆柱模板博主发布。

   先看一下还未改进的版本:

   

import numpy as np
from keras.models import Sequential
#载入全部的数据!!
X, y = np.load('some_training_set_with_labels.npy')
#设计模型
model = Sequential()
[...] #网络结构
model.compile()
# 在数据集上进行模型训练
model.fit(x=X, y=y)

  下面的结构将改变一次性载入全部数据的情况。接下来将介绍如何一步一步的构造数据生成器,此数据生成器也可应用在你自己的项目当中;复制下来,并根据自己的需求填充空白处。

在构建之前先定义统一几个变量,并介绍几个小tips,对我们处理大的数据量很重要。 
ID type为string,代表数据集中的某个样本。 
调整以下结构,编译处理样本和他们的label:

1.新建一个词典名叫 partition :

partition[‘train’] 为训练集的ID,type为list
partition[‘validation’] 为验证集的ID,type为list

  2.新建一个词典名叫 * labels * ,根据ID可找到数据集中的样本,同样可通过labels[ID]找到样本标签。 
举个例子: 
假设训练集包含三个样本,ID分别为id-1,id-2和id-3,相应的label分别为0,1,2。验证集包含样本ID id-4,标签为 1。此时两个词典partition和 labels分别如下:

partition
{'train': ['id-1', 'id-2', 'id-3'], 'validation': ['id-4']}

  

labels
{'id-1': 0, 'id-2': 1, 'id-3': 2, 'id-4': 1}

  data/ 中为数据集文件。

数据生成器(data generator)

接下来将介绍如何构建数据生成器 DataGenerator ,DataGenerator将实时的对训练模型feed数据。 
接下来,将先初始化类。我们使此类继承自keras.utils.Sequence,这样我们可以使用多线程。

def __init__(self, list_IDs, labels, batch_size=32,
dim=(32,32,32), n_channels=1,
n_classes=10, shuffle=True):
'Initialization'
self.dim = dim
self.batch_size = batch_size
self.labels = labels
self.list_IDs = list_IDs
self.n_channels = n_channels
self.n_classes = n_classes
self.shuffle = shuffle
self.on_epoch_end()

  我们给了一些与数据相关的参数 dim,channels,classes,batch size ;方法 on_epoch_end 在一个epoch开始时或者结束时触发,shuffle决定是否在数据生成时要对数据进行打乱。

def on_epoch_end(self):
'Updates indexes after each epoch'
self.indexes = np.arange(len(self.list_IDs))
if self.shuffle == True:
np.random.shuffle(self.indexes)

  另一个数据生成核心的方法__data_generation 是生成批数据。

def __data_generation(self, list_IDs_temp):
'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
# Initialization
X = np.empty((self.batch_size, *self.dim, self.n_channels))
y = np.empty((self.batch_size), dtype=int) # Generate data
for i, ID in enumerate(list_IDs_temp):
# Store sample
X[i,] = np.load('data/' + ID + '.npy') # Store class
y[i] = self.labels[ID] return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

  在数据生成期间,代码读取包含各个样本ID的代码ID.py.因为我们的代码是可以应用多线程的,所以可以采用更为复杂的操作,不用担心数据生成成为总体效率的瓶颈。 
另外,我们使用Keras的方法keras.utils.to_categorical对label进行2值化 
(比如,对6分类而言,第三个label则相应的变成 to [0 0 1 0 0 0]) 。

def __len__(self):
'Denotes the number of batches per epoch'
return int(np.floor(len(self.list_IDs) / self.batch_size))

  现在,当相应的index的batch被选到,则生成器执行_getitem_方法来生成它。

def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the batch
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size] # Find list of IDs
list_IDs_temp = [self.list_IDs[k] for k in indexes] # Generate data
X, y = self.__data_generation(list_IDs_temp) return X, y

  

Keras神经网络data generators解决数据内存的更多相关文章

  1. mysql查询语句出现sending data耗时解决

    在执行一个简单的sql查询,表中数据量为14万 sql语句为:SELECT id,titile,published_at from spider_36kr_record where is_analyz ...

  2. 压缩Sqlite数据文件大小,解决数据删除后占用空间不变的问题

    最近有一网站使用Sqlite数据库作为数据临时性的缓存,对多片区进行划分 Sqlite数据库文件,每天大概新增近1万的数据量,起初效率有明显的提高,但历经一个多月后数据库文件从几K也上升到了近160M ...

  3. [转]怎样解决Myeclipse内存溢出?

    在用myeclipes10 开发 遇到了 内存溢出问题,百度了很久,这篇比较完善. 总结起来三个方面去检查 1)myeclipes的配置:myeclipes 10 的安装路径下 的myeclipse. ...

  4. JAVA 大数据内存耗用测试

    JAVA 大数据内存耗用测试import java.lang.management.ManagementFactory;import java.lang.management.MemoryMXBean ...

  5. 解决Windows内存问题的两个小工具RamMap和VMMap(这个更牛更好)

    来源:http://www.cr173.com/html/13006_1.html .net程序内存监测分配工具(CLR Profiler for .NET Framework 4)官方安装版 类型: ...

  6. Spark性能调优之解决数据倾斜

    Spark性能调优之解决数据倾斜 数据倾斜七种解决方案 shuffle的过程最容易引起数据倾斜 1.使用Hive ETL预处理数据    • 方案适用场景:如果导致数据倾斜的是Hive表.如果该Hiv ...

  7. SAS DATA步读取数据

    上面一节讲了SAS的基本概念,以及语法结构,这次主要讲解SAS DATA步读取数据.    1 ·列表输入    2 ·按列输入    3 ·格式化输入  使用DATA步读取数据的基本形式如下: DA ...

  8. [MapReduce_add_3] MapReduce 通过分区解决数据倾斜

    0. 说明 数据倾斜及解决方法的介绍与代码实现 1. 介绍 [1.1 数据倾斜的含义] 大量数据发送到同一个节点进行处理,造成此节点繁忙甚至瘫痪,而其他节点资源空闲 [1.2 解决数据倾斜的方式] 重 ...

  9. 解决Windows内存问题的两个小工具RamMap和VMMap

    解决Windows内存问题需要对操作系统的深入理解,同时对于如何运用Windows调试器或性能监控器要有工作认知.如果你正试着得到细节,诸如内核堆栈大小或硬盘内存消耗,你会需要调试器命令和内核数据架构 ...

随机推荐

  1. objectARX2010及其以上版本使用publish打印(发布)图纸,后台布局打印图纸例子浅析

    AutoCAD 2010版本开始新增了一个发布图纸的功能,可以后台打印图纸,以下是ADN官方博客例子浅析 原文地址 https://adndevblog.typepad.com/autocad/201 ...

  2. IP通信学习心得01

    一.物理拓扑 1. 1) 总线拓扑 特点:所有设备都处于同一个冲突域与广播域,共享相同的带宽 一次只能有一个设备传输,且两端要安装端接器. 传输介质:同轴电缆.(注:10Base5:容量10M 传输5 ...

  3. Word2016经常复制公式卡死无响应如何解决?

    Word文件 > 选项 > 高级 > 显示 > 禁用“硬件图形加速”

  4. 深入浅出JVM(一):运行时数据区域

    程序计数器 线程私有 指向了正在执行的虚拟机字节码指令的地址:如果是本地方法,数值为空 没有 OutOfMemoryError 错误的区域 Java虚拟机栈 线程私有: 生命周期与线程相同: 代表着 ...

  5. yield再理解--绝对够透彻

    首先,拿好宝剑: 先把yield看做“return”, 普通的return是什么意思,就是在程序中返回某个值,返回之后程序就不再往下运行了. 看做return之后再把它看做一个是生成器(generat ...

  6. 登录和退出Mysql

    这里介绍的是通过cmd方式登录和退出Mysql的方式 一.登录命令 登录命令:mysql.exe -h主机地址   -P端口   -u用户名    -p密码 即依次输入服务器地址.服务器监听的端口.用 ...

  7. System.AccessViolationException处理

    程序出现 System.AccessViolationException异常会终止进程,try catch是无法捕捉的. 有个处理方法在引发异常的发放上面加上 [System.Runtime.Exce ...

  8. K8S conul部署

    官网有Helm方式的安装文档(https://www.consul.io/docs/platform/k8s/index.html) 一,准备工作: 1,k8s环境 2,nfs服务器 二,创建PV n ...

  9. 3)创建,测试,发布 第一个NET CORE程序

    工具:Visual Studio Code 或者 Visual Studio 环境:.NET CORE 2.0 VS Code很强大 当然支持netcore的开发,但是我还是选择更熟悉更强大的VS. ...

  10. Install Gnome desktop

    Install Gnome desktop http://www.dinggd.com/index.php/freebsd-8-0-rc1-gnome%E6%A1%8C%E9%9D%A2%E5%AE% ...