from mxnet import gluon,init,nd,autograd
from mxnet.gluon import data as gdata,nn
from mxnet.gluon import loss as gloss
import mxnet as mx
import time
import os
import sys # 建立网络
net = nn.Sequential()
# 使用较大的 11 x 11 窗口来捕获物体。同时使用步幅 4 来较大减小输出高和宽。
# 这里使用的输入通道数比 LeNet 中的也要大很多。
net.add(nn.Conv2D(96, kernel_size=11, strides=4, activation='relu'),
nn.MaxPool2D(pool_size=3, strides=2),
# 减小卷积窗口,使用填充为 2 来使得输入输出高宽一致,且增大输出通道数。
nn.Conv2D(256, kernel_size=5, padding=2, activation='relu'),
nn.MaxPool2D(pool_size=3, strides=2),
# 连续三个卷积层,且使用更小的卷积窗口。除了最后的卷积层外,进一步增大了输出通道数。
# 前两个卷积层后不使用池化层来减小输入的高和宽。
nn.Conv2D(384, kernel_size=3, padding=1, activation='relu'),
nn.Conv2D(384, kernel_size=3, padding=1, activation='relu'),
nn.Conv2D(256, kernel_size=3, padding=1, activation='relu'),
nn.MaxPool2D(pool_size=3, strides=2),
# 这里全连接层的输出个数比 LeNet 中的大数倍。使用丢弃层来缓解过拟合。
nn.Dense(4096, activation="relu"), nn.Dropout(0.5),
nn.Dense(4096, activation="relu"), nn.Dropout(0.5),
# 输出层。由于这里使用 Fashion-MNIST,所以用类别数为 10,而非论文中的 1000。
nn.Dense(10)) X = nd.random.uniform(shape=(1,1,224,224))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name,'output shape:\t',X.shape) # 读取数据
# fashionMNIST 28*28 转为224*224
def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join(
'~', '.mxnet', 'datasets', 'fashion-mnist')):
root = os.path.expanduser(root) # 展开用户路径 '~'。
transformer = []
if resize:
transformer += [gdata.vision.transforms.Resize(resize)]
transformer += [gdata.vision.transforms.ToTensor()]
transformer = gdata.vision.transforms.Compose(transformer)
mnist_train = gdata.vision.FashionMNIST(root=root, train=True)
mnist_test = gdata.vision.FashionMNIST(root=root, train=False)
num_workers = 0 if sys.platform.startswith('win32') else 4
train_iter = gdata.DataLoader(
mnist_train.transform_first(transformer), batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(
mnist_test.transform_first(transformer), batch_size, shuffle=False,
num_workers=num_workers)
return train_iter, test_iter batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224) def accuracy(y_hat,y):
return (y_hat.argmax(axis=1)==y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter,net,ctx):
acc = nd.array([0],ctx=ctx)
for X,y in data_iter:
X = X.as_in_context(ctx)
y = y.as_in_context(ctx)
acc+=accuracy(net(X),y)
return acc.asscalar() / len(data_iter) # 训练模型
def train(net,train_iter,test_iter,batch_size,trainer,ctx,num_epochs):
print('training on',ctx)
loss = gloss.SoftmaxCrossEntropyLoss() for epoch in range(num_epochs):
train_l_sum = 0
train_acc_sum = 0
start = time.time()
for X,y in train_iter:
X = X.as_in_context(ctx)
y = y.as_in_context(ctx) with autograd.record():
y_hat = net(X)
l = loss(y_hat,y) l.backward()
trainer.step(batch_size) train_l_sum += l.mean().asscalar()
train_acc_sum += evaluate_accuracy(test_iter,net,ctx)
test_acc = evaluate_accuracy(test_iter,net,ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch+1,train_l_sum/len(train_iter),test_acc,time.time()-start)) def try_gpu():
try:
ctx = mx.gpu()
_ = nd.zeros((1,),ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx lr = 0.01
num_epochs = 5
ctx = try_gpu() net.initialize(force_reinit=True,ctx=ctx,init=init.Xavier())
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':lr})
train(net,train_iter,test_iter,batch_size,trainer,ctx,num_epochs)

AlexNet 分类 FashionMNIST的更多相关文章

  1. LeNet 分类 FashionMNIST

    import mxnet as mx from mxnet import autograd, gluon, init, nd from mxnet.gluon import loss as gloss ...

  2. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  3. gluon实现softmax分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...

  4. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  5. Pytorch分类和准确性评估--基于FashionMNIST数据集

    最近在学习Pytorch v1.3最新版和Tensorflow2.0. 我学习Pytorch的主要途径:莫烦Python和Pytorch 1.3官方文档 ,Pytorch v1.3跟之前的Pytorc ...

  6. 【分类】AlexNet论文总结

    目录 0. 论文链接 1. 概述 2. 对数据集的处理 3. 网络模型 3.1 ReLU Nonlinearity 3.2 Training on multiple GPUs 3.3 Local Re ...

  7. AlexNet实现cifar10数据集分类

    import tensorflow as tf import os from matplotlib import pyplot as plt import tensorflow.keras.datas ...

  8. 从头学pytorch(十五):AlexNet

    AlexNet AlexNet是2012年提出的一个模型,并且赢得了ImageNet图像识别挑战赛的冠军.首次证明了由计算机自动学习到的特征可以超越手工设计的特征,对计算机视觉的研究有着极其重要的意义 ...

  9. 《动手学深度学习》系列笔记—— 1.2 Softmax回归与分类模型

    目录 softmax的基本概念 交叉熵损失函数 模型训练和预测 获取Fashion-MNIST训练集和读取数据 get dataset softmax从零开始的实现 获取训练集数据和测试集数据 模型参 ...

随机推荐

  1. 架构实战项目心得(五):mysql安装

    1.  yum安装mysql yum -y install mysql-server 2.  启动mysql服务 启动mysql:service mysqld start 查看mysql的状态:ser ...

  2. PostgreSQL Entity Framework 自动迁移

    1.依次添加NuGet包 EntityFramework.Npgsql.EntityFramework6.Npgsql,会自动生成一些配置文件,不过缺少数据库驱动的配置节点: <system.d ...

  3. Java - XPath解析爬取内容

    code { margin: 0; padding: 0; white-space: pre; border: none; background: transparent; } pre { backg ...

  4. Java接口和抽象类理解(New)

    一. 抽象类和接口的特点  包含抽象方法的类称为抽象类,但并不意味着抽象类中只能有抽象方法,它和普通类一样,同样可以拥有成员变量和普通的成员方法.注意,抽象类和普通类的主要有三点区别: 1)抽象方法必 ...

  5. Csharp: speech to text, text to speech in win

    using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; usin ...

  6. IE8 td元素 width无效的bug;

    不经意间做项目发现IE的td在某种情况下好奇怪,自己设置的width不起作用: 后经google大法,发现解决方案:已验证过完美解决bug; <table style="width:  ...

  7. AngularJS之控制器

    控制器在Angularjs中的作用是增强视图,它实际就是一个函数,用来向视图中的作用域添加额外的功能,我们用它来给作用域对象设置初始状态,并添加自定义行为. 当我们在页面上创建一个控制器时,Angul ...

  8. Linux-学习笔记(PHP向)<一>

    Linux常用命令 使用PHP服务器端脚本编程语言进行网站开发,需要在lamp环境下进行,Linux作为”四剑客”之一是有必要了解熟悉的,而Linux系统并不像windows操作系统那样,以图形化的界 ...

  9. Java从入门到精通——数据库篇Mongo DB 安装启动及配置详解

    一.概述     Mongo DB 下载下来以后我们应该如何去安装启动和配置才能使用Mongo DB,本篇博客就给大家讲述一下Mongo DB的安装启动及配置详解. 二.安装 1.下载Mongo DB ...

  10. 聚合maven+spring-boot打包可执行jar

    整整搞了一天,终于解决这个问题了.这里是四个module,module之间存在依赖,打包两个可执行jar,看下最终效果吧 聚合maven+spring-boot的搭建很简单,和普通的聚合maven没有 ...