旷视MegEngine数据加载与处理
旷视MegEngine数据加载与处理
在网络训练与测试中,数据的加载和预处理往往会耗费大量的精力。 MegEngine 提供了一系列接口来规范化这些处理工作。
利用 Dataset 封装一个数据集
数据集是一组数据的集合,例如 MNIST、Cifar10等图像数据集。 Dataset 是 MegEngine 中表示数据集的抽象类。自定义的数据集类应该继承 Dataset 并重写下列方法:
- __init__() :一般在其中实现读取数据源文件的功能。也可以添加任何其它的必要功能;
- __getitem__() :通过索引操作来获取数据集中某一个样本,使得可以通过 for 循环来遍历整个数据集;
- __len__() :返回数据集大小;
下面是一个简单示例。 根据下图所示的二分类数据,创建一个 Dataset 。每个数据是一个二维平面上的点,横坐标和纵坐标在 [-1, 1] 之间。共有两个类别标签(图1中的蓝色 * 和红色 +),标签为0的点处于一、三象限;标签为1的点处于二、四象限。
图1
该数据集的创建过程如下:
- 在 __init__() 中利用 NumPy 随机生成 ndarray 作为数据;
- 在 __getitem__() 中返回 ndarray 中的一个样本;
- 在 __len__() 中返回整个数据集中样本的个数;
import numpy as np
from typing import Tuple
# 导入需要被继承的 Dataset 类
from megengine.data.dataset import Dataset
class XORDataset(Dataset):
def __init__(self, num_points):
"""
生成如图1所示的二分类数据集,数据集长度为 num_points
"""
super().__init__()
# 初始化一个维度为 (50000, 2) 的 NumPy 数组。
# 数组的每一行是一个横坐标和纵坐标都落在 [-1, 1] 区间的一个数据点 (x, y)
self.data = np.random.rand(num_points, 2).astype(np.float32) * 2 - 1
# 为上述 NumPy 数组构建标签。每一行的 (x, y) 如果符合 x*y < 0,则对应标签为1,反之,标签为0
self.label = np.zeros(num_points, dtype=np.int32)
for i in range(num_points):
self.label[i] = 1 if np.prod(self.data[i]) < 0 else 0
# 定义获取数据集中每个样本的方法
def __getitem__(self, index: int) -> Tuple:
return self.data[index], self.label[index]
# 定义返回数据集长度的方法
def __len__(self) -> int:
return len(self.data)
np.random.seed(2020)
# 构建一个包含 30000 个点的训练数据集
xor_train_dataset = XORDataset(30000)
print("The length of train dataset is: {}".format(len(xor_train_dataset)))
# 通过 for 遍历数据集中的每一个样本
for cor, tag in xor_train_dataset:
print("The first data point is: {}, {}".format(cor, tag))
break
print("The second data point is: {}".format(xor_train_dataset[1]))
输出:
The length of train dataset is: 30000
The first data point is: [0.97255366 0.74678389], 0
The second data point is: (array([ 0.01949105, -0.45632857]), 1)
MegEngine 中也提供了一些已经继承自 Dataset 的数据集类,方便使用,比如 ArrayDataset 。 ArrayDataset 允许通过传入单个或多个 NumPy 数组,对它进行初始化。其内部实现如下:
- __init__() :检查传入的多个 NumPy 数组的长度是否一致;不一致则无法成功创建;
- __getitem__() :将多个 NumPy 数组相同索引位置的元素构成一个 tuple 并返回;
- __len__() :返回数据集的大小;
以图1所示的数据集为例,可以通过坐标数据和标签数据的数组直接构造 ArrayDataset ,无需用户自己定义数据集类。
from megengine.data.dataset import ArrayDataset
# 准备 NumPy 形式的 data 和 label 数据
np.random.seed(2020)
num_points = 30000
data = np.random.rand(num_points, 2).astype(np.float32) * 2 - 1
label = np.zeros(num_points, dtype=np.int32)
for i in range(num_points):
label[i] = 1 if np.prod(data[i]) < 0 else 0
# 利用 ArrayDataset 创建一个数据集类
xor_dataset = ArrayDataset(data, label)
通过 Sampler 从 Dataset 中采样
Dataset 仅能通过一个固定的顺序(其 __getitem__ 实现)访问所有样本, 而 Sampler 使得可以以所期望的方式从 Dataset 中采样,生成训练和测试的批(minibatch)数据。 Sampler 本质上是一个数据集中数据索引的迭代器,接收 Dataset 的实例和批大小(batch_size)来进行初始化。
MegEngine 中提供各种常见的采样器,如 RandomSampler (通常用于训练)、 SequentialSampler (通常用于测试) 等。
下面示例,来熟悉 Sampler 的基本用法:
# 导入 MegEngine 中采样器
from megengine.data import RandomSampler
# 创建一个随机采样器
random_sampler = RandomSampler(dataset=xor_dataset, batch_size=4)
# 获取迭代sampler时每次返回的数据集索引
for indices in random_sampler:
print(indices)
break
输出:
[19827, 2614, 8788, 8641]
可以看到,在 batch_size 为4时,每次迭代 sampler 返回的是长度为4的列表,列表中的每个元素是随机采样出的数据索引。
如果创建的是一个序列化采样器 SequentialSampler ,那么每次返回的就是顺序索引。
from megengine.data import SequentialSampler
sequential_sampler = SequentialSampler(dataset=xor_dataset, batch_size=4)
# 获取迭代sampler时返回的数据集索引信息
for indices in sequential_sampler:
print(indices)
break
输出:
[0, 1, 2, 3]
用户也可以继承 Sampler 自定义采样器,这里不做详述。
用 DataLoader 生成批数据
MegEngine 中,DataLoader 本质上是一个迭代器,它通过 Dataset 和 Sampler 生成 minibatch 数据。
下列代码通过 for 循环获取每个 minibatch 的数据。
from megengine.data import DataLoader
# 创建一个 DataLoader,并指定数据集和顺序采样器
xor_dataloader = DataLoader(
dataset=xor_dataset,
sampler=sequential_sampler,
)
print("The length of the xor_dataloader is: {}".format(len(xor_dataloader)))
# 从 DataLoader 中迭代地获取每批数据
for idx, (cor, tag) in enumerate(xor_dataloader):
print("iter %d : " % (idx), cor, tag)
break
输出:
The length of the xor_dataloader is: 7500
iter 0 : [[ 0.97255366 0.74678389]
[ 0.01949105 -0.45632857]
[-0.32616254 -0.56609147]
[-0.44704571 -0.31336881]] [0 1 0 0]
DataLoader 中的数据变换(Transform)
在深度学习模型的训练中,经常需要对数据进行各种转换,比如,归一化、各种形式的数据增广等。 Transform 是数据变换的基类,其各种派生类提供了常见的数据转换功能。 DataLoader 构造函数可以接收一个 Transform 参数, 在构建 minibatch 时,对该批数据进行相应的转换操作。
接下来通过 MNIST 数据集(MegEngine 提供了 MNIST Dataset)来熟悉 Transform 的使用。 首先构建一个不做 Transform 的 MNIST DataLoader,并可视化第一个 minibatch 数据。
# 从 MegEngine 中导入 MNIST 数据集
from megengine.data.dataset import MNIST
# 若是第一次下载 MNIST 数据集,download 需设置成 True
# 若已经下载 MNIST 数据集,通过 root 指定 MNIST数据集 raw 路径
# 通过设置 train=True/False 获取训练集或测试集
mnist_train_dataset = MNIST(root="./dataset/MNIST", train=True, download=True)
# mnist_test_dataset = MNIST(root="./dataset/MNIST", train=False, download=True)
sequential_sampler = SequentialSampler(dataset=mnist_train_dataset, batch_size=4)
mnist_train_dataloader = DataLoader(
dataset=mnist_train_dataset,
sampler=sequential_sampler,
)
for i, batch_sample in enumerate(mnist_train_dataloader):
batch_image, batch_label = batch_sample[0], batch_sample[1]
# 下面可以将 batch_image, batch_label 传递给网络做训练,这里省略
# trainging code ...
# 中断
break
print("The shape of minibatch is: {}".format(batch_image.shape))
# 导入可视化 Python 库,若没有,安装
import matplotlib.pyplot as plt
def show(batch_image, batch_label):
for i in range(4):
plt.subplot(1, 4, i+1)
plt.imshow(batch_image[i][:,:,-1], cmap='gray')
plt.xticks([])
plt.yticks([])
plt.title("label: {}".format(batch_label[i]))
plt.show()
# 可视化数据
show(batch_image, batch_label)
输出:
The shape of minibatch is: (4, 28, 28, 1)
可视化第一批 MNIST 数据:
图2
然后,构建一个做 RandomResizedCrop transform 的 MNIST DataLoader,并查看此时第一个 minibatch 的图片。
# 导入 MegEngine 已支持的一些数据增强操作
from megengine.data.transform import RandomResizedCrop
dataloader = DataLoader(
mnist_train_dataset,
sampler=sequential_sampler,
# 指定随机裁剪后的图片的输出size
transform=RandomResizedCrop(output_size=28),
)
for i, batch_sample in enumerate(dataloader):
batch_image, batch_label = batch_sample[0], batch_sample[1]
break
show(batch_image, batch_label)
可视化第一个批数据:
图3
可以看到,此时图片经过了随机裁剪并 resize 回原尺寸。
组合变换(Compose Transform)
经常需要做一系列数据变换。比如:
- 数据归一化:可以通过 Transform 中提供的 Normalize 类来实现;
- Pad:对图片的每条边补零以增大图片尺寸,通过 Pad 类来实现;
- 维度转换:将 (Batch-size, Hight, Width, Channel) 维度的 minibatch 转换为 (Batch-size, Channel, Hight, Width)(因为这是 MegEngine 支持的数据格式),通过 ToMode 类来实现;
- 其它的转换操作
为了方便使用,MegEngine 中的 Compose 类允许组合多个 Transform 并传递给 DataLoader 的 transform 参数。
接下来通过 Compose 类将之前的 RandomResizedCrop 操作与 Normalize 、 Pad 和 ToMode 操作组合起来, 实现多种数据转换操作的混合使用。运行如下代码查看转换 minibatch 的维度信息。
from megengine.data.transform import RandomResizedCrop, Normalize, ToMode, Pad, Compose
# 利用 Compose 组合多个 Transform 操作
dataloader = DataLoader(
mnist_train_dataset,
sampler=sequential_sampler,
transform=Compose([
RandomResizedCrop(output_size=28),
# mean 和 std 分别是 MNIST 数据的均值和标准差,图片数值范围是 0~255
Normalize(mean=0.1307*255, std=0.3081*255),
Pad(2),
# 'CHW'表示把图片由 (height, width, channel) 格式转换成 (channel, height, width) 格式
ToMode('CHW'),
])
)
for i, batch_sample in enumerate(dataloader):
batch_image, batch_label = batch_sample[0], batch_sample[1]
break
print("The shape of the batch is now: {}".format(batch_image.shape))
输出:
The shape of the batch is now: (4, 1, 32, 32)
可以看到,此时 minibatch 数据的 channel 维换了位置,且图片尺寸变为32。
DataLoader 中其他参数的用法请参考 DataLoader 文档。
旷视MegEngine数据加载与处理的更多相关文章
- 旷视MegEngine核心技术升级
旷视MegEngine核心技术升级 7 月 11 日,旷视研究院在 2020 WAIC · 开发者日「深度学习框架与技术生态论坛」上围绕 6 月底发布的天元深度学习框架(MegEngine)Beta ...
- 旷视MegEngine网络搭建
旷视MegEngine网络搭建 在 基本概念 中,介绍了计算图.张量和算子,神经网络可以看成一个计算图.在 MegEngine 中,按照计算图的拓扑结构,将张量和算子连接起来,即可完成对网络的搭建.M ...
- 旷视MegEngine基本概念
旷视MegEngine基本概念 MegEngine 是基于计算图的深度神经网络学习框架. 本文简要介绍计算图及其相关基本概念,以及它们在 MegEngine 中的实现. 计算图(Computation ...
- ScrollView嵌套ListView,GridView数据加载不全问题的解决
我们大家都知道ListView,GridView加载数据项,如果数据项过多时,就会显示滚动条.ScrollView组件里面只能包含一个组件,当ScrollView里面嵌套listView,GridVi ...
- python多种格式数据加载、处理与存储
多种格式数据加载.处理与存储 实际的场景中,我们会在不同的地方遇到各种不同的数据格式(比如大家熟悉的csv与txt,比如网页HTML格式,比如XML格式),我们来一起看看python如何和这些格式的数 ...
- flask+sqlite3+echarts3+ajax 异步数据加载
结构: /www | |-- /static |....|-- jquery-3.1.1.js |....|-- echarts.js(echarts3是单文件!!) | |-- /templates ...
- Entity Framework关联查询以及数据加载(延迟加载,预加载)
数据加载分为延迟加载和预加载 EF的关联实体加载有三种方式:Lazy Loading,Eager Loading,Explicit Loading,其中Lazy Loading和Explicit Lo ...
- JQuery插件:遮罩+数据加载中。。。(特点:遮你想遮,罩你想罩)
在很多项目中都会涉及到数据加载.数据加载有时可能会是2-3秒,为了给一个友好的提示,一般都会给一个[数据加载中...]的提示.今天就做了一个这样的提示框. 先去jQuery官网看看怎么写jQuery插 ...
- 如何评估ETL的数据加载时间
简述如何评估大型ETL数据加载时间. 答:评估一个大型的ETL的数据加载时间是一件很复杂的事情.数据加载分为两类,一类是初次加载,另一类是增量加载. 在数据仓库正式投入使用时,需要进行一次初次加载,而 ...
随机推荐
- Laravel 定时任务 任务调度 可手动执行
1.创建一个命令 php artisan make:command TestCommand 执行成功后会提示: Console command created successfully. 生成了一个新 ...
- Linux-鸟菜-0-计算机概论
Linux-鸟菜-0-计算机概论 这一章在说计算机概论,额....,总的来说看完之后还是有点收获,回忆了下计算机基本知识.没有什么可上手操作的东西,全是概念,直接把最后的总结给截图过来吧,因为概念的话 ...
- 【JavaScript】Leetcode每日一题-移除元素
[JavaScript]Leetcode每日一题-移除元素 [题目描述] 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度. 不要使用 ...
- 安装和简单使用apidoc
安装nodejs 参考链接 安装apidoc 参考链接 使用 https://www.bilibili.com/video/BV1MW411Q7g4 https://www.bilibili.com/ ...
- vue 2.9.6升级到最新版本
在看文档https://cli.vuejs.org/zh/guide/installation.html中,按步骤升级vue: 于是就先通过 npm uninstall vue-cli -g卸载vue ...
- StreamReader & StreamWriter
这节讲StreamReader & StreamWriter,这两个类用于操作字符或者字符串,它将流的操作封装在了底层,相对来说用法比较简单,但是它不支持Seek()方法. 先看一下代码: F ...
- 浅入浅出 MySQL 索引
简单了解索引 首先,索引(Index)是什么?如果我直接告诉你索引是数据库管理系统中的一个有序的数据结构,你可能会有点懵逼. 为了避免这种情况,我打算举几个例子来帮助你更容易的认识索引. 我们查询字典 ...
- 【Web前端HTML5&CSS3】06-盒模型
笔记来源:尚硅谷Web前端HTML5&CSS3初学者零基础入门全套完整版 目录 盒模型 1. 文档流(normalflow) 2. 块元素 3. 行内元素 4. 盒子模型 盒模型.盒子模型.框 ...
- (五)Jira Api对接:修改任务状态
项目迭代结束后我们需要把sprint下面的story.task任务状态修改到结束状态,如果手动修改会花费不少时间,本文就介绍如何通过jira api自动修改任务状态,提高工作效率. 一.查看任务工作流 ...
- laravel 伪静态实现
Route::get('show{id}.html',['as'=>'products.detail','uses'=>'companyController@show']) ->wh ...