MXnet的设计结构是C++做后端运算,python、R等做前端来使用,这样既兼顾了效率,又让使用者方便了很多,完整的使用MXnet训练自己的数据集需要了解几个方面。今天我们先谈一谈Data iterators。

  MXnet中的data iterator和python中的迭代器是很相似的, 当其内置方法next被call的时候它每次返回一个 data batch。所谓databatch,就是神经网络的输入和label,一般是(n, c, h, w)的格式的图片输入和(n, h, w)或者标量式样的label。直接上官网上的一个简单的例子来说说吧。

  

 import numpy as np
class SimpleIter:
def __init__(self, data_names, data_shapes, data_gen,
label_names, label_shapes, label_gen, num_batches=10):
self._provide_data = zip(data_names, data_shapes)
self._provide_label = zip(label_names, label_shapes)
self.num_batches = num_batches
self.data_gen = data_gen
self.label_gen = label_gen
self.cur_batch = 0 def __iter__(self):
return self def reset(self):
self.cur_batch = 0 def __next__(self):
return self.next() @property
def provide_data(self):
return self._provide_data @property
def provide_label(self):
return self._provide_label def next(self):
if self.cur_batch < self.num_batches:
self.cur_batch += 1
data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
assert len(data) > 0, "Empty batch data."
label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
assert len(label) > 0, "Empty batch label."
return SimpleBatch(data, label)
else:
raise StopIteration

  上面的代码是最简单的一个dataiter了,没有对数据的预处理,甚至于没有自己去读取数据,但是基本的意思是到了,一个dataiter必须要实现上面的几个方法,provide_data返回的格式是(dataname, batchsize, channel, width, height), provide_label返回的格式是(label_name, batchsize, width, height),reset()的目的是在每个epoch后打乱读取图片的顺序,这样随机采样的话训练效果会好一点,一般情况下是用shuffle你的lst(上篇用来读取图片的lst)实现的,next()的方法就很显然了,用来返回你的databatch,如果出现问题...记得raise stopIteration,这里或许用try更好吧...需要注意的是,databatch返回的数据类型是mx.nd.ndarry。

  下面是我最近做segmentation的时候用的一个稍微复杂的dataiter,多了预处理和shuffle等步骤:

  

 # pylint: skip-file
import random import cv2
import mxnet as mx
import numpy as np
import os
from mxnet.io import DataIter, DataBatch class FileIter(DataIter): #一般都是继承DataIter
"""FileIter object in fcn-xs example. Taking a file list file to get dataiter.
in this example, we use the whole image training for fcn-xs, that is to say
we do not need resize/crop the image to the same size, so the batch_size is
set to 1 here
Parameters
----------
root_dir : string
the root dir of image/label lie in
flist_name : string
the list file of iamge and label, every line owns the form:
index \t image_data_path \t image_label_path
cut_off_size : int
if the maximal size of one image is larger than cut_off_size, then it will
crop the image with the minimal size of that image
data_name : string
the data name used in symbol data(default data name)
label_name : string
the label name used in symbol softmax_label(default label name)
""" def __init__(self, root_dir, flist_name, rgb_mean=(117, 117, 117),
data_name="data", label_name="softmax_label", p=None):
super(FileIter, self).__init__() self.fac = p.fac #这里的P是自己定义的config
self.root_dir = root_dir
self.flist_name = os.path.join(self.root_dir, flist_name)
self.mean = np.array(rgb_mean) # (R, G, B)
self.data_name = data_name
self.label_name = label_name
self.batch_size = p.batch_size
self.random_crop = p.random_crop
self.random_flip = p.random_flip
self.random_color = p.random_color
self.random_scale = p.random_scale
self.output_size = p.output_size
self.color_aug_range = p.color_aug_range
self.use_rnn = p.use_rnn
self.num_hidden = p.num_hidden
if self.use_rnn:
self.init_h_name = 'init_h'
self.init_h = mx.nd.zeros((self.batch_size, self.num_hidden))
self.cursor = -1 self.data = mx.nd.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
self.label = mx.nd.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
self.data_list = []
self.label_list = []
self.order = []
self.dict = {}
lines = file(self.flist_name).read().splitlines()
cnt = 0
for line in lines: #读取lst,为后面读取图片做好准备
_, data_img_name, label_img_name = line.strip('\n').split("\t")
self.data_list.append(data_img_name)
self.label_list.append(label_img_name)
self.order.append(cnt)
cnt += 1
self.num_data = cnt
self._shuffle() def _shuffle(self):
random.shuffle(self.order) def _read_img(self, img_name, label_name):
     # 这个是在服务器上跑的时候,因为数据集很小,而且经常被同事卡IO,所以我就把数据全部放进了内存
