MXnet的设计结构是C++做后端运算,python、R等做前端来使用,这样既兼顾了效率,又让使用者方便了很多,完整的使用MXnet训练自己的数据集需要了解几个方面。今天我们先谈一谈Data iterators。

  MXnet中的data iterator和python中的迭代器是很相似的, 当其内置方法next被call的时候它每次返回一个 data batch。所谓databatch,就是神经网络的输入和label,一般是(n, c, h, w)的格式的图片输入和(n, h, w)或者标量式样的label。直接上官网上的一个简单的例子来说说吧。

  

 import numpy as np
class SimpleIter:
def __init__(self, data_names, data_shapes, data_gen,
label_names, label_shapes, label_gen, num_batches=10):
self._provide_data = zip(data_names, data_shapes)
self._provide_label = zip(label_names, label_shapes)
self.num_batches = num_batches
self.data_gen = data_gen
self.label_gen = label_gen
self.cur_batch = 0 def __iter__(self):
return self def reset(self):
self.cur_batch = 0 def __next__(self):
return self.next() @property
def provide_data(self):
return self._provide_data @property
def provide_label(self):
return self._provide_label def next(self):
if self.cur_batch < self.num_batches:
self.cur_batch += 1
data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
assert len(data) > 0, "Empty batch data."
label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
assert len(label) > 0, "Empty batch label."
return SimpleBatch(data, label)
else:
raise StopIteration

  上面的代码是最简单的一个dataiter了,没有对数据的预处理,甚至于没有自己去读取数据,但是基本的意思是到了,一个dataiter必须要实现上面的几个方法,provide_data返回的格式是(dataname, batchsize, channel, width, height), provide_label返回的格式是(label_name, batchsize, width, height),reset()的目的是在每个epoch后打乱读取图片的顺序,这样随机采样的话训练效果会好一点,一般情况下是用shuffle你的lst(上篇用来读取图片的lst)实现的,next()的方法就很显然了,用来返回你的databatch,如果出现问题...记得raise stopIteration,这里或许用try更好吧...需要注意的是,databatch返回的数据类型是mx.nd.ndarry。

  下面是我最近做segmentation的时候用的一个稍微复杂的dataiter,多了预处理和shuffle等步骤:

  

 # pylint: skip-file
import random import cv2
import mxnet as mx
import numpy as np
import os
from mxnet.io import DataIter, DataBatch class FileIter(DataIter): #一般都是继承DataIter
"""FileIter object in fcn-xs example. Taking a file list file to get dataiter.
in this example, we use the whole image training for fcn-xs, that is to say
we do not need resize/crop the image to the same size, so the batch_size is
set to 1 here
Parameters
----------
root_dir : string
the root dir of image/label lie in
flist_name : string
the list file of iamge and label, every line owns the form:
index \t image_data_path \t image_label_path
cut_off_size : int
if the maximal size of one image is larger than cut_off_size, then it will
crop the image with the minimal size of that image
data_name : string
the data name used in symbol data(default data name)
label_name : string
the label name used in symbol softmax_label(default label name)
""" def __init__(self, root_dir, flist_name, rgb_mean=(117, 117, 117),
data_name="data", label_name="softmax_label", p=None):
super(FileIter, self).__init__() self.fac = p.fac #这里的P是自己定义的config
self.root_dir = root_dir
self.flist_name = os.path.join(self.root_dir, flist_name)
self.mean = np.array(rgb_mean) # (R, G, B)
self.data_name = data_name
self.label_name = label_name
self.batch_size = p.batch_size
self.random_crop = p.random_crop
self.random_flip = p.random_flip
self.random_color = p.random_color
self.random_scale = p.random_scale
self.output_size = p.output_size
self.color_aug_range = p.color_aug_range
self.use_rnn = p.use_rnn
self.num_hidden = p.num_hidden
if self.use_rnn:
self.init_h_name = 'init_h'
self.init_h = mx.nd.zeros((self.batch_size, self.num_hidden))
self.cursor = -1 self.data = mx.nd.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
self.label = mx.nd.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
self.data_list = []
self.label_list = []
self.order = []
self.dict = {}
lines = file(self.flist_name).read().splitlines()
cnt = 0
for line in lines: #读取lst,为后面读取图片做好准备
_, data_img_name, label_img_name = line.strip('\n').split("\t")
self.data_list.append(data_img_name)
self.label_list.append(label_img_name)
self.order.append(cnt)
cnt += 1
self.num_data = cnt
self._shuffle() def _shuffle(self):
random.shuffle(self.order) def _read_img(self, img_name, label_name):
     # 这个是在服务器上跑的时候,因为数据集很小,而且经常被同事卡IO,所以我就把数据全部放进了内存
