多层感知机中:

hi 以 p 的概率被丢弃,以 1-p 的概率被拉伸,除以  1 - p

import mxnet as mx
import sys
import os
import time
import gluonbook as gb
from mxnet import autograd,init
from mxnet import nd,gluon
from mxnet.gluon import data as gdata,nn
from mxnet.gluon import loss as gloss '''
# 模型参数
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784,10,256,256 W1 = nd.random.normal(scale=0.01,shape=(num_inputs,num_hiddens1))
b1 = nd.zeros(num_hiddens1) W2 = nd.random.normal(scale=0.01,shape=(num_hiddens1,num_hiddens2))
b2 = nd.zeros(num_hiddens2) W3 = nd.random.normal(scale=0.01,shape=(num_hiddens2,num_outputs))
b3 = nd.zeros(num_outputs) params = [W1,b1,W2,b2,W3,b3] for param in params:
param.attach_grad() # 定义网络 '''
# 读取数据
# fashionMNIST 28*28 转为224*224
def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join(
'~', '.mxnet', 'datasets', 'fashion-mnist')):
root = os.path.expanduser(root) # 展开用户路径 '~'。
transformer = []
if resize:
transformer += [gdata.vision.transforms.Resize(resize)]
transformer += [gdata.vision.transforms.ToTensor()]
transformer = gdata.vision.transforms.Compose(transformer)
mnist_train = gdata.vision.FashionMNIST(root=root, train=True)
mnist_test = gdata.vision.FashionMNIST(root=root, train=False)
num_workers = 0 if sys.platform.startswith('win32') else 4
train_iter = gdata.DataLoader(
mnist_train.transform_first(transformer), batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(
mnist_test.transform_first(transformer), batch_size, shuffle=False,
num_workers=num_workers)
return train_iter, test_iter # 定义网络
drop_prob1,drop_prob2 = 0.2,0.5
# Gluon版
net = nn.Sequential()
net.add(nn.Dense(256,activation="relu"),
nn.Dropout(drop_prob1),
nn.Dense(256,activation="relu"),
nn.Dropout(drop_prob2),
nn.Dense(10)
)
net.initialize(init.Normal(sigma=0.01)) # 训练模型 def accuracy(y_hat, y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()
def evaluate_accuracy(data_iter, net):
acc = 0
for X, y in data_iter:
acc += accuracy(net(X), y)
return acc / len(data_iter) def train(net, train_iter, test_iter, loss, num_epochs, batch_size,
params=None, lr=None, trainer=None):
for epoch in range(num_epochs):
train_l_sum = 0
train_acc_sum = 0
for X, y in train_iter:
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
if trainer is None:
gb.sgd(params, lr, batch_size)
else:
trainer.step(batch_size) # 下一节将用到。
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
% (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter), test_acc)) num_epochs = 5
lr = 0.5
batch_size = 256
loss = gloss.SoftmaxCrossEntropyLoss()
train_iter, test_iter = load_data_fashion_mnist(batch_size) trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':lr})
train(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,trainer)

