关于为什么要用Sampler可以阅读一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

本文我们会从源代码的角度了解Sampler。

Sampler

首先需要知道的是所有的采样器都继承自Sampler这个类,如下:

可以看到主要有三种方法:分别是:

  • __init__: 这个很好理解,就是初始化
  • __iter__: 这个是用来产生迭代索引值的,也就是指定每个step需要读取哪些数据
  • __len__: 这个是用来返回每次迭代器的长度
class Sampler(object):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an __iter__ method, providing a way
to iterate over indices of dataset elements, and a __len__ method that
returns the length of the returned iterators.
"""
# 一个 迭代器 基类
def __init__(self, data_source):
pass def __iter__(self):
raise NotImplementedError def __len__(self):
raise NotImplementedError

子类Sampler

介绍完父类后我们看看Pytorch给我们提供了哪些采样器

SequentialSampler

这个看名字就很好理解,其实就是按顺序对数据集采样。

其原理是首先在初始化的时候拿到数据集data_source,之后在__iter__方法中首先得到一个和data_source一样长度的range可迭代器。每次只会返回一个索引值

class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
"""
# 产生顺序 迭代器
def __init__(self, data_source):
self.data_source = data_source def __iter__(self):
return iter(range(len(self.data_source))) def __len__(self):
return len(self.data_source)

使用示例:

a = [1,5,78,9,68]
b = torch.utils.data.SequentialSampler(a)
for x in b:
print(x) >>> 0
1
2
3
4

RandomSampler

参数作用:

  • data_source: 同上
  • num_samples: 指定采样的数量,默认是所有。
  • replacement: 若为True,则表示可以重复采样,即同一个样本可以重复采样,这样可能导致有的样本采样不到。所以此时我们可以设置num_samples来增加采样数量使得每个样本都可能被采样到。
class RandomSampler(Sampler):
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify ``num_samples`` to draw.
Arguments:
data_source (Dataset): dataset to sample from
num_samples (int): number of samples to draw, default=len(dataset)
replacement (bool): samples are drawn with replacement if ``True``, default=False
""" def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self.num_samples = num_samples if self.num_samples is not None and replacement is False:
raise ValueError("With replacement=False, num_samples should not be specified, "
"since a random permute will be performed.") if self.num_samples is None:
self.num_samples = len(self.data_source) if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integeral "
"value, but got num_samples={}".format(self.num_samples))
if not isinstance(self.replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement)) def __iter__(self):
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist()) def __len__(self):
return len(self.data_source)

SubsetRandomSampler

class SubsetRandomSampler(Sampler):
r"""Samples elements randomly from a given list of indices, without replacement.
Arguments:
indices (sequence): a sequence of indices
""" def __init__(self, indices):
self.indices = indices def __iter__(self):
return (self.indices[i] for i in torch.randperm(len(self.indices))) def __len__(self):
return len(self.indices)

这个采样器常见的使用场景是将训练集划分成训练集和验证集,示例如下:

n_train = len(train_dataset)
split = n_train // 3
indices = random.shuffle(list(range(n_train)))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
train_loader = DataLoader(..., sampler=train_sampler, ...)
valid_loader = DataLoader(..., sampler=valid_sampler, ...)

WeightedRandomSampler

参数作用同上面的RandomSampler,不再赘述。


class WeightedRandomSampler(Sampler):
r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
Arguments:
weights (sequence) : a sequence of weights, not necessary summing up to one
num_samples (int): number of samples to draw
replacement (bool): if ``True``, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a
sample index is drawn for a row, it cannot be drawn again for that row.
""" def __init__(self, weights, num_samples, replacement=True):
if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
num_samples <= 0:
raise ValueError("num_samples should be a positive integeral "
"value, but got num_samples={}".format(num_samples))
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(replacement))
self.weights = torch.tensor(weights, dtype=torch.double)
self.num_samples = num_samples
self.replacement = replacement def __iter__(self):
return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist()) def __len__(self):
return self.num_samples ## 指的是一次一共采样的样本的数量

BatchSampler

前面的采样器每次都只返回一个索引,但是我们在训练时是对批量的数据进行训练,而这个工作就需要BatchSampler来做。也就是说BatchSampler的作用就是将前面的Sampler采样得到的索引值进行合并,当数量等于一个batch大小后就将这一批的索引值返回。

class BatchSampler(Sampler):
r"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
# 批次采样
def __init__(self, sampler, batch_size, drop_last):
if not isinstance(sampler, Sampler):
raise ValueError("sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}"
.format(sampler))
if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integeral value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com




