我们在《torch.utils.data.DataLoader与迭代器转换》中介绍了如何使用Pytorch内置的数据集进行论文实现,如torchvision.datasets。下面是加载内置训练数据集的常见操作:

from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, ToTensor, Normalize
RAW_DATA_PATH = './rawdata'
transform = Compose(
[ToTensor(),
Normalize((0.1307,), (0.3081,))
]
)
train_data = FashionMNIST(
root=RAW_DATA_PATH,
download=True,
train=True,
transform=transform
)

这里的train_data做为dataset对象,它拥有许多熟悉,我们可以通过以下方法获取样本数据的分类类别集合、样本的特征维度、样本的标签集合等信息。

classes = train_data.classes
num_features = train_data.data[0].shape[0]
train_labels = train_data.targets print(classes)
print(num_features)
print(train_labels)

输出如下:

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
28
tensor([9, 0, 0, ..., 3, 0, 5])

但是,我们常常会在训练集的基础上拆分出验证集(或者只用部分数据来进行训练)。我们想到的第一个方法是使用torch.utils.data.random_splitdataset进行划分,下面我们假设划分10000个样本做为训练集,其余样本做为验证集:

from torch.utils.data import random_split
k = 10000
train_data, valid_data = random_split(train_data, [k, len(train_data)-k])

注意我们如果打印train_datavalid_data的类型,可以看到显示:

<class 'torch.utils.data.dataset.Subset'>

已经不再是torchvision.datasets.mnist.FashionMNIST对象,而是一个所谓的Subset对象!此时Subset对象虽然仍然还存有data属性,但是内置的targetclasses属性已经不复存在,比如如果我们强行访问valid_datatarget属性:

valid_target = valid_data.target

就会报如下错误:

'Subset' object has no attribute 'target'

但如果我们在后续的代码中常常会将拆分后的数据集也默认为dataset对象,那么该如何做到代码的一致性呢?

这里有一个trick,那就是以继承SubSet类的方式的方式定义一个新的CustomSubSet类,使新类在保持SubSet类的基本属性的基础上,拥有和原本数据集类相似的属性,如targetsclasses等:

from torch.utils.data import Subset
class CustomSubset(Subset):
'''A custom subset class'''
def __init__(self, dataset, indices):
super().__init__(dataset, indices)
self.targets = dataset.targets # 保留targets属性
self.classes = dataset.classes # 保留classes属性 def __getitem__(self, idx): #同时支持索引访问操作
x, y = self.dataset[self.indices[idx]]
return x, y def __len__(self): # 同时支持取长度操作
return len(self.indices)

然后就引出了第二种划分方法,即通过初始化CustomSubset对象的方式直接对数据集进行划分(这里为了简化省略了shuffle的步骤):

import numpy as np
from copy import deepcopy
origin_data = deepcopy(train_data)
train_data = CustomSubset(origin_data, np.arange(k))
valid_data = CustomSubset(origin_data, np.arange(k, len(origin_data))-k)

注意,CustomSubset类的初始化方法的第二个参数indices为样本索引,我们可以通过np.arange()的方法来创建。

然后,我们再访问valid_data对应的classestarges属性:

print(valid_data.classes)
print(valid_data.targets)

此时,我们发现可以成功访问这些属性了:

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
tensor([9, 0, 0, ..., 3, 0, 5])

当然,CustomSubset的作用并不只是添加数据集的属性,我们还可以自定义一些数据预处理操作。我们将类的结构修改如下:

class CustomSubset(Subset):
'''A custom subset class with customizable data transformation'''
def __init__(self, dataset, indices, subset_transform=None):
super().__init__(dataset, indices)
self.targets = dataset.targets
self.classes = dataset.classes
self.subset_transform = subset_transform def __getitem__(self, idx):
x, y = self.dataset[self.indices[idx]] if self.subset_transform:
x = self.subset_transform(x) return x, y def __len__(self):
return len(self.indices)

我们可以在使用样本前设置好数据预处理算子:

from torchvision import transforms
valid_data.subset_transform = transforms.Compose(\
[transforms.RandomRotation((180,180))])

这样,我们再像下列这样用索引访问取出数据集样本时,就会自动调用算子完成预处理操作:

print(valid_data[0])

打印结果缩略如下:


(tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)

