MXnet的设计结构是C++做后端运算,python、R等做前端来使用,这样既兼顾了效率,又让使用者方便了很多,完整的使用MXnet训练自己的数据集需要了解几个方面。今天我们先谈一谈Data iterators。

  MXnet中的data iterator和python中的迭代器是很相似的, 当其内置方法next被call的时候它每次返回一个 data batch。所谓databatch,就是神经网络的输入和label,一般是(n, c, h, w)的格式的图片输入和(n, h, w)或者标量式样的label。直接上官网上的一个简单的例子来说说吧。

  

  1. import numpy as np
  2. class SimpleIter:
  3. def __init__(self, data_names, data_shapes, data_gen,
  4. label_names, label_shapes, label_gen, num_batches=10):
  5. self._provide_data = zip(data_names, data_shapes)
  6. self._provide_label = zip(label_names, label_shapes)
  7. self.num_batches = num_batches
  8. self.data_gen = data_gen
  9. self.label_gen = label_gen
  10. self.cur_batch = 0
  11.  
  12. def __iter__(self):
  13. return self
  14.  
  15. def reset(self):
  16. self.cur_batch = 0
  17.  
  18. def __next__(self):
  19. return self.next()
  20.  
  21. @property
  22. def provide_data(self):
  23. return self._provide_data
  24.  
  25. @property
  26. def provide_label(self):
  27. return self._provide_label
  28.  
  29. def next(self):
  30. if self.cur_batch < self.num_batches:
  31. self.cur_batch += 1
  32. data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
  33. assert len(data) > 0, "Empty batch data."
  34. label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
  35. assert len(label) > 0, "Empty batch label."
  36. return SimpleBatch(data, label)
  37. else:
  38. raise StopIteration

  上面的代码是最简单的一个dataiter了,没有对数据的预处理,甚至于没有自己去读取数据,但是基本的意思是到了,一个dataiter必须要实现上面的几个方法,provide_data返回的格式是(dataname, batchsize, channel, width, height), provide_label返回的格式是(label_name, batchsize, width, height),reset()的目的是在每个epoch后打乱读取图片的顺序,这样随机采样的话训练效果会好一点,一般情况下是用shuffle你的lst(上篇用来读取图片的lst)实现的,next()的方法就很显然了,用来返回你的databatch,如果出现问题...记得raise stopIteration,这里或许用try更好吧...需要注意的是,databatch返回的数据类型是mx.nd.ndarry。

  下面是我最近做segmentation的时候用的一个稍微复杂的dataiter,多了预处理和shuffle等步骤:

  

  1. # pylint: skip-file
  2. import random
  3.  
  4. import cv2
  5. import mxnet as mx
  6. import numpy as np
  7. import os
  8. from mxnet.io import DataIter, DataBatch
  9.  
  10. class FileIter(DataIter): #一般都是继承DataIter
  11. """FileIter object in fcn-xs example. Taking a file list file to get dataiter.
  12. in this example, we use the whole image training for fcn-xs, that is to say
  13. we do not need resize/crop the image to the same size, so the batch_size is
  14. set to 1 here
  15. Parameters
  16. ----------
  17. root_dir : string
  18. the root dir of image/label lie in
  19. flist_name : string
  20. the list file of iamge and label, every line owns the form:
  21. index \t image_data_path \t image_label_path
  22. cut_off_size : int
  23. if the maximal size of one image is larger than cut_off_size, then it will
  24. crop the image with the minimal size of that image
  25. data_name : string
  26. the data name used in symbol data(default data name)
  27. label_name : string
  28. the label name used in symbol softmax_label(default label name)
  29. """
  30.  
  31. def __init__(self, root_dir, flist_name, rgb_mean=(117, 117, 117),
  32. data_name="data", label_name="softmax_label", p=None):
  33. super(FileIter, self).__init__()
  34.  
  35. self.fac = p.fac #这里的P是自己定义的config
  36. self.root_dir = root_dir
  37. self.flist_name = os.path.join(self.root_dir, flist_name)
  38. self.mean = np.array(rgb_mean) # (R, G, B)
  39. self.data_name = data_name
  40. self.label_name = label_name
  41. self.batch_size = p.batch_size
  42. self.random_crop = p.random_crop
  43. self.random_flip = p.random_flip
  44. self.random_color = p.random_color
  45. self.random_scale = p.random_scale
  46. self.output_size = p.output_size
  47. self.color_aug_range = p.color_aug_range
  48. self.use_rnn = p.use_rnn
  49. self.num_hidden = p.num_hidden
  50. if self.use_rnn:
  51. self.init_h_name = 'init_h'
  52. self.init_h = mx.nd.zeros((self.batch_size, self.num_hidden))
  53. self.cursor = -1
  54.  
  55. self.data = mx.nd.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
  56. self.label = mx.nd.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
  57. self.data_list = []
  58. self.label_list = []
  59. self.order = []
  60. self.dict = {}
  61. lines = file(self.flist_name).read().splitlines()
  62. cnt = 0
  63. for line in lines: #读取lst,为后面读取图片做好准备
  64. _, data_img_name, label_img_name = line.strip('\n').split("\t")
  65. self.data_list.append(data_img_name)
  66. self.label_list.append(label_img_name)
  67. self.order.append(cnt)
  68. cnt += 1
  69. self.num_data = cnt
  70. self._shuffle()
  71.  
  72. def _shuffle(self):
  73. random.shuffle(self.order)
  74.  
  75. def _read_img(self, img_name, label_name):
  76.      # 这个是在服务器上跑的时候,因为数据集很小,而且经常被同事卡IO,所以我就把数据全部放进了内存
  77. if os.path.join(self.root_dir, img_name) in self.dict:
  78. img = self.dict[os.path.join(self.root_dir, img_name)]
  79. else:
  80. img = cv2.imread(os.path.join(self.root_dir, img_name))
  81. self.dict[os.path.join(self.root_dir, img_name)] = img
  82.  
  83. if os.path.join(self.root_dir, label_name) in self.dict:
  84. label = self.dict[os.path.join(self.root_dir, label_name)]
  85. else:
  86. label = cv2.imread(os.path.join(self.root_dir, label_name),0)
  87. self.dict[os.path.join(self.root_dir, label_name)] = label
  88.  
  89.      # 下面是读取图片后的一系统预处理工作
  90. if self.random_flip:
  91. flip = random.randint(0, 1)
  92. if flip == 1:
  93. img = cv2.flip(img, 1)
  94. label = cv2.flip(label, 1)
  95. # scale jittering
  96. scale = random.uniform(self.random_scale[0], self.random_scale[1])
  97. new_width = int(img.shape[1] * scale) #
  98. new_height = int(img.shape[0] * scale) # new_width * img.size[1] / img.size[0]
  99. img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
  100. label = cv2.resize(label, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
  101. #img = cv2.resize(img, (900,450), interpolation=cv2.INTER_NEAREST)
  102. #label = cv2.resize(label, (900, 450), interpolation=cv2.INTER_NEAREST)
  103. if self.random_crop:
  104. start_w = np.random.randint(0, img.shape[1] - self.output_size[1] + 1)
  105. start_h = np.random.randint(0, img.shape[0] - self.output_size[0] + 1)
  106. img = img[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1], :]
  107. label = label[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1]]
  108. if self.random_color:
  109. img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
  110. hue = random.uniform(-self.color_aug_range[0], self.color_aug_range[0])
  111. sat = random.uniform(-self.color_aug_range[1], self.color_aug_range[1])
  112. val = random.uniform(-self.color_aug_range[2], self.color_aug_range[2])
  113. img = np.array(img, dtype=np.float32)
  114. img[..., 0] += hue
  115. img[..., 1] += sat
  116. img[..., 2] += val
  117. img[..., 0] = np.clip(img[..., 0], 0, 255)
  118. img[..., 1] = np.clip(img[..., 1], 0, 255)
  119. img[..., 2] = np.clip(img[..., 2], 0, 255)
  120. img = cv2.cvtColor(img.astype('uint8'), cv2.COLOR_HSV2BGR)
  121. is_rgb = True
  122. #cv2.imshow('main', img)
  123. #cv2.waitKey()
  124. #cv2.imshow('maain', label)
  125. #cv2.waitKey()
  126. img = np.array(img, dtype=np.float32) # (h, w, c)
  127. reshaped_mean = self.mean.reshape(1, 1, 3)
  128. img = img - reshaped_mean
  129. img[:, :, :] = img[:, :, [2, 1, 0]]
  130. img = img.transpose(2, 0, 1)
  131. # img = np.expand_dims(img, axis=0) # (1, c, h, w)
  132.  
  133. label_zoomed = cv2.resize(label, None, fx = 1.0 / self.fac, fy = 1.0 / self.fac)
  134. label_zoomed = label_zoomed.astype('uint8')
  135. return (img, label_zoomed)
  136.  
  137. @property
  138. def provide_data(self):
  139. """The name and shape of data provided by this iterator"""
  140. if self.use_rnn:
  141. return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1])),
  142. (self.init_h_name, (self.batch_size, self.num_hidden))]
  143. else:
  144. return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1]))]
  145.  
  146. @property
  147. def provide_label(self):
  148. """The name and shape of label provided by this iterator"""
  149. return [(self.label_name, (self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))]
  150.  
  151. def get_batch_size(self):
  152. return self.batch_size
  153.  
  154. def reset(self):
  155. self.cursor = -self.batch_size
  156. self._shuffle()
  157.  
  158. def iter_next(self):
  159. self.cursor += self.batch_size
  160. return self.cursor < self.num_data
  161.  
  162. def _getpad(self):
  163. if self.cursor + self.batch_size > self.num_data:
  164. return self.cursor + self.batch_size - self.num_data
  165. else:
  166. return 0
  167.  
  168. def _getdata(self):
  169. """Load data from underlying arrays, internal use only"""
  170. assert(self.cursor < self.num_data), "DataIter needs reset."
  171. data = np.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
  172. label = np.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
  173. if self.cursor + self.batch_size <= self.num_data:
  174. for i in range(self.batch_size):
  175. idx = self.order[self.cursor + i]
  176. data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
  177. data[i] = data_
  178. label[i] = label_
  179. else:
  180. for i in range(self.num_data - self.cursor):
  181. idx = self.order[self.cursor + i]
  182. data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
  183. data[i] = data_
  184. label[i] = label_
  185. pad = self.batch_size - self.num_data + self.cursor
  186. #for i in pad:
  187. for i in range(pad):
  188. idx = self.order[i]
  189. data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
  190. data[i + self.num_data - self.cursor] = data_
  191. label[i + self.num_data - self.cursor] = label_
  192. return mx.nd.array(data), mx.nd.array(label)
  193.  
  194. def next(self):
  195. """return one dict which contains "data" and "label" """
  196. if self.iter_next():
  197. data, label = self._getdata()
  198. data = [data, self.init_h] if self.use_rnn else [data]
  199. label = [label]
  200. return DataBatch(data=data, label=label,
  201. pad=self._getpad(), index=None,
  202. provide_data=self.provide_data,
  203. provide_label=self.provide_label)
  204. else:
  205. raise StopIteration

    到这里基本上正常的训练我们就可以开始了,但是当你有了很多新的想法的时候,你又会遇到新的问题...比如:multi input/output怎么办?

    其实也很简单,只需要修改几个地方:

      1、provide_label和provide_data,注意到之前我们的return都是一个list,所以之间在里面添加和之前一样的格式就行了。

      2. next() 如果你需要传 data和depth两个输入,只需要传 input = sum([[data],[depth],[]])到databatch的data就行了,label也同理。

    值得一提的时候,MXnet的multi loss实现起来需要在写network的symbol的时候注意一点,假设你有softmax_loss和regression_loss。那么只要在最后return mx.symbol.Group([softmax_loss, regression_loss])。

    总之......That's all~~~~

  