2020-01-23 17:45:35

Pytorch Sampler详解的更多相关文章

  1. 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)

    首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple- ...

  2. 目标检测之Faster-RCNN的pytorch代码详解(模型训练篇)

    本文所用代码gayhub的地址:https://github.com/chenyuntc/simple-faster-rcnn-pytorch  (非本人所写,博文只是解释代码) 好长时间没有发博客了 ...

  3. 目标检测之Faster-RCNN的pytorch代码详解(模型准备篇)

    十月一的假期转眼就结束了,这个假期带女朋友到处玩了玩,虽然经济仿佛要陷入危机,不过没关系,要是吃不上饭就看书,吃精神粮食也不错,哈哈!开个玩笑,是要收收心好好干活了,继续写Faster-RCNN的代码 ...

  4. Pytorch autograd,backward详解

    平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...

  5. pytorch之nn.Conv1d详解

    转自:https://blog.csdn.net/sunny_xsc1994/article/details/82969867,感谢分享 pytorch之nn.Conv1d详解

  6. [转载]Pytorch详解NLLLoss和CrossEntropyLoss

    [转载]Pytorch详解NLLLoss和CrossEntropyLoss 来源:https://blog.csdn.net/qq_22210253/article/details/85229988 ...

  7. 【小白学PyTorch】11 MobileNet详解及PyTorch实现

    文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...

  8. 【小白学PyTorch】21 Keras的API详解(上)卷积、激活、初始化、正则

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑答疑解惑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx6450 ...

  9. PyTorch常用参数初始化方法详解

    1. 均匀分布 torch.nn.init.uniform_(tensor, a=0, b=1) 从均匀分布U(a, b)中采样,初始化张量. 参数: tensor - 需要填充的张量 a - 均匀分 ...

随机推荐

  1. 【06月10日】A股ROE最高排名

    个股滚动ROE = 最近4个季度的归母净利润 / ((期初归母净资产 + 期末归母净资产) / 2). 查看更多个股ROE最高排名 兰州民百(SH600738) - ROE_TTM:86.45% - ...

  2. Ribbon核心组件IRule及配置指定的负载均衡算法

    Ribbon在工作时分为两步: 第一步:先选择 EurekaServer,它优先选择在同一个区域内负载较少的Server: 第二步:再根据用户指定的策略,在从Server取到的服务注册列表中选择一个地 ...

  3. c# EPPlus读取Excel里面的时间字段时,1900-01-01转成了1899-12-31

    看到一篇文章:https://bbs.csdn.net/topics/70511379,5楼的回复: 我发现EXCEL有千年虫BUG,把1900年算成闰年了,2月有29天.1900年3月1日以后就没问 ...

  4. 解决 ImportError: cannot import name 'initializations' from 'keras' (C:\Users\admin\AppData\Roaming\Python\Python37\site-packages\keras\__init__.py)

    解决 ImportError: cannot import name 'initializations' from 'keras' : [原因剖析] 上述代码用的是 Keras version: '1 ...

  5. mknod命令的使用

    1.mknod命令 在Linux系统下,mknod命令可用于系统下字符设备文件和块设备文件的创建. (1)命令语法 mknod(选项)(参数) (2)常用选项说明 -Z:设置安全的上下文. -m:设置 ...

  6. Zhe Second Language——

    You got things to do. Place to go. People to see. Futures to make. 我们还有事要做,有地儿要去,有人要见,有梦要实现. Remembe ...

  7. Zookeeper的介绍与基本部署

    目录 简介 架构 安装 StandAlone模式 1. 安装 2. 修改配置 3. 启动 4. 验证 5. 基本用法 Distributed模式 1. 配置hosts 2. 配置zoo.cfg 3. ...

  8. springmvc流程图以及配置

    springmvc:是完成数据的封装和跳转的功能 流程图如下: springmvc的配置流程 1.导入jar包 二.配置servlet文件 init-param的作用是在启动servlet启动时规定其 ...

  9. Scala Type Parameters 2

    类型关系 Scala 支持在泛型类上使用型变注释,用来表示复杂类型.组合类型的子类型关系间的相关性 协变 +T,变化方向相同,通常用在生产 假设 A extends T, 对于 Clazz[+T],则 ...

  10. C++动态规划求解0-1背包问题

    问题描述: 给定n种物品和一背包.物品i的重量是wi,其价值为vi,背包的容量为C.问:应该如何选择装入背包的物品,是的装入背包中物品的总价值最大? 细节须知: 暂无. 算法原理: a.最优子结构性质 ...