从零开始学习MXnet(二)之dataiter
MXnet的设计结构是C++做后端运算,python、R等做前端来使用,这样既兼顾了效率,又让使用者方便了很多,完整的使用MXnet训练自己的数据集需要了解几个方面。今天我们先谈一谈Data iterators。
MXnet中的data iterator和python中的迭代器是很相似的, 当其内置方法next被call的时候它每次返回一个 data batch。所谓databatch,就是神经网络的输入和label,一般是(n, c, h, w)的格式的图片输入和(n, h, w)或者标量式样的label。直接上官网上的一个简单的例子来说说吧。
- import numpy as np
- class SimpleIter:
- def __init__(self, data_names, data_shapes, data_gen,
- label_names, label_shapes, label_gen, num_batches=10):
- self._provide_data = zip(data_names, data_shapes)
- self._provide_label = zip(label_names, label_shapes)
- self.num_batches = num_batches
- self.data_gen = data_gen
- self.label_gen = label_gen
- self.cur_batch = 0
- def __iter__(self):
- return self
- def reset(self):
- self.cur_batch = 0
- def __next__(self):
- return self.next()
- @property
- def provide_data(self):
- return self._provide_data
- @property
- def provide_label(self):
- return self._provide_label
- def next(self):
- if self.cur_batch < self.num_batches:
- self.cur_batch += 1
- data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
- assert len(data) > 0, "Empty batch data."
- label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
- assert len(label) > 0, "Empty batch label."
- return SimpleBatch(data, label)
- else:
- raise StopIteration
上面的代码是最简单的一个dataiter了,没有对数据的预处理,甚至于没有自己去读取数据,但是基本的意思是到了,一个dataiter必须要实现上面的几个方法,provide_data返回的格式是(dataname, batchsize, channel, width, height), provide_label返回的格式是(label_name, batchsize, width, height),reset()的目的是在每个epoch后打乱读取图片的顺序,这样随机采样的话训练效果会好一点,一般情况下是用shuffle你的lst(上篇用来读取图片的lst)实现的,next()的方法就很显然了,用来返回你的databatch,如果出现问题...记得raise stopIteration,这里或许用try更好吧...需要注意的是,databatch返回的数据类型是mx.nd.ndarry。
下面是我最近做segmentation的时候用的一个稍微复杂的dataiter,多了预处理和shuffle等步骤:
- # pylint: skip-file
- import random
- import cv2
- import mxnet as mx
- import numpy as np
- import os
- from mxnet.io import DataIter, DataBatch
- class FileIter(DataIter): #一般都是继承DataIter
- """FileIter object in fcn-xs example. Taking a file list file to get dataiter.
- in this example, we use the whole image training for fcn-xs, that is to say
- we do not need resize/crop the image to the same size, so the batch_size is
- set to 1 here
- Parameters
- ----------
- root_dir : string
- the root dir of image/label lie in
- flist_name : string
- the list file of iamge and label, every line owns the form:
- index \t image_data_path \t image_label_path
- cut_off_size : int
- if the maximal size of one image is larger than cut_off_size, then it will
- crop the image with the minimal size of that image
- data_name : string
- the data name used in symbol data(default data name)
- label_name : string
- the label name used in symbol softmax_label(default label name)
- """
- def __init__(self, root_dir, flist_name, rgb_mean=(117, 117, 117),
- data_name="data", label_name="softmax_label", p=None):
- super(FileIter, self).__init__()
- self.fac = p.fac #这里的P是自己定义的config
- self.root_dir = root_dir
- self.flist_name = os.path.join(self.root_dir, flist_name)
- self.mean = np.array(rgb_mean) # (R, G, B)
- self.data_name = data_name
- self.label_name = label_name
- self.batch_size = p.batch_size
- self.random_crop = p.random_crop
- self.random_flip = p.random_flip
- self.random_color = p.random_color
- self.random_scale = p.random_scale
- self.output_size = p.output_size
- self.color_aug_range = p.color_aug_range
- self.use_rnn = p.use_rnn
- self.num_hidden = p.num_hidden
- if self.use_rnn:
- self.init_h_name = 'init_h'
- self.init_h = mx.nd.zeros((self.batch_size, self.num_hidden))
- self.cursor = -1
- self.data = mx.nd.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
- self.label = mx.nd.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
- self.data_list = []
- self.label_list = []
- self.order = []
- self.dict = {}
- lines = file(self.flist_name).read().splitlines()
- cnt = 0
- for line in lines: #读取lst,为后面读取图片做好准备
- _, data_img_name, label_img_name = line.strip('\n').split("\t")
- self.data_list.append(data_img_name)
- self.label_list.append(label_img_name)
- self.order.append(cnt)
- cnt += 1
- self.num_data = cnt
- self._shuffle()
- def _shuffle(self):
- random.shuffle(self.order)
- def _read_img(self, img_name, label_name):
- # 这个是在服务器上跑的时候,因为数据集很小,而且经常被同事卡IO,所以我就把数据全部放进了内存
- if os.path.join(self.root_dir, img_name) in self.dict:
- img = self.dict[os.path.join(self.root_dir, img_name)]
- else:
- img = cv2.imread(os.path.join(self.root_dir, img_name))
- self.dict[os.path.join(self.root_dir, img_name)] = img
- if os.path.join(self.root_dir, label_name) in self.dict:
- label = self.dict[os.path.join(self.root_dir, label_name)]
- else:
- label = cv2.imread(os.path.join(self.root_dir, label_name),0)
- self.dict[os.path.join(self.root_dir, label_name)] = label
- # 下面是读取图片后的一系统预处理工作
- if self.random_flip:
- flip = random.randint(0, 1)
- if flip == 1:
- img = cv2.flip(img, 1)
- label = cv2.flip(label, 1)
- # scale jittering
- scale = random.uniform(self.random_scale[0], self.random_scale[1])
- new_width = int(img.shape[1] * scale) #
- new_height = int(img.shape[0] * scale) # new_width * img.size[1] / img.size[0]
- img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
- label = cv2.resize(label, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
- #img = cv2.resize(img, (900,450), interpolation=cv2.INTER_NEAREST)
- #label = cv2.resize(label, (900, 450), interpolation=cv2.INTER_NEAREST)
- if self.random_crop:
- start_w = np.random.randint(0, img.shape[1] - self.output_size[1] + 1)
- start_h = np.random.randint(0, img.shape[0] - self.output_size[0] + 1)
- img = img[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1], :]
- label = label[start_h : start_h + self.output_size[0], start_w : start_w + self.output_size[1]]
- if self.random_color:
- img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
- hue = random.uniform(-self.color_aug_range[0], self.color_aug_range[0])
- sat = random.uniform(-self.color_aug_range[1], self.color_aug_range[1])
- val = random.uniform(-self.color_aug_range[2], self.color_aug_range[2])
- img = np.array(img, dtype=np.float32)
- img[..., 0] += hue
- img[..., 1] += sat
- img[..., 2] += val
- img[..., 0] = np.clip(img[..., 0], 0, 255)
- img[..., 1] = np.clip(img[..., 1], 0, 255)
- img[..., 2] = np.clip(img[..., 2], 0, 255)
- img = cv2.cvtColor(img.astype('uint8'), cv2.COLOR_HSV2BGR)
- is_rgb = True
- #cv2.imshow('main', img)
- #cv2.waitKey()
- #cv2.imshow('maain', label)
- #cv2.waitKey()
- img = np.array(img, dtype=np.float32) # (h, w, c)
- reshaped_mean = self.mean.reshape(1, 1, 3)
- img = img - reshaped_mean
- img[:, :, :] = img[:, :, [2, 1, 0]]
- img = img.transpose(2, 0, 1)
- # img = np.expand_dims(img, axis=0) # (1, c, h, w)
- label_zoomed = cv2.resize(label, None, fx = 1.0 / self.fac, fy = 1.0 / self.fac)
- label_zoomed = label_zoomed.astype('uint8')
- return (img, label_zoomed)
- @property
- def provide_data(self):
- """The name and shape of data provided by this iterator"""
- if self.use_rnn:
- return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1])),
- (self.init_h_name, (self.batch_size, self.num_hidden))]
- else:
- return [(self.data_name, (self.batch_size, 3, self.output_size[0], self.output_size[1]))]
- @property
- def provide_label(self):
- """The name and shape of label provided by this iterator"""
- return [(self.label_name, (self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))]
- def get_batch_size(self):
- return self.batch_size
- def reset(self):
- self.cursor = -self.batch_size
- self._shuffle()
- def iter_next(self):
- self.cursor += self.batch_size
- return self.cursor < self.num_data
- def _getpad(self):
- if self.cursor + self.batch_size > self.num_data:
- return self.cursor + self.batch_size - self.num_data
- else:
- return 0
- def _getdata(self):
- """Load data from underlying arrays, internal use only"""
- assert(self.cursor < self.num_data), "DataIter needs reset."
- data = np.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
- label = np.zeros((self.batch_size, self.output_size[0] / self.fac, self.output_size[1] / self.fac))
- if self.cursor + self.batch_size <= self.num_data:
- for i in range(self.batch_size):
- idx = self.order[self.cursor + i]
- data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
- data[i] = data_
- label[i] = label_
- else:
- for i in range(self.num_data - self.cursor):
- idx = self.order[self.cursor + i]
- data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
- data[i] = data_
- label[i] = label_
- pad = self.batch_size - self.num_data + self.cursor
- #for i in pad:
- for i in range(pad):
- idx = self.order[i]
- data_, label_ = self._read_img(self.data_list[idx], self.label_list[idx])
- data[i + self.num_data - self.cursor] = data_
- label[i + self.num_data - self.cursor] = label_
- return mx.nd.array(data), mx.nd.array(label)
- def next(self):
- """return one dict which contains "data" and "label" """
- if self.iter_next():
- data, label = self._getdata()
- data = [data, self.init_h] if self.use_rnn else [data]
- label = [label]
- return DataBatch(data=data, label=label,
- pad=self._getpad(), index=None,
- provide_data=self.provide_data,
- provide_label=self.provide_label)
- else:
- raise StopIteration
到这里基本上正常的训练我们就可以开始了,但是当你有了很多新的想法的时候,你又会遇到新的问题...比如:multi input/output怎么办?
其实也很简单,只需要修改几个地方:
1、provide_label和provide_data,注意到之前我们的return都是一个list,所以之间在里面添加和之前一样的格式就行了。
2. next() 如果你需要传 data和depth两个输入,只需要传 input = sum([[data],[depth],[]])到databatch的data就行了,label也同理。
值得一提的时候,MXnet的multi loss实现起来需要在写network的symbol的时候注意一点,假设你有softmax_loss和regression_loss。那么只要在最后return mx.symbol.Group([softmax_loss, regression_loss])。
总之......That's all~~~~
从零开始学习MXnet(二)之dataiter的更多相关文章
- 从零开始学习jQuery (二) 万能的选择器
本系列文章导航 从零开始学习jQuery (二) 万能的选择器 一.摘要 本章讲解jQuery最重要的选择器部分的知识. 有了jQuery的选择器我们几乎可以获取页面上任意的一个或一组对象, 可以明显 ...
- 从零开始学习MXnet(四)计算图和粗细粒度以及自动求导
这篇其实跟使用MXnet的关系不大,但对于我们理解深度学习的框架设计还是很有帮助的. 首先还是对promgramming models的一个简单介绍,这个东西实际上是在编译里面经常出现的东西,我们在编 ...
- 从零开始学习MXnet(三)之Model和Module
在我们在MXnet中定义好symbol.写好dataiter并且准备好data之后,就可以开开心的去训练了.一般训练一个网络有两种常用的策略,基于model的和基于module的.今天,我想谈一谈他们 ...
- 从零开始学习Android(二)从架构开始说起
我们刚开始学新东西的时候,往往希望能从一个实例进行入手学习.接下来的系列连载文章也主要是围绕这个实例进行.这个实例原形是从电子书<Android应用开发详解>得到的,我们在这里对其进行详细 ...
- 从零开始学习MXnet(五)MXnet的黑科技之显存节省大法
写完发现名字有点拗口..- -# 大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉 ...
- 从零开始学习MXnet(一)
最近工作要开始用到MXnet,然而MXnet的文档写的实在是.....所以在这记录点东西,方便自己,也方便大家. 我觉得搞清楚一个框架怎么使用,第一步就是用它来训练自己的数据,这是个很关键的一步. 一 ...
- oracle从零开始学习笔记 二
多表查询 等值连接(Equijoin) select ename,empno,sal,emp.deptno from emp,dept where dept.deptno=emp.deptno; 非等 ...
- 从零开始学习Vue(二)
思维方式的变化 WebForm时代, Aspx.cs 取得数据,绑定到前台的Repeater之类的控件.重新渲染整个HTML页面.就是整个页面不断的刷新;后来微软打了个补丁,推出了AJAX控件,比如U ...
- 从零开始学习jQuery(转)
本系列文章导航 从零开始学习jQuery (一) 开天辟地入门篇 从零开始学习jQuery (二) 万能的选择器 从零开始学习jQuery (三) 管理jQuery包装集 从零开始学习jQuery ( ...
随机推荐
- flask过滤器
过滤器的本质就是函数.有时候我们不仅仅只是需要输出变量的值,我们还需要修改变量的显示,甚至格式化.运算等等,而在模板中是不能直接调用 Python 中的某些方法,那么这就用到了过滤器. 过滤器的使用方 ...
- 常用 css html 样式
CSS基础必学列表 CSS width宽度 CSS height高度 CSS border边框 CSS background背景 CSS sprites背景拼合 CSS float浮动 CSS mar ...
- css在线sprite
大家知道网站图片多,浏览器下载多个图片要有多个请求.可是请求比较耗时,那怎么办呢? 对,方法就是css sprite. 今天我们来看看css在线sprite 百度搜索css-sprite 打开www. ...
- 【转】谈谈 iOS 中图片的解压缩
转自:http://blog.leichunfeng.com/blog/2017/02/20/talking-about-the-decompression-of-the-image-in-ios/ ...
- Linux - 信息收集
1. #!,代表加载器(解释器)的路径,如: #!/bin/bash echo "Hello Boy!" 上面的意思是说,把下面的字符(#!/bin/bash以下的所有字符)统统传 ...
- java.sql.Date java.sql.Time java.sql.Timestamp 之比较
java.sql.Date,java.sql.Time和java.sql.Timestamp 三个都是java.util.Date的子类(包装类). java.sql.Date是java.util.D ...
- react实现页面切换动画效果
一.前情概要 注:(我使用的路由是react-router4) 如下图所示,我们需要在页面切换时有一个过渡效果,这样就不会使页面切换显得生硬,用户体验大大提升: but the 问题是 ...
- 11-Mysql数据库----单表查询
本节重点: 单表查询 语法: 一.单表查询的语法 SELECT 字段1,字段2... FROM 表名 WHERE 条件 GROUP BY field HAVING 筛选 ORDER BY field ...
- NLP系列-中文分词(基于统计)
上文已经介绍了基于词典的中文分词,现在让我们来看一下基于统计的中文分词. 统计分词: 统计分词的主要思想是把每个词看做是由字组成的,如果相连的字在不同文本中出现的次数越多,就证明这段相连的字很有可能就 ...
- GraphSAGE 代码解析(四) - models.py
原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(二) - layers.py ...