在Pytorch0.4版本的DARTS代码里,有一行代码是

trn_data = datasets.CIFAR10(root=data_path, train=True, download=False, transform=train_transform)
shape = trn_data.train_data.shape

在1.2及以上版本里,查看源码可知,CIFAR10这个类已经没有train_data这个属性了,取而代之的是data,因此要把第二行改成

shape = trn_data.data.shape

datasets.CIFAR10源码如下:

from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import sys if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive [docs]class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. """
base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
] test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
meta = {
'filename': 'batches.meta',
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
} def __init__(self, root, train=True, transform=None, target_transform=None,
download=False): super(CIFAR10, self).__init__(root, transform=transform,
target_transform=target_transform) self.train = train # training set or test set if download:
self.download() if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it') if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list self.data = []
self.targets = [] # now load the picked numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels']) self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC self._load_meta()

关于torchvision.datasets.CIFAR10的更多相关文章

  1. torchvision.datasets.ImageFolder数据加载

    ImageFolder 一个通用的数据加载器,数据集中的数据以以下方式组织 root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/12 ...

  2. torchvision.datasets

    转载  https://ptorch.com/docs/8/torchvision-datasets

  3. 试着用教程跑cifar10数据

    1.terminal里已经可import torchvision了,为什么Spyder里还是不能import torchvision 重启. 2. trainset = torchvision.dat ...

  4. PyTorch入门-CIFAR10图像分类

    CIFAR10数据集下载 CIFAR10数据集包含10个类别,图像尺寸为 3×32×32 官方下载地址很慢,这里给一个百度云: https://pan.baidu.com/s/1oTvW8wNa-VO ...

  5. PyTorch教程之Training a classifier

    我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...

  6. 学习笔记-ResNet网络

    ResNet网络 ResNet原理和实现 总结 一.ResNet原理和实现 神经网络第一次出现在1998年,当时用5层的全连接网络LetNet实现了手写数字识别,现在这个模型已经是神经网络界的“hel ...

  7. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  8. PyTorch进行深度学习入门

    一.PyTorch是什么? 这是一个基于Python的科学计算软件包,针对两组受众: ①.NumPy的替代品,可以使用GPU的强大功能 ②.深入学习研究平台,提供最大的灵活性和速度 二.入门 ①.张量 ...

  9. 深度学习(pytorch)-1.基于简单神经网络的图片自动分类

    这是pytorch官方的一个例子 官方教程地址:http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-b ...

随机推荐

  1. Netty入门 零基础

    因为接下来的项目要用到netty,所以就了解一下这个程序,奈何网上的教程都是稍微有点基础的,所以,就写一篇对于netty零基础的,顺便也记录一下. 先扔几个参考学习的网页: netty 官方API:  ...

  2. jquery selected选择器 语法

    jquery selected选择器 语法 作用::selected 选择器选取被选择的 <option> 元素.直线电机生产厂家 语法:$(":selected") ...

  3. web+大文件上传

    总结一下大文件分片上传和断点续传的问题.因为文件过大(比如1G以上),必须要考虑上传过程网络中断的情况.http的网络请求中本身就已经具备了分片上传功能,当传输的文件比较大时,http协议自动会将文件 ...

  4. bzoj 4899 记忆的轮廓 题解(概率dp+决策单调性优化)

    题目背景 四次死亡轮回后,昴终于到达了贤者之塔,当代贤者夏乌拉一见到昴就上前抱住了昴“师傅!你终于回来了!你有着和师傅一样的魔女的余香,肯定是师傅”.众所周知,大贤者是嫉妒魔女沙提拉的老公,400年前 ...

  5. 关于int main(int argc,char* argv[])详解

    平时在VS的环境下,主函数总会看到这两个参数,今天突然很想知道这两个参数的原理以及作用,因此查了下资料.真心受教了. 下面的博文是在百度空间看一位大神的,原文链接:http://hi.baidu.co ...

  6. C语言的预编译,程序员必须懂的知识!【预编译指令】【预编译过程】

    由“源代码”到“可执行文件”的过程包括四个步骤:预编译.编译.汇编.链接.所以,首先就应该清楚的首要问题就是:预编译只是对程序的文本起作用,换句话说就是,预编译阶段仅仅对源代码的单词进行变换,而不是对 ...

  7. [题解] [AtCoder2134] Zigzag MST

    题面 题解 考虑kruscal的过程 对于三个点\(x, y, x + 1\), 我们可以将\((x, y, z), (y, x + 1, z + 1)\)看做\((x, y, z), (x, x + ...

  8. Linux TCP自连接问题

    [参考文章]:net.ipv4.ip_local_port_range 的值究竟影响了啥 [参考文章]:Linux内核参数优化 最近卸载MySQL服务偶尔会遇到MySQL端口自连接问题.导致MySQL ...

  9. React 番外篇

    小技巧:如果我们想了解一门技术,不知道如何学习,那就在 BOSS 直聘上,来看看对这门技术的要求 这篇给大家讲的是 React 1.0 的初始版本,仅仅是让大家有个了解,毕竟回顾历史,我们才能找到他最 ...

  10. 往Angular应用程序中添加DevExtreme

    To start this tutorial, you need an Angular 5+ application created using the Angular CLI. Refer to t ...