以下内容都是针对Pytorch 1.0-1.1介绍。



很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会自上而下地对Pytorch数据读取方法进行介绍。

自上而下理解三者关系

首先我们看一下DataLoader.next的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据)。

  1. class DataLoader(object):
  2. ...
  3. def __next__(self):
  4. if self.num_workers == 0:
  5. indices = next(self.sample_iter) # Sampler
  6. batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
  7. if self.pin_memory:
  8. batch = _utils.pin_memory.pin_memory_batch(batch)
  9. return batch

在阅读上面代码前,我们可以假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler完成,现在你不需要具体的细节,后面会介绍,你只需要知道DataLoader和Sampler在这里产生关系。

那么Dataset和DataLoader在什么时候产生关系呢?没错就是下面一行。我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。

再下面的if语句的作用简单理解就是,如果pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。

综上可以知道DataLoader,Sampler和Dataset三者关系如下:

在阅读后文的过程中,你始终需要将上面的关系记在心里,这样能帮助你更好地理解。

Sampler

参数传递

要更加细致地理解Sampler原理,我们需要先阅读一下DataLoader 的源代码,如下:

  1. class DataLoader(object):
  2. def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
  3. batch_sampler=None, num_workers=0, collate_fn=default_collate,
  4. pin_memory=False, drop_last=False, timeout=0,
  5. worker_init_fn=None)

可以看到初始化参数里有两种sampler:samplerbatch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。例如下面示例中,BatchSamplerSequentialSampler生成的index按照指定的batch size分组。

  1. >>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
  2. >>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Pytorch中已经实现的Sampler有如下几种:

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读源码更深地理解,这里只做总结:

  • 如果你自定义了batch_sampler,那么这些参数都必须使用默认值:batch_size, shuffle,sampler,drop_last.
  • 如果你自定义了sampler,那么shuffle需要设置为False
  • 如果samplerbatch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:
    • shuffle=True,则sampler=RandomSampler(dataset)
    • shuffle=False,则sampler=SequentialSampler(dataset)

如何自定义Sampler和BatchSampler?

仔细查看源代码其实可以发现,所有采样器其实都继承自同一个父类,即Sampler,其代码定义如下:

  1. class Sampler(object):
  2. r"""Base class for all Samplers.
  3. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
  4. way to iterate over indices of dataset elements, and a :meth:`__len__` method
  5. that returns the length of the returned iterators.
  6. .. note:: The :meth:`__len__` method isn't strictly required by
  7. :class:`~torch.utils.data.DataLoader`, but is expected in any
  8. calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
  9. """
  10. def __init__(self, data_source):
  11. pass
  12. def __iter__(self):
  13. raise NotImplementedError
  14. def __len__(self):
  15. return len(self.data_source)

所以你要做的就是定义好__iter__(self)函数,不过要注意的是该函数的返回值需要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))

另外BatchSampler与其他Sampler的主要区别是它需要将Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表。也就是说在后面的读取数据过程中使用的都是batch sampler。

Dataset

Dataset定义方式如下:

  1. class Dataset(object):
  2. def __init__(self):
  3. ...
  4. def __getitem__(self, index):
  5. return ...
  6. def __len__(self):
  7. return ...

上面三个方法是最基本的,其中__getitem__是最主要的方法,它规定了如何读取数据。但是它又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假如你定义好了一个dataset,那么你可以直接通过dataset[0]来访问第一个数据。在此之前我一直没弄清楚__getitem__是什么作用,所以一直不知道该怎么进入到这个函数进行调试。现在如果你想对__getitem__方法进行调试,你可以写一个for循环遍历dataset来进行调试了,而不用构建dataloader等一大堆东西了,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:

  1. class DataLoader(object):
  2. ...
  3. def __next__(self):
  4. if self.num_workers == 0:
  5. indices = next(self.sample_iter)
  6. batch = self.collate_fn([self.dataset[i] for i in indices]) # this line
  7. if self.pin_memory:
  8. batch = _utils.pin_memory.pin_memory_batch(batch)
  9. return batch

我们仔细看可以发现,前面还有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前我们需要知道每个参数的意义:

  • indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表
  • self.dataset[i]: 前面已经介绍了,这里就是对第i个数据进行读取操作,一般来说self.dataset[i]=(img, label)

看到这不难猜出collate_fn的作用就是将一个batch的数据进行合并操作。默认的collate_fn是将img和label分别合并成imgs和labels,所以如果你的__getitem__方法只是返回 img, label,那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com


2019-8-6

