解压文件命令:

with zipfile.ZipFile('../data/kaggle_cifar10/' + fin, 'r') as zin:
zin.extractall('../data/kaggle_cifar10/')

拷贝文件命令:

shutil.copy(原文件, 目标文件)

一、整理数据

我们有两个文件夹'../data/kaggle_cifar10/train'和'../data/kaggle_cifar10/test',一个记录了文件名和类别的索引文件

我们的目的是在新的文件夹下形成拷贝,包含三个文件夹train_valid、train、valid,每个文件夹下存放不同的类别文件夹,里面存放对应类别的图片,

import os
import shutil def reorg_cifar10_data(data_dir, label_file, train_dir, test_dir, input_dir, valid_ratio):
"""
处理之后,新建三个文件夹存放数据,train_valid、train、valid
data_dir:'../data/kaggle_cifar10'
label_file:'trainLabels.csv'
train_dir = 'train'
test_dir = 'test'
input_dir = 'train_valid_test'
valid_ratio = 0.1
"""
# 读取训练数据标签。
# 打开csv索引:'../data/kaggle_cifar10/trainLabels.csv'
with open(os.path.join(data_dir, label_file), 'r') as f:
# 跳过文件头行(栏名称)。
lines = f.readlines()[1:]
tokens = [l.rstrip().split(',') for l in lines]
# {索引:标签}
idx_label = dict(((int(idx), label) for idx, label in tokens))
# 标签集合
labels = set(idx_label.values())
# 训练数据数目:'../data/kaggle_cifar10/train'
num_train = len(os.listdir(os.path.join(data_dir, train_dir)))
# train数目(对应valid)
num_train_tuning = int(num_train * (1 - valid_ratio))
# <---异常检测
assert 0 < num_train_tuning < num_train
# 每个label的train数据条目
num_train_tuning_per_label = num_train_tuning // len(labels)
label_count = dict() def mkdir_if_not_exist(path):
if not os.path.exists(os.path.join(*path)):
os.makedirs(os.path.join(*path)) # 整理训练和验证集。
# 循环训练数据图片 '../data/kaggle_cifar10/train'
for train_file in os.listdir(os.path.join(data_dir, train_dir)):
# 去掉扩展名作为索引
idx = int(train_file.split('.')[0])
# 索引到标签
label = idx_label[idx] # '../data/kaggle_cifar10/train_valid_test/train_valid' + 标签名称
mkdir_if_not_exist([data_dir, input_dir, 'train_valid', label])
# 拷贝图片
shutil.copy(os.path.join(data_dir, train_dir, train_file),
os.path.join(data_dir, input_dir, 'train_valid', label)) # 保证train文件夹下的每类标签训练数目足够后,分给valid文件夹
if label not in label_count or label_count[label] < num_train_tuning_per_label:
# '../data/kaggle_cifar10/train_valid_test/train' + 标签名称
mkdir_if_not_exist([data_dir, input_dir, 'train', label])
shutil.copy(os.path.join(data_dir, train_dir, train_file),
os.path.join(data_dir, input_dir, 'train', label))
label_count[label] = label_count.get(label, 0) + 1
else:
mkdir_if_not_exist([data_dir, input_dir, 'valid', label])
shutil.copy(os.path.join(data_dir, train_dir, train_file),
os.path.join(data_dir, input_dir, 'valid', label)) # 整理测试集
# '../data/kaggle_cifar10/train_valid_test/test/unknown' 里面存放test图片
mkdir_if_not_exist([data_dir, input_dir, 'test', 'unknown'])
for test_file in os.listdir(os.path.join(data_dir, test_dir)):
shutil.copy(os.path.join(data_dir, test_dir, test_file),
os.path.join(data_dir, input_dir, 'test', 'unknown')) train_dir = 'train'
test_dir = 'test'
batch_size = 128 data_dir = '../data/kaggle_cifar10'
label_file = 'trainLabels.csv'
input_dir = 'train_valid_test'
valid_ratio = 0.1
reorg_cifar10_data(data_dir, label_file, train_dir, test_dir, input_dir, valid_ratio)

二、数据预处理

