MXnet的设计结构是C++做后端运算,python、R等做前端来使用,这样既兼顾了效率,又让使用者方便了很多,完整的使用MXnet训练自己的数据集需要了解几个方面。今天我们先谈一谈Data iterators。

  MXnet中的data iterator和python中的迭代器是很相似的, 当其内置方法next被call的时候它每次返回一个 data batch。所谓databatch,就是神经网络的输入和label,一般是(n, c, h, w)的格式的图片输入和(n, h, w)或者标量式样的label。直接上官网上的一个简单的例子来说说吧。

  

 import numpy as np
class SimpleIter:
def __init__(self, data_names, data_shapes, data_gen,
label_names, label_shapes, label_gen, num_batches=10):
self._provide_data = zip(data_names, data_shapes)
self._provide_label = zip(label_names, label_shapes)
self.num_batches = num_batches
self.data_gen = data_gen
self.label_gen = label_gen
self.cur_batch = 0 def __iter__(self):
return self def reset(self):
self.cur_batch = 0 def __next__(self):
return self.next() @property
def provide_data(self):
return self._provide_data @property
def provide_label(self):
return self._provide_label def next(self):
if self.cur_batch < self.num_batches:
self.cur_batch += 1
data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
assert len(data) > 0, "Empty batch data."
label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
assert len(label) > 0, "Empty batch label."
return SimpleBatch(data, label)
else:
raise StopIteration

  上面的代码是最简单的一个dataiter了,没有对数据的预处理,甚至于没有自己去读取数据,但是基本的意思是到了,一个dataiter必须要实现上面的几个方法,provide_data返回的格式是(dataname, batchsize, channel, width, height), provide_label返回的格式是(label_name, batchsize, width, height),reset()的目的是在每个epoch后打乱读取图片的顺序,这样随机采样的话训练效果会好一点,一般情况下是用shuffle你的lst(上篇用来读取图片的lst)实现的,next()的方法就很显然了,用来返回你的databatch,如果出现问题...记得raise stopIteration,这里或许用try更好吧...需要注意的是,databatch返回的数据类型是mx.nd.ndarry。

  下面是我最近做segmentation的时候用的一个稍微复杂的dataiter,多了预处理和shuffle等步骤:

  

 # pylint: skip-file
import random import cv2
import mxnet as mx
import numpy as np
import os
from mxnet.io import DataIter, DataBatch class FileIter(DataIter): #一般都是继承DataIter
"""FileIter object in fcn-xs example. Taking a file list file to get dataiter.
in this example, we use the whole image training for fcn-xs, that is to say
we do not need resize/crop the image to the same size, so the batch_size is
set to 1 here
Parameters
----------
root_dir : string
the root dir of image/label lie in
flist_name : string
the list file of iamge and label, every line owns the form:
index \t image_data_path \t image_label_path
cut_off_size : int
if the maximal size of one image is larger than cut_off_size, then it will
crop the image with the minimal size of that image
data_name : string
the data name used in symbol data(default data name)
label_name : string
the label name used in symbol softmax_label(default label name)
""" def __init__(self, root_dir, flist_name, rgb_mean=(117, 117, 117),
data_name="data", label_name="softmax_label", p=None):
super(FileIter, self).__init__() self.fac = p.fac #这里的P是自己定义的config
self.root_dir = root_dir
self.flist_name = os.path.join(self.root_dir, flist_name)
self.mean = np.array(rgb_mean) # (R, G, B)
self.data_name = data_name
self.label_name = label_name
self.batch_size = p.batch_size
self.random_crop = p.random_crop
self.random_flip = p.random_flip
self.random_color = p.random_color
self.random_scale = p.random_scale
self.output_size = p.output_size
self.color_aug_range = p.color_aug_range
self.use_rnn = p.use_rnn
self.num_hidden = p.num_hidden
if self.use_rnn:
self.init_h_name = 'init_h'
self.init_h = mx.nd.zeros((self.batch_size, self.num_hidden))
self.cursor = -1 self.data = mx.nd.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
self.label = mx.nd.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
self.data_list = []
self.label_list = []
self.order = []
self.dict = {}
lines = file(self.flist_name).read().splitlines()
cnt = 0
for line in lines: #读取lst,为后面读取图片做好准备
_, data_img_name, label_img_name = line.strip('\n').split("\t")
self.data_list.append(data_img_name)
self.label_list.append(label_img_name)
self.order.append(cnt)
cnt += 1
self.num_data = cnt
self._shuffle() def _shuffle(self):
random.shuffle(self.order) def _read_img(self, img_name, label_name):
     # 这个是在服务器上跑的时候,因为数据集很小,而且经常被同事卡IO,所以我就把数据全部放进了内存