if os.path.join(self.root_dir, img_name) in self.dict:
img = self.dict[os.path.join(self.root_dir, img_name)]
else:
img = cv2.imread(os.path.join(self.root_dir, img_name))
self.dict[os.path.join(self.root_dir, img_name)] = img if os.path.join(self.root_dir, label_name) in self.dict:
label = self.dict[os.path.join(self.root_dir, label_name)]
else:
label = cv2.imread(os.path.join(self.root_dir, label_name),0)
self.dict[os.path.join(self.root_dir, label_name)] = label      # 下面是读取图片后的一系统预处理工作
if self.random_flip:
flip = random.randint(0, 1)
if flip == 1:
img = cv2.flip(img, 1)
label = cv2.flip(label, 1)
# scale jittering
scale = random.uniform(self.random_scale[0], self.random_scale[1])
new_width = int(img.shape[1] * scale) #
new_height = int(img.shape[0] * scale) # new_width * img.size[1] / img.size[0]
img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
label = cv2.resize(label, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
#img = cv2.resize(img, (900,450), interpolation=cv2.INTER_NEAREST)
#label = cv2.resize(label, (900, 450), interpolation=cv2.INTER_NEAREST)
if self.random_crop:
start_w = np.random.randint(0, img.shape[1] - self.output_size[1] + 1)
start_h = np.random.randint(0, img.shape[0] - self.output_size[0] + 1)
img = img[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1], :]
label = label[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1]]
if self.random_color:
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
hue = random.uniform(-self.color_aug_range[0], self.color_aug_range[0])
sat = random.uniform(-self.color_aug_range[1], self.color_aug_range[1])
val = random.uniform(-self.color_aug_range[2], self.color_aug_range[2])
img = np.array(img, dtype=np.float32)
img[..., 0] += hue
img[..., 1] += sat
img[..., 2] += val
img[..., 0] = np.clip(img[..., 0], 0, 255)
img[..., 1] = np.clip(img[..., 1], 0, 255)
img[..., 2] = np.clip(img[..., 2], 0, 255)
img = cv2.cvtColor(img.astype('uint8'), cv2.COLOR_HSV2BGR)
is_rgb = True
#cv2.imshow('main', img)
#cv2.waitKey()
#cv2.imshow('maain', label)
#cv2.waitKey()
img = np.array(img, dtype=np.float32) # (h, w, c)
reshaped_mean = self.mean.reshape(1, 1, 3)
img = img - reshaped_mean
img[:, :, :] = img[:, :, [2, 1, 0]]
img = img.transpose(2, 0, 1)
# img = np.expand_dims(img, axis=0) # (1, c, h, w) label_zoomed = cv2.resize(label, None, fx = 1.0 / self.fac, fy = 1.0 / self.fac)
label_zoomed = label_zoomed.astype('uint8')
return (img, label_zoomed) @property
def provide_data(self):
"""The name and shape of data provided by this iterator"""
if self.use_rnn:
return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1])),
(self.init_h_name, (self.batch_size, self.num_hidden))]
else:
return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1]))] @property
def provide_label(self):
"""The name and shape of label provided by this iterator"""
return [(self.label_name, (self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))] def get_batch_size(self):
return self.batch_size def reset(self):
self.cursor = -self.batch_size
self._shuffle() def iter_next(self):
self.cursor += self.batch_size
return self.cursor < self.num_data def _getpad(self):
if self.cursor + self.batch_size > self.num_data:
return self.cursor + self.batch_size - self.num_data
else:
return 0 def _getdata(self):
"""Load data from underlying arrays, internal use only"""
assert(self.cursor < self.num_data), "DataIter needs reset."
data = np.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
label = np.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
if self.cursor + self.batch_size <= self.num_data:
for i in range(self.batch_size):
idx = self.order[self.cursor + i]
data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
data[i] = data_
label[i] = label_
else:
for i in range(self.num_data - self.cursor):
idx = self.order[self.cursor + i]
data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
data[i] = data_
label[i] = label_
pad = self.batch_size - self.num_data + self.cursor
#for i in pad:
for i in range(pad):
idx = self.order[i]
data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
data[i + self.num_data - self.cursor] = data_
label[i + self.num_data - self.cursor] = label_
return mx.nd.array(data), mx.nd.array(label) def next(self):
"""return one dict which contains "data" and "label" """
if self.iter_next():
data, label = self._getdata()
data = [data, self.init_h] if self.use_rnn else [data]
label = [label]
return DataBatch(data=data, label=label,
pad=self._getpad(), index=None,
provide_data=self.provide_data,
provide_label=self.provide_label)
else:
raise StopIteration

    到这里基本上正常的训练我们就可以开始了,但是当你有了很多新的想法的时候,你又会遇到新的问题...比如:multi input/output怎么办?

    其实也很简单,只需要修改几个地方:

      1、provide_label和provide_data,注意到之前我们的return都是一个list,所以之间在里面添加和之前一样的格式就行了。

      2. next() 如果你需要传 data和depth两个输入,只需要传 input = sum([[data],[depth],[]])到databatch的data就行了,label也同理。

    值得一提的时候,MXnet的multi loss实现起来需要在写network的symbol的时候注意一点,假设你有softmax_loss和regression_loss。那么只要在最后return mx.symbol.Group([softmax_loss, regression_loss])。

    总之......That's all~~~~

  