一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系的更多相关文章

  1. pytorch中DataLoader, DataSet, Sampler之间的关系

    转自:https://mp.weixin.qq.com/s/RTv0cUWvc0kuXBeNoXVu_A 自上而下理解三者关系 首先我们看一下DataLoader.__next__的源代码长什么样,为 ...

  2. 一文弄懂pytorch搭建网络流程+多分类评价指标

    讲在前面,本来想通过一个简单的多层感知机实验一下不同的优化方法的,结果写着写着就先研究起评价指标来了,之前也写过一篇:https://www.cnblogs.com/xiximayou/p/13700 ...

  3. 【编码】彻底弄懂ASCII、Unicode、UTF-8之间的关系

    计算机中的所有字符,说到底都是用二进制的0.1的排列组合来表示的,因此就需要有一个规范,来枚举规定每个字符对应哪个0.1的排列组合,这样的规范就是字符集. ASCII 全称是“美国信息交换标准码”(A ...

  4. 一文弄懂神经网络中的反向传播法——BackPropagation【转】

    本文转载自:https://www.cnblogs.com/charlotte77/p/5629865.html 一文弄懂神经网络中的反向传播法——BackPropagation   最近在看深度学习 ...

  5. 一文弄懂-Netty核心功能及线程模型

    目录 一. Netty是什么? 二. Netty 的使用场景 三. Netty通讯示例 1. Netty的maven依赖 2. 服务端代码 3. 客户端代码 四. Netty线程模型 五. Netty ...

  6. 一文弄懂-《Scalable IO In Java》

    目录 一. <Scalable IO In Java> 是什么? 二. IO架构的演变历程 1. Classic Service Designs 经典服务模型 2. Event-drive ...

  7. 一文弄懂-BIO,NIO,AIO

    目录 一文弄懂-BIO,NIO,AIO 1. BIO: 同步阻塞IO模型 2. NIO: 同步非阻塞IO模型(多路复用) 3.Epoll函数详解 4.Redis线程模型 5. AIO: 异步非阻塞IO ...

  8. 一文弄懂CGAffineTransform和CTM

    一文弄懂CGAffineTransform和CTM 一些概念 坐标空间(系):视图(View)坐标空间与绘制(draw)坐标空间 CTM:全称current transformation matrix ...

  9. 【TensorFlow】一文弄懂CNN中的padding参数

    在深度学习的图像识别领域中,我们经常使用卷积神经网络CNN来对图像进行特征提取,当我们使用TensorFlow搭建自己的CNN时,一般会使用TensorFlow中的卷积函数和池化函数来对图像进行卷积和 ...

随机推荐

  1. Unix/Linux小计

    1. centos查看cpu信息 cat /proc/cpuinfo processor有几个就是有几个cpu,每一列是每个cpu的信息 每个processor中的cores是当前cpu中有几个核心. ...

  2. 命名法:骆驼(Camel)、帕斯卡(pascal)、匈牙利(Hungarian)、下划线(_)

    首先欢迎大家到来! 常用的命名法:骆驼(Camel).帕斯卡(pascal).匈牙利(Hungarian).下划线(_) 骆驼:是指混合使用大小写字母来构成变量和函数的名字 帕斯卡:与骆驼命名法类似只 ...

  3. Oracle--存储过程中之循环语句

    一般循环语句有两种: 1)使用for循环实现 declare  cursor cur is    select * from tablename;   aw_row  tablename%rowtyp ...

  4. docker 学习操作记录 3

    记录3 [BEGIN] // :: Last :: from 192.168.114.1 root@coder:~# man addgroup ADDUSER() System Manager's M ...

  5. C中关键字inline用法

    一.什么是内联函数 在C语言中,如果一些函数被频繁的调用,不断地用函数入栈,即函数栈,则会造成栈空间或者栈内存的大量消耗,为了解决这个问题,特别的引入了inline关键字,表示为内联函数.栈空间指的是 ...

  6. Shuffle an Array (水塘抽样)

    随机性问题 水塘抽样算法可保证每个样本被抽到的概率相等 使用场景:从包含n个项目的集合S中选取k个样本,其中n为一很大或未知的数量,尤其适用于不能把所有n个项目都存放到主内存的情况 Knuth洗牌算法 ...

  7. SQL系列(十二)—— insert update delete

    前言 这个系列的前面都一直在介绍查询select.但是SQL中十分广泛,按对数据的不同处理可以分为: DML:全称Data Manipulation Language,从名字上可以看出,DML是对数据 ...

  8. SQL系列(五)—— 排序(order by)

    对查询结果进行排序是日常应用开发中最为常见的需求,在SQL中通过order by实现.order by是select语句中一部分,即子句. 1.order by 1.1 单列排序 其实,检索出的数据并 ...

  9. (原创)对比组态软件,使用C#开发的服务器和客户端软件的优势

    在当前经济形势和市场环境下,中小企业面对萧条的消费市场,恶化的外部贸易环境,刚性支出高成本人工和生产要素,通货膨胀,隐性的腐化支出等各种因素的作用导致企业生存艰难,企业需要在各方面削减支出,拓展市场寻 ...

  10. vs2019 netocore项目本地程序ip地址访问需修改的配置文件

    IISPress启动项目后,打开IISPress托盘可以看到当前项目 根据图中标识出来的applicationhost.config文件路径,一般为你的项目解决方案目录下的.vs\解决方案文件夹\co ...