if os.path.join(self.root_dir, img_name) in self.dict:
img = self.dict[os.path.join(self.root_dir, img_name)]
else:
img = cv2.imread(os.path.join(self.root_dir, img_name))
self.dict[os.path.join(self.root_dir, img_name)] = img if os.path.join(self.root_dir, label_name) in self.dict:
label = self.dict[os.path.join(self.root_dir, label_name)]
else:
label = cv2.imread(os.path.join(self.root_dir, label_name),0)
self.dict[os.path.join(self.root_dir, label_name)] = label      # 下面是读取图片后的一系统预处理工作
if self.random_flip:
flip = random.randint(0, 1)
if flip == 1:
img = cv2.flip(img, 1)
label = cv2.flip(label, 1)
# scale jittering
scale = random.uniform(self.random_scale[0], self.random_scale[1])
new_width = int(img.shape[1] * scale) #
new_height = int(img.shape[0] * scale) # new_width * img.size[1] / img.size[0]
img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
label = cv2.resize(label, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
#img = cv2.resize(img, (900,450), interpolation=cv2.INTER_NEAREST)
#label = cv2.resize(label, (900, 450), interpolation=cv2.INTER_NEAREST)
if self.random_crop:
start_w = np.random.randint(0, img.shape[1] - self.output_size[1] + 1)
start_h = np.random.randint(0, img.shape[0] - self.output_size[0] + 1)
img = img[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1], :]
label = label[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1]]
if self.random_color:
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
hue = random.uniform(-self.color_aug_range[0], self.color_aug_range[0])
sat = random.uniform(-self.color_aug_range[1], self.color_aug_range[1])
val = random.uniform(-self.color_aug_range[2], self.color_aug_range[2])
img = np.array(img, dtype=np.float32)
img[..., 0] += hue
img[..., 1] += sat
img[..., 2] += val
img[..., 0] = np.clip(img[..., 0], 0, 255)
img[..., 1] = np.clip(img[..., 1], 0, 255)
img[..., 2] = np.clip(img[..., 2], 0, 255)
img = cv2.cvtColor(img.astype('uint8'), cv2.COLOR_HSV2BGR)
is_rgb = True
#cv2.imshow('main', img)
#cv2.waitKey()
#cv2.imshow('maain', label)
#cv2.waitKey()
img = np.array(img, dtype=np.float32) # (h, w, c)
reshaped_mean = self.mean.reshape(1, 1, 3)
img = img - reshaped_mean
img[:, :, :] = img[:, :, [2, 1, 0]]
img = img.transpose(2, 0, 1)
# img = np.expand_dims(img, axis=0) # (1, c, h, w) label_zoomed = cv2.resize(label, None, fx = 1.0 / self.fac, fy = 1.0 / self.fac)
label_zoomed = label_zoomed.astype('uint8')
return (img, label_zoomed) @property
def provide_data(self):
"""The name and shape of data provided by this iterator"""
if self.use_rnn:
return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1])),
(self.init_h_name, (self.batch_size, self.num_hidden))]
else:
return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1]))] @property
def provide_label(self):
"""The name and shape of label provided by this iterator"""
return [(self.label_name, (self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))] def get_batch_size(self):
return self.batch_size def reset(self):
self.cursor = -self.batch_size
self._shuffle() def iter_next(self):
self.cursor += self.batch_size
return self.cursor < self.num_data def _getpad(self):
if self.cursor + self.batch_size > self.num_data:
return self.cursor + self.batch_size - self.num_data
else:
return 0 def _getdata(self):
"""Load data from underlying arrays, internal use only"""
assert(self.cursor < self.num_data), "DataIter needs reset."
data = np.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
label = np.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
if self.cursor + self.batch_size <= self.num_data:
for i in range(self.batch_size):
idx = self.order[self.cursor + i]
data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
data[i] = data_
label[i] = label_
else:
for i in range(self.num_data - self.cursor):
idx = self.order[self.cursor + i]
data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
data[i] = data_
label[i] = label_
pad = self.batch_size - self.num_data + self.cursor
#for i in pad:
for i in range(pad):
idx = self.order[i]
data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
data[i + self.num_data - self.cursor] = data_
label[i + self.num_data - self.cursor] = label_
return mx.nd.array(data), mx.nd.array(label) def next(self):
"""return one dict which contains "data" and "label" """
if self.iter_next():
data, label = self._getdata()
data = [data, self.init_h] if self.use_rnn else [data]
label = [label]
return DataBatch(data=data, label=label,
pad=self._getpad(), index=None,
provide_data=self.provide_data,
provide_label=self.provide_label)
else:
raise StopIteration

    到这里基本上正常的训练我们就可以开始了,但是当你有了很多新的想法的时候,你又会遇到新的问题...比如:multi input/output怎么办?

    其实也很简单,只需要修改几个地方:

      1、provide_label和provide_data,注意到之前我们的return都是一个list,所以之间在里面添加和之前一样的格式就行了。

      2. next() 如果你需要传 data和depth两个输入,只需要传 input = sum([[data],[depth],[]])到databatch的data就行了,label也同理。

    值得一提的时候,MXnet的multi loss实现起来需要在写network的symbol的时候注意一点,假设你有softmax_loss和regression_loss。那么只要在最后return mx.symbol.Group([softmax_loss, regression_loss])。

    总之......That's all~~~~

  

