Pytorch数据读取与预处理实现与探索
在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的。Pytorch帮我们实现了方便的数据读取与预处理方法,下面记录两个DEMO,便于加快以后的代码效率。
根据数据是否一次性读取完,将DEMO分为:
1、串行式读取。也就是一次性读取完所有需要的数据到内存,模型训练时不会再访问外存。通常用在内存足够的情况下使用,速度更快。
2、并行式读取。也就是边训练边读取数据。通常用在内存不够的情况下使用,会占用计算资源,如果分配的好的话,几乎不损失速度。
Pytorch官方的数据提取方式尽管方便编码,但由于它提取数据方式比较死板,会浪费资源,下面对其进行分析。
串行式读取
DEMO代码
import torch
from torch.utils.data import Dataset,DataLoader class MyDataSet(Dataset):# ————1————
def __init__(self):
self.data = torch.tensor(range(10)).reshape([5,2])
self.label = torch.tensor(range(5)) def __getitem__(self, index):
return self.data[index], self.label[index] def __len__(self):
return len(self.data) my_data_set = MyDataSet()# ————2————
my_data_loader = DataLoader(
dataset=my_data_set, # ————3————
batch_size=2, # ————4————
shuffle=True, # ————5————
sampler=None, # ————6————
batch_sampler=None, # ————7————
num_workers=0 , # ————8————
collate_fn=None, # ————9————
pin_memory=True, # ————10————
drop_last=True # ————11————
) for i in my_data_loader: # ————12————
print(i)
注释处解释如下:
1、重写数据集类,用于保存数据。除了 __init__() 外,必须实现 __getitem__() 和 __len__() 两个方法。前一个方法用于输出索引对应的数据。后一个方法用于获取数据集的长度。
2~5、 2准备好数据集后,传入DataLoader来迭代生成数据。前三个参数分别是传入的数据集对象、每次获取的批量大小、是否打乱数据集输出。
6、采样器,如果定义这个,shuffle只能设置为False。所谓采样器就是用于生成数据索引的可迭代对象,比如列表。因此,定义了采样器,采样都按它来,shuffle再打乱就没意义了。
7、批量采样器,如果定义这个,batch_size、shuffle、sampler、drop_last都不能定义。实际上,如果没有特殊的数据生成顺序的要求,采样器并没有必要定义。torch.utils.data 中的各种 Sampler 就是采样器类,如果需要,可以使用它们来定义。
8、用于生成数据的子进程数。默认为0,不并行。
9、拼接多个样本的方法,默认是将每个batch的数据在第一维上进行拼接。这样可能说不清楚,并且由于这里可以探究一下获取数据的速度,后面再详细说明。
10、是否使用锁页内存。用的话会更快,内存不充足最好别用。
11、是否把最后小于batch的数据丢掉。
12、迭代获取数据并输出。
速度探索
首先看一下DEMO的输出:
输出了两个batch的数据,每组数据中data和label都正确排列,符合我们的预期。那么DataLoader是怎么把数据整合起来的呢?首先,我们把collate_fn定义为直接映射(不用它默认的方法),来查看看每次DataLoader从MyDataSet中读取了什么,将上面部分代码修改如下:
my_data_loader = DataLoader(
dataset=my_data_set,
batch_size=2,
shuffle=True,
sampler=None,
batch_sampler=None,
num_workers=0 ,
collate_fn=lambda x:x, #修改处
pin_memory=True,
drop_last=True
)
结果如下:
输出还是两个batch,然而每个batch中,单个的data和label是在一个list中的。似乎可以看出,DataLoader是一个一个读取MyDataSet中的数据的,然后再进行相应数据的拼接。为了验证这点,代码修改如下:
import torch
from torch.utils.data import Dataset,DataLoader class MyDataSet(Dataset):
def __init__(self):
self.data = torch.tensor(range(10)).reshape([5,2])
self.label = torch.tensor(range(5)) def __getitem__(self, index):
print(index) #修改处2
return self.data[index], self.label[index] def __len__(self):
return len(self.data) my_data_set = MyDataSet()
my_data_loader = DataLoader(
dataset=my_data_set,
batch_size=2,
shuffle=True,
sampler=None,
batch_sampler=None,
num_workers=0 ,
collate_fn=lambda x:x, #修改处1
pin_memory=True,
drop_last=True
) for i in my_data_loader:
print(i)
输出如下:
验证了前面的猜想,的确是一个一个读取的。如果数据集定义的不是格式化的数据,那还好,但是我这里定义的是tensor,是可以直接通过列表来索引对应的tensor的。因此,DataLoader的操作比直接索引多了拼接这一步,肯定是会慢很多的。一两次的读取还好,但在训练中,大量的读取累加起来,就会浪费很多时间了。
自定义一个DataLoader可以证明这一点,代码如下:
import torch
from torch.utils.data import Dataset,DataLoader
from time import time class MyDataSet(Dataset):
def __init__(self):
self.data = torch.tensor(range(100000)).reshape([50000,2])
self.label = torch.tensor(range(50000)) def __getitem__(self, index):
return self.data[index], self.label[index] def __len__(self):
return len(self.data) # 自定义DataLoader
class MyDataLoader():
def __init__(self, dataset,batch_size):
self.dataset = dataset
self.batch_size = batch_size
def __iter__(self):
self.now = 0
self.shuffle_i = np.array(range(self.dataset.__len__()))
np.random.shuffle(self.shuffle_i)
return self def __next__(self):
self.now += self.batch_size
if self.now <= len(self.shuffle_i):
indexes = self.shuffle_i[self.now-self.batch_size:self.now]
return self.dataset.__getitem__(indexes)
else:
raise StopIteration # 使用官方DataLoader
my_data_set = MyDataSet()
my_data_loader = DataLoader(
dataset=my_data_set,
batch_size=256,
shuffle=True,
sampler=None,
batch_sampler=None,
num_workers=0 ,
collate_fn=None,
pin_memory=True,
drop_last=True
) start_t = time()
for t in range(10):
for i in my_data_loader:
pass
print("官方:", time() - start_t) #自定义DataLoader
my_data_set = MyDataSet()
my_data_loader = MyDataLoader(my_data_set,256) start_t = time()
for t in range(10):
for i in my_data_loader:
pass
print("自定义:", time() - start_t)
运行结果如下:
以上使用batch大小为256,仅各读取10 epoch的数据,都有30多倍的时间上的差距,更大的batch差距会更明显。另外,这里用于测试的每个数据只有两个浮点数,如果是图像,所需的时间可能会增加几百倍。因此,如果数据量和batch都比较大,并且数据是格式化的,最好自己写数据生成器。
并行式读取
DEMO代码
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder path = r'E:\DataSets\ImageNet\ILSVRC2012_img_train\10-19\128x128'
my_data_set = ImageFolder( #————1————
root = path, #————2————
transform = transforms.Compose([ #————3————
transforms.ToTensor(),
transforms.CenterCrop(64)
]),
loader = plt.imread #————4————
)
my_data_loader = DataLoader(
dataset=my_data_set,
batch_size=128,
shuffle=True,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=True,
drop_last=True
) for i in my_data_loader:
print(i)
注释处解释如下:
1/2、ImageFolder类继承自DataSet类,因此可以按索引读取图像。路径必须包含文件夹,ImageFolder会给每个文件夹中的图像添加索引,并且每张图像会给予其所在文件夹的标签。举个例子,代码中my_data_set[0] 输出的是图像对象和它对应的标签组成的列表。
3、图像到格式化数据的转换组合。更多的转换方法可以看 transform 模块。
4、图像法的读取方式,默认是PIL.Image.open(),但我发现plt.imread()更快一些。
由于是边训练边读取,transform会占用很多时间,因此可以先将图像转换为需要的形式存入外存再读取,从而避免重复操作。
其中transform.ToTensor()会把正常读取的图像转换为torch.tensor,并且像素值会映射至$[0,1]$。由于plt.imread()读取png图像时,像素值在$[0,1]$,而读取jpg图像时,像素值却在$[0,255]$,因此使用transform.ToTensor()能将图像像素区间统一化。
Pytorch数据读取与预处理实现与探索的更多相关文章
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- Pytorch数据读取框架
训练一个模型需要有一个数据库,一个网络,一个优化函数.数据读取是训练的第一步,以下是pytorch数据输入框架. 1)实例化一个数据库 假设我们已经定义了一个FaceLandmarksDataset数 ...
- Pytorch数据读取详解
原文:http://studyai.com/article/11efc2bf#%E9%87%87%E6%A0%B7%E5%99%A8%20Sampler%20&%20BatchSampler ...
- 从零搭建Pytorch模型教程(一)数据读取
前言 本文介绍了classdataset的几个要点,由哪些部分组成,每个部分需要完成哪些事情,如何进行数据增强,如何实现自己设计的数据增强.然后,介绍了分布式训练的数据加载方式,数据读取的整个 ...
- [Pytorch]PyTorch Dataloader自定义数据读取
整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...
- 『TensorFlow』SSD源码学习_其五:TFR数据读取&数据预处理
Fork版本项目地址:SSD 一.TFR数据读取 创建slim.dataset.Dataset对象 在train_ssd_network.py获取数据操作如下,首先需要slim.dataset.Dat ...
- tensorflow之数据读取探究(1)
Tensorflow中之前主要用的数据读取方式主要有: 建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用.使用这种方法十分灵活,可以一下子将所有数据 ...
- SSD源码解读——数据读取
之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html. 为了加深对SSD的理解,因此对SSD的源码进 ...
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...
随机推荐
- Semantic Pull Requests All In One
Semantic Pull Requests All In One https://github.com/zeke/semantic-pull-requests docs: Update direct ...
- React Query & SWR
React Query & SWR HTTP request all in one solution React Query Hooks for fetching, caching and u ...
- CSS font-weight all in one
CSS font-weight all in one font-weight: bolder: 没毛病呀! /* 关键字值 */ font-weight: normal; font-weight: b ...
- HTTPS in depth
HTTPS in depth HTTPS Hypertext Transfer Protocol Secure How does HTTPS work? https://www.cloudflare. ...
- Scratch 游戏开发
Scratch 游戏开发 可视化少儿编程 https://scratch.mit.edu/ Scratch Desktop https://scratch.mit.edu/download https ...
- leetcode bug & 9. Palindrome Number
leetcode bug & 9. Palindrome Number bug shit bug "use strict"; /** * * @author xgqfrms ...
- JavaScript & Error Types
JavaScript & Error Types JavaScript提供了8个错误对象,这些错误对象会根据错误类型在try / catch表达式中引发: Error EvalError Ra ...
- taro external-class
taro external-class https://nervjs.github.io/taro/docs/component-style.html externalClasses child co ...
- 用Python实现一个“百度翻译”
import requests import json s = input("请输入你要翻译的内容:") headers = {"User-Agent":&qu ...
- HTTP 协议中的并发限制及队首阻塞问题
本文转载自HTTP 协议中的并发限制及队首阻塞问题 串行连接 HTTP/0.9 和早期的 HTTP/1.0 协议对 HTTP 请求处理是串行化的.假如一个页面包含 3 个样式文件,同属于一个协议.域名 ...