实际上关于pytorch加载自己的数据之前有写过一篇博客,但是最近接触了mxnet,发现关于这方面的教程很少

如果要加载自己定义的数据的话,看mxnet关于mnist基本上能够推测12

看pytorch与mxnet他们加载数据方式的对比

上图左边是pytorch的,右图是mxnet

实际上,mxnet与pytorch他们的datalayer有着相似之处,为什么这样说呢?直接看上面的代码,基本上都是输入图像的路径,然后输出一个可以供loader调用的可以迭代的对象,所以无论是pytorch或者是mxnet,如果要有自己的数据,只需要在自己的数据那一部分继承与修改ImageFolderDataset这个函数就行,就是直接继承dataset.Dataset类即可

对于pytorch而言,它使用了find_class这样一个函数,而对于mxnet而言,实际上它在类内部定义了一个_list_images的函数,事实上我并没有发现这有没有用,只需要get_item这个函数中返回list,list中是一个tuple,一个是文件的名字,另外一个是文件所对应的label即可。

只需要继承这一个类即可

直接撸代码

这个是我参加kaggle比赛的一段代码,尽管并不收敛,但请不要在意这些细节

  1. # -*-coding:utf-8-*-
  2. from mxnet import autograd
  3. from mxnet import gluon
  4. from mxnet import image
  5. from mxnet import init
  6. from mxnet import nd
  7. from mxnet.gluon.data import vision
  8. import numpy as np
  9. from mxnet.gluon.data import dataset
  10. import os
  11. import warnings
  12. import random
  13. from mxnet import gpu
  14. from mxnet.gluon.data.vision import datasets
  15.  
  16. class MyImageFolderDataset(dataset.Dataset):
  17. def __init__(self, root, label, flag=1, transform=None):
  18. self._root = os.path.expanduser(root)
  19. self._flag = flag
  20. self._label = label
  21. self._transform = transform
  22. self._exts = ['.jpg', '.jpeg', '.png']
  23. self._list_images(self._root, self._label)
  24.  
  25. def _list_images(self, root, label): # label是一个list
  26. self.synsets = []
  27. self.synsets.append(root)
  28. self.items = []
  29. #file = open(label)
  30. #lines = file.readlines()
  31. #random.shuffle(lines)
  32. c = 0
  33. for line in label:
  34. cls = line.split()
  35. fn = cls.pop(0)
  36. fn = fn + '.jpg'
  37. # print(os.path.join(root, fn))
  38. if os.path.isfile(os.path.join(root, fn)):
  39. self.items.append((os.path.join(root, fn), float(cls[0])))
  40. # print((os.path.join(root, fn), float(cls[0])))
  41. else:
  42. print('what')
  43. c = c + 1
  44. print('the total image is ', c)
  45.  
  46. def __getitem__(self, idx):
  47. img = image.imread(self.items[idx][0], self._flag)
  48. label = self.items[idx][1]
  49. if self._transform is not None:
  50. return self._transform(img, label)
  51. return img, label
  52.  
  53. def __len__(self):
  54. return len(self.items)
  55.  
  56. def _get_batch(batch, ctx): # 可以在循环中直接for i, data, label,函数主要把data放在ctx上
  57. """return data and label on ctx"""
  58. if isinstance(batch, mx.io.DataBatch):
  59. data = batch.data[0]
  60. label = batch.label[0]
  61. else:
  62. data, label = batch
  63. return (gluon.utils.split_and_load(data, ctx),
  64. gluon.utils.split_and_load(label, ctx),
  65. data.shape[0])
  66.  
  67. def transform_train(data, label):
  68. im = image.imresize(data.astype('float32') / 255, 256, 256)
  69. auglist = image.CreateAugmenter(data_shape=(3, 256, 256), resize=0,
  70. rand_crop=False, rand_resize=False, rand_mirror=True,
  71. mean=None, std=None,
  72. brightness=0, contrast=0,
  73. saturation=0, hue=0,
  74. pca_noise=0, rand_gray=0, inter_method=2)
  75. for aug in auglist:
  76. im = aug(im)
  77. # 将数据格式从"高*宽*通道"改为"通道*高*宽"。
  78. im = nd.transpose(im, (2, 0, 1))
  79. return (im, nd.array([label]).asscalar().astype('float32'))
  80.  
  81. def transform_test(data, label):
  82. im = image.imresize(data.astype('float32') / 255, 256, 256)
  83. im = nd.transpose(im, (2, 0, 1)) # 之前没有运行此变换
  84. return (im, nd.array([label]).asscalar().astype('float32'))
  85.  
  86. batch_size = 16
  87. root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'
  88. def random_choose_data(label_path):
  89. f = open(label_path)
  90. lines = f.readlins()
  91. random.shuffle(lines)
  92. total_number = len(lines)
  93. train_number = total_number/10*7
  94. train_list = lines[:train_number]
  95. test_list = lines[train_number:]
  96. return (train_list, test_list)
  97.  
  98. label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
  99. train_list, test_list = random_choose_data(label_path)
  100. loader = gluon.data.DataLoader
  101. train_ds = MyImageFolderDataset(os.path.join(root, 'image'), train_list, flag=1, transform=transform_train)
  102. test_ds = MyImageFolderDataset(os.path.join(root, 'Testing'), test_list, flag=1, transform=transform_test)
  103. train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
  104. test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')
  105. softmax_cross_entropy = gluon.loss.L2Loss() # 定义L2 loss
  106.  
  107. from mxnet.gluon import nn
  108.  
  109. net = nn.Sequential()
  110. with net.name_scope():
  111. net.add(
  112. # 第一阶段
  113. nn.Conv2D(channels=96, kernel_size=11,
  114. strides=4, activation='relu'),
  115. nn.MaxPool2D(pool_size=3, strides=2),
  116. # 第二阶段
  117. nn.Conv2D(channels=256, kernel_size=5,
  118. padding=2, activation='relu'),
  119. nn.MaxPool2D(pool_size=3, strides=2),
  120. # 第三阶段
  121. nn.Conv2D(channels=384, kernel_size=3,
  122. padding=1, activation='relu'),
  123. nn.Conv2D(channels=384, kernel_size=3,
  124. padding=1, activation='relu'),
  125. nn.Conv2D(channels=256, kernel_size=3,
  126. padding=1, activation='relu'),
  127. nn.MaxPool2D(pool_size=3, strides=2),
  128. # 第四阶段
  129. nn.Flatten(),
  130. nn.Dense(4096, activation="relu"),
  131. nn.Dropout(.5),
  132. # 第五阶段
  133. nn.Dense(4096, activation="relu"),
  134. nn.Dropout(.5),
  135. # 第六阶段
  136. nn.Dense(14950) # 输出为1个值
  137. )
  138.  
  139. from mxnet import init
  140. from mxnet import gluon
  141. import mxnet as mx
  142. import utils
  143. import datetime
  144. from time import time
  145.  
  146. ctx = utils.try_gpu()
  147. net.initialize(ctx=ctx, init=init.Xavier())
  148.  
  149. mse_loss = gluon.loss.L2Loss()
  150.  
  151. # utils.train(train_data, test_data, net, loss,
  152. # trainer, ctx, num_epochs=10)
  153. #def train(train_data, test_data, net, loss, trainer, ctx, num_epochs, print_batches=None):
  154. num_epochs = 10
  155. print_batches = 100
  156. """Train a network"""
  157. print("Start training on ", ctx)
  158. if isinstance(ctx, mx.Context):
  159. ctx = [ctx]
  160. def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_period, lr_decay):
  161. trainer = gluon.Trainer(net.collect_params(), 'sgd',
  162. {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})
  163. prev_time = datetime.datetime.now()
  164. for epoch in range(num_epochs):
  165. train_loss = 0.0
  166. if epoch > 0 and epoch % lr_period == 0:
  167. trainer.set_learning_rate(trainer.learning_rate*lr_decay)
  168. for data, label in train_data:
  169. label = label.as_in_context(ctx)
  170. with autograd.record():
  171. output = net(data.as_in_context(ctx))
  172. loss = mse_loss(output, label)
  173. loss.backward()
  174. trainer.step(batch_size) # do the update, Trainer needs to know the batch size of the data to normalize
  175. # the gradient by 1/batch_size
  176. train_loss += nd.mean(loss).asscalar()
  177. print(nd.mean(loss).asscalar())
  178. cur_time = datetime.datetime.now()
  179. h, remainder = divmod((cur_time - prev_time).seconds, 3600)
  180. m, s = divmod(remainder, 60)
  181. time_str = "Time %02d:%02d:%02d" % (h, m, s)
  182. epoch_str = ('Epoch %d. Train loss: %f, ' % (epoch, train_loss / len(train_data)))
  183. prev_time = cur_time
  184. print(epoch_str + time_str + ', lr' + str(trainer.learning_rate))
  185. net.collect_params().save('./model/alexnet.params')
  186. ctx = utils.try_gpu()
  187. num_epochs = 100
  188. learning_rate = 0.001
  189. weight_decay = 5e-4
  190. lr_period = 10
  191. lr_decay = 0.1
  192.  
  193. train(net, train_data, test_data, num_epochs, learning_rate,
  194. weight_decay, ctx, lr_period, lr_decay)

