数据并不总是满足机器学习算法所需的格式。我们使用transform对数据进行一些操作,使得其能适用于训练。

所有的TorchVision数据集都有两个参数,用以接受包含transform逻辑的可调用项-transform 修改features,targe_transform 修改标签。torchvision.transforms提供了几种现成的常用转换操作。

FashionMNIST features是PIL Image格式,标签是整型。为了训练,我们需要将其转换为标准的tensors,并且标签是one-hot编码的tensor。为了完成这些转换,使用 ToTensorLambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda ds = datasets.FashionMNIST(
root='data',
train=True,
download=True,
transform=ToTensor(),
# 在创建的具有10个0值数组中,单独取第一个维度的y位置(原始标签),赋为1,即为one-hot编码
target_tansform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0,
torch.tensor(y), value=1))
)

输出:

点击查看代码
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

ToTensor()

ToTensor将PIL图像或NumPy ndarray 转换为 FloatTensor。并且将图片像素值缩放到范围[0., 1.]

Lambda Transforms

Lambda转换可使用任何用户定义的lambda函数。这里,我们定义了一个函数,可以将整型转换成one-hot编码的tensor,首先创建一个大小为10的0值tensor,根据给定标签 y得到索引位置,调用scatter_将其赋为1。

target_transform = Lambda(lambda y: torch.zeros(
10,dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

延伸阅读

PyTorch 介绍 | TRANSFORMS的更多相关文章

  1. PyTorch 介绍 | DATSETS & DATALOADERS

    用于处理数据样本的代码可能会变得凌乱且难以维护:理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性.PyTorch提供了两个data primitives:torch ...

  2. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  3. PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD

    训练神经网络时,最常用的算法就是反向传播.在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整. 为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎. ...

  4. pytorch随笔

    pytorch中transform函数 一般用Compose把多个步骤整合到一起: 比如说 transforms.Compose([ transforms.CenterCrop(10), transf ...

  5. Keras vs. PyTorch in Transfer Learning

    We perform image classification, one of the computer vision tasks deep learning shines at. As traini ...

  6. Pytorch(一)

    一.Pytorch介绍 Pytorch 是Torch在Python上的衍生物 和Tensorflow相比: Pytorch建立的神经网络是动态的,而Tensorflow建立的神经网络是静态的 Tens ...

  7. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  8. Generative Adversarial Network (GAN) - Pytorch版

    import os import torch import torchvision import torch.nn as nn from torchvision import transforms f ...

  9. Tensorflow和pytorch安装(windows安装)

    一. Tensorflow安装 1. Tensorflow介绍 Tensorflow是广泛使用的实现机器学习以及其它涉及大量数学运算的算法库之一.Tensorflow由Google开发,是GitHub ...

随机推荐

  1. 【LeetCode】801. Minimum Swaps To Make Sequences Increasing 解题报告(Python)

    作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 题目大意 解题方法 动态规划 参考资料 日期 题目地址:https:// ...

  2. ELBO surgery: yet another way to carve up the variational evidence lower bound

    目录 概 主要内容 Evidence minus posterior KL Average negative energy plus entropy Average term-by-term reco ...

  3. 计算机图形学——梁友栋-Barsky算法

    梁算法是计算机图形学上最经典的几个算法,也是目前唯一一个以中国人命名的出现在国内外计算机图形学课本的算法,我之前在介绍裁剪算法的时候介绍过这个算法 https://www.cnblogs.com/wk ...

  4. ☕【难点攻克技术系列】「海量数据计算系列」如何使用BitMap在海量数据中对相应的进行去重、查找和排序

    BitMap(位图)的介绍 BitMap从字面的意思,很多人认为是位图,其实准确的来说,翻译成基于位的映射,其中数据库中有一种索引就叫做位图索引. 在具有性能优化的数据结构中,大家使用最多的就是has ...

  5. HTML网页设计基础笔记 • 【第4章 CSS3基础】

    全部章节   >>>> 本章目录 4.1 CSS 概述 4.1.1 CSS 简介 4.1.2 CSS3 基本语法 4.1.3 样式表的分类 4.2 CSS 基本选择器 4.2. ...

  6. MySQL数据操作与查询笔记 • 【目录】

    持续更新中- 我的大学笔记>>> 章节 内容 第1章 MySQL数据操作与查询笔记 • [第1章 MySQL数据库基础] 第2章 MySQL数据操作与查询笔记 • [第2章 表结构管 ...

  7. playwright--自动化(二):过滑块验证码 验证码缺口识别

    前两天需要自动化登录一个商城的后台 用的是playwright 没有用selenium 中间出了一个滑块验证 现阶段playwright教程不是太多,自己做移动的时候各种找,费劲巴拉的.现在自己整出来 ...

  8. 自学java,如何快速地找到工作

    本人最近一直在帮零基础的java开发者提升能力和找工作,在这个过程中,发现零基础的java程序员,在自学和找工作时,普遍会出现一些问题,同时在实践过程中,也总结出了一些能帮零基础java开发尽快提升能 ...

  9. Java 设置系统参数和运行参数

    系统参数 系统级全局变量,该参数在程序中任何位置都可以访问到.优先级最高,覆盖程序中同名配置. 系统参数的标准格式为:-Dargname=argvalue,多个参数之间用空格隔开,如果参数值中间有空格 ...

  10. cpu负载

    查看cpu负载,我们经常会使用top,或者是uptime命令 但是这只能看到cpu的总体的负载情况.如果我们想看cpu每个核心的负载情况是看不到的. 所以我们可以用mpstat命令 服务器一共32核心 ...