# 预处理
from mxnet import autograd
from mxnet import gluon
from mxnet import init
from mxnet import nd
from mxnet.gluon.data import vision
from mxnet.gluon.data.vision import transforms
import numpy as np transform_train = transforms.Compose([
# transforms.CenterCrop(32)
# transforms.RandomFlipTopBottom(),
# transforms.RandomColorJitter(brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0),
# transforms.RandomLighting(0.0),
# transforms.Cast('float32'),
# transforms.Resize(32), # 随机按照scale和ratio裁剪,并放缩为32x32的正方形
transforms.RandomResizedCrop(32, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0)),
# 随机左右翻转图片
transforms.RandomFlipLeftRight(),
# 将图片像素值缩小到(0,1)内,并将数据格式从"高*宽*通道"改为"通道*高*宽"
transforms.ToTensor(),
# 对图片的每个通道做标准化
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
]) # 测试时,无需对图像做标准化以外的增强数据处理。
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
]) # '../data/kaggle_cifar10、train_valid_test/'
input_str = data_dir + '/' + input_dir + '/' # 读取原始图像文件。flag=1说明输入图像有三个通道(彩色)。
train_ds = vision.ImageFolderDataset(input_str + 'train', flag=1)
valid_ds = vision.ImageFolderDataset(input_str + 'valid', flag=1)
train_valid_ds = vision.ImageFolderDataset(input_str + 'train_valid', flag=1)
test_ds = vision.ImageFolderDataset(input_str + 'test', flag=1) loader = gluon.data.DataLoader
train_data = loader(train_ds.transform_first(transform_train),
batch_size, shuffle=True, last_batch='keep')
valid_data = loader(valid_ds.transform_first(transform_test),
batch_size, shuffle=True, last_batch='keep')
train_valid_data = loader(train_valid_ds.transform_first(transform_train),
batch_size, shuffle=True, last_batch='keep')
test_data = loader(test_ds.transform_first(transform_test),
batch_size, shuffle=False, last_batch='keep') # 交叉熵损失函数。
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

mxnet.gluon.vision.ImageFolderDataset
mxnet.gluon.data.DataLoader

数据的预处理放在DataLoader中,这样后面可以调用ImageFolderDataset,获取原始图片集

至此,数据准备完成。

三、模型定义

1、新的ResNet