Gluon 实现 dropout 丢弃法的更多相关文章

  1. 【神经网络】丢弃法(dropout)

    丢弃法是一种降低过拟合的方法,具体过程是在神经网络传播的过程中,随机"沉默"一些节点.这个行为让模型过度贴合训练集的难度更高. 添加丢弃层后,训练速度明显上升,在同样的轮数下测试集 ...

  2. MXNET:丢弃法

    除了前面介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout)来应对过拟合问题. 方法与原理 为了确保测试模型的确定性,丢弃法的使用只发生在训练模型时,并非测试模型时.当神经网络中的某一层使 ...

  3. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  4. dropout——gluon

    https://blog.csdn.net/lizzy05/article/details/80162060 from mxnet import nd def dropout(X, drop_prob ...

  5. 动手学深度学习14- pytorch Dropout 实现与原理

    方法 从零开始实现 定义模型参数 网络 评估函数 优化方法 定义损失函数 数据提取与训练评估 pytorch简洁实现 小结 针对深度学习中的过拟合问题,通常使用丢弃法(dropout),丢弃法有很多的 ...

  6. 神经网络优化算法:Dropout、梯度消失/爆炸、Adam优化算法,一篇就够了!

    1. 训练误差和泛化误差 机器学习模型在训练数据集和测试数据集上的表现.如果你改变过实验中的模型结构或者超参数,你也许发现了:当模型在训练数据集上更准确时,它在测试数据集上却不⼀定更准确.这是为什么呢 ...

  7. 从头学pytorch(七):dropout防止过拟合

    上一篇讲了防止过拟合的一种方式,权重衰减,也即在loss上加上一部分\(\frac{\lambda}{2n} \|\boldsymbol{w}\|^2\),从而使得w不至于过大,即不过分偏向某个特征. ...

  8. dropout总结

    1.伯努利分布:伯努利分布亦称“零一分布”.“两点分布”.称随机变量X有伯努利分布, 参数为p(0<p<1),如果它分别以概率p和1-p取1和0为值.EX= p,DX=p(1-p). 2. ...

  9. mxnet(gluon)—— 模型、数据集、损失函数、优化子等类、接口大全

    1. 数据集 dataset_train = gluon.data.ArrayDataset(X_train, y_train) data_iter = gluon.data.DataLoader(d ...

随机推荐

  1. jenkins启动脚本

    [root@localhost system]# cat /etc/init.d/jenkins #!/bin/sh # # SUSE system statup script for Jenkins ...

  2. Redis 小结

    一.redis简介 redis是一款基于C语言编写的,开源的非关系型数据库,由于其卓越的数据处理机制(按照规则,将常用的部分数据放置缓存,其余数据序列化到硬盘),大家也通常将其当做缓存服务器来使用. ...

  3. 【request获取用户请求ip】

    1:request.getRemoteAddr() 2:如果请求的客户端使用了nginx 等反向代理发送请求的时候:就不能获取到真是的ip地址了:如:将http://192.168.1.110:204 ...

  4. Firebird 列可空非空修改

    2018-12-04 至少到Firebird 3.0.4 已经添加了设置可空 和 非空的语法:如 -- 删除非空(设置为可空) ALTER TABLE TECH ALTER label drop NO ...

  5. XmlSerialize

    以前配置文件都直接写在TXT文件,能看懂就行: 后来写了点代码,就把配置写在ini文件里: 再后来随着趋势就把配置类序列化到本地,即xml配置: 现在懒了,直接ToJson到本地,需要时FromJso ...

  6. MySql数据库与JDBC编程

    JDBC -- Java Database Connectivity,即Java数据库连接,通过使用JDBC就可以使用同一种API访问不同的数据库 SQL语句基础(SQL结构化查询语言) 能完成的任务 ...

  7. [android] socket在手机上的应用

    1.手机助手 1.1 USB链接 可以读取手机的PID和VID,确定唯一的设备,可以给手机安装对应的驱动等 socket在固定端口通信 1.2 WIFI链接 pc在电脑在整个网段发送UDP数据包,手机 ...

  8. Linux必会必知

    一.前言 Linux作为一个开源系统,被极客极力推崇,作为程序员不来了解一下,那就亏了 Linux是一种自由和开放源代码的类UNIX操作系统.该操作系统的内核由林纳斯·托瓦兹在1991年10月5日首次 ...

  9. 域对象中属性变更及感知session绑定的事件监听器

    域对象中属性的变更的时间监听器就是用来监听ServletContext,HttpSession,HttpServletRequest这三个对象中的属性变更信息事件的监听器.这三个监听器接口分别是Ser ...

  10. Java String、string[]、List初始化方法

    String初始化: 1.String str = new String("string1"); 2.String str = "string1"; Strin ...