if os.path.join(self.root_dir, img_name) in self.dict:
img = self.dict[os.path.join(self.root_dir, img_name)]
else:
img = cv2.imread(os.path.join(self.root_dir, img_name))
self.dict[os.path.join(self.root_dir, img_name)] = img if os.path.join(self.root_dir, label_name) in self.dict:
label = self.dict[os.path.join(self.root_dir, label_name)]
else:
label = cv2.imread(os.path.join(self.root_dir, label_name),0)
self.dict[os.path.join(self.root_dir, label_name)] = label      # 下面是读取图片后的一系统预处理工作
if self.random_flip:
flip = random.randint(0, 1)
if flip == 1:
img = cv2.flip(img, 1)
label = cv2.flip(label, 1)
# scale jittering
scale = random.uniform(self.random_scale[0], self.random_scale[1])
new_width = int(img.shape[1] * scale) #
new_height = int(img.shape[0] * scale) # new_width * img.size[1] / img.size[0]
img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
label = cv2.resize(label, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
#img = cv2.resize(img, (900,450), interpolation=cv2.INTER_NEAREST)
#label = cv2.resize(label, (900, 450), interpolation=cv2.INTER_NEAREST)
if self.random_crop:
start_w = np.random.randint(0, img.shape[1] - self.output_size[1] + 1)
start_h = np.random.randint(0, img.shape[0] - self.output_size[0] + 1)
img = img[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1], :]
label = label[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1]]
if self.random_color:
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
hue = random.uniform(-self.color_aug_range[0], self.color_aug_range[0])
sat = random.uniform(-self.color_aug_range[1], self.color_aug_range[1])
val = random.uniform(-self.color_aug_range[2], self.color_aug_range[2])
img = np.array(img, dtype=np.float32)
img[..., 0] += hue
img[..., 1] += sat
img[..., 2] += val
img[..., 0] = np.clip(img[..., 0], 0, 255)
img[..., 1] = np.clip(img[..., 1], 0, 255)
img[..., 2] = np.clip(img[..., 2], 0, 255)
img = cv2.cvtColor(img.astype('uint8'), cv2.COLOR_HSV2BGR)
is_rgb = True
#cv2.imshow('main', img)
#cv2.waitKey()
#cv2.imshow('maain', label)
#cv2.waitKey()
img = np.array(img, dtype=np.float32) # (h, w, c)
reshaped_mean = self.mean.reshape(1, 1, 3)
img = img - reshaped_mean
img[:, :, :] = img[:, :, [2, 1, 0]]
img = img.transpose(2, 0, 1)
# img = np.expand_dims(img, axis=0) # (1, c, h, w) label_zoomed = cv2.resize(label, None, fx = 1.0 / self.fac, fy = 1.0 / self.fac)
label_zoomed = label_zoomed.astype('uint8')
return (img, label_zoomed) @property
def provide_data(self):
"""The name and shape of data provided by this iterator"""
if self.use_rnn:
return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1])),
(self.init_h_name, (self.batch_size, self.num_hidden))]
else:
return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1]))] @property
def provide_label(self):
"""The name and shape of label provided by this iterator"""
return [(self.label_name, (self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))] def get_batch_size(self):
return self.batch_size def reset(self):
self.cursor = -self.batch_size
self._shuffle() def iter_next(self):
self.cursor += self.batch_size
return self.cursor < self.num_data def _getpad(self):
if self.cursor + self.batch_size > self.num_data:
return self.cursor + self.batch_size - self.num_data
else:
return 0 def _getdata(self):
"""Load data from underlying arrays, internal use only"""
assert(self.cursor < self.num_data), "DataIter needs reset."
data = np.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
label = np.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
if self.cursor + self.batch_size <= self.num_data:
for i in range(self.batch_size):
idx = self.order[self.cursor + i]
data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
data[i] = data_
label[i] = label_
else:
for i in range(self.num_data - self.cursor):
idx = self.order[self.cursor + i]
data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
data[i] = data_
label[i] = label_
pad = self.batch_size - self.num_data + self.cursor
#for i in pad:
for i in range(pad):
idx = self.order[i]
data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
data[i + self.num_data - self.cursor] = data_
label[i + self.num_data - self.cursor] = label_
return mx.nd.array(data), mx.nd.array(label) def next(self):
"""return one dict which contains "data" and "label" """
if self.iter_next():
data, label = self._getdata()
data = [data, self.init_h] if self.use_rnn else [data]
label = [label]
return DataBatch(data=data, label=label,
pad=self._getpad(), index=None,
provide_data=self.provide_data,
provide_label=self.provide_label)
else:
raise StopIteration

    到这里基本上正常的训练我们就可以开始了,但是当你有了很多新的想法的时候,你又会遇到新的问题...比如:multi input/output怎么办?

    其实也很简单,只需要修改几个地方:

      1、provide_label和provide_data,注意到之前我们的return都是一个list,所以之间在里面添加和之前一样的格式就行了。

      2. next() 如果你需要传 data和depth两个输入,只需要传 input = sum([[data],[depth],[]])到databatch的data就行了,label也同理。

    值得一提的时候,MXnet的multi loss实现起来需要在写network的symbol的时候注意一点,假设你有softmax_loss和regression_loss。那么只要在最后return mx.symbol.Group([softmax_loss, regression_loss])。

    总之......That's all~~~~

  

