LeNet 分类 FashionMNIST
import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn
from mxnet.gluon import data as gdata
import time
import sys net = nn.Sequential()
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
nn.MaxPool2D(pool_size=2, strides=2),
# Dense 会默认将(批量大小,通道,高,宽)形状的输入转换成
# (批量大小,通道 * 高 * 宽)形状的输入。
nn.Dense(120, activation='sigmoid'),
nn.Dense(84, activation='sigmoid'),
nn.Dense(10)) X = nd.random.uniform(shape=(1, 1, 28, 28))
net.initialize()
for layer in net:
X = layer(X)
print(layer.name, 'output shape:\t', X.shape) # batch_size = 256
# train_iter, test_iter = gb.load_data_fashion_mnist(batch_size=batch_size)
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器(在cpu上)
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size=batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size=batch_size, shuffle=False,
num_workers=num_workers) def try_gpu4():
try:
ctx = mx.gpu()
_ = nd.zeros((1,), ctx=ctx)
except mx.base.MXNetError:
ctx = mx.cpu()
return ctx ctx = try_gpu4() def accuracy(y_hat,y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net, ctx):
acc = nd.array([0], ctx=ctx)
for X, y in data_iter:
# 如果 ctx 是 GPU,将数据复制到 GPU 上。
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
acc += accuracy(net(X), y)
return acc.asscalar() / len(data_iter) def train(net, train_iter, test_iter, batch_size, trainer, ctx,
num_epochs):
print('training on', ctx)
loss = gloss.SoftmaxCrossEntropyLoss()
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, start = 0, 0, time.time()
for X, y in train_iter:
X, y = X.as_in_context(ctx), y.as_in_context(ctx)
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
trainer.step(batch_size)
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net, ctx)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
'time %.1f sec' % (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter),
test_acc, time.time() - start)) lr, num_epochs = 0.9, 200
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier()) trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

LeNet 分类 FashionMNIST的更多相关文章
- AlexNet 分类 FashionMNIST
from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...
- gluon 实现多层感知机MLP分类FashionMNIST
from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...
- gluon实现softmax分类FashionMNIST
from mxnet import gluon,init from mxnet.gluon import loss as gloss,nn from mxnet.gluon import data a ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- CNN卷积神经网络详解
前言 在学计算机视觉的这段时间里整理了不少的笔记,想着就把这些笔记再重新整理出来,然后写成Blog和大家一起分享.目前的计划如下(以下网络全部使用Pytorch搭建): 专题一:计算机视觉基础 介 ...
- 不用写代码就能实现深度学习?手把手教你用英伟达 DIGITS 解决图像分类问题
2006年,机器学习界泰斗Hinton,在Science上发表了一篇使用深度神经网络进行维数约简的论文 ,自此,神经网络再次走进人们的视野,进而引发了一场深度学习革命.深度学习之所以如此受关注,是因为 ...
- PyTorch 介绍 | BUILD THE NEURAL NETWORK
神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...
- Tensorflow学习教程------lenet多标签分类
本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...
- Tensorflow学习教程------实现lenet并且进行二分类
#coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...
随机推荐
- OutSystems学习笔记。
ew job and new software, new challenge as well. OutSystems这软件挺好上手的.虽然没有中文文档,但英文文档超级详细,堪称傻瓜版SOP 照着步骤写 ...
- FTPS Firewall
989 for the FTPS data channel implicit FTPS was expected to listen on the IANA Well Known Port 990/T ...
- Tomcat配置连接池的java实现
1.准备 JNDI(Java Naming and Directory Interface),Java命名和目录接口.JNDI的作用就是:在服务器上配置资源,然后通过统一的方式来获取配置的资源.我们这 ...
- Markdown调查
Markdown调查 一.Editor.md 文档详细,使用者较多 1.1 主要特性 支持“标准”Markdown / CommonMark和Github风格的语法,也可变身为代码编辑器: 支持实 ...
- Sequence contains no elements : LINQ error
1.错误意思: 出现错误的原因是:你要从一个null中取的数据. 2.错误的处理 1,使用FirstOrDefault() 来代替 First() 2.使用SingleOrDefault 来代替 Si ...
- SZU1
CodeForces 518A 字符串进位.. #include <iostream> #include <string> #include <cstring> # ...
- Iphone各个型号机型的详细参数,尺寸和dpr以及像素
1.iPhone尺寸规格 2.单位inch(英吋) 1 inch = 2.54cm = 25.4mm 3.iPhone手机宽高 上表中的宽高(width/height)为手机的物理尺寸,包括显示屏和边 ...
- js 密码 正则表达式
1. 代码 function checkPassword(str){ var reg1 = /[!@#$%^&*()_?<>{}]{1}/; var reg2 = /([a-zA- ...
- bzoj2119 [ZJOI2010]base基站选址
传送门 n年前的考试题,今天才填上…… 听说你们会决策单调性+主席树?然而我多年不写决策单调性,懒得写了……于是就写了一发线段树. 其实线段树应该不难想,毕竟转移是分层转移,并且这个题的转移函数可以快 ...
- CSS 2D转换
转换是使元素改变形状.尺寸和位置的一种效果.通过 CSS3 转换,我们能够对元素进行移动.缩放.转动.拉长或拉伸,可以大致分为2D转换和3D转换.下面介绍的是2D转换的相关知识点. 首先,CSS中2D ...