数据集加载和处理

这里主要涉及两个包:torchvision.datasets 和torch.utils.data.Dataset 和DataLoader

torchvision.datasets是一些包装好的数据集

里边所有可用的dataset都是 torch.utils.data.Dataset 的子类,这些子类都要有 __getitem__ __len__ 方法是实现。

这样, 定义的数据集才能够被 torch.utils.data.DataLoader ,DataLoader能够使用torch.multiprocessing并行加载许多样本

例如:

imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads) 当我们需要使用我们的数据集的时候,就需要进行包装成DataLoader能够识别的Dataset这样就能把我们从无穷的数据预处理中解脱出来。
创建数据集
首先导入,创建一个子类:
from torch.utils.data import Dataset
import torch
class MyDateset(Dataset):
    def __init__(self,num=10000,transform=None):  #这里就可以写你的参数了,比如文件夹什么的。
        self.len=num
        self.transform=transform
    def __len__(self):
        return self.len
    def __getitem__(self,idx):
        data=torch.rand(3,3,5)  #这里就是你的数据图像的话就是C*M*N的tensor,这里创建了一个3*3*5的张量
        label=torch.LongTensor([1])   #label也是需要一个张量
        if self.transform:    #这里就是数据预处理的部分 、
            data=self.transform(data)  #处理完必须要返回torch.Tensor类型
        return data,label
下面我们测试一下:
md=MyDateset()
print(md[0])
print(len(md))
输出:
(tensor([[[0.2753, 0.8114, 0.2916, 0.9600, 0.5057], [0.8595, 0.1195, 0.8065, 0.6393, 0.6213], 
[0.0997, 0.8590, 0.2469, 0.2158, 0.5296]], [[0.4764, 0.0561, 0.5866, 0.6129, 0.1882],
[0.4666, 0.9362, 0.5397, 0.3065, 0.4307], [0.4700, 0.6202, 0.3649, 0.6357, 0.5181]],
[[0.9794, 0.8127, 0.9842, 0.8821, 0.2447], [0.2320, 0.6406, 0.5683, 0.5637, 0.2734],
[0.2131, 0.5853, 0.5633, 0.9069, 0.9250]]]), tensor([1]))
10000
输出:这样我们就自定义了一个数据集Dataset,这样我们需要使用已有的数据集的时候就可以知道torchvision.dataset下许多数据集的构成了。
 
预处理数据

返回来再看上边定义数据集里有个参数transform,从定义getitem函数里看到,transform其实是一个函数。
torchvision.transforms里就包括了好多的操作。当然它主要处理的是图像,就是C*H*W类型的举证了。
可以直接这样使用:
from torchvision import transforms md=MyDateset(transform=transforms.Normalize((0,0,0),(0.1,0.2,0.3)))
print(md[0])
(tensor([[[2.5435, 9.1073, 4.1653, 9.4720, 0.7595],
[0.4840, 7.2377, 3.1578, 4.5391, 2.7440],
[4.6951, 4.7698, 1.1308, 0.5321, 3.5101]], [[2.6714, 4.5143, 0.0582, 0.2880, 0.2565],
[2.2951, 0.0680, 0.3542, 4.7372, 2.0162],
[1.4065, 2.5195, 0.8911, 4.8432, 3.1045]], [[2.7726, 2.5199, 0.8066, 0.7089, 2.0651],
[1.8641, 1.6599, 0.5546, 2.8716, 2.0964],
[2.5320, 1.5349, 1.8792, 0.0933, 3.2289]]]), tensor([1]))
更多的变换参见:https://pytorch.org/docs/master/torchvision/transforms.html

当然我们也可以自定义一个函数传入:
def add1(x):
    return x+1
