PyTorch 介绍 | TRANSFORMS
数据并不总是满足机器学习算法所需的格式。我们使用transform对数据进行一些操作,使得其能适用于训练。
所有的TorchVision数据集都有两个参数,用以接受包含transform逻辑的可调用项-transform
修改features,targe_transform
修改标签。torchvision.transforms提供了几种现成的常用转换操作。
FashionMNIST features是PIL Image格式,标签是整型。为了训练,我们需要将其转换为标准的tensors,并且标签是one-hot编码的tensor。为了完成这些转换,使用 ToTensor
和 Lambda
。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root='data',
train=True,
download=True,
transform=ToTensor(),
# 在创建的具有10个0值数组中,单独取第一个维度的y位置(原始标签),赋为1,即为one-hot编码
target_tansform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0,
torch.tensor(y), value=1))
)
输出:
点击查看代码
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
ToTensor()
ToTensor将PIL图像或NumPy ndarray
转换为 FloatTensor
。并且将图片像素值缩放到范围[0., 1.]
Lambda Transforms
Lambda转换可使用任何用户定义的lambda函数。这里,我们定义了一个函数,可以将整型转换成one-hot编码的tensor,首先创建一个大小为10的0值tensor,根据给定标签 y
得到索引位置,调用scatter_将其赋为1。
target_transform = Lambda(lambda y: torch.zeros(
10,dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
延伸阅读
PyTorch 介绍 | TRANSFORMS的更多相关文章
- PyTorch 介绍 | DATSETS & DATALOADERS
用于处理数据样本的代码可能会变得凌乱且难以维护:理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性.PyTorch提供了两个data primitives:torch ...
- PyTorch 介绍 | BUILD THE NEURAL NETWORK
神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...
- PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD
训练神经网络时,最常用的算法就是反向传播.在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整. 为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎. ...
- pytorch随笔
pytorch中transform函数 一般用Compose把多个步骤整合到一起: 比如说 transforms.Compose([ transforms.CenterCrop(10), transf ...
- Keras vs. PyTorch in Transfer Learning
We perform image classification, one of the computer vision tasks deep learning shines at. As traini ...
- Pytorch(一)
一.Pytorch介绍 Pytorch 是Torch在Python上的衍生物 和Tensorflow相比: Pytorch建立的神经网络是动态的,而Tensorflow建立的神经网络是静态的 Tens ...
- PyTorch 实战:计算 Wasserstein 距离
PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...
- Generative Adversarial Network (GAN) - Pytorch版
import os import torch import torchvision import torch.nn as nn from torchvision import transforms f ...
- Tensorflow和pytorch安装(windows安装)
一. Tensorflow安装 1. Tensorflow介绍 Tensorflow是广泛使用的实现机器学习以及其它涉及大量数学运算的算法库之一.Tensorflow由Google开发,是GitHub ...
随机推荐
- 【LeetCode】208. Implement Trie (Prefix Tree) 实现 Trie (前缀树)
作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 公众号:负雪明烛 本文关键词:Leetcode, 力扣,Trie, 前缀树,字典树,20 ...
- python xlwt写Excel表
1 xlwt第三方库 说明:xlwt是一个用于将数据和格式化信息写入并生成Excel文件的库. 注意:xlwt不支持写xlsx表,打开表文件报错. 官方文档:https://xlwt.readthed ...
- RabbitMQ学习笔记六:RabbitMQ之消息确认
使用消息队列,必须要考虑的问题就是生产者消息发送失败和消费者消息处理失败,这两种情况怎么处理. 生产者发送消息,成功,则确认消息发送成功;失败,则返回消息发送失败信息,再做处理. 消费者处理消息,成功 ...
- vue项目报错如下:(Emitted value instead of an instance of Error)
(Emitted value instead of an instance of Error) the "scope" attribute for scoped slots hav ...
- Typec转HDMI+VGA+PD3.0+USB3.0扩展坞四合一芯片方案|CS5268替代AG9321
CS5268是一种高度集成的单芯片,适用于多个细分市场和显示应用,如拓展坞.扩展底座等. 2.CS5268参数说明 总则 USB Type-C规范1.2 HDMI规范v2.0b兼容发射机,数据速率高达 ...
- 基于Spring MVC + Spring + MyBatis的【图书信息管理系统(二)】
资源下载:https://download.csdn.net/download/weixin_44893902/35123371 练习点设计:添加.删除.修改 一.语言和环境 实现语言:JAVA语言. ...
- 使用tomcat搭建HTTP文件下载服务器
使用tomcat搭建HTTP文件下载服务器, 有时我们的应用或者服务需要去外网下载一些资源, 但是如果在内网环境或者网络不好的情况下, 我们可以在内网提供文件下载服务, 将预先下载好的资源放在某个地方 ...
- 微信公众号开发--.net core接入
.net进行微信公众号开发的例子好像比较少,这里做个笔记 首先,我们需要让微信能访问到我们的项目,所以要么需要有一个可以部署项目的连接到公网下的服务器,要么可以通过端口转发将请求转发到我们的项目,总之 ...
- Linux 安装并启用 PHP-FPM
首先,在编译时带上 --enable-fpm 参数: [root@localhost local]# yum -y install libxml2 libxml2-devel gd gd-devel ...
- python+pytest,通过自定义命令行参数,实现浏览器兼容性跑用例
场景拓展: UI自动化可能需要指定浏览器进行测试,为了做成自定义配置浏览器,可以通过动态添加pytest的命令行参数,在执行的时候,获取命令行传入的参数,在对应的浏览器执行用例. 1.自动化用例需要支 ...