Pytorch数据读取详解
原文:http://studyai.com/article/11efc2bf#采样器 Sampler & BatchSampler
数据库DataBase + 数据集DataSet + 采样器Sampler = 加载器Loader
from torch.utils.data import *
IMDB
+ Dataset
+ Sampler
|| BatchSampler
= DataLoader
数据库 DataBase
Image DataBase 简称IMDB,指的是存储在文件中的数据信息。
文件格式可以多种多样。比如xml, yaml, json, sql.
VOC是xml格式的,COCO是JSON格式的。
构造IMDB的过程,就是解析这些文件,并建立数据索引的过程。
一般会被解析为Python列表, 以方便后续迭代读取。
数据集 DataSet
数据集 DataSet: 在数据库IMDB的基础上,提供对数据的单例或切片访问方法。
换言之,就是定义数据库中对象的索引机制,如何实现单例索引或切片索引。
简言之,DataSet,通过__getitem__
定义了数据集DataSet是一个可索引对象,An Indexerable Object。
即传入一个给定的索引Index之后,如何按此索引进行单例或切片访问,单例还是切片视Index是单值还是列表。
Pytorch源码如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
# 定义单例/切片访问方法,即 dataItem = Dataset[index]
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
自定义数据集要基于上述Dataset基类、IMDB基类,有两种方法。
# 方法一: 单继承
class XxDataset(Dataset)
# 将IMDB作为参数传入,进行二次封装
imdb = IMDB()
pass
# 方法二: 双继承
class XxDataset(IMDB, Dataset):
pass
采样器 Sampler & BatchSampler
在实际应用中,数据并不一定是循规蹈矩的序惯访问,而需要随机打乱顺序来访问,或需要随机加权访问,
因此,按某种特定的规则来读取数据,就是采样操作,需要定义采样器:Sampler。
另外,数据也可能并不是一个一个读取的,而需要一批一批的读取,即需要批量采样操作,定义批量采样器:BatchSampler。
所以,只有Dataset的单例访问方法还不够,还需要在此基础上,进一步的定义批量访问方法。
简言之,采样器定义了索引(index)的产生规则,按指定规则去产生索引,从而控制数据的读取机制
BatchSampler 是基于 Sampler 来构造的: BatchSampler = Sampler + BatchSize
Pytorch源码如下,
class Sampler(object):
"""Base class for all Samplers.
采样器基类,可以基于此自定义采样器。
Every Sampler subclass has to provide an __iter__ method, providing a way
to iterate over indices of dataset elements, and a __len__ method that
returns the length of the returned iterators.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
# 序惯采样
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
# 随机采样
class RandomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(torch.randperm(len(self.data_source)).long())
def __len__(self):
return len(self.data_source)
# 随机子采样
class SubsetRandomSampler(Sampler):
pass
# 加权随机采样
class WeightedRandomSampler(Sampler):
pass
class BatchSampler(object):
"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler # ******
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
由上可见,Sampler本质就是个具有特定规则的可迭代对象,但只能单例迭代。
如 [x for x in range(10)]
, range(10)就是个最基本的Sampler,每次循环只能取出其中的一个值.
[x for x in range(10)]
Out[10]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import SequentialSampler
[x for x in SequentialSampler(range(10))]
Out[14]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import RandomSampler
[x for x in RandomSampler(range(10))]
Out[12]: [4, 9, 5, 0, 2, 8, 3, 1, 7, 6]
BatchSampler对Sampler进行二次封装,引入了batchSize参数,实现了批量迭代。
from torch.utils.data.sampler import BatchSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]
加载器 DataLoader
在实际计算中,如果数据量很大,考虑到内存有限,且IO速度很慢,
因此不能一次性的将其全部加载到内存中,也不能只用一个线程去加载。
因而需要多线程、迭代加载, 因而专门定义加载器:DataLoader。
DataLoader 是一个可迭代对象, An Iterable Object, 内部配置了魔法函数——iter——
,调用它将返回一个迭代器。
该函数可用内置函数iter
直接调用,即 DataIteror = iter(DataLoader)
。
dataloader = DataLoader(dataset=Dataset(imdb=IMDB()), sampler=Sampler(), num_works, ...)
__init__
参数包含两部分,前半部分用于指定数据集 + 采样器
,后半部分为多线程参数
。
class DataLoader(object):
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
"""
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
self.timeout = timeout
self.worker_init_fn = worker_init_fn
if timeout < 0:
raise ValueError('timeout option should be non-negative')
# 检测是否存在参数冲突: 默认batchSampler vs 自定义BatchSampler
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if self.num_workers < 0:
raise ValueError('num_workers cannot be negative; '
'use num_workers=0 to disable multiprocessing.')
# 在此处会强行指定一个 BatchSampler
if batch_sampler is None:
# 在此处会强行指定一个 Sampler
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
# 使用自定义的采样器和批采样器
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
# 调用Pytorch的多线程迭代器加载数据
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
数据迭代器 DataLoaderIter
迭代器与可迭代对象之间是有区别的。
可迭代对象,意思是对其使用Iter
函数时,它可以返回一个迭代器,从而可以连续的迭代访问它。
迭代器对象,内部有额外的魔法函数__next__
,用内置函数next
作用其上,则可以连续产生下一个数据,产生规则即是由此函数来确定的。
可迭代对象描述了对象具有可迭代性,但具体的迭代规则由迭代器来描述,这样解耦的好处是可以对同一个可迭代对象配置多种不同规则的迭代器。
数据集/容器遍历的一般化流程:NILIS
NILIS规则
: data = next(iter(loader(DataSet[sampler])))data=next(iter(loader(DataSet[sampler])))
- sampler 定义索引index的生成规则,返回一个index列表,控制后续的索引访问过程。
- indexer 基于
__item__
在容器上定义按索引访问的规则,让容器成为可索引对象,可用[]操作。 - loader 基于
__iter__
在容器上定义可迭代性,描述加载规则,包括返回一个迭代器,让容器成为可迭代对象, 可用iter()操作。 - next 基于
__next__
在容器上定义迭代器,描述具体的迭代规则,让容器成为迭代器对象, 可用next()操作。
## 初始化
sampler = Sampler()
dataSet = DataSet(sampler) # __getitem__
dataLoader = DataLoader(dataSet, sampler) / DataIterable() # __iter__()
dataIterator = DataLoaderIter(dataLoader) #__next__()
data_iter = iter(dataLoader)
## 遍历方法1
for _ in range(len(data_iter))
data = next(data_iter)
## 遍历方法2
for i, data in enumerate(dataLoader):
data = data
Pytorch数据读取详解的更多相关文章
- hbase实践之数据读取详解
hbase基本存储组织结构与数据读取组织结构对比 Segment是Hbase2.0的概念,MemStore由一个可写的Segment,以及一个或多个不可写的Segments构成.故hbase 1.*版 ...
- Pytorch autograd,backward详解
平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...
- ContentProvider数据访问详解
ContentProvider数据访问详解 Android官方指出的数据存储方式总共有五种:Shared Preferences.网络存储.文件存储.外储存储.SQLite,这些存储方式一般都只是在一 ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- 【HANA系列】SAP HANA XS使用JavaScript数据交互详解
公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列]SAP HANA XS使用Jav ...
- JVM 运行时数据区详解
一.运行时数据区 Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同数据区域. 1.有一些是随虚拟机的启动而创建,随虚拟机的退出而销毁,所有的线程共享这些数据区. 2.第二种则 ...
- 学习《深度学习与计算机视觉算法原理框架应用》《大数据架构详解从数据获取到深度学习》PDF代码
<深度学习与计算机视觉 算法原理.框架应用>全书共13章,分为2篇,第1篇基础知识,第2篇实例精讲.用通俗易懂的文字表达公式背后的原理,实例部分提供了一些工具,很实用. <大数据架构 ...
- 【HANA系列】【第一篇】SAP HANA XS使用JavaScript数据交互详解
公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列][第一篇]SAP HANA XS ...
- 3dTiles 数据规范详解[1] 介绍
版权:转载请带原地址.https://www.cnblogs.com/onsummer/p/12799366.html @秋意正寒 Web中的三维 html5和webgl技术使得浏览器三维变成了可能. ...
随机推荐
- IBM X3650 m4 面板指示灯
- SCIE和SCI
SCI和SCIE(SCI Expanded)分别是科学引文索引及科学引文索引扩展版(即网络版),主要是收录自然科学.工程技术领域最具影响力的重要期刊,包括2000多种外围刊. SCIE和SCI一样吗? ...
- 【RS】:论文《Neural Collaborative Filtering》的思路及模型框架
[论文的思路] NCF 框架如上: 1.输入层:首先将输入的user.item表示为二值化的稀疏向量(用one-hot encoding) 2.嵌入层(embedding):将稀疏表示映射为稠密向量( ...
- scrapy中间件中发送邮件
背景介绍:之前写过通过通过scrapy的扩展发送邮件,在爬虫关闭的时候发送邮件.那个时候有个问题就是MailSender对象需要return出去.这次需要在中间件中发送邮件,但是中间件中不能随便使用r ...
- 【LeetCode】搜索旋转排序数组【两次二分】
假设按照升序排序的数组在预先未知的某个点上进行了旋转. ( 例如,数组 [0,1,2,4,5,6,7] 可能变为 [4,5,6,7,0,1,2] ). 搜索一个给定的目标值,如果数组中存在这个目标值, ...
- C/C++ 指针常量和常量指针
为了区分是指向常量的指针还是const指针(表示指针本身是常量) 一个简便方法:从由往左读,遇到p就替换为“p is a”,遇到*就替换为“point to”,其余不变. const int * p ...
- Python 文件编码问题解决
最近使用python操作文件,经常遇到编码错误的问题,例如: UnicodeDecodeError: 'utf-8' codec can't decode byte 0xbe in position ...
- 【题解】Luogu P5471 [NOI2019]弹跳
原题传送门 先考虑部分分做法: subtask1: 暴力\(O(nm)\)枚举,跑最短路 subtask2: 吧一行的点压到vector中并排序,二分查找每一个弹跳装置珂以到达的城市,跑最短路 sub ...
- (原创)C#监控软件通信模型
直接操作现场的设备是PLC,不是服务器和客户端.所以,以PLC为核心分析设备故障以及在PC端的C#程序中加入故障处理代码. PC端读和写PLC哪个重要?写重要.因为写会影响PLC的寄存器值,进而影响工 ...
- C# 调用Access数据库关于like模糊查询的写法
在access查询视图中要使用"*"做模糊匹配,但是在程序中要用%来匹配.在access中:NEIBUBH like '*1234*'在程序中:NEIBUBH like '%123 ...