from mxnet.gluon import nn
from mxnet import nd class Residual(nn.HybridBlock):
def __init__(self, channels, same_shape=True, **kwargs):
super(Residual, self).__init__(**kwargs)
self.same_shape = same_shape
with self.name_scope():
strides = 1 if same_shape else 2
self.conv1 = nn.Conv2D(channels, kernel_size=3, padding=1,
strides=strides)
self.bn1 = nn.BatchNorm()
self.conv2 = nn.Conv2D(channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm()
if not same_shape:
self.conv3 = nn.Conv2D(channels, kernel_size=1,
strides=strides) def hybrid_forward(self, F, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
if not self.same_shape:
x = self.conv3(x)
return F.relu(out + x) class ResNet(nn.HybridBlock):
def __init__(self, num_classes, verbose=False, **kwargs):
super(ResNet, self).__init__(**kwargs)
self.verbose = verbose
with self.name_scope():
net = self.net = nn.HybridSequential()
# 模块1
net.add(nn.Conv2D(channels=32, kernel_size=3, strides=1, padding=1))
net.add(nn.BatchNorm())
net.add(nn.Activation(activation='relu'))
# 模块2
for _ in range(3):
net.add(Residual(channels=32))
# 模块3
net.add(Residual(channels=64, same_shape=False))
for _ in range(2):
net.add(Residual(channels=64))
# 模块4
net.add(Residual(channels=128, same_shape=False))
for _ in range(2):
net.add(Residual(channels=128))
# 模块5
net.add(nn.AvgPool2D(pool_size=8))
net.add(nn.Flatten())
net.add(nn.Dense(num_classes)) def hybrid_forward(self, F, x):
out = x
for i, b in enumerate(self.net):
out = b(out)
if self.verbose:
print('Block %d output: %s'%(i+1, out.shape))
return out def get_net(ctx):
num_outputs = 10
net = ResNet(num_outputs)
net.initialize(ctx=ctx, init=init.Xavier())
return net

2、迁移学习ResNet

mxnet.gluon.model_zoo中有预训练好的model
通常预训练好的模型由两块构成,一是features,二是output。后者主要包括最后一层全连接层,前者包含从输入开始的大部分层。这样的划分的一个主要目的是为了更方便做微调。

from mxnet.gluon.model_zoo import vision as models

pretrained_net = models.resnet18_v2(pretrained=True)
finetune_net = models.resnet18_v2(classes=10)
finetune_net.features = pretrained_net.features
finetune_net.output.initialize(init.Xavier())

迁移学习网络定义补充说明

尝试使用class定义迁移学习网络,

# 迁移学习

from mxnet.gluon import nn
from mxnet.gluon.model_zoo import vision as models

class ResNet(nn.HybridBlock):
    def __init__(self, num_classes, verbose=False, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        # 获取pretrained=True的模型
        pretrained_net = models.resnet18_v2(pretrained=True)
        # 获取空模型,分类数目为10
        finetune_net = models.resnet18_v2(classes=num_classes)
        self.net = nn.HybridSequential()
        self.net.add(pretrained_net.features)
        self.net.add(finetune_net.output)
        
    def hybrid_forward(self, F, x):
        out = self.net(x)
        return out

def get_net(ctx):
    num_outputs = 10
    net = ResNet(num_outputs)
    # print(net)
    net.net[-1].initialize(init.Xavier())
    net.collect_params().reset_ctx(ctx)
    return net

有几个小总结,

1,实际上class ResNet类包含.net属性,这个属性本身是一个HybridSequential,也就是说和python其他类没什么不同

2,初始化时如果不使用net[-1](表示finetune_net.output层),会warning参数已经初始化,问我们是否强制初始化,这也就是因为pretrained_net已经训练过的原因

3,鉴于上面的class和属性关系,对于其他有子结构的网络class,其实都可以这样区分,例如model_zoo里网络的features结构和output结构的索引获取

4,补充2,实际上参数初始化是以每一个Parameter对象为单位,记录层、模型初始化与否也是参数本身在记录,警报实际上也是一个参数一条而非一层一条

四、训练

gb.accuracy(output, label)

trainer.set_learning_rate(trainer.learning_rate * lr_decay)

gb.evaluate_accuracy(valid_data, net, ctx)

import datetime
import sys
sys.path.append('..')
import gluonbook as gb def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_period, lr_decay):
trainer = gluon.Trainer(
net.collect_params(), 'sgd', {'learning_rate': lr, 'momentum': 0.9, 'wd': wd}) prev_time = datetime.datetime.now()
for epoch in range(num_epochs):
train_loss = 0.0
train_acc = 0.0
if epoch > 0 and epoch % lr_period == 0:
trainer.set_learning_rate(trainer.learning_rate * lr_decay)
for data, label in train_data:
label = label.astype('float32').as_in_context(ctx)
with autograd.record():
output = net(data.as_in_context(ctx))
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size)
train_loss += nd.mean(loss).asscalar()
train_acc += gb.accuracy(output, label)
cur_time = datetime.datetime.now()
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60)
time_str = "Time %02d:%02d:%02d" % (h, m, s)
if valid_data is not None:
valid_acc = gb.evaluate_accuracy(valid_data, net, ctx)
epoch_str = ("Epoch %d. Loss: %f, Train acc %f, Valid acc %f, "
% (epoch, train_loss / len(train_data),
train_acc / len(train_data), valid_acc))
else:
epoch_str = ("Epoch %d. Loss: %f, Train acc %f, "
% (epoch, train_loss / len(train_data),
train_acc / len(train_data)))
prev_time = cur_time
print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))  

实际训练起来,

ctx = gb.try_gpu()
num_epochs = 1
learning_rate = 0.1
weight_decay = 5e-4
lr_period = 80
lr_decay = 0.1 finetune = False
if finetune:
finetune_net.collect_params().reset_ctx(ctx)
finetune_net.hybridize()
net = finetune_net
else:
net = get_net(ctx)
net.hybridize() train(net, train_data, valid_data, num_epochs, learning_rate,
weight_decay, ctx, lr_period, lr_decay)

五、预测

