pytorch之dataloader深入剖析
- dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问;
- 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问;
- 也可以使用`for inputs, labels in dataloaders`进行可迭代对象的访问;
- 一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据;
① DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存
② Queue的特点 当队列里面没有数据时: queue.get() 会阻塞, 阻塞的时候,其它进程/线程如果有queue.put() 操作,本线程/进程会被通知,然后就可以 get 成功。
当数据满了: queue.put() 会阻塞 ③ DataLoader是一个高效,简洁,直观的网络输入数据结构,便于使用和扩展
输入数据PipeLine
pytorch 的数据加载到模型的操作顺序是这样的:
① 创建一个 Dataset 对象
② 创建一个 DataLoader 对象
③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for img, label in dataloader:
....
所以,作为直接对数据进入模型中的关键一步, DataLoader非常重要。
首先简单介绍一下DataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…),该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。
官方对DataLoader的说明是:“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。”关于iterator和iterable的区别和概念请自行查阅,在实现中的差别就是iterators有__iter__和__next__方法,而iterable只有__iter__方法。
1.DataLoader
先介绍一下DataLoader(object)的参数:
dataset(Dataset): 传入的数据集
batch_size(int, optional): 每个batch有多少个样本
shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序
sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数
pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中. drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。 timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each
worker subprocess with the worker id (an int in [0, num_workers - 1]) as
input, after seeding and before data loading. (default: None)
- 首先dataloader初始化时得到datasets的采样list
class DataLoader(object):
r"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset. Arguments:
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: 1).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: False).
sampler (Sampler, optional): defines the strategy to draw samples from
the dataset. If specified, ``shuffle`` must be False.
batch_sampler (Sampler, optional): like sampler, but returns a batch of
indices at a time. Mutually exclusive with batch_size, shuffle,
sampler, and drop_last.
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means that the data will be loaded in the main process.
(default: 0)
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
into CUDA pinned memory before returning them.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: False)
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional): If not None, this will be called on each
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
input, after seeding and before data loading. (default: None) .. note:: By default, each worker will have its PyTorch seed set to
``base_seed + worker_id``, where ``base_seed`` is a long generated
by main process using its RNG. However, seeds for other libraies
may be duplicated upon initializing workers (w.g., NumPy), causing
each worker to return identical random numbers. (See
:ref:`dataloader-workers-random-seed` section in FAQ.) You may
use ``torch.initial_seed()`` to access the PyTorch seed for each
worker in :attr:`worker_init_fn`, and use it to set other seeds
before data loading. .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
unpicklable object, e.g., a lambda function.
""" __initialized = False def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
self.timeout = timeout
self.worker_init_fn = worker_init_fn if timeout < 0:
raise ValueError('timeout option should be non-negative') if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler option is mutually exclusive '
'with batch_size, shuffle, sampler, and '
'drop_last')
self.batch_size = None
self.drop_last = None if sampler is not None and shuffle:
raise ValueError('sampler option is mutually exclusive with '
'shuffle') if self.num_workers < 0:
raise ValueError('num_workers option cannot be negative; '
'use num_workers=0 to disable multiprocessing.') if batch_sampler is None:
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset) //将list打乱
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler
self.batch_sampler = batch_sampler
self.__initialized = True def __setattr__(self, attr, val):
if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
raise ValueError('{} attribute should not be set after {} is '
'initialized'.format(attr, self.__class__.__name__)) super(DataLoader, self).__setattr__(attr, val) def __iter__(self):
return _DataLoaderIter(self) def __len__(self):
return len(self.batch_sampler)
其中:RandomSampler,BatchSampler已经得到了采用batch数据的index索引;yield batch机制已经在!!!
class RandomSampler(Sampler):
r"""Samples elements randomly, without replacement. Arguments:
data_source (Dataset): dataset to sample from
""" def __init__(self, data_source):
self.data_source = data_source def __iter__(self):
return iter(torch.randperm(len(self.data_source)).tolist()) def __len__(self):
return len(self.data_source)
class BatchSampler(Sampler):
r"""Wraps another sampler to yield a mini-batch of indices. Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size`` Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
""" def __init__(self, sampler, batch_size, drop_last):
if not isinstance(sampler, Sampler):
raise ValueError("sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}"
.format(sampler))
if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integeral value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
- 其中 _DataLoaderIter(self)输入为一个dataloader对象;如果num_workers=0很好理解,num_workers!=0引入多线程机制,加速数据加载过程;
- 没有多线程时:batch = self.collate_fn([self.dataset[i] for i in indices])进行将index转化为data数据,返回(image,label);self.dataset[i]会调用datasets对象的
__getitem__()方法;
- 多线程下,会为每个线程创建一个索引队列index_queues;共享一个worker_result_queue数据队列!在_worker_loop方法中加载数据;
class _DataLoaderIter(object):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" def __init__(self, loader):
self.dataset = loader.dataset
self.collate_fn = loader.collate_fn
self.batch_sampler = loader.batch_sampler
self.num_workers = loader.num_workers
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
self.timeout = loader.timeout
self.done_event = threading.Event() self.sample_iter = iter(self.batch_sampler) base_seed = torch.LongTensor(1).random_().item() if self.num_workers > 0:
self.worker_init_fn = loader.worker_init_fn
self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
self.worker_queue_idx = 0
self.worker_result_queue = multiprocessing.SimpleQueue()
self.batches_outstanding = 0
self.worker_pids_set = False
self.shutdown = False
self.send_idx = 0
self.rcvd_idx = 0
self.reorder_dict = {} self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queues[i],
self.worker_result_queue, self.collate_fn, base_seed + i,
self.worker_init_fn, i))
for i in range(self.num_workers)] if self.pin_memory or self.timeout > 0:
self.data_queue = queue.Queue()
if self.pin_memory:
maybe_device_id = torch.cuda.current_device()
else:
# do not initialize cuda context if not necessary
maybe_device_id = None
self.worker_manager_thread = threading.Thread(
target=_worker_manager_loop,
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
maybe_device_id))
self.worker_manager_thread.daemon = True
self.worker_manager_thread.start()
else:
self.data_queue = self.worker_result_queue for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
w.start() _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
_set_SIGCHLD_handler()
self.worker_pids_set = True # prime the prefetch loop
for _ in range(2 * self.num_workers):
self._put_indices() def __len__(self):
return len(self.batch_sampler) def _get_batch(self):
if self.timeout > 0:
try:
return self.data_queue.get(timeout=self.timeout)
except queue.Empty:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
else:
return self.data_queue.get() def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch # check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch) if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration while True:
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self._get_batch()
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
return self._process_next_batch(batch) next = __next__ # Python 2 compatibility def __iter__(self):
return self def _put_indices(self):
assert self.batches_outstanding < 2 * self.num_workers
indices = next(self.sample_iter, None)
if indices is None:
return
self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
self.batches_outstanding += 1
self.send_idx += 1 def _process_next_batch(self, batch):
self.rcvd_idx += 1
self._put_indices()
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
global _use_shared_memory
_use_shared_memory = True # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal happened again already.
# https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
_set_worker_signal_handlers() torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed) if init_fn is not None:
init_fn(worker_id) watchdog = ManagerWatchdog() while True:
try:
r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
except queue.Empty:
if watchdog.is_alive():
continue
else:
break
if r is None:
break
idx, batch_indices = r
try:
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx, samples))
del samples
- 需要对队列操作,缓存数据,使得加载提速!
pytorch之dataloader深入剖析的更多相关文章
- PyTorch 之 DataLoader
DataLoader DataLoader 是 PyTorch 中读取数据的一个重要接口,该接口定义在 dataloader.py 文件中,该接口的目的: 将自定义的 Dataset 根据 batch ...
- PyTorch之DataLoader杂谈
输入数据PipeLine pytorch 的数据加载到模型的操作顺序是这样的: ①创建一个 Dataset 对象②创建一个 DataLoader 对象③循环这个 DataLoader 对象,将img, ...
- [pytorch修改]dataloader.py 实现darknet中的subdivision功能
dataloader.py import random import torch import torch.multiprocessing as multiprocessing from torch. ...
- 对pytorch中Tensor的剖析
不是python层面Tensor的剖析,是C层面的剖析. 看pytorch下lib库中的TH好一阵子了,TH也是torch7下面的一个重要的库. 可以在torch的github上看到相关文档.看了半天 ...
- Pytorch自定义dataloader以及在迭代过程中返回image的name
pytorch官方给的加载数据的方式是已经定义好的dataset以及loader,如何加载自己本地的图片以及label? 形如数据格式为 image1 label1 image2 label2 ... ...
- 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...
- pytorch中DataLoader, DataSet, Sampler之间的关系
转自:https://mp.weixin.qq.com/s/RTv0cUWvc0kuXBeNoXVu_A 自上而下理解三者关系 首先我们看一下DataLoader.__next__的源代码长什么样,为 ...
- pytorch Dataset Dataloader用法(一个示例)
from torch.utils.data import Dataset from torch.utils.data import DataLoader from torch.utils.data i ...
- pytorch 中Dataloader中的collate_fn参数
一般的,默认的collate_fn函数是要求一个batch中的图片都具有相同size(因为要做stack操作),当一个batch中的图片大小都不同时,可以使用自定义的collate_fn函数,则一个b ...
随机推荐
- 【翻译】What is State Machine Diagram(什么是状态机图)?
[翻译]What is State Machine Diagram(什么是状态机图)? 写在前面 在上一篇学习类图的时候将这个网站上的类图的一篇文章翻译了出来,感觉受益良多,今天来学习UML状态机图, ...
- 隐马尔科夫模型(HMM)与词性标注问题
一.马尔科夫过程: 在已知目前状态(现在)的条件下,它未来的演变(将来)不依赖于它以往的演变 (过去 ).例如森林中动物头数的变化构成——马尔可夫过程.在现实世界中,有很多过程都是马尔可夫过程,如液体 ...
- 2977 二叉堆练习1 codevs
题目描述 Description 已知一个二叉树,判断它是否为二叉堆(小根堆) 输入描述 Input Description 二叉树的节点数N和N个节点(按层输入) 输出描述 Output Descr ...
- Nginx报Primary script unknown的错误解决
fastcgi_param SCRIPT_FILENAME $document_root$fastcgi_script_name; 改成红色部分变量 root /usr/local/nginx/h ...
- 在Eclipse添加Android兼容包( v4、v7 appcompat )(转)
昨天添加Android兼容包,碰到了很多问题,在这里记录一下,让后面的路好走. 如何选择兼容包, 请参考Android Support Library Features(二) 一.下载Support ...
- 微信小程序开发需要注意的29个坑
1.小程序名称可以由中文.数字.英文.长度在3-20个字符之间,一个中文字等于2个字符. 2.小程序名称不得与公众平台已有的订阅号.服务号重复.如提示重名,请更换名称进行设置. 3.小程序名称在帐号信 ...
- 在线即时展现 Html、JS、CSS 编辑工具 - JSFiddle
在线即时展现 Html.JS.CSS 编辑工具 - JSFiddle 想对它做些说明介绍.但是它确是那么的easy使用. 兴许有时间,把左側列表作以相关介绍和演示样例演示吧.
- [DLX反复覆盖] hdu 3656 Fire station
题意: N个点.再点上建M个消防站. 问消防站到每一个点的最大距离的最小是多少. 思路: DLX直接二分推断TLE了. 这时候一个非常巧妙的思路 我们求的距离一定是两个点之间的距离 因此我们把距离都求 ...
- delphi project of object
http://www.cnblogs.com/ywangzi/archive/2012/08/28/2659811.html 其实要了解这些东西,适当的学些反汇编,WINDOWS内存管理机制,PE结构 ...
- c#开发activex注册问题
最近使用c#开发activex,遇到一个问题,生成的dll文件在本地可以嵌入到web里面,但是到其他机器上就会出现activex无法加载的情况,页面里面出现一个红色的X.mfc开发的activex是使 ...