从零开始学习MXnet(二)之dataiter的更多相关文章

  1. 从零开始学习jQuery (二) 万能的选择器

    本系列文章导航 从零开始学习jQuery (二) 万能的选择器 一.摘要 本章讲解jQuery最重要的选择器部分的知识. 有了jQuery的选择器我们几乎可以获取页面上任意的一个或一组对象, 可以明显 ...

  2. 从零开始学习MXnet(四)计算图和粗细粒度以及自动求导

    这篇其实跟使用MXnet的关系不大,但对于我们理解深度学习的框架设计还是很有帮助的. 首先还是对promgramming models的一个简单介绍,这个东西实际上是在编译里面经常出现的东西,我们在编 ...

  3. 从零开始学习MXnet(三)之Model和Module

    在我们在MXnet中定义好symbol.写好dataiter并且准备好data之后,就可以开开心的去训练了.一般训练一个网络有两种常用的策略,基于model的和基于module的.今天,我想谈一谈他们 ...

  4. 从零开始学习Android(二)从架构开始说起

    我们刚开始学新东西的时候,往往希望能从一个实例进行入手学习.接下来的系列连载文章也主要是围绕这个实例进行.这个实例原形是从电子书<Android应用开发详解>得到的,我们在这里对其进行详细 ...

  5. 从零开始学习MXnet(五)MXnet的黑科技之显存节省大法

    写完发现名字有点拗口..- -# 大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉 ...

  6. 从零开始学习MXnet(一)

    最近工作要开始用到MXnet,然而MXnet的文档写的实在是.....所以在这记录点东西,方便自己,也方便大家. 我觉得搞清楚一个框架怎么使用,第一步就是用它来训练自己的数据,这是个很关键的一步. 一 ...

  7. oracle从零开始学习笔记 二

    多表查询 等值连接(Equijoin) select ename,empno,sal,emp.deptno from emp,dept where dept.deptno=emp.deptno; 非等 ...

  8. 从零开始学习Vue(二)

    思维方式的变化 WebForm时代, Aspx.cs 取得数据,绑定到前台的Repeater之类的控件.重新渲染整个HTML页面.就是整个页面不断的刷新;后来微软打了个补丁,推出了AJAX控件,比如U ...

  9. 从零开始学习jQuery(转)

    本系列文章导航 从零开始学习jQuery (一) 开天辟地入门篇 从零开始学习jQuery (二) 万能的选择器 从零开始学习jQuery (三) 管理jQuery包装集 从零开始学习jQuery ( ...

随机推荐

  1. Vue 去脚手架插件,自动加载vue文件

    接上回 一些本质 本质上,去脚手架也好,读取vue文件也好,无非是维护options,每个Vue对象的初始化配置对象不触及Vue内部而言,在外部想怎么改都是可以的,只要保证options的正确,一切都 ...

  2. CSS3实现加载数据动画2

    <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...

  3. windows系统下npm升级的正确姿势以及原理

    本文来自网易云社区 作者:陈观喜 网上关于npm升级很多方法多种多样,但是在windows系统下不是每种方法都会正确升级.其中在windows系统下主要的升级方法有以下三种: 首先最暴力的方法删掉no ...

  4. runtime总结 iOS

    Runtime的特性主要是消息(方法)传递,如果消息(方法)在对象中找不到,就进行转发,具体怎么实现的呢.我们从下面几个方面探寻Runtime的实现机制. Runtime介绍 Runtime消息传递 ...

  5. 修改有数据oracle字段类型 从number转为varchar

    --修改有数据oracle字段类型 从number转为varchar--例:修改ta_sp_org_invoice表中RESCUE_PHONE字段类型,从number转为varchar --step1 ...

  6. WOW.js 的使用方法

    WOW.js 是一个非常轻量级的动画效果插件,使用它可以组合多种炫酷的效果. 使用WOW.js可以实现我们在网站上常看到的,页面滚动到指定区域时就显示动画的效果. 1.要使用WOW.js必须引入:WO ...

  7. centos7使用Gogs搭建Git服务器

    一.初次接触Gogs,记录一下搭建过程 二.平台环境 Linux: CentOS7.5.1804 MySQL: 5.6.35 安装步骤: linux服务器新建git用户: 下载.解压gogs安装包: ...

  8. LeetCode 206——反转链表

    对单链表进行反转有迭代法和递归法两种. 1. 迭代法 迭代法从前往后遍历链表,定义三个指针分别指向相邻的三个结点,反转前两个结点,即让第二个结点指向第一个结点.然后依次往后移动指针,直到第二个结点为空 ...

  9. 可以随着SeekBar滑块滑动显示的Demo

    //关于Seek的自定义样式,之前也有总结过,但是,一直做不出随着滑块移动的效果,查询了很多资料终于解决了这个问题,现在把代码写出来有bug的地方 希望大家批评指正. Step 1 :自定义一个Vie ...

  10. hibernate延时加载机制

    延迟加载: 延迟加载机制是为了避免一些无谓的性能开销而提出来的,所谓延迟加载就是当在真正需要数据的时候,才真正执行数据加载操作.在Hibernate中提供了对实体对象的延迟加载以及对集合的延迟加载,另 ...