From: https://liudongdong1.github.io/

All datasets are subclasses of torch.utils.data.Dataset i.e, they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple samples parallelly using torch.multiprocessing workers.

  • Dataloader: offer a way to parallelly load data, batch load, and offer shuffle policy.
  • Dataset: the dataset entry, offer getitem function., Transformer function 在这里执行。

0. DataLoader

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
  • dataset (Dataset) – dataset from which to load the data.
  • batch_size (int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler or Iterable**, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
  • batch_sampler (Sampler or Iterable**, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • collate_fn (callable**, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
  • timeout (numeric**, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
  • worker_init_fn (callable**, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
  • prefetch_factor (int, optional**, keyword-only arg) – Number of samples loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2)
  • persistent_workers (bool, optional) – If True, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)

.1. Map-style Dataset

implements the __getitem__() and __len__() protocols, and represents a map from (poissibly non-integral) indices/keys to data samples. dataset[idx]={image, label}

.2. Iterable-style dataset

implements the __iter__() protocol, and represents an iterable over data samples. This type of datasets is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.

class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`. .. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
""" def __getitem__(self, index) -> T_co:
raise NotImplementedError def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other]) # No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py

1. Available Datasets

All the datasets have almost similar API. They all have two common arguments: transform and target_transform to transform the input and target respectively.

Take MNIST for example:

from .vision import VisionDataset
import warnings
from PIL import Image
import os
import os.path
import numpy as np
import torch
import codecs
import string
from .utils import download_url, download_and_extract_archive, extract_archive, \
verify_str_arg
[docs]class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args:
root (string): Root directory of dataset where ``MNIST/processed/training.pt``
and ``MNIST/processed/test.pt`` exist.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
resources = [
("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
]
training_file = 'training.pt'
test_file = 'test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
@property
def train_labels(self):
warnings.warn("train_labels has been renamed targets")
return self.targets
@property
def test_labels(self):
warnings.warn("test_labels has been renamed targets")
return self.targets
@property
def train_data(self):
warnings.warn("train_data has been renamed data")
return self.data
@property
def test_data(self):
warnings.warn("test_data has been renamed data")
return self.data
def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
super(MNIST, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
if download:
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
if self.train:
data_file = self.training_file
else:
data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
@property
def raw_folder(self):
return os.path.join(self.root, self.__class__.__name__, 'raw')
@property
def processed_folder(self):
return os.path.join(self.root, self.__class__.__name__, 'processed')
@property
def class_to_idx(self):
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self):
return (os.path.exists(os.path.join(self.processed_folder,
self.training_file)) and
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))
def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
# download files
for url, md5 in self.resources:
filename = url.rpartition('/')[2]
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
# process and save as torch files
print('Processing...')
training_set = (
read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
)
test_set = (
read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
)
with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
torch.save(test_set, f)
print('Done!')
def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test")

2. Generic Dataloader

2.1. ImageFolder

A generic data loader where the images are arranged in this way:

torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=, is_valid_file=None)

__getitem__(index)

  • Parameters

    index (int) – Index

  • Returns

    (sample, target) where target is class_index of the target class.

  • Return type

    tuple

def load_data(root_path, dir, batch_size, phase):
transform_dict = {
'src': transforms.Compose(
[transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
]),
'tar': transforms.Compose(
[transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])}
data = datasets.ImageFolder(root=root_path + dir, transform=transform_dict[phase])
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4)
return data_loader

2.2. DatasetFolder

torchvision.datasets.``DatasetFolder(root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None)

3. Examples

from multiprocessing import freeze_support
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
# Hyperparameters.
num_epochs = 20
num_classes = 5
batch_size = 100
learning_rate = 0.001
num_of_workers = 5
DATA_PATH_TRAIN = Path('C:/Users/Aeryes/PycharmProjects/simplecnn/images/train/')
DATA_PATH_TEST = Path('C:/Users/Aeryes/PycharmProjects/simplecnn/images/test/')
MODEL_STORE_PATH = Path('C:/Users/Aeryes/PycharmProjects/simplecnn/model')
trans = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(32),
transforms.CenterCrop(32),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
])
# Flowers dataset.
train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
test_dataset = datasets.ImageFolder(root=DATA_PATH_TEST, transform=trans)
# Create custom random sampler class to iter over dataloader.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_of_workers)
for i, (images, labels) in enumerate(train_loader):
# Move images and labels to gpu if available
if cuda_avail:
images = Variable(images.cuda())
labels = Variable(labels.cuda())

Learning Resources

【AI】TorchVision_DataLoad的更多相关文章

  1. 【AI】Exponential Stochastic Cellular Automata for Massively Parallel Inference - 大规模并行推理的指数随机元胞自动机

    [论文标题]Exponential Stochastic Cellular Automata for Massively Parallel Inference     (19th-ICAIS,PMLR ...

  2. 【AI】【人工智能】【计算机】人工智能工程技术人员等职业信息公示

    人社厅发[2019]48号 各省.自治区.直辖市及新疆生产建设兵团人力资源社会保障厅(局).市场监管局.统计局,国务院各部门.各直属机构.各中央企业.有关社会组织人事劳动保障工作机构,中央军委政治工作 ...

  3. 【AI】Android Pie中引入的AI功能

    前言 “无AI,不未来”,绝对不是一句豪情壮语,AI早已进入到了我们生活当中.去年Google发布的Android Pie系统在AI功能方面就做了重大革新,本文就对Google在新系统中引入的AI功能 ...

  4. 【AI】Computing Machinery and Intelligence - 计算机器与智能

    [论文标题] Computing Machinery and Intelligence (1950) [论文作者] A. M. Turing (Alan Mathison Turing) [论文链接] ...

  5. 【AI】【计算机】【中国人工智能学会通讯】【学会通讯2019年第01期】中国人工智能学会重磅发布 《2018 人工智能产业创新评估白皮书》

    封面: 中国人工智能学会重磅发布 <2018 人工智能产业创新评估白皮书> < 2018 人工智能产业创新评估白皮书>由中国人工智能学会.国家工信安全中心.华夏幸福产业研究院. ...

  6. 【AI】蒙特卡洛搜索树

    http://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/ 蒙特卡洛方法与随机优化: http://iacs-co ...

  7. 【AI】PaddlePaddle-Docker运行

    1.参考官方安装Docker环境,使用一键安装包安装 https://www.jianshu.com/p/b2766173d754 http://www.paddlepaddle.org/docume ...

  8. 【AI】神经网络基本词汇

    neural networks 神经网络activation function 激活函数hyperbolic tangent 双曲正切函数bias units 偏置项activation 激活值for ...

  9. 【AI】基本概念-准确率、精准率、召回率的理解

    样本全集:TP+FP+FN+TN TP:样本为正,预测结果为正 FP:样本为负,预测结果为正 TN:样本为负,预测结果为负 FN:样本为正,预测结果为负 准确率(accuracy):(TP+TN)/ ...

随机推荐

  1. Ubuntu命令总结

    sudo apt-get update 系统更新 shutdown -h now 关闭服务器 shutdown -r now 重启服务器 uname -a ubuntu中查看内核版本的命令 gedit ...

  2. I-Identical Day[题解]

    原题目地址(牛客) Identical Day 题目大意 给定一个长度为 \(n\) 的 \(01\) 串,对于每段长度为 \(l\) 的连续的 \(1\) ,其权值为 \(\frac{l\times ...

  3. java+selenium UI自动化001

    selenium是一个用于Web应用程序测试的工具,可以用来模拟用户在浏览器上的操作. 支持的浏览器包括IE(7, 8, 9, 10, 11),Mozilla Firefox,Safari,Googl ...

  4. 如何热更新长缓存的 HTTP 资源

    前言 HTTP 缓存时间一直让开发者头疼.时间太短,性能不够好:时间太长,更新不及时.当遇到严重问题需紧急修复时,尽管后端文件可快速替换,但前端文件仍从本地缓存加载,导致更新长时间无法生效. 对于这个 ...

  5. 每天五分钟Go - 指针

    什么是指针 一个指向内存地址的变量,称为指针变量,指针是一个特殊的变量,他的值存储的是另一个值的内存地址 指针变量的声明 var var_name *type var_name 是指针变量的名称,ty ...

  6. P5110 块速递推-光速幂、斐波那契数列通项

    P5110 块速递推 题意 多次询问,求数列 \[a_i=\begin{cases}233a_{i-1}+666a_{i-2} & i>1\\ 0 & i=0\\ 1 & ...

  7. Python RPC 不会?不妨看看这篇文章

    1. 前言 大家好,我是安果! RPC,全程为 Remote Procedure Call,是一种进程间的通信方式,它采用「 服务端 / 客户机 」模式,是一种请求响应模型 其中,服务端负责提供服务程 ...

  8. C++手写内存池

    引言 使用new expression为类的多个实例分配动态内存时,cookie导致内存利用率可能不高,此时我们通过实现类的内存池来降低overhead.从不成熟到巧妙优化的内存池,得益于union的 ...

  9. Windows系统安装Mariadb数据库(zip包方式安装)--九五小庞

    1.去Mariadb官网下载zip安装包 下载地址:https://downloads.mariadb.org/mariadb/10.3.31/ 2.解压压缩包到指定的安装位置 3.在安装包的data ...

  10. WarError syncing load balancer: failed to ensure load balancer: network.SubnetsClient#Get: Failure responding to request: StatusCode=403

    Warning SyncLoadBalancerFailed 4m55s (x8 over 15m) service-controller Error syncing load balancer: f ...