md=MyDateset(transform=add1)
print(md[0])
输出:
(tensor([[[1.9552, 1.1294, 1.9435, 1.6476, 1.2726],
[1.1544, 1.7726, 1.1975, 1.9914, 1.2694],
当然也可以组合起来个transform形成一个一个处理级联:
tc=transforms.Compose([transforms.Normalize((0,0,0),(0.1,0.2,0.3)),add1])
md=MyDateset(transform=tc)
print(md[0]) 输出:
(tensor([[[ 1.9232,  6.4972,  7.9916,  4.3426, 10.9737],
[ 5.4062, 2.6264, 6.8474, 4.7810, 3.3232],
[ 8.6633, 4.1399, 2.3371, 5.5058, 3.9724]],
等等。

用Dataloader加载数据集

在训练网络,测试网络时我们就需要使用刚才定义好的数据集了。

from torch.utils.data import Dataset, DataLoader
md=MyDateset()
print(md[1])
dl=DataLoader(md, batch_size=4,  shuffle=False,  num_workers=4)
print(len(dl.dataset)) 这样dl就可以在程序里循环生成批样本,提供训练,测试了。

什么是pytorch(4.数据集加载和处理)(翻译)的更多相关文章

  1. OFRecord 数据集加载

    OFRecord 数据集加载 在数据输入一文中知道了使用 DataLoader 及相关算子加载数据,往往效率更高,并且学习了如何使用 DataLoader 及相关算子. 在 OFrecord 数据格式 ...

  2. Pytorch读取,加载图像数据(一)

    在学习Pytorch的时候,先学会如何正确创建或者加载数据,至关重要. 有了数据,很多函数,操作的效果就变得很直观. 本文主要用其他库读取图像文件(学会这个,你就可以在之后的学习中,将一些效果直观化) ...

  3. PIE SDK 多数据源的复合数据集加载

    1. 功能简介 GIS遥感图像数据复合是将多种遥感图像数据融合成一种新的图像数据的技术,是目前遥感应用分析的前沿,PIESDK通过复合数据技术可以将多幅幅影像数据集(多光谱和全色数据)组合成一幅多波段 ...

  4. tensorflow数据集加载

    本篇涉及的内容主要有小型常用的经典数据集的加载步骤,tensorflow提供了如下接口:keras.datasets.tf.data.Dataset.from_tensor_slices(shuffl ...

  5. [深度学习]-Dataset数据集加载

    加载数据集dataloader from torch.utils.data import DataLoader form 自己写的dataset import Dataset train_set = ...

  6. las数据集加载las数据

    引用的类库:ESRI.ArcGIS.GeoDatabaseExtensions 逻辑步骤: 1.创建las数据集(ILasDataset). 2.实例化las数据集的编辑器(ILasDatasetEd ...

  7. Pytorch 0.3加载0.4模型及其之间版本的变化

    1. 0.4中使用设备:.to(device) 2. 0.4中删除了Variable,直接tensor就可以 3. with torch.no_grad():的使用代替volatile:弃用volat ...

  8. Pytorch划分数据集的方法

    之前用过sklearn提供的划分数据集的函数,觉得超级方便.但是在使用TensorFlow和Pytorch的时候一直找不到类似的功能,之前搜索的关键字都是"pytorch split dat ...

  9. pytorch 加载数据集

    pytorch初学者,想加载自己的数据,了解了一下数据类型.维度等信息,方便以后加载其他数据. 1 torchvision.transforms实现数据预处理 transforms.Totensor( ...

随机推荐

  1. 剑指offer 04:重构二叉树

    题目描述 输入某二叉树的前序遍历和中序遍历的结果,请重建出该二叉树.假设输入的前序遍历和中序遍历的结果中都不含重复的数字.例如输入前序遍历序列{1,2,4,7,3,5,6,8}和中序遍历序列{4,7, ...

  2. oracle 11 g release 2 卸载

    Win 10 系统,Oracle 11 g R 2 ,安装目录C盘根目录 1.停止Oracle的所有服务 打开“服务”窗口,关闭Oracle的所有服务 2.运行Oracle Universal Ins ...

  3. NYOJ 542 试制品(第五届河南省省赛)

    解法不唯一,但是还是set好理解而且用着爽,代码注释应该够详细了 #include<stdio.h> #include<string.h> #include<math.h ...

  4. loadrunner中面向目标场景的设计

    在一个面向目标的方案中,可以定义五种类型的目标:虚拟用户数.每秒点击次数(仅 Web Vuser).每秒事务数.每分钟页面数(仅 Web Vuser)或方案的事务响应时间.使用“编辑方案目标”对话框可 ...

  5. English trip EM2-LP-4B At school Teacher:Will

    课上内容(Lesson) 词汇(Key Word ) art  美术:艺术 business  商科 engineering  工程学 graphic design  平面造型学 history  历 ...

  6. Hadoop启动之后jps没有NameNode节点

    这是因为多次格式化namenode节点出现的问题 1.先运行stop-all.sh 2.删除原目录,即core-site.xml下配置的<name>hadoop.tmp.dir</n ...

  7. MYSQL 总结——2

    1.mysql限制显示条目数:Limit,  Offset 图片网址:https://sqlbolt.com/lesson/filtering_sorting_query_results 实例: SE ...

  8. 记录一个下最近用tensorflow的几个坑

    1, softmax_cross_entropy_with_logits 的中的logits=x*w+b,其中w应该是[nfeats,nclass],b是[nclass]是对输出的每个类上logits ...

  9. map传参上下文赋值的问题

    今天开发遇到一个问题就是声明一个map<String,String> param ,给param赋值,明明有结果但是就是返回为空:下面附上代码: 因为在一个大的循环中,param是公用赋值 ...

  10. spoj Minimax Triangulation

    题解: dp+计算几何 F[i][j]表示第i-j条边的答案 然后转移一下 代码: #include<bits/stdc++.h> using namespace std; ]; ][]; ...