从零开始学习MXnet(二)之dataiter的更多相关文章

  1. 从零开始学习jQuery (二) 万能的选择器

    本系列文章导航 从零开始学习jQuery (二) 万能的选择器 一.摘要 本章讲解jQuery最重要的选择器部分的知识. 有了jQuery的选择器我们几乎可以获取页面上任意的一个或一组对象, 可以明显 ...

  2. 从零开始学习MXnet(四)计算图和粗细粒度以及自动求导

    这篇其实跟使用MXnet的关系不大,但对于我们理解深度学习的框架设计还是很有帮助的. 首先还是对promgramming models的一个简单介绍,这个东西实际上是在编译里面经常出现的东西,我们在编 ...

  3. 从零开始学习MXnet(三)之Model和Module

    在我们在MXnet中定义好symbol.写好dataiter并且准备好data之后,就可以开开心的去训练了.一般训练一个网络有两种常用的策略,基于model的和基于module的.今天,我想谈一谈他们 ...

  4. 从零开始学习Android(二)从架构开始说起

    我们刚开始学新东西的时候,往往希望能从一个实例进行入手学习.接下来的系列连载文章也主要是围绕这个实例进行.这个实例原形是从电子书<Android应用开发详解>得到的,我们在这里对其进行详细 ...

  5. 从零开始学习MXnet(五)MXnet的黑科技之显存节省大法

    写完发现名字有点拗口..- -# 大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉 ...

  6. 从零开始学习MXnet(一)

    最近工作要开始用到MXnet,然而MXnet的文档写的实在是.....所以在这记录点东西,方便自己,也方便大家. 我觉得搞清楚一个框架怎么使用,第一步就是用它来训练自己的数据,这是个很关键的一步. 一 ...

  7. oracle从零开始学习笔记 二

    多表查询 等值连接(Equijoin) select ename,empno,sal,emp.deptno from emp,dept where dept.deptno=emp.deptno; 非等 ...

  8. 从零开始学习Vue(二)

    思维方式的变化 WebForm时代, Aspx.cs 取得数据,绑定到前台的Repeater之类的控件.重新渲染整个HTML页面.就是整个页面不断的刷新;后来微软打了个补丁,推出了AJAX控件,比如U ...

  9. 从零开始学习jQuery(转)

    本系列文章导航 从零开始学习jQuery (一) 开天辟地入门篇 从零开始学习jQuery (二) 万能的选择器 从零开始学习jQuery (三) 管理jQuery包装集 从零开始学习jQuery ( ...

随机推荐

  1. 数据库中pymysql模块的使用

    pymysql 模块 使用步骤: 核心类Connect链接用和Cursor读写用 1. 与数据库服务器建立链接 2. 获取游标对象(用于发送和接收数据) 3. 用游标执行sql语句 4. 使用fetc ...

  2. 阿里云异常网络连接-可疑WebShell通信行为的分析解决办法

    2018年10月27日接到新客户网站服务器被上传了webshell脚本木马后门问题的求助,对此我们sine安全公司针对此阿里云提示的安全问题进行了详细分析,ECS服务器被阿里云提示异常网络连接-可疑W ...

  3. DedeCMS V5.7sp2最新版本parse_str函数SQL注入漏洞

    织梦dedecms,在整个互联网中许多企业网站,个人网站,优化网站都在使用dede作为整个网站的开发架构,dedecms采用php+mysql数据库的架构来承载整个网站的运行与用户的访问,首页以及栏目 ...

  4. C语言:类型、运算符、表达式

    看了一天书,有点累了.就写写随笔记录一下今天的复习成果吧. C语言的基本数据类型 数值型:整型数,浮点数,布尔数,复数和虚数. 非数值型:字符. 整数最基本的是int,由此引出许多变式诸如有符号整数s ...

  5. windows下subversion服务器搭建

    一.下载subversion服务器端和客户端软件 1.subversion下载地址:http://subversion.tigris.org/ 2.svn比较流行的客户端Tortoisesvn下载地址 ...

  6. 初步学习pg_control文件之六

    接前文:初步学习pg_control文件之五 ,DB_IN_ARCHIVE_RECOVERY何时出现? 看代码:如果recovery.conf文件存在,则返回 InArchiveRecovery = ...

  7. P2340 奶牛会展(状压dp)

    P2340 奶牛会展 题目背景 奶牛想证明它们是聪明而风趣的.为此,贝西筹备了一个奶牛博览会,她已经对N 头奶牛进行 了面试,确定了每头奶牛的智商和情商. 题目描述 贝西有权选择让哪些奶牛参加展览.由 ...

  8. SLAM中的常识与经验

    双目矫正 双目通常事先是通过畸变矫正标定的,而RGB-D和单目则并不一定完成了矫正. 因此,对于RGB-D和单目获取的图像,在提取特征点之后,需要矫正,而双目则可以省略这一过程. 词袋模型反向索引 D ...

  9. [网站日志]当Memcached缓存服务挂掉时性能监视器中的表现

    我们用的Memcached缓存服务是阿里云OCS,今天晚上遇到了一次OCS挂掉的情况(计划中的升级),看一下性能监视器中的表现,也许对分析黑色1秒问题有帮助. 应用日志中错误: 2014-06-05 ...

  10. 「暑期训练」「Brute Force」 Far Relative’s Problem (CFR343D2B)

    题意 之后补 分析 我哭了,强行增加自己的思考复杂度...明明一道尬写的题- -(往区间贪心方向想了 其实完全没必要,注意到只有366天,直接穷举判断即可. 代码 #include <bits/ ...