from mxnet import gluon,init
from mxnet.gluon import loss as gloss,nn
from mxnet.gluon import data as gdata
from mxnet import autograd,nd
import gluonbook as gb
import sys # 读取数据
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False) batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4 # 小批量数据迭代器
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),batch_size=batch_size,shuffle=True,num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),batch_size=batch_size,shuffle=False,num_workers=num_workers) # 模型参数初始化
net = nn.Sequential()
net.add(nn.Dense(10))
net.initialize(init.Normal(sigma=0.01)) # 损失函数
loss = gloss.SoftmaxCrossEntropyLoss() # 优化算法
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.1}) def accuracy(y_hat, y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar() def evaluate_accuracy(data_iter, net):
acc = 0
for X, y in data_iter:
acc += accuracy(net(X), y)
return acc / len(data_iter) num_epochs = 5 def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,trainer=None):
for epoch in range(num_epochs):
train_l_sum = 0
train_acc_sum = 0
for X,y in train_iter:
with autograd.record():
y_hat = net(X)
l = loss(y_hat,y)
l.backward() if trainer is None:
gb.sgd(params,lr,batch_size)
else:
trainer.step(batch_size) train_l_sum += l.mean().asscalar() test_acc = evaluate_accuracy(test_iter,net)
print('epoch %d,loss %.4f,test acc %.3f'%(epoch+1,train_l_sum / len(train_iter),test_acc)) train(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,trainer)

gluon实现softmax分类FashionMNIST的更多相关文章

  1. 从零和使用mxnet实现softmax分类

    1.softmax从零实现 from mxnet.gluon import data as gdata from sklearn import datasets from mxnet import n ...

  2. 学习笔记TF010:softmax分类

    回答多选项问题,使用softmax函数,对数几率回归在多个可能不同值上的推广.函数返回值是C个分量的概率向量,每个分量对应一个输出类别概率.分量为概率,C个分量和始终为1.每个样本必须属于某个输出类别 ...

  3. gluon 实现多层感知机MLP分类FashionMNIST

    from mxnet import gluon,init from mxnet.gluon import loss as gloss, nn from mxnet.gluon import data ...

  4. 动手学深度学习7-从零开始完成softmax分类

    获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...

  5. AlexNet 分类 FashionMNIST

    from mxnet import gluon,init,nd,autograd from mxnet.gluon import data as gdata,nn from mxnet.gluon i ...

  6. LeNet 分类 FashionMNIST

    import mxnet as mx from mxnet import autograd, gluon, init, nd from mxnet.gluon import loss as gloss ...

  7. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  8. Keras 多层感知机 多类别的 softmax 分类模型代码

    Multilayer Perceptron (MLP) for multi-class softmax classification: from keras.models import Sequent ...

  9. tf.nn.softmax 分类

    tf.nn.softmax(logits,axis=None,name=None,dim=None) 参数: logits:一个非空的Tensor.必须是下列类型之一:half, float32,fl ...

随机推荐

  1. PIE SDK符号选择器

    1. 功能简介 符号选择器可以根据不同的需求进行改变图层的符号形状以及颜色,下面基于PIE SDK介绍如何使用符号选择器. 2. 功能实现说明 2.1.  实现思路及原理说明 第一步 加载图层 第二步 ...

  2. Spark遇到的报错和坑

    1. Java版本不一致,导致启动报错. # 解决方法: 在启动脚本最前边添加系统参数,指定Java版本 export JAVA_HOME=/usr/java/jdk1..0_181-amd64/jr ...

  3. HexChat访问Docker频道

    1.使用HexChat登录Freenode.net 2.在Freenode下输入并回车: /msg NickServ REGISTER yourpassword youremail@example.c ...

  4. 牛客网Java刷题知识点之ArrayList 、LinkedList 、Vector 的底层实现和区别

    不多说,直接上干货! 这篇我是从整体出发去写的. 牛客网Java刷题知识点之Java 集合框架的构成.集合框架中的迭代器Iterator.集合框架中的集合接口Collection(List和Set). ...

  5. 用js语句控制css样式

    <!DOCTYPE html><html lang="en"><head> <meta charset="UTF-8" ...

  6. 4、加载:Loading

    /* ---html----*/ <ion-content> <button (click)="manual()">手动关闭</button> ...

  7. (四) HTML之表单元素

    HTML中的表单元素,是构成动态网页的重要组成部分,因此,熟知表单元素是十分重要的.下面将根据表单中的一些常用标签进行介绍 1.单选按钮 <input type="radio" ...

  8. Java线程同步打印ABC

    需求: 三个线程,依次打印ABCABCABC.... 方案一: 使用阻塞队列,线程1从队列1获取内容打印,线程2从队列2获取内容打印,线程3从队列3中获取内容打印.线程1把B放到队列3中,线程2把C放 ...

  9. Celery-------周期任务

    在项目目录例子的基础上进行修改一下celery文件 from celery import Celery from celery.schedules import crontab celery_task ...

  10. JS基础学习——对象

    JS基础学习--对象 什么是对象 对象object是JS的一种基本数据类型,除此之外还包括的基本数据类型有string.number.boolean.null.undefined.与其他数据类型不同的 ...