最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络。

pytorch中有很方便的dataloader函数来方便我们进行批处理,做了简单的例子,过程很简单,就像把大象装进冰箱里一共需要几步?


第一步:打开冰箱门。

我们要创建torch能够识别的数据集类型(pytorch中也有很多现成的数据集类型,以后再说)。

首先我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果:

随后我们需要把X和Y组成一个完整的数据集,并转化为pytorch能识别的数据集类型:

我们来看一下这些数据的数据类型:

可以看出我们把X和Y通过Data.TensorDataset() 这个函数拼装成了一个数据集,数据集的类型是【TensorDataset】。

好了,第一步结束了,冰箱门打开了。


第二步:把大象装进去。

就是把上一步做成的数据集放入Data.DataLoader中,可以生成一个迭代器,从而我们可以方便的进行批处理。

DataLoader中也有很多其他参数:

dataset:Dataset类型,从其中加载数据
batch_size:int,可选。每个batch加载多少样本
shuffle:bool,可选。为True时表示每个epoch都对数据进行洗牌
sampler:Sampler,可选。从数据集中采样样本的方法。
num_workers:int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。
collate_fn:callable,可选。
pin_memory:bool,可选
drop_last:bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。

好了,第二步结束了,大象装进去了。


第三步:把冰箱门关上。

好啦,现在我们就可以愉快的用我们上面定义好的迭代器进行训练啦。

在这里我们利用print来模拟我们的训练过程,即我们在这里对搭建好的网络进行喂入。

输出的结果是:

可以看到,我们一共训练了所有的数据训练了5次。数据中一共10组,我们设置的mini-batch是3,即每一次我们训练网络的时候喂入3组数据,到了最后一次我们只有1组数据了,比mini-batch小,我们就仅输出这一个。

此外,还可以利用python中的enumerate(),是对所有可以迭代的数据类型(含有很多东西的list等等)进行取操作的函数,用法如下:

好啦,现在冰箱门就关上啦,(*^__^*)

pytorch中如何使用DataLoader对数据集进行批处理的更多相关文章

  1. pytorch Dataset数据集和Dataloader迭代数据集

    import torch from torch.utils.data import Dataset,DataLoader class SmsDataset(Dataset): def __init__ ...

  2. pytorch中DataLoader, DataSet, Sampler之间的关系

    转自:https://mp.weixin.qq.com/s/RTv0cUWvc0kuXBeNoXVu_A 自上而下理解三者关系 首先我们看一下DataLoader.__next__的源代码长什么样,为 ...

  3. PyTorch中的MIT ADE20K数据集的语义分割

    PyTorch中的MIT ADE20K数据集的语义分割 代码地址:https://github.com/CSAILVision/semantic-segmentation-pytorch Semant ...

  4. pytorch中tensorboardX的用法

    在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...

  5. 转pytorch中训练深度神经网络模型的关键知识点

    版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...

  6. [深度学习] pytorch利用Datasets和DataLoader读取数据

    本文简单描述如果自定义dataset,代码并未经过测试(只是说明思路),为半伪代码.所有逻辑需按自己需求另外实现: 一.分析DataLoader train_loader = DataLoader( ...

  7. 第五章——Pytorch中常用的工具

    2018年07月07日 17:30:40 __矮油不错哟 阅读数:221   1. 数据处理 数据加载 ImageFolder DataLoader加载数据 sampler:采样模块 1. 数据处理 ...

  8. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  9. PyTorch中使用深度学习(CNN和LSTM)的自动图像标题

    介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...

随机推荐

  1. js 錯誤

    try{ //需要被檢測是否拋出錯誤 } catch(err) { //錯誤處理代碼 } try.catch成對出現 throw:拋出錯誤 當錯誤發生時,javascript引擎停止運行,并生成一個錯 ...

  2. linux中内存使用原理

    首先介绍一下linux中内存是如何使用的. 当有应用需要读写磁盘数据时,由系统把相关数据从磁盘读取到内存,如果物理内存不够,则把内存中的部分数据导入到磁盘,从而把磁盘的部分空间当作虚拟内存 来使用,也 ...

  3. BZOJ2001 HNOI2010城市建设(线段树分治+LCT)

    一个很显然的思路是把边按时间段拆开线段树分治一下,用lct维护MST.理论上复杂度是O((M+Q)logNlogQ),实际常数爆炸T成狗.正解写不动了. #include<iostream> ...

  4. MT【204】离散型最值

    (联赛一试2006,14).将2006表示成5个正整数$x_1,x_2,x_3,x_4,x_5$之和.记$S=\sum\limits_{1\le i<j\le5}{x_ix_j}$问:(1) 当 ...

  5. [洛谷P4245]【模板】任意模数NTT

    题目大意:给你两个多项式$f(x)$和$g(x)$以及一个模数$p(p\leqslant10^9)$,求$f*g\pmod p$ 题解:任意模数$NTT$,最大的数为$p^2\times\max\{n ...

  6. 自学Zabbix11.6 Zabbix SNMP自定义OID

    点击返回:自学Zabbix之路 点击返回:自学Zabbix4.0之路 点击返回:自学zabbix集锦 自学Zabbix11.6 Zabbix SNMP自定义OID 为什么要自定义OID? 前面已经讲过 ...

  7. hdu3072 Intelligence System (最小树形图?)

    题意:给一个有向图,问要从0号点能到达所有点所需要经过路径的最小权值和是多少,然而,若两点强联通,则这两点互相到达不需要花费.保证0号点能到达所有点 tarjan缩点以后直接取每个点入边中花费最小的即 ...

  8. centos7配置本地yum源 使用安装镜像安装软件

    1. 在cdrom挂载安装镜像.(物理机则插入光盘,虚拟机则在CD/DVD中选择iso镜像.如果虚拟机mount时提示找不到则在选择iso镜像上方勾选“已连接”和“启动时连接”,或者点击 虚拟机下方状 ...

  9. 【模板】spfa

    代码如下 #include <bits/stdc++.h> using namespace std; const int maxv=1e4+10; const int maxe=5e5+1 ...

  10. (转)Maven之自定义archetype生成项目骨架

    背景:最近在开发一个项目的基础构件,在以后项目的开发过程中可以直接使用该构件快速的生成项目骨架进行开发. 摘要:使用过Maven的人都知道maven中有许多功能都是通过插件来提供的,今天我们来说一下其 ...