从零开始学习MXnet(二)之dataiter的更多相关文章

  1. 从零开始学习jQuery (二) 万能的选择器

    本系列文章导航 从零开始学习jQuery (二) 万能的选择器 一.摘要 本章讲解jQuery最重要的选择器部分的知识. 有了jQuery的选择器我们几乎可以获取页面上任意的一个或一组对象, 可以明显 ...

  2. 从零开始学习MXnet(四)计算图和粗细粒度以及自动求导

    这篇其实跟使用MXnet的关系不大,但对于我们理解深度学习的框架设计还是很有帮助的. 首先还是对promgramming models的一个简单介绍,这个东西实际上是在编译里面经常出现的东西,我们在编 ...

  3. 从零开始学习MXnet(三)之Model和Module

    在我们在MXnet中定义好symbol.写好dataiter并且准备好data之后,就可以开开心的去训练了.一般训练一个网络有两种常用的策略,基于model的和基于module的.今天,我想谈一谈他们 ...

  4. 从零开始学习Android(二)从架构开始说起

    我们刚开始学新东西的时候,往往希望能从一个实例进行入手学习.接下来的系列连载文章也主要是围绕这个实例进行.这个实例原形是从电子书<Android应用开发详解>得到的,我们在这里对其进行详细 ...

  5. 从零开始学习MXnet(五)MXnet的黑科技之显存节省大法

    写完发现名字有点拗口..- -# 大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉 ...

  6. 从零开始学习MXnet(一)

    最近工作要开始用到MXnet,然而MXnet的文档写的实在是.....所以在这记录点东西,方便自己,也方便大家. 我觉得搞清楚一个框架怎么使用,第一步就是用它来训练自己的数据,这是个很关键的一步. 一 ...

  7. oracle从零开始学习笔记 二

    多表查询 等值连接(Equijoin) select ename,empno,sal,emp.deptno from emp,dept where dept.deptno=emp.deptno; 非等 ...

  8. 从零开始学习Vue(二)

    思维方式的变化 WebForm时代, Aspx.cs 取得数据,绑定到前台的Repeater之类的控件.重新渲染整个HTML页面.就是整个页面不断的刷新;后来微软打了个补丁,推出了AJAX控件,比如U ...

  9. 从零开始学习jQuery(转)

    本系列文章导航 从零开始学习jQuery (一) 开天辟地入门篇 从零开始学习jQuery (二) 万能的选择器 从零开始学习jQuery (三) 管理jQuery包装集 从零开始学习jQuery ( ...

随机推荐

  1. Linux 系统无法登录?你的程序有问题吧!

    今天遇到一个问题,有个用户连接不上服务器(无法ssh远程连接) su: failed to execute /bin/bash: Resource temporarily unavailable 谷歌 ...

  2. BAT批处理

    常用命令 查看目录内容命令dir 指定可执行文件搜索目录path 创建目录命令md 打开指定目录命令cd 删除当前指定的子目录命令rd 改变当前盘符命令d: 文件复制命令copy 显示文本文件内容命令 ...

  3. C# 面试题 (一)

    一.C# 理论 1.1.简述 private. protected. public. internal.protected internal 访问修饰符和访问权限 private : 私有成员, 在类 ...

  4. linux进程 生产者消费者

    #include<stdio.h> #include<unistd.h> #include<stdlib.h> #include<string.h> # ...

  5. HBase-site.xml 常见重要配置参数

    HBase 常见重要配置参数 (1) Hbase.rpc.timeout rpc 的超时时间,默认 60s,不建议修改,避免影响正常的业务,在线上环境刚开始配置的是 3 秒,运行半天后发现了大量的 t ...

  6. Grok Debugger本地安装(转载)

    原文链接:http://fengwan.blog.51cto.com/508652/1758845 最近在使用ELK对日志进行集中管理,因为涉及到日志的规则经常要用到http://grokdebug. ...

  7. Hackerrank - [Algo] Matrix Rotation

    https://www.hackerrank.com/challenges/matrix-rotation-algo 又是一道耗了两小时以上的题,做完了才想起来,这不就是几年前在POJ上做过的一个同类 ...

  8. C++类数组批量赋值

    类和结构体不同,结构体在初始化时可以使用{...}的方法全部赋值,但是结构体怎么办呢?一种是把数据数组写到一个相同的结构体内,然后for循环使用一个非构造函数写入到类数组中.另一种方法是直接写入到对应 ...

  9. wpf显示视频,image控件闪屏,使用winform控件实现

    使用C#调用mingw的动态库实现视频识别软件,程序通过C++调用opencv打开视频,将图像的原始数据以rgb24的方式传递给C#端,C#通过构造图像对象给控件赋值的方式显示图片. 一开始使用wpf ...

  10. Python杂篇

    一:文件保存 def save_to_file(file_name, contents): fh = open(file_name, 'w') fh.write(contents) fh.close( ...