请看这一段

  1. class MyImageFolderDataset(dataset.Dataset):
  2. def __init__(self, root, label, flag=1, transform=None):
  3. self._root = os.path.expanduser(root)
  4. self._flag = flag
  5. self._label = label
  6. self._transform = transform
  7. self._exts = ['.jpg', '.jpeg', '.png']
  8. self._list_images(self._root, self._label)
  9.  
  10. def _list_images(self, root, label): # label是一个list
  11. self.synsets = []
  12. self.synsets.append(root)
  13. self.items = []
  14. #file = open(label)
  15. #lines = file.readlines()
  16. #random.shuffle(lines)
  17. c = 0
  18. for line in label:
  19. cls = line.split()
  20. fn = cls.pop(0)
  21. fn = fn + '.jpg'
  22. # print(os.path.join(root, fn))
  23. if os.path.isfile(os.path.join(root, fn)):
  24. self.items.append((os.path.join(root, fn), float(cls[0])))
  25. # print((os.path.join(root, fn), float(cls[0])))
  26. else:
  27. print('what')
  28. c = c + 1
  29. print('the total image is ', c)
  30.  
  31. def __getitem__(self, idx):
  32. img = image.imread(self.items[idx][0], self._flag)
  33. label = self.items[idx][1]
  34. if self._transform is not None:
  35. return self._transform(img, label)
  36. return img, label
  37.  
  38. def __len__(self):
  39. return len(self.items)
  40. batch_size = 16
  41. root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'
  42. def random_choose_data(label_path):
  43. f = open(label_path)
  44. lines = f.readlins()
  45. random.shuffle(lines)
  46. total_number = len(lines)
  47. train_number = total_number/10*7
  48. train_list = lines[:train_number]
  49. test_list = lines[train_number:]
  50. return (train_list, test_list)
  51.  
  52. label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
  53. train_list, test_list = random_choose_data(label_path)
  54.  
  55. loader = gluon.data.DataLoader
  56. train_ds = MyImageFolderDataset(os.path.join(root, 'image'), train_list, flag=1, transform=transform_train)
  57. test_ds = MyImageFolderDataset(os.path.join(root, 'Testing'), test_list, flag=1, transform=transform_test)
  58. train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
  59. test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')