import numpy as np
import pandas as pd # 训练
net = get_net(ctx)
net.hybridize()
train(net, train_valid_data, None, num_epochs, learning_rate,
weight_decay, ctx, lr_period, lr_decay) # 预测
preds = []
for data, label in test_data:
output = net(data.as_in_context(ctx))
preds.extend(output.argmax(axis=1).astype(int).asnumpy()) sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key = lambda x:str(x)) df = pd.DataFrame({'id': sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.synsets[x])
df.to_csv('submission.csv', index=False)

『MXNet』第九弹_分类器以及迁移学习DEMO的更多相关文章

  1. 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...

  2. 『TensorFlow』第九弹_图像预处理_不爱红妆爱武装

    部分代码单独测试: 这里实践了图像大小调整的代码,值得注意的是格式问题: 输入输出图像时一定要使用uint8编码, 但是数据处理过程中TF会自动把编码方式调整为float32,所以输入时没问题,输出时 ...

  3. 『MXNet』第一弹_基础架构及API

    MXNet是基础,Gluon是封装,两者犹如TensorFlow和Keras,不过得益于动态图机制,两者交互比TensorFlow和Keras要方便得多,其基础操作和pytorch极为相似,但是方便不 ...

  4. 『MXNet』第二弹_Gluon构建模型

    上节用了Sequential类来构造模型.这里我们另外一种基于Block类的模型构造方法,它让构造模型更加灵活,也将让你能更好的理解Sequential的运行机制. 回顾: 序列模型生成 层填充 初始 ...

  5. 『PyTorch』第二弹_张量

    参考:http://www.jianshu.com/p/5ae644748f21# 几个数学概念: 标量(Scalar)是只有大小,没有方向的量,如1,2,3等 向量(Vector)是有大小和方向的量 ...

  6. 『TensorFlow』第二弹_线性拟合&神经网络拟合_恰是故人归

    Step1: 目标: 使用线性模拟器模拟指定的直线:y = 0.1*x + 0.3 代码: import tensorflow as tf import numpy as np import matp ...

  7. 『PyTorch』第一弹_静动态图构建if逻辑对比

    对比TensorFlow和Pytorch的动静态图构建上的差异 静态图框架设计好了不能够修改,且定义静态图时需要使用新的特殊语法,这也意味着图设定时无法使用if.while.for-loop等结构,而 ...

  8. 『MXNet』专题汇总

    MXNet文档 MXNet官方教程 持久化模型 框架介绍 『MXNet』第一弹_基础架构及API 『MXNet』第二弹_Gluon构建模型 『MXNet』第三弹_Gluon模型参数 『MXNet』第四 ...

  9. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

随机推荐

  1. vim命令详解

    VIM编辑常用技巧 vim编辑器 简介: vi: Visual Interface,文本编辑器 文本:ASCII, Unicode 文本编辑种类: 行编辑器: sed 全屏编辑器:nano, vi V ...

  2. linux golang

    wget -c http://www.golangtc.com/static/go/go1.3.linux-386.tar.gz #下载32位Linux的够源码包 tar -zxvf go1.1.li ...

  3. R工具包一网打尽

    这里有很多非常不错的R包和工具. 该想法来自于awesome-machine-learning. 这里是包的导航清单,看起来更方便 >>>导航清单 通过这些翻译了解这些工具包,以后干 ...

  4. 【Java】【异常】

    java中2种方法处理异常:1.在发⽣异常的地方直接处理:2.将异常抛给调用者,让调⽤者处理.异常分类1.检查性异常: java.lang.Exception2.运⾏期异常: java.lang.Ru ...

  5. hdu 6069 Counting Divisors 筛法

    Counting Divisors Time Limit: 10000/5000 MS (Java/Others)    Memory Limit: 524288/524288 K (Java/Oth ...

  6. mac的终端怎么退出git:(master)

    今天在终端误操作,在主目录下执行git init命令,结果杯具了, 总是出现这个提示. 各种搜索解决方案,终于退出了. 方法如下: 删掉.git目录: rm -rf ~/.git

  7. Spring boot2.0 与 2.0以前版本 跨域配置的区别

    一·简介 spring boot升级到2.0后发现继承WebMvcConfigurerAdapter实现跨域过时了,那我们就紧随潮流. 二·全局配置 2.0以前 支持跨域请求代码: import or ...

  8. Java+selenium 爬Boss直聘中职位信息,薪资水平和职位描述

      需要下载合适的selenium webdirver jar包和对应浏览器的驱动jar包 import org.openqa.selenium.By; import org.openqa.selen ...

  9. SqlParameter的两种用法【二】

    private void Loadprovince() { string sql = "select * from Tables where ArealdPid=@pid"; /第 ...

  10. 使用Hexo搭建一个简单的博客(一)

    搭建好简洁的博客框架后,回看时发现,简洁之中透露着一丝丝简陋,好的,网上关于丰富hexo的文章也很多 记录一下自己的一些瞎操作. 在你的hexo目录下,你可以看到themes文件夹里有个默认的land ...