Pytorch技法:继承Subset类完成自定义数据拆分的更多相关文章

  1. [Pytorch]PyTorch Dataloader自定义数据读取

    整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...

  2. [深度学习] pytorch学习笔记(4)(Module类、实现Flatten类、Module类作用、数据增强)

    一.继承nn.Module类并自定义层 我们要利用pytorch提供的很多便利的方法,则需要将很多自定义操作封装成nn.Module类. 首先,简单实现一个Mylinear类: from torch ...

  3. .Net 配置文件--继承ConfigurationSection实现自定义处理类处理自定义配置节点

    除了使用继承IConfigurationSectionHandler的方法定义处理自定义节点的类,还可以通过继承ConfigurationSection类实现同样效果. 首先说下.Net配置文件中一个 ...

  4. .Net 配置文件——继承ConfigurationSection实现自定义处理类处理自定义配置节点

    除了使用继承IConfigurationSectionHandler的方法定义处理自定义节点的类,还可以通过继承ConfigurationSection类实现同样效果. 首先说下.Net配置文件中一个 ...

  5. WPF 之 创建继承自Window 基类的自定义窗口基类

    开发项目时,按照美工的设计其外边框(包括最大化,最小化,关闭等按钮)自然不同于 Window 自身的,但窗口的外边框及窗口移动.最小化等标题栏操作基本都是一样的.所以通过查看资料,可按如下方法创建继承 ...

  6. QVariant类及QVariant与自定义数据类型转换的方法

    这个类型相当于是Java里面的Object,它把绝大多数Qt提供的数据类型都封装起来,起到一个数据类型“擦除”的作用.比如我们的 table单元格可以是string,也可以是int,也可以是一个颜色值 ...

  7. 【spring boot】7.静态资源和拦截器处理 以及继承WebMvcConfigurerAdapter类进行更多自定义配置

    开头是鸡蛋,后面全靠编!!! ========================================================  1.默认静态资源映射路径以及优先顺序 Spring B ...

  8. JS面向对象(1) -- 简介,入门,系统常用类,自定义类,constructor,typeof,instanceof,对象在内存中的表现形式

    相关链接: JS面向对象(1) -- 简介,入门,系统常用类,自定义类,constructor,typeof,instanceof,对象在内存中的表现形式 JS面向对象(2) -- this的使用,对 ...

  9. [转]MVC自定义数据验证(两个时间的比较)

    本文转自:http://www.cnblogs.com/zhangliangzlee/archive/2012/07/26/2610071.html Model: public class Model ...

随机推荐

  1. 使用 Jenkins + Ansible 实现 Spring Boot 自动化部署101

    本文要点:设计一条 Spring Boot 最基本的流水线:包括构建.制品上传.部署.使用 Docker 容器运行构建逻辑.自动化整个实验环境:包括 Jenkins 的配置,Jenkins agent ...

  2. test_5 排序‘+’、‘-’

    题目是:有一组"+"和"-"符号,要求将"+"排到左边,"-"排到右边,写出具体的实现方法. 方法一: l=['-', ...

  3. WAFW00F waf识别工具 源码学习

    我实习工作的第一个任务根据已有的java waf识别工具 实现了一个python的waf识别工具 代码结构非常乱 仅仅达到了能用的水平. 顶头svp推荐这个项目当时我已经写好了开始用了自己的 稍微看了 ...

  4. Java 中如何实现线程间通信

    世界以痛吻我,要我报之以歌 -- 泰戈尔<飞鸟集> 虽然通常每个子线程只需要完成自己的任务,但是有时我们希望多个线程一起工作来完成一个任务,这就涉及到线程间通信. 关于线程间通信本文涉及到 ...

  5. 记一次异步处理导致Jetty Request对象泄漏

    最近排查一个bug,发现了一系列有意思的东西,对「自定义线程池」.「Jetty线程模型」都有了一些新的认识. 本文预计阅读时间10分钟,包括: 问题表现 常见原因筛查 根因与源码分析 最佳实践 一些小 ...

  6. Java 各个版本中的新特性

    新特性你知道多少? Java 8 Lambda 表达式 接口增加默认方法等 方法引用 流 Stream Java 9 模块系统 交互式工具jshell .of() 创建不可变集合 接口支持私有方法 更 ...

  7. 《剑指offer》面试题18. 删除链表的节点

    问题描述 给定单向链表的头指针和一个要删除的节点的值,定义一个函数删除该节点. 返回删除后的链表的头节点. 注意:此题对比原题有改动 示例 1: 输入: head = [4,5,1,9], val = ...

  8. 《剑指offer》面试题57 - II. 和为s的连续正数序列

    问题描述 输入一个正整数 target ,输出所有和为 target 的连续正整数序列(至少含有两个数). 序列内的数字由小到大排列,不同序列按照首个数字从小到大排列. 示例 1: 输入:target ...

  9. docker安装、下载镜像、容器的基本操作

    文章目录 一.docker安装与基本使用 1.docker的安装.从远程仓库下载镜像 2.配置docker国内源 二.创建容器 1.create i.创建容器 ii.进入容器 iii.启动容器 2.r ...

  10. JS调用堆栈

    调用栈 JavaScript 是一门单线程的语言,这意味着它只有一个调用栈,因此,它同一时间只能做一件事.如果我们运行到一个函数,它就会将其放置到栈顶.当从这个函数返回的时候,就会将这个函数从栈顶弹出 ...