【小白学PyTorch】3 浅谈Dataset和Dataloader
文章目录:
1 Dataset基类
PyTorch 读取其他的数据,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。在看很多PyTorch的代码的时候,也会经常看到dataset这个东西的存在。Dataset类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。
先看一下源码:
这里有一个__getitem__
函数,__getitem__
函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。之后会举例子来讲解这个逻辑。
其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,这是触发去读取图片这些操作的是DataLoader里的__iter__(self)
(后面再讲)。
2 构建Dataset子类
下面我们构建一下Dataset的子类,叫他MyDataset类:
import torch
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
def __init__(self):
self.data = torch.tensor([[1,2,3],[2,3,4],[3,4,5],[4,5,6]])
self.label = torch.LongTensor([1,1,0,0])
def __getitem__(self,index):
return self.data[index],self.label[index]
def __len__(self):
return len(self.data)
2.1 Init
- 初始化中,一般是把数据直接保存在这个类的属性中。像是
self.data,self.label
2.2 getitem
- index是一个索引,这个索引的取值范围是要根据
__len__
这个返回值确定的,在上面的例子中,__len__
的返回值是4,所以这个index会在0,1,2,3这个范围内。
3 dataloader
从上文中,我们知道了MyDataset这个类中的__getitem__
的返回值,应该是某一个样本的数据和标签(如果是测试集的dataset,那么就只返回数据),在梯度下降的过程中,一般是需要将多个数据组成batch,这个需要我们自己来组合吗?不需要的,所以PyTorch中存在DataLoader这个迭代器(这个名词用的准不准确有待考究)。
继续上面的代码,我们接着写代码:
mydataloader = DataLoader(dataset=mydataset,
batch_size=1)
我们现在创建了一个DataLoader的实例,并且把之前实例化的mydataset作为参数输入进去,并且还输入了batch_size这个参数,现在我们使用的batch_size是1.下面来用for循环来遍历这个dataloader:
for i,(data,label) in enumerate(mydataloader):
print(data,label)
输出结果是:
意料之中的结果,总共输出了4个batch,每个batch都是只有1个样本(数据+标签),值得注意的是,这个输出过程是顺序的。
我们稍微修改一下上面的DataLoader的参数:
mydataloader = DataLoader(dataset=mydataset,
batch_size=2,
shuffle=True)
for i,(data,label) in enumerate(mydataloader):
print(data,label)
结果是:
可以看到每一个batch内出现了2个样本。假如我们再运行一遍上面的代码,得到:
两次结果不同,这是因为shuffle=True
,dataset中的index不再是按照顺序从0到3了,而是乱序,可能是[0,1,2,3],也可能是[2,3,1,0]。
【个人感想】
Dataloader和Dataset两个类是非常方便的,因为这个可以快速的做出来batch数据,修改batch_size和乱序都非常地方便。有下面两个希望注意的地方:
- 一般标签值应该是Long整数的,所以标签的tensor可以用
torch.LongTensor(数据)
或者用.long()
来转化成Long整数的形式。 - 如果要使用PyTorch的GPU训练的话,一般是先判断cuda是否可用,然后把数据标签都用
to()
放到GPU显存上进行GPU加速。
device = 'cuda' if torch.cuda.is_available() else 'cpu'
for i,(data,label) in enumerate(mydataloader):
data = data.to(device)
label = label.to(device)
print(data,label)
看一下输出:
【小白学PyTorch】3 浅谈Dataset和Dataloader的更多相关文章
- 【小白学C#】浅谈.NET中的IL代码
一.前言 前几天群里有位水友提问:”C#中,当一个方法所传入的参数是一个静态字段的时候,程序是直接到静态字段拿数据还是从复制的函数栈中拿数据“.其实很明显,这和方法参数的传递方式有关,如果是引用传递的 ...
- 【小白学PyTorch】20 TF2的eager模式与求导
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 【小白学PyTorch】15 TF2实现一个简单的服装分类任务
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 【小白学PyTorch】8 实战之MNIST小试牛刀
文章来自微信公众号[机器学习炼丹术].有什么问题都可以咨询作者WX:cyx645016617.想交个朋友占一个好友位也是可以的~好友位快满了不过. 参考目录: 目录 1 探索性数据分析 1.1 数据集 ...
- 【小白学PyTorch】4 构建模型三要素与权重初始化
文章目录: 目录 1 模型三要素 2 参数初始化 3 完整运行代码 4 尺寸计算与参数计算 1 模型三要素 三要素其实很简单 必须要继承nn.Module这个类,要让PyTorch知道这个类是一个Mo ...
- 【小白学PyTorch】5 torchvision预训练模型与数据集全览
文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...
- 【小白学PyTorch】11 MobileNet详解及PyTorch实现
文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...
- 【小白学PyTorch】16 TF2读取图片的方法
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.NLP等多个学术交流分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx645016617. 参考 ...
- 【小白学PyTorch】17 TFrec文件的创建与读取
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
随机推荐
- pagehelper的使用和一些坑!
[toc] ##1.1 pagehelper介绍和使用 PageHelper是一款好用的开源免费的Mybatis第三方物理分页插件. 原本以为分页插件,应该是很简单的,然而PageHelper比我想象 ...
- log4j2 自动删除过期日志文件配置及实现原理解析
日志文件自动删除功能必不可少,当然你可以让运维去做这事,只是这不地道.而日志组件是一个必备组件,让其多做一件删除的工作,无可厚非.本文就来探讨下 log4j 的日志文件自动删除实现吧. 0. 自动删除 ...
- Android JNI之编译
JNI代码都写好了,在编译之前我们有非常重要的一部,就是写mk文件,mk文件就相当于gcc编译时的Makefile文件,它是用来告诉编译器如何去编译的. 这里只对自己理解和常用的知识点做记录,想要看关 ...
- centos,linux环境下安装JDK1.8完整
进入oracle官网下载安装包,cetos一般选择xx-xx-linux-x64.tar.gz.获取到地址后可以点击下载,也可以使用wget命令下载. 在得到下载好的文件后下面就可以开始安装了.比如我 ...
- Python最全pdf学习书籍资料分享
本人学习Python两年时间,期间统计了一些比较好的学习资料 1.基础资料 下载地址: 链接:https://pan.baidu.com/s/1sjtyYayBbQLsrUdaXWmzkg提取码:1 ...
- SPP、ASPP、RFB、CBAM
SPP:ASPP:将pooling 改为了 空洞卷积RFB:不同大小的卷积核和空洞卷积进行组合,认为大的卷积应该有更大的感受野. CBAM:空间和通道的注意力机制 SPP: Spatial Pyram ...
- C#LeetCode刷题之#724-寻找数组的中心索引( Find Pivot Index)
问题 该文章的最新版本已迁移至个人博客[比特飞],单击链接 https://www.byteflying.com/archives/3742 访问. 给定一个整数类型的数组 nums,请编写一个能够返 ...
- NumPy速查笔记(持续更新中)
目录 1 总览 2 ndarray 3 常用API 3.1 创建ndarray (1)将Python类似数组的对象转化成Numpy数组 (2)numpy内置的数组创建 (3)从磁盘中读取标准格式或者自 ...
- ElasticSearch 7.8.1集群搭建
通往集群的大门 集群由什么用? 高可用 高可用(High Availability)是分布式系统架构设计中必须考虑的因素之一,它通常是指,通过设计减少系统不能提供服务的时间.如果系统每运行100个时间 ...
- Oracle从回收站找回误删的数据
Step1 先根据删除时间查看删除了那些表 select * from recyclebin where type = 'TABLE' and createtime like '${删除时间}%' o ...