数据并不总是满足机器学习算法所需的格式。我们使用transform对数据进行一些操作,使得其能适用于训练。

所有的TorchVision数据集都有两个参数,用以接受包含transform逻辑的可调用项-transform 修改features,targe_transform 修改标签。torchvision.transforms提供了几种现成的常用转换操作。

FashionMNIST features是PIL Image格式,标签是整型。为了训练,我们需要将其转换为标准的tensors,并且标签是one-hot编码的tensor。为了完成这些转换,使用 ToTensorLambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda ds = datasets.FashionMNIST(
root='data',
train=True,
download=True,
transform=ToTensor(),
# 在创建的具有10个0值数组中,单独取第一个维度的y位置(原始标签),赋为1,即为one-hot编码
target_tansform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0,
torch.tensor(y), value=1))
)

输出:

点击查看代码
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

ToTensor()

ToTensor将PIL图像或NumPy ndarray 转换为 FloatTensor。并且将图片像素值缩放到范围[0., 1.]

Lambda Transforms

Lambda转换可使用任何用户定义的lambda函数。这里,我们定义了一个函数,可以将整型转换成one-hot编码的tensor,首先创建一个大小为10的0值tensor,根据给定标签 y得到索引位置,调用scatter_将其赋为1。

target_transform = Lambda(lambda y: torch.zeros(
10,dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

延伸阅读

PyTorch 介绍 | TRANSFORMS的更多相关文章

  1. PyTorch 介绍 | DATSETS & DATALOADERS

    用于处理数据样本的代码可能会变得凌乱且难以维护:理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性.PyTorch提供了两个data primitives:torch ...

  2. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  3. PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD

    训练神经网络时,最常用的算法就是反向传播.在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整. 为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎. ...

  4. pytorch随笔

    pytorch中transform函数 一般用Compose把多个步骤整合到一起: 比如说 transforms.Compose([ transforms.CenterCrop(10), transf ...

  5. Keras vs. PyTorch in Transfer Learning

    We perform image classification, one of the computer vision tasks deep learning shines at. As traini ...

  6. Pytorch(一)

    一.Pytorch介绍 Pytorch 是Torch在Python上的衍生物 和Tensorflow相比: Pytorch建立的神经网络是动态的,而Tensorflow建立的神经网络是静态的 Tens ...

  7. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  8. Generative Adversarial Network (GAN) - Pytorch版

    import os import torch import torchvision import torch.nn as nn from torchvision import transforms f ...

  9. Tensorflow和pytorch安装(windows安装)

    一. Tensorflow安装 1. Tensorflow介绍 Tensorflow是广泛使用的实现机器学习以及其它涉及大量数学运算的算法库之一.Tensorflow由Google开发,是GitHub ...

随机推荐

  1. 【九度OJ】题目1124:Digital Roots 解题报告

    [九度OJ]题目1124:Digital Roots 解题报告 标签(空格分隔): 九度OJ 原题地址:http://ac.jobdu.com/problem.php?pid=1124 题目描述: T ...

  2. RXD and math

    RXD and math 题目链接 思路 \(u\)函数是莫比乌斯函数,这个不影响做题,这个式子算的是\([1,n^k]\)中能够写成\(a*b^2\)的数的个数,\(u(a)!=0\).然后我们可以 ...

  3. (gcd(a,b) == axorb)==>b = a-gcd(a,b)

    \((gcd(a,b) == axorb)==>b = a-gcd(a,b)\)的证明 \[(gcd(a,b) == axorb)==>b = a-gcd(a,b) \] 证明$$a-b& ...

  4. Estimation of Non-Normalized Statistical Models by Score Matching

    目录 概 主要内容 方法 损失函数的转换 一个例子 Hyv"{a}rinen A. Estimation of Non-Normalized Statistical Models by Sc ...

  5. Certified Adversarial Robustness via Randomized Smoothing

    目录 概 主要内容 定理1 代码 Cohen J., Rosenfeld E., Kolter J. Certified Adversarial Robustness via Randomized S ...

  6. EDP转LVDS屏转接板方案|基于INTELX86主板和商显应用EDP转LVDS设计CS5211

    众所周知LVDS接口是美国NS美国国家半导体公司为克服以TTL电平方式传输宽带高码率数据时功耗大,电磁干扰大等缺点而研制的一种数字视频信号传输方式.由于其采用低压和低电流驱动方式,实现了低噪声和低功耗 ...

  7. Java Web程序设计笔记 • 【第6章 Servlet技术进阶】

    全部章节   >>>> 本章目录 6.1 应用 Servlet API(一) 6.1.1 Servlet 类的层次结构 6.1.2 使用 Servlet API 的原则 6.1 ...

  8. 使用 jQuery 操作页面元素的方法,实现浏览大图片的效果,在页面上插入一幅小图片,当鼠标悬停到小图片上时,在小图片的右侧出现与之相对应的大图片

    查看本章节 查看作业目录 需求说明: 使用 jQuery 操作页面元素的方法,实现浏览大图片的效果,在页面上插入一幅小图片,当鼠标悬停到小图片上时,在小图片的右侧出现与之相对应的大图片 实现思路: 在 ...

  9. Java中的构造方法「注意事项」

    构造方法是专门用来创建对象的方法,当我们通过关键字new来创建对象时,其实就是调用构造方法. 语法: public 类名称(参数类型 参数名称){ 方法体 } 注意事项: 构造方法的名称必须和所在的类 ...

  10. .net Core WebApi使用AutoFac

    1.在要添加的项目中选中 依赖项->右键->管理NuGet程序包(N) 2.在NuGet包管理器中输入Autofac,安装选中文件 3.在项目中找到Program.cs文件,添加如下代码 ...