Making training mini-batches
Here is where we'll make our mini-batches for training. Remember that we want our batches to be multiple sequences of some desired number of sequence steps. Considering a simple example, our batches would look like this:
We have our text encoded as integers as one long array in encoded
. Let's create a function that will give us an iterator for our batches. I like using generator functions to do this. Then we can pass encoded
into this function and get our batch generator.
The first thing we need to do is discard some of the text so we only have completely full batches. Each batch contains N×MN×M characters, where NN is the batch size (the number of sequences) and MM is the number of steps. Then, to get the number of batches we can make from some array arr
, you divide the length of arr
by the batch size. Once you know the number of batches and the batch size, you can get the total number of characters to keep.
After that, we need to split arr
into NN sequences. You can do this using arr.reshape(size)
where size
is a tuple containing the dimensions sizes of the reshaped array. We know we want NN sequences (n_seqs
below), let's make that the size of the first dimension. For the second dimension, you can use -1
as a placeholder in the size, it'll fill up the array with the appropriate data for you. After this, you should have an array that is N×(M∗K)N×(M∗K) where KK is the number of batches.
Now that we have this array, we can iterate through it to get our batches. The idea is each batch is a N×MN×M window on the array. For each subsequent batch, the window moves over by n_steps
. We also want to create both the input and target arrays. Remember that the targets are the inputs shifted over one character. You'll usually see the first input character used as the last target character, so something like this:
y[:, :-1], y[:, -1] = x[:, 1:], x[:, 0]
where x
is the input batch and y
is the target batch.
The way I like to do this window is use range
to take steps of size n_steps
from 00 to arr.shape[1]
, the total number of steps in each sequence. That way, the integers you get from range
always point to the start of a batch, and each window is n_steps
wide.
def get_batches(arr, n_seqs, n_steps):
'''Create a generator that returns batches of size
n_seqs x n_steps from arr. Arguments
---------
arr: Array you want to make batches from
n_seqs: Batch size, the number of sequences per batch
n_steps: Number of sequence steps per batch
'''
# Get the number of characters per batch and number of batches we can make
characters_per_batch = n_seqs * n_steps
n_batches = len(arr) // characters_per_batch # Keep only enough characters to make full batches
arr = arr[:n_batches*characters_per_batch] # Reshape into n_seqs rows
arr = arr.reshape((n_seqs, -1)) for n in range(0, arr.shape[1], n_steps):
# The features
x = arr[:, n:n+n_steps]
# The targets, shifted by one
y = np.zeros_like(x)
y[:, :-1], y[:, -1] = x[:, 1:], x[:, 0]
yield x, y batches = get_batches(encoded, 10, 50)
x, y = next(batches) print('x\n', x[:10, :10])
print('\ny\n', y[:10, :10])
Making training mini-batches的更多相关文章
- 第四章——训练模型(Training Models)
前几章在不知道原理的情况下,已经学会使用了多个机器学习模型机器算法.Scikit-Learn很方便,以至于隐藏了太多的实现细节. 知其然知其所以然是必要的,这有利于快速选择合适的模型.正确的训练算法. ...
- deeplearning.ai 改善深层神经网络 week2 优化算法 听课笔记
这一周的主题是优化算法. 1. Mini-batch: 上一门课讨论的向量化的目的是去掉for循环加速优化计算,X = [x(1) x(2) x(3) ... x(m)],X的每一个列向量x(i)是 ...
- Must Know Tips/Tricks in Deep Neural Networks
Must Know Tips/Tricks in Deep Neural Networks (by Xiu-Shen Wei) Deep Neural Networks, especially C ...
- Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week2, Assignment(Optimization Methods)
声明:所有内容来自coursera,作为个人学习笔记记录在这里. 请不要ctrl+c/ctrl+v作业. Optimization Methods Until now, you've always u ...
- MLlib之LR算法源码学习
/** * :: DeveloperApi :: * GeneralizedLinearModel (GLM) represents a model trained using * Generaliz ...
- Must Know Tips/Tricks in Deep Neural Networks (by Xiu-Shen Wei)
http://lamda.nju.edu.cn/weixs/project/CNNTricks/CNNTricks.html Deep Neural Networks, especially Conv ...
- GeneralizedLinearAlgorithm in Spark MLLib
GeneralizedLinearAlgorithm SparkMllib涉及到的算法 Classification Linear Support Vector Machines (SVMs) Log ...
- 课程二(Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization),第二周(Optimization algorithms) —— 2.Programming assignments:Optimization
Optimization Welcome to the optimization's programming assignment of the hyper-parameters tuning spe ...
- 几种梯度下降方法对比(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)
https://blog.csdn.net/u012328159/article/details/80252012 我们在训练神经网络模型时,最常用的就是梯度下降,这篇博客主要介绍下几种梯度下降的变种 ...
- 吴恩达深度学习笔记(五) —— 优化算法:Mini-Batch GD、Momentum、RMSprop、Adam、学习率衰减
主要内容: 一.Mini-Batch Gradient descent 二.Momentum 四.RMSprop 五.Adam 六.优化算法性能比较 七.学习率衰减 一.Mini-Batch Grad ...
随机推荐
- Android USB转串口通信开发基本流程
好久没有写文章了,年前公司新开了一个项目,是和usb转串口通信相关的,需求是用安卓平板通过usb转接后与好几个外设进行通信.一直忙到近期,才慢慢闲下来,趁着这个周末不忙.记录下usb转串口通信开发的基 ...
- 解决异常断电导致的: CorruptSSTableException: java.io.EOFException
问题产生 服务器重启,导致cassandra损坏,整个集群不可用.所使用的cassandra为2.1.9版本. 问题描述 运行启动命令,报错如下: DEBUG :: All segments have ...
- server2008系统修改3389远程端口
我给大家简单谈谈正确修改远程端口的方法 在开始-----运行菜单里,输入regedit,进入注册表编辑,按先面的路径进入修改端口的地方 HKEY_LOCAL_MACHINE\System ...
- 为每个页面加上Session判断
首先新建一个类,继承自System.Web.UI.Page,然后重写OnInit,如下: using System; using System.Data; using System.Configura ...
- 点滴积累【JS】---JS小功能(checkbox实现全选和全取消)
效果: 代码: <head runat="server"> <title></title> <script type="text ...
- Sublime Text 2 入门与总结
Sublime Text 2 入门与总结 首语 : 考完试,但又没什么兴趣做课程设计,蛋疼的弄点软件入门的介绍,希望给各位还在吃香蕉的程序猿带来一点启示... 代码编辑器,就像武侠中的武 ...
- python之读取Excel 文件
# -*- coding: utf-8 -*- """ Created on Thu May 24 13:53:10 2018 @author: Frank " ...
- 通过某个进程号显示该进行打开的文件 lsof -p 1 11. 列出多个进程号对应的文件信息 lsof -p 123,456,789 5. 列出某个用户打开的文件信息 lsof -u username
linux命令 — lsof 查看进程打开那些文件 或者 查看文件给那个进程使用 lsof命令是什么? 可以列出被进程所打开的文件的信息.被打开的文件可以是 1.普通的文件,2.目录 3.网络文件系 ...
- sql server 订阅发布的配置
网上sql server 的发布订阅功能的教程很多,但是很多东西写的不是很详细,常常给人误解,现在根据自己的情况从新整理一下: 1.服务器端 然后一路下一步, 2.订阅端(重点) 给服务器在本地取一 ...
- sql语句中3表删除和3表查询
好久没来咱们博客园了,主要近期在忙一些七七八八的杂事,包括打羽毛球比赛的准备和自己在学jqgrid的迷茫.先不扯这些没用的了,希望大家能记得小弟,小弟在此谢过大家了. 回归正题:(以下的sql是本人在 ...