在深度学习中,数据的处理对于神经网络的训练来说十分重要,良好的数据(包括图像、文本、语音等)处理不仅可以加速模型的训练,同时也直接关系到模型的效果。本文以处理图像数据为例,记录一些使用PyTorch进行图像预处理和数据加载的方法


一、数据的加载

  在PyTorch中,数据加载需要自定义数据集类,并用此类来实例化数据对象,实现自定义的数据集需要继承torch.utils.data包中的Dataset类

  在继承Dataset实现自己的类时,需要实现以下两个Python魔法方法:

  • __getitem__(index): 返回一个样本数据,当使用obj[index]时实际就是在调用obj.__getitem__(index)
  • __len__():返回样本的数量,当使用len(obj)时实际就是在调用obj.__len__()

  例如,以猫狗大战的二分类数据集为例,其加载过程如下:

  1. import os
  2. import torch as t
  3. from torch.utils import data
  4. from PIL import Image
  5. import numpy as np
  6. class dogCat(data.Dataset):
  7. def __init__(self,root): # root为数据存放目录
  8. imgs = os.listdir(root) #列出当前路径下所有的文件
  9. self.imgs = [os.path.join(root,img) for img in imgs] # 所有图片的路径
  10. #print(self.imgs)
  11. """返回一个样本数据"""
  12. def __getitem__(self, item):
  13. img_path = self.imgs[item] # 第item张图片的路径
  14. #dog 1 cat 0
  15. label = 1 if 'dog' in img_path.split('\\')[-1] else 0 # 获取标签信息
  16. #print(label)
  17. pil_img = Image.open(img_path) #读入图片
  18. print(type(pil_img))
  19. array = np.asarray(pil_img) # 转为numpy.array类型
  20. data = t.from_numpy(array) # 转为tensor类型
  21. return data,label #返回图片对应的tensor及其标签
  22. """样本的数量"""
  23. def __len__(self):
  24. return len(self.imgs)
  25. if __name__ == '__main__':
  26. dogcat = dogCat('D:\pycode\dogsVScats\data\catvsdog\\train') #数据集对象
  27. data,label = dogcat[0] # 返回第0张图片的信息
  28. print(data.size())
  29. print(label)
  30. print(len(dogcat))

二、计算机视觉工具包:torchvision

  对于图像数据来说,以上的数据加载时不完善的,因为只是将图片读入,而没有进行相关的处理,如每张图片的大小和形状,样本的数值归一化等等。

  为了解决这一问题,PyTorch开发了一个视觉工具包torchvision,这个包独立于torch,需要通过pip install torchvision来单独安装。

  torchvision有三个部分组成:

  • models提供各种经典的网络结构和预训练好的模型,如AlexNet、VGG、ResNet、Inception等
  1. from torchvision import models
  2. from torch import nn
  3. resnet34 = models.resnet34(pretrained=True,num_classes=1000) # 加载预训练模型
  4. resnet34.fc=nn.Linear(512,10) # 修改全连接层为10分类
  • datasets提供了常用的数据集,如MNIST、CIFAR10/100、ImageNet、COCO等
  1. from torchvision import datasets
  2. dataset = datasets.MNIST('data/',download=True,train=False,transform=transform)

  除了常用数据集外,需要特别注意的是ImageFolder,ImageFolder假设所有的文件按文件夹存放,每个文件夹下面存储同一类的图片,文件夹的名字为这一类别的名字。这是我们经常用到的一种数据组织形式。

  1. # 使用方法:
  2. ImageFolder(root,transform=None,target_transform=None,loader=default_loader
  3. # 参数:文件夹路径,对图像做什么样的转换,对标签做什么样的转换,如何加载图片
  4. from torchvision.datasets import ImageFolder
  5. dataset = ImageFolder('data\\')
  6. print(dataset.class_to_idx) # class_to_idx ,label和id的对应关系,从0开始
  7. print(dataset.imgs) # 数据和标签对应
  • transforms: 提供常用的数据预处理操作,主要是对Tensor和PIL Image对象的处理操作

  对PIL Image的操作:Resize、CenterCrop、RandomCrop、RandomsizedCrop、Pad、ToTensor等。

  对Tensor的操作:Normalize、ToPILImage等。

  如果要进行多个操作,可以通过transforms.Compose([])将操作拼接起来。但是需要注意的是需要首先构建转换操作,然后再执行转换操作。

  1. import os
  2. from torch.utils import data
  3. from PIL import Image
  4. import numpy as np
  5. from torchvision import transforms as T
  6. transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])]) # 构建转换操作
  7. class dogCat(data.Dataset):
  8. def __init__(self,root,transforms):
  9. imgs = os.listdir(root)
  10. #print(imgs)
  11. self.imgs = [os.path.join(root,img) for img in imgs]
  12. #print(self.imgs)
  13. self.transforms = transforms
  14. def __getitem__(self, item):
  15. img_path = self.imgs[item]
  16. #dog 1 cat 0
  17. label = 1 if 'dog' in img_path.split('\\')[-1] else 0
  18. #print(label)
  19. pil_img = Image.open(img_path)
  20. if self.transforms:
  21. pil_img = self.transforms(pil_img) #执行准换操作
  22. return pil_img,label,item
  23. def __len__(self):
  24. return len(self.imgs)

