
# Build the CIFAR-10 training split. ``datasets``, ``data_path`` and
# ``train_transform`` come from elsewhere (presumably ``torchvision.datasets``
# and the caller's configuration) — not visible here.
trn_data = datasets.CIFAR10(root=data_path, train=True, download=False,
                            transform=train_transform)

# The image array is exposed as ``.data``. ``train_data`` is an older attribute
# name that the CIFAR10 class in this file never sets, so reading it fails
# before any shape is computed. Assuming ``datasets.CIFAR10`` matches that
# class, the shape is (num_images, 32, 32, 3) after its HWC transpose.
shape = trn_data.data.shape


from __future__ import print_function

import os
import os.path
import sys

import numpy as np
from PIL import Image

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive


class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Attributes set by ``__init__``:
        data (numpy.ndarray): every image of the selected split, reshaped to
            ``(N, 32, 32, 3)`` — height, width, channel order.
        targets (list): one label per image, in the same order as ``data``.

    Raises:
        RuntimeError: if the dataset files are missing or fail the integrity check.
    """
    # Sub-directory of ``root`` holding the extracted batch files.
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # [batch file name, checksum] pairs for the training split.
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    # [batch file name, checksum] pairs for the test split.
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    # Describes the label-name metadata file; presumably consumed by
    # ``_load_meta`` (not visible here).
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        # NOTE(review): download(), _check_integrity() and _load_meta() are
        # called below but are not defined in the visible part of this class —
        # confirm they exist further down the file.
        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # Select the batch files for the requested split.
        downloaded_list = self.train_list if self.train else self.test_list

        self.data = []
        self.targets = []

        # Load the pickled numpy arrays batch by batch.
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # this relies on _check_integrity() having verified every batch file.
        for file_name, _checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    # latin1 lets Python 3 unpickle Python 2 byte strings
                    # (assumes the batches were written by Python 2).
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                # Prefer the 'labels' key; fall back to 'fine_labels' for
                # files that use that name instead (presumably CIFAR-100-style
                # subclasses).
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        # Stack all batches, view them as (N, 3, 32, 32) and move the channel
        # axis last.
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()