MyImageFolderDataset是dataset.Dataset的子类,主要是是重载索引运算__getitem__,并且返回image以及其对应的label即可,前面的的_list_image函数只要是能够返回item这个list就行,关于运算符重载给自己挖个坑

可以说和pytorch非常像了,就连沐神在讲课的时候还在说,其实在写mxnet的时候,借鉴了很多pytorch的内容

mxnet自定义dataloader加载自己的数据的更多相关文章

  1. hive 压缩全解读(hive表存储格式以及外部表直接加载压缩格式数据);HADOOP存储数据压缩方案对比(LZO,gz,ORC)

    数据做压缩和解压缩会增加CPU的开销,但可以最大程度的减少文件所需的磁盘空间和网络I/O的开销,所以最好对那些I/O密集型的作业使用数据压缩,cpu密集型,使用压缩反而会降低性能. 而hive中间结果 ...

  2. [原创.数据可视化系列之三]使用Ol3加载大量点数据

    不管是百度地图还是高德地图,都很难得见到在地图上加载大量点要素,比如同屏1000的,因为这样客户端性能会很低,尤其是IE系列的浏览器,简直是卡的要死.但有的时候,还真的需要,比如,我要加载全球的AQI ...

  3. jsTree 的简单用法--异步加载和刷新数据

    首先这两个文件是必须要引用的,还有就是引用 jQuery 文件就不说了: <link href="/css/plugins/jsTree/style.min.css" rel ...

  4. 使用getJSON()方法异步加载JSON格式数据

    使用getJSON()方法异步加载JSON格式数据 使用getJSON()方法可以通过Ajax异步请求的方式,获取服务器中的数组,并对获取的数据进行解析,显示在页面中,它的调用格式为: jQuery. ...

  5. 异步加载回来的数据不受JS控制了

    写成下面这种方式时,异步加载回来的数据不受JS控制 $(."orderdiv").click(function(){ $(this).find(".orderinfo&q ...

  6. echarts 图表重新加载,原来的数据依然存在图表上

    问题 在做一个全国地图上一些饼图,并且向省一级的地图钻取的时候,原来的饼图依然显示 原因 echars所有添加的图表都在一个series属性集合中,并且同一个echars对象默认是合并之前的数据的,所 ...

  7. 实用ExtJS教程100例-010:ExtJS Form异步加载和提交数据

    ExtJS Form 为我们提供了两个方法:load 和 submit,分别用来加载和提交数据,这两个方法都是异步的. 系列ExtJS教程持续更新中,点击查看>>最新ExtJS教程目录 F ...

  8. HTTP 筛选器 DLL C:\Windows\Microsoft.Net\Framework\v4.0.30319\aspnet_filter.dll 加载失败。数据是错误。

    今天在一台win2003的云主机上,安装.net 4.0时,所有的网站都打不开了.打开事件查看器,发现以下错误: HTTP 筛选器 DLL C:\Windows\Microsoft.Net\Frame ...

  9. Flex 4 自定义预加载器

    本示例的目的是在Flash Professional里创建自定义预加载器SWC,并扩展SparkDownloadProgressBar类在Flex 4应用程序中使用.    预加载器显示加载进度百分比 ...

随机推荐

  1. Vue 项目骨架屏注入与实践

    作为与用户联系最为密切的前端开发者,用户体验是最值得关注的问题.关于页面loading状态的展示,主流的主要有loading图和进度条两种.除此之外,越来越多的APP采用了“骨架屏”的方式去展示未加载 ...

  2. 常用java命令

    javap 反编译 javap xxx.class 查看大概 javap -v -p xxx.class 查看详细 jps 查看有哪些java进程 jinfo 查看或设置java进程的 vm 参数,只 ...

  3. CSS三列布局之左右宽度固定,中间元素自适应问题

    最近学到了几种关于左右固定宽度,中间自适应的三列布局的方法,整理了一下,在这里跟大家一起分享分享,其中有什么不足的还望各位给指导指导哈. 首先我想到的是float——浮动布局 使用浮动,先渲染左右两个 ...

  4. Eclipse集成weblogic教程

    1.在线安装插件 1.1安装Oracle Weblogic Servers Tools oeop是添加的软件仓库的名字,随便写主要是方便记. 仓库链接:http://www.oracle.com/te ...

  5. maven打包上传到本地中央库

    pom文件中添加插件如下 <build> <plugins> <plugin> <groupId>org.apache.maven.plugins< ...

  6. 转 Deep Learning for NLP 文章列举

    原文链接:http://www.xperseverance.net/blogs/2013/07/2124/   大部分文章来自: http://www.socher.org/ http://deepl ...

  7. Linux 第一周作业

    [](http://images2017.cnblogs.com/blog/1249774/201710/1249774-20171001234038872-10d31233192.pngd

  8. asp.net 发送电子邮件本地测试正常,但服务器上异常的解决办法

    如题,这个问题曾经非常苦恼,代码肯定是没有问题的.在网上也查找了不少资料,按照他们的步骤做了,还是无效. 最后问题解决了,原来:我租用腾讯云服务器,腾讯为了防止垃圾邮件,禁止了邮件发送的25号端口,原 ...

  9. 整数中1出现的次数(1~n)

    题目描述 求出1~13的整数中1出现的次数,并算出100~1300的整数中1出现的次数?为此他特别数了一下1~13中包含1的数字有1.10.11.12.13因此共出现6次,但是对于后面问题他就没辙了. ...

  10. 十九. Python基础(19)--异常

    十九. Python基础(19)--异常 1 ● 捕获异常 if VS异常处理: if是预防异常出现, 异常处理是处理异常出现 异常处理一般格式: try:     <............. ...