三、使用DataLoader进行数据再处理

  通过上述描述,我们通过自定义数据集类,使用视觉工具包进行图像的转换等操作,最终得到的是一个dataset的数据集对象,使用此对象可以一次返回一个样本。

  但是,我们应该清楚:训练神经网络时,一般采用的是小批量的梯度下降,因此我们是对一批数据进行处理,也就是一个batch,同时,数据还需要进行打乱(shuffle)和并行加速等。PyTorch提供了DataLoader来实现这些功能。

  DataLoader定义如下:

  1. DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=default_collate,pin_memory=False,drop_last=False)

  参数含义如下:

  • dataset:加载的数据集
  • batch_zize: 批大小
  • shuffle: 是否将数据打乱
  • sampler:样本抽样,常用的有随机采样RandomSampler,shuffle=True时自动调用随机采样,默认是顺序采样,还有一个常用的是:WeightedRandomSampler,按照样本的权重进行采样。
  • num_workers: 使用的进程数,0代表不使用多进程。
  • collate_fn: 拼接方式。
  • pin_memory: 是否将数据保存在pin memory区。
  • drop_last: 是否将多出来的不足一个batch的丢弃。

  调用DataLoader得到的结果是一个可迭代的对象,可以和使用迭代器一样使用它。

  1. from torchvision import transforms as T
  2. from torch.utils.data import DataLoader
  3. transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])])
  4. if __name__ == '__main__':
  5. dogcat = dogCat('D:\pycode\dogsVScats\data\catvsdog\\train', transform)
  6. data, label, index = dogcat[0]
  7. dataloader = DataLoader(dogcat,batch_size=3,shuffle=False,num_workers=0,drop_last=False)
  8. for batchDatas,batchLabels in dataloader:
  9. train()

总结

  本文记录了使用PyTorch进行数据预处理的相关操作流程,重点是掌握Dataset和DataLoader两个类的使用,另外,视觉工具包torchvision的三个模块灵活运用,会对数据处理过程有很好的帮助。

