实现一个定制的 Dataset 类

Dataset 类是 PyTorch 图像数据集中最为重要的一个类,也是 PyTorch 中所有数据集加载类中应该继承的父类。其中,父类的两个私有成员函数必须被重载。

  • getitem(self, index) # 支持数据集索引的函数
  • len(self) # 返回数据集的大小

Datasets 的框架:

class CustomDataset(data.Dataset): # 需要继承 data.Dataset
def __init__(self):
# TODO
# Initialize file path or list of file names.
pass def __getitem__(self, index):
# TODO
# 1. 从文件中读取指定 index 的数据(例:使用 numpy.fromfile, PIL.Image.open)
# 2. 预处理读取的数据(例:torchvision.Transform)
# 3. 返回数据对(例:图像和对应标签)
pass def __len__(self):
# TODO
# You should change 0 to the total size of your dataset.
return 0

举例:

class MyDataset(Dataset):
"""
root: 图像存放地址根路径
augment:是否需要图像增强
""" def __init__(self, root, augment=None):
# 这个 list 存放所有图像的地址
self.image_files = np.array([
x.path for x in os.scandir(root)
if x.name.endswith(".jpg") or x.name.endswith(".png") or x.name.endswith(".JPG")
])
self.augment = augment def __getitem__(self, index):
if self.augment:
image = open_image(self.image_files[index]) # 这里的 open_image 是读取图像的函数,可以用 PIL 或者 OpenCV 等库进行读取
image = self.augment(image) # 这里对图像进行了数据增强
return to_tensor(image) # PyTorch 中得到的图像必须是 tensor
else:
image = open_image(self.image_files[index])
return to_tensor(image)

下面是官方 MNIST 的例子:

class MNIST(data.Dataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args:
root (string): Root directory of dataset where ``processed/training.pt``
and ``processed/test.pt`` exist.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
urls = [
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
]
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'training.pt'
test_file = 'test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
class_to_idx = {_class: i for i, _class in enumerate(classes)} @property
def targets(self):
if self.train:
return self.train_labels
else:
return self.test_labels def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set if download:
self.download() if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it') if self.train:
self.train_data, self.train_labels = torch.load(
os.path.join(self.root, self.processed_folder, self.training_file))
else:
self.test_data, self.test_labels = torch.load(
os.path.join(self.root, self.processed_folder, self.test_file)) def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index] # doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L') if self.transform is not None:
img = self.transform(img) if self.target_transform is not None:
target = self.target_transform(target) return img, target def __len__(self):
if self.train:
return len(self.train_data)
else:
return len(self.test_data) def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)) def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip if self._check_exists():
return # download files
try:
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise for url in self.urls:
print('Downloading ' + url)
data = urllib.request.urlopen(url)
filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename)
with open(file_path, 'wb') as f:
f.write(data.read())
with open(file_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(file_path) as zip_f:
out_f.write(zip_f.read())
os.unlink(file_path) # process and save as torch files
print('Processing...') training_set = (
read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
)
test_set = (
read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
)
with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
torch.save(test_set, f) print('Done!') def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str

PyTorch 之 Datasets的更多相关文章

  1. [深度学习] pytorch利用Datasets和DataLoader读取数据

    本文简单描述如果自定义dataset,代码并未经过测试(只是说明思路),为半伪代码.所有逻辑需按自己需求另外实现: 一.分析DataLoader train_loader = DataLoader( ...

  2. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  3. 如何使用Pytorch迅速实现Mnist数据及分类器

    一段时间没有更新博文,想着也该写两篇文章玩玩了.而从一个简单的例子作为开端是一个比较不错的选择.本文章会手把手地教读者构建一个简单的Mnist(Fashion-Mnist同理)的分类器,并且会使用相对 ...

  4. pytorch实现VAE

    一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...

  5. PyTorch教程之Training a classifier

    我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...

  6. PyTorch官方中文文档:PyTorch中文文档

    PyTorch中文文档 PyTorch是使用GPU和CPU优化的深度学习张量库. 说明 自动求导机制 CUDA语义 扩展PyTorch 多进程最佳实践 序列化语义 Package参考 torch to ...

  7. 基于pytorch的电影推荐系统

    本文介绍一个基于pytorch的电影推荐系统. 代码移植自https://github.com/chengstone/movie_recommender. 原作者用了tf1.0实现了这个基于movie ...

  8. [深度应用]·实战掌握PyTorch图片分类简明教程

    [深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn/ 项目GitHub地址--> https://github.com ...

  9. pytorch识别CIFAR10:训练ResNet-34(自定义transform,动态调整学习率,准确率提升到94.33%)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 前面通过数据增强,ResNet-34残差网络识别CIFAR10,准确率达到了92.6. 这里对训练过程 ...

随机推荐

  1. 互联网渗透测试之Wireshark的高级应用

    互联网渗透测试之Wireshark的高级应用 1.1说明 在本节将介绍Wireshark的一些高级特性 1.2. "Follow TCP Stream" 如果你处理TCP协议,想要 ...

  2. PAT 乙级真题 1003 我要通过!题解

    1003 我要通过! (20 分) “答案正确”是自动判题系统给出的最令人欢喜的回复.本题属于 PAT 的“答案正确”大派送 —— 只要读入的字符串满足下列条件,系统就输出“答案正确”,否则输出“答案 ...

  3. Unity Built-In Shader造成的运行时内存暴涨

    在某个PC项目中使用了大量的材质球, 并且都使用了自带的Standard Shader, 在编辑器运行的时候, 一切良好, 运行内存只在1G左右, 然而在进行AssetBundle打包之后, EXE运 ...

  4. 2. chromium开发工具--gclient

    1.gclient简介 gclient是谷歌开发的一套跨平台git仓库管理工具,用来将多个git仓库组成一个solution进行管理.总体上,其核心功能是根据一个Solution的DEPS文件所定义的 ...

  5. lf 前后端分离 (2) 课程数据获取,Serializer的返回

    一.关于课程数据的返回 在进行前后端分离时,会通过def 进行前后端传值, 本质上遵循rest 网址规范  增删改查查 get,post,put,del get(\d+) 1.在从数据库获取数据后,进 ...

  6. 解析YAML文件

    YamlMapFactoryBean yamlMapFactoryBean = new YamlMapFactoryBean(); yamlMapFactoryBean.setResources(ne ...

  7. Springboot jackSon -序列化-详解

    在项目中有事需要对值为NULL的对象中Field不做序列化输入配置方式如下: [配置类型]: 源码包中的枚举类: public static enum Include { ALWAYS, NON_NU ...

  8. 201871020225-牟星源《面向对象程序设计(java)》第八周学习总结

    201871020225-牟星源<面向对象程序设计(java)>第八周学习总结 项目 内容 这个作业属于哪个课程 https://www.cnblogs.com/nwnu-daizh/ 这 ...

  9. Python前言之Markdown使用

    一.Markdown基本语法 1.1标题 代码: # 一级标题 ## 二级标题 ### 三级标题 #### 四级标题 ##### 五级标题 ###### 六级标题 效果: 一级标题 二级标题 三级标题 ...

  10. 每天一道Rust-LeetCode(2019-06-10)

    每天一道Rust-LeetCode(2019-06-02) Z 字形变换 坚持每天一道题,刷题学习Rust. 题目描述 https://leetcode-cn.com/problems/simplif ...