从零开始学习MXnet(二)之dataiter的更多相关文章

  1. 从零开始学习jQuery (二) 万能的选择器

    本系列文章导航 从零开始学习jQuery (二) 万能的选择器 一.摘要 本章讲解jQuery最重要的选择器部分的知识. 有了jQuery的选择器我们几乎可以获取页面上任意的一个或一组对象, 可以明显 ...

  2. 从零开始学习MXnet(四)计算图和粗细粒度以及自动求导

    这篇其实跟使用MXnet的关系不大,但对于我们理解深度学习的框架设计还是很有帮助的. 首先还是对promgramming models的一个简单介绍,这个东西实际上是在编译里面经常出现的东西,我们在编 ...

  3. 从零开始学习MXnet(三)之Model和Module

    在我们在MXnet中定义好symbol.写好dataiter并且准备好data之后,就可以开开心的去训练了.一般训练一个网络有两种常用的策略,基于model的和基于module的.今天,我想谈一谈他们 ...

  4. 从零开始学习Android(二)从架构开始说起

    我们刚开始学新东西的时候,往往希望能从一个实例进行入手学习.接下来的系列连载文章也主要是围绕这个实例进行.这个实例原形是从电子书<Android应用开发详解>得到的,我们在这里对其进行详细 ...

  5. 从零开始学习MXnet(五)MXnet的黑科技之显存节省大法

    写完发现名字有点拗口..- -# 大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉 ...

  6. 从零开始学习MXnet(一)

    最近工作要开始用到MXnet,然而MXnet的文档写的实在是.....所以在这记录点东西,方便自己,也方便大家. 我觉得搞清楚一个框架怎么使用,第一步就是用它来训练自己的数据,这是个很关键的一步. 一 ...

  7. oracle从零开始学习笔记 二

    多表查询 等值连接(Equijoin) select ename,empno,sal,emp.deptno from emp,dept where dept.deptno=emp.deptno; 非等 ...

  8. 从零开始学习Vue(二)

    思维方式的变化 WebForm时代, Aspx.cs 取得数据,绑定到前台的Repeater之类的控件.重新渲染整个HTML页面.就是整个页面不断的刷新;后来微软打了个补丁,推出了AJAX控件,比如U ...

  9. 从零开始学习jQuery(转)

    本系列文章导航 从零开始学习jQuery (一) 开天辟地入门篇 从零开始学习jQuery (二) 万能的选择器 从零开始学习jQuery (三) 管理jQuery包装集 从零开始学习jQuery ( ...

随机推荐

  1. flask过滤器

    过滤器的本质就是函数.有时候我们不仅仅只是需要输出变量的值,我们还需要修改变量的显示,甚至格式化.运算等等,而在模板中是不能直接调用 Python 中的某些方法,那么这就用到了过滤器. 过滤器的使用方 ...

  2. 常用 css html 样式

    CSS基础必学列表 CSS width宽度 CSS height高度 CSS border边框 CSS background背景 CSS sprites背景拼合 CSS float浮动 CSS mar ...

  3. css在线sprite

    大家知道网站图片多,浏览器下载多个图片要有多个请求.可是请求比较耗时,那怎么办呢? 对,方法就是css sprite. 今天我们来看看css在线sprite 百度搜索css-sprite 打开www. ...

  4. 【转】谈谈 iOS 中图片的解压缩

    转自:http://blog.leichunfeng.com/blog/2017/02/20/talking-about-the-decompression-of-the-image-in-ios/ ...

  5. Linux - 信息收集

    1. #!,代表加载器(解释器)的路径,如: #!/bin/bash echo "Hello Boy!" 上面的意思是说,把下面的字符(#!/bin/bash以下的所有字符)统统传 ...

  6. java.sql.Date java.sql.Time java.sql.Timestamp 之比较

    java.sql.Date,java.sql.Time和java.sql.Timestamp 三个都是java.util.Date的子类(包装类). java.sql.Date是java.util.D ...

  7. react实现页面切换动画效果

    一.前情概要 注:(我使用的路由是react-router4)     如下图所示,我们需要在页面切换时有一个过渡效果,这样就不会使页面切换显得生硬,用户体验大大提升:     but the 问题是 ...

  8. 11-Mysql数据库----单表查询

    本节重点: 单表查询 语法: 一.单表查询的语法 SELECT 字段1,字段2... FROM 表名 WHERE 条件 GROUP BY field HAVING 筛选 ORDER BY field ...

  9. NLP系列-中文分词(基于统计)

    上文已经介绍了基于词典的中文分词,现在让我们来看一下基于统计的中文分词. 统计分词: 统计分词的主要思想是把每个词看做是由字组成的,如果相连的字在不同文本中出现的次数越多,就证明这段相连的字很有可能就 ...

  10. GraphSAGE 代码解析(四) - models.py

    原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(二) - layers.py ...