【深度学习框架】使用PyTorch进行数据处理的更多相关文章

  1. 金玉良缘易配而木石前盟难得|M1 Mac os(Apple Silicon)天生一对Python3开发环境搭建(集成深度学习框架Tensorflow/Pytorch)

    原文转载自「刘悦的技术博客」https://v3u.cn/a_id_189 笔者投入M1的怀抱已经有一段时间了,俗话说得好,但闻新人笑,不见旧人哭,Intel mac早已被束之高阁,而M1 mac已经 ...

  2. 常用深度学习框架(keras,pytorch.cntk,theano)conda 安装--未整理

    版本查询 cpu tensorflow conda env list source activate tensorflow python import tensorflow as tf 和 tf.__ ...

  3. 神工鬼斧惟肖惟妙,M1 mac系统深度学习框架Pytorch的二次元动漫动画风格迁移滤镜AnimeGANv2+Ffmpeg(图片+视频)快速实践

    原文转载自「刘悦的技术博客」https://v3u.cn/a_id_201 前段时间,业界鼎鼎有名的动漫风格转化滤镜库AnimeGAN发布了最新的v2版本,一时间街谈巷议,风头无两.提起二次元,目前国 ...

  4. ArXiv最受欢迎开源深度学习框架榜单:TensorFlow第一,PyTorch第四

    [导读]Kears作者François Chollet刚刚在Twitter贴出最近三个月在arXiv提到的深度学习框架,TensorFlow不出意外排名第一,Keras排名第二.随后是Caffe.Py ...

  5. 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题

    在学习陈云的教程<深度学习框架PyTorch:入门与实践>的损失函数构建时代码如下: 可我运行如下代码: output = net(input) target = Variable(t.a ...

  6. 深度学习框架Keras与Pytorch对比

    对于许多科学家.工程师和开发人员来说,TensorFlow是他们的第一个深度学习框架.TensorFlow 1.0于2017年2月发布,可以说,它对用户不太友好. 在过去的几年里,两个主要的深度学习库 ...

  7. 从TensorFlow到PyTorch:九大深度学习框架哪款最适合你?

    开源的深度学习神经网络正步入成熟,而现在有许多框架具备为个性化方案提供先进的机器学习和人工智能的能力.那么如何决定哪个开源框架最适合你呢?本文试图通过对比深度学习各大框架的优缺点,从而为各位读者提供一 ...

  8. Spark如何与深度学习框架协作,处理非结构化数据

    随着大数据和AI业务的不断融合,大数据分析和处理过程中,通过深度学习技术对非结构化数据(如图片.音频.文本)进行大数据处理的业务场景越来越多.本文会介绍Spark如何与深度学习框架进行协同工作,在大数 ...

  9. Cs231n课堂内容记录-Lecture 8 深度学习框架

    Lecture 8  Deep Learning Software 课堂笔记参见:https://blog.csdn.net/u012554092/article/details/78159316 今 ...

  10. [Tensorflow实战Google深度学习框架]笔记4

    本系列为Tensorflow实战Google深度学习框架知识笔记,仅为博主看书过程中觉得较为重要的知识点,简单摘要下来,内容较为零散,请见谅. 2017-11-06 [第五章] MNIST数字识别问题 ...

随机推荐

  1. Spark 颠覆 MapReduce 保持的排序记录

    在过去几年,Apache Spark的採用以惊人的速度添加着,通常被作为MapReduce后继,能够支撑数千节点规模的集群部署. 在内存中数 据处理上,Apache Spark比MapReduce更加 ...

  2. CF #323 DIV2 D题

    可以知道,当T较大时,对于LIS,肯定会有很长的一部分是重复的,而这重复的部分,只能是一个block中出现次数最多的数字组成一序列.所以,对于T>1000时,可以直接求出LIS,剩下T-=100 ...

  3. CF #319 div 2 E

    在一个边长为10^6正方形中,可以把它x轴分段,分成1000段.奇数的时候由底往上扫描,偶数的时候由上往下扫描.估计一下这个最长的长度,首先,我们知道有10^6个点,则y邮方向最多移动10^3*10^ ...

  4. MFC 小知识总结四

    1 PlaySound  播放WAV格式的音乐 This function plays a sound specified by a file name, resource, or system ev ...

  5. OllyDbg 使用笔记 (七)

    OllyDbg 使用笔记 (七) 參考 书:<加密与解密> 视频:小甲鱼 解密系列 视频 演示样例程序下载:http://pan.baidu.com/s/1gvwlS 暴力破解 观察这个程 ...

  6. hdu 3810 Magina 队列模拟0-1背包

    题意: 出一些独立的陆地,每片陆地上有非常多怪物.杀掉每一个怪物都须要一定的时间,并能获得一定的金钱.给出指定的金钱m, 求最少要多少时间能够得到m金钱,仅能选择一个陆地进行杀怪. 题解: 这题,假设 ...

  7. 利用Node.js对某智能家居server重构

    原文摘自我的前端博客,欢迎大家来訪问 http://www.hacke2.cn 之前负责过一个智能家居项目的开发,外包重庆一家公司的.我们主要开发server监控和集群版管理. 移动端和机顶盒的远程通 ...

  8. NOIP2013--火柴排队(树状数组)

    转载: 树状数组,具体的说是 离散化+树状数组.这也是学习树状数组的第一题. 算法的大体流程就是: 1.先对输入的数组离散化,使得各个元素比较接近,而不是离散的, 2.接着,运用树状数组的标准操作来累 ...

  9. Error-MySQL:2005 - Unknown MySQL server host 'localhost'(0)

    ylbtech-Error-MySQL:2005 - Unknown MySQL server host 'localhost'(0) 1.返回顶部 1. 今天在外面开navicat for mysq ...

  10. Rails5 关联表格搜索

    创建: 2017/08/13   other_type_car = Car.joins(:car_type).active.find_by(car_type: car_type)   @recomme ...