数据并不总是满足机器学习算法所需的格式。我们使用transform对数据进行一些操作,使得其能适用于训练。

所有的TorchVision数据集都有两个参数,用以接受包含transform逻辑的可调用项-transform 修改features,targe_transform 修改标签。torchvision.transforms提供了几种现成的常用转换操作。

FashionMNIST features是PIL Image格式,标签是整型。为了训练,我们需要将其转换为标准的tensors,并且标签是one-hot编码的tensor。为了完成这些转换,使用 ToTensorLambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda ds = datasets.FashionMNIST(
root='data',
train=True,
download=True,
transform=ToTensor(),
# 在创建的具有10个0值数组中,单独取第一个维度的y位置(原始标签),赋为1,即为one-hot编码
target_tansform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0,
torch.tensor(y), value=1))
)

输出:

点击查看代码
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

ToTensor()

ToTensor将PIL图像或NumPy ndarray 转换为 FloatTensor。并且将图片像素值缩放到范围[0., 1.]

Lambda Transforms

Lambda转换可使用任何用户定义的lambda函数。这里,我们定义了一个函数,可以将整型转换成one-hot编码的tensor,首先创建一个大小为10的0值tensor,根据给定标签 y得到索引位置,调用scatter_将其赋为1。

target_transform = Lambda(lambda y: torch.zeros(
10,dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

延伸阅读

PyTorch 介绍 | TRANSFORMS的更多相关文章

  1. PyTorch 介绍 | DATSETS & DATALOADERS

    用于处理数据样本的代码可能会变得凌乱且难以维护:理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性.PyTorch提供了两个data primitives:torch ...

  2. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  3. PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD

    训练神经网络时,最常用的算法就是反向传播.在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整. 为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎. ...

  4. pytorch随笔

    pytorch中transform函数 一般用Compose把多个步骤整合到一起: 比如说 transforms.Compose([ transforms.CenterCrop(10), transf ...

  5. Keras vs. PyTorch in Transfer Learning

    We perform image classification, one of the computer vision tasks deep learning shines at. As traini ...

  6. Pytorch(一)

    一.Pytorch介绍 Pytorch 是Torch在Python上的衍生物 和Tensorflow相比: Pytorch建立的神经网络是动态的,而Tensorflow建立的神经网络是静态的 Tens ...

  7. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  8. Generative Adversarial Network (GAN) - Pytorch版

    import os import torch import torchvision import torch.nn as nn from torchvision import transforms f ...

  9. Tensorflow和pytorch安装(windows安装)

    一. Tensorflow安装 1. Tensorflow介绍 Tensorflow是广泛使用的实现机器学习以及其它涉及大量数学运算的算法库之一.Tensorflow由Google开发,是GitHub ...

随机推荐

  1. 【LeetCode】430. Flatten a Multilevel Doubly Linked List 解题报告(Python)

    [LeetCode]430. Flatten a Multilevel Doubly Linked List 解题报告(Python) 标签(空格分隔): LeetCode 作者: 负雪明烛 id: ...

  2. E. Congruence Equation

    E. Congruence Equation 思路: 中国剩余定理 \(a^n(modp) = a^{nmod(p-1)}(modp)\),那么枚举在\([0,n-2]\)枚举指数 求\(a^i\)关 ...

  3. 【LeetCode】297. Serialize and Deserialize Binary Tree 解题报告(Python)

    [LeetCode]297. Serialize and Deserialize Binary Tree 解题报告(Python) 标签: LeetCode 题目地址:https://leetcode ...

  4. 1052 - String Growth

    1052 - String Growth    PDF (English) Statistics Forum Time Limit: 2 second(s) Memory Limit: 32 MB Z ...

  5. .Net Core&Agile Config配置中心

    当服务逐渐的增多,对各服务的配置管理愈加重要,轻量级的配置中心,入手或是搭建都简单许多,基于.net core开发的轻量级配置中心AgileConfig,功能强大,上手简单. https://gith ...

  6. [数学]高数部分-Part I 极限与连续

    Part I 极限与连续 回到总目录 Part I 极限与连续 一.极限 泰勒公式 基本微分公式 常用等价无穷小 函数极限定义 数列极限数列极限 极限的性质 极限的唯一性 极限的局部有限性 极限的局部 ...

  7. C++ std-11 常用方法

    对多个值取最值 C++标准库提供了获取最大值和最小值的方法: int mi = std::min(x1, x2); int ma = std::max(x1, x2); 如果想获取超过两个数的最值呢? ...

  8. 20道JavaScript经典面试题

    该篇文章整理了一些前端经典面试题,附带详解,涉及到JavaScript多方面知识点,满满都是干货-建议收藏阅读 前言 如果这篇文章有帮助到你,️关注+点赞️鼓励一下作者,文章公众号首发,关注 前端南玖 ...

  9. versions-maven-plugin插件批量修改版本号

    1.简介 versions-maven-plugin插件可以管理项目版本, 特别是当Maven工程项目中有大量子模块时, 可以批量修改pom版本号, 插件会把父模块更新到指定版本号, 然后更新子模块版 ...

  10. HiSql 实现case语法操作 新一代无实体ORM框架

    HiSql 实现case语法操作 在SqlServer,Oralce,Hana,PostGreSql,MySql 这些数据都支持SQL case语法,平常在实现业务开发中也会常用到,那么HiSql对于 ...