keras系列︱利用fit_generator最小化显存占用比率/数据Batch化
本文主要参考两篇文献:
1、《深度学习theano/tensorflow多显卡多人使用问题集》
2、基于双向LSTM和迁移学习的seq2seq核心实体识别
运行机器学习算法时,很多人一开始都会有意无意将数据集默认直接装进显卡显存中,如果处理大型数据集(例如图片尺寸很大)或是网络很深且隐藏层很宽,也可能造成显存不足。
这个情况随着工作的深入会经常碰到,解决方法其实很多人知道,就是分块装入。以keras为例,默认情况下用fit方法载数据,就是全部载入。换用fit_generator方法就会以自己手写的方法用yield逐块装入。这里稍微深入讲一下fit_generator方法。
.
— fit_generator源码
def fit_generator(self, generator, samples_per_epoch, nb_epoch,
verbose=1, callbacks=[],
validation_data=None, nb_val_samples=None,
class_weight=None, max_q_size=10, **kwargs):
.
— generator 该怎么写?
其中generator参数传入的是一个方法,validation_data参数既可以传入一个方法也可以直接传入验证数据集,通常我们都可以传入方法。这个方法需要我们自己手写,伪代码如下:
def generate_batch_data_random(x, y, batch_size):
"""逐步提取batch数据到显存,降低对显存的占用"""
ylen = len(y)
loopcount = ylen // batch_size
while (True):
i = randint(0,loopcount)
yield x[i * batch_size:(i + 1) * batch_size], y[i * batch_size:(i + 1) * batch_size]
.
为什么推荐在自己写的方法中用随机呢?
- 因为fit方法默认shuffle参数也是True,fit_generator需要我们自己随机打乱数据。
- 另外,在方法中需要用while写成死循环,因为每个epoch不会重新调用方法,这个是新手通常会碰到的问题。
当然,如果原始数据已经随机打乱过,那么可以不在这里做随机处理。否则还是建议加上随机取数逻辑(如果数据集比较大则可以保证基本乱序输出)。深度学习中随机打乱数据是非常重要的,具体参见《深度学习Deep Learning》一书的8.1.3节:《Batch and Minibatch Algorithm》。(2017年5月25日补充说明)
.
调用示例:
model.fit_generator(self.generate_batch_data_random(x_train, y_train, batch_size),
samples_per_epoch=len(y_train)//batch_size*batch_size,
nb_epoch=epoch,
validation_data=self.generate_valid_data(x_valid, y_valid,batch_size),
nb_val_samples=(len(y_valid)//batch_size*batch_size),
verbose=verbose,
callbacks=[early_stopping])
这样就可以将对显存的占用压低了,配合第一部分的方法可以方便同时执行多程序。
.
来看看一个《基于双向LSTM和迁移学习的seq2seq核心实体识别》实战案例:
'''
gen_matrix实现从分词后的list来输出训练样本
gen_target实现将输出序列转换为one hot形式的目标
超过maxlen则截断,不足补0
'''
gen_matrix = lambda z: np.vstack((word2vec[z[:maxlen]], np.zeros((maxlen-len(z[:maxlen]), word_size))))
gen_target = lambda z: np_utils.to_categorical(np.array(z[:maxlen] + [0]*(maxlen-len(z[:maxlen]))), 5)
#从节省内存的角度,通过生成器的方式来训练
def data_generator(data, targets, batch_size):
idx = np.arange(len(data))
np.random.shuffle(idx)
batches = [idx[range(batch_size*i, min(len(data), batch_size*(i+1)))] for i in range(len(data)/batch_size+1)]
while True:
for i in batches:
xx, yy = np.array(map(gen_matrix, data[i])), np.array(map(gen_target, targets[i]))
yield (xx, yy)
batch_size = 1024
history = model.fit_generator(data_generator(d['words'], d['label'], batch_size), samples_per_epoch=len(d), nb_epoch=200)
model.save_weights('words_seq2seq_final_1.model')
延伸一:edwardlib/observations 规范数据导入、数据Batch化
def generator(array, batch_size):
"""Generate batch with respect to array's first axis."""
start = 0 # pointer to where we are in iteration
while True:
stop = start + batch_size
diff = stop - array.shape[0]
if diff <= 0:
batch = array[start:stop]
start += batch_size
else:
batch = np.concatenate((array[start:], array[:diff]))
start = diff
yield batch
To use it, simply write
from observations import cifar10
(x_train, y_train), (x_test, y_test) = cifar10("~/data")
x_train_data = generator(x_train, 256)
keras系列︱利用fit_generator最小化显存占用比率/数据Batch化的更多相关文章
- Linux显存占用无进程清理方法(附批量清理命令)
在跑TensorFlow.pytorch之类的需要CUDA的程序时,强行Kill掉进程后发现显存仍然占用,这时候可以使用如下命令查看到top或者ps中看不到的进程,之后再kill掉: fuser -v ...
- 深度学习中GPU和显存分析
刚入门深度学习时,没有显存的概念,后来在实验中才渐渐建立了这个意识. 下面这篇文章很好的对GPU和显存总结了一番,于是我转载了过来. 作者:陈云 链接:https://zhuanlan.zhihu. ...
- keras系列︱keras是如何指定显卡且限制显存用量
keras在使用GPU的时候有个特点,就是默认全部占满显存. 若单核GPU也无所谓,若是服务器GPU较多,性能较好,全部占满就太浪费了. 于是乎有以下三种情况: - 1.指定GPU - 2.使用固定显 ...
- 我的Keras使用总结(5)——Keras指定显卡且限制显存用量,常见函数的用法及其习题练习
Keras 是一个高层神经网络API,Keras是由纯Python编写而成并基于TensorFlow,Theano以及CNTK后端.Keras为支持快速实验而生,能够将我们的idea迅速转换为结果.好 ...
- keras系列︱Sequential与Model模型、keras基本结构功能(一)
引自:http://blog.csdn.net/sinat_26917383/article/details/72857454 中文文档:http://keras-cn.readthedocs.io/ ...
- [Pytorch]深度模型的显存计算以及优化
原文链接:https://oldpan.me/archives/how-to-calculate-gpu-memory 前言 亲,显存炸了,你的显卡快冒烟了! torch.FatalError: cu ...
- MegEngine亚线性显存优化
MegEngine亚线性显存优化 MegEngine经过工程扩展和优化,发展出一套行之有效的加强版亚线性显存优化技术,既可在计算存储资源受限的条件下,轻松训练更深的模型,又可使用更大batch siz ...
- 自制操作系统Antz(3)——进入保护模式 (中) 直接操作显存
Antz系统更新地址: https://www.cnblogs.com/LexMoon/category/1262287.html Linux内核源码分析地址:https://www.cnblogs. ...
- 解决GPU显存未释放问题
前言 今早我想用多块GPU测试模型,于是就用了PyTorch里的torch.nn.parallel.DistributedDataParallel来支持用多块GPU的同时使用(下面简称其为Dist). ...
随机推荐
- 20145301实验五 Java网络编程及安全
北京电子科技学院(BESTI)实验报告 课程:Java程序设计 班级:1453 指导教师:娄嘉鹏 实验日期:2016.05.06 18:30-21:30 实验名称:实验五 Java网络编程 实验内容 ...
- 网络攻防工具介绍——Metasploit
Metasploit 简介 Metasploit是一款开源的安全漏洞检测工具,可以帮助安全和IT专业人士识别安全性问题,验证漏洞的缓解措施,并管理专家驱动的安全性进行评估,提供真正的安全风险情报.这些 ...
- markdown工作随笔总结
1. 锚点 (使用方法和链接很像) ## 目录 1. [命名](#命名) ....... **[返回顶部](#目录)** ## 命名 ###命名原则 可以从返回顶部回到目录,也可以点击目录的命名跳到命 ...
- 如何使一个openwrt下的软件开机自启动
条件有三: 1.需要在软件包的Makefile中添加宏定义Package/$(package-name)/preinst和Package/$(package-name)/prerm define Pa ...
- [BZOJ2109]Plane 航空管制
Description 世博期间,上海的航空客运量大大超过了平时,随之而来的航空管制也频频 发生.最近,小X就因为航空管制,连续两次在机场被延误超过了两小时.对此, 小X表示很不满意. 在这次来烟台的 ...
- codeforces 848c - two TVs
2017-08-22 15:42:44 writer:pprp 参考:http://blog.csdn.net/qq_37497322/article/details/77463376#comment ...
- 使用 Python 连接 Caché 数据库
有不少医院的 HIS 系统用的是 Caché 数据库,比如北京协和医院.四川大学华西医院等.用过 Caché 开发的都知道,Caché 数据库的开发维护同我们常见的关系型数据库有很大差别,如 SQL ...
- angular2.x 下拉多选框选择组件
angular2.x - 5.x 的下拉多选框选择组件 ng2 -- ng5.最近在学angular4,经常在交流群看见很多人问 下拉多选怎么做... 今天就随便写的个. 组件源码 百度云 链接: ...
- (转)全面认识一下.NET 4的缓存功能
很多关于.NET 4.0新特性的介绍,缓存功能的增强肯定是不会被忽略的一个重要亮点.在很多文档中都会介绍到在.NET 4.0中,缓存功能的增强主要是在扩展性方面做了改进,改变了原来只能利用内存进行缓存 ...
- Svn Replacement For Git Stash
svn 实现git stash类似的功能 % svn diff > WorkInProgress.txt % svn revert -R . <make changes> % svn ...