构建数据集

# -*- coding: utf-8 -*-
from mxnet import init
from mxnet import ndarray as nd
from mxnet.gluon import loss as gloss
import gb n_train = 20
n_test = 100 num_inputs = 200
true_w = nd.ones((num_inputs, 1)) * 0.01
true_b = 0.05
features = nd.random.normal(shape=(n_train+n_test, num_inputs))
labels = nd.dot(features, true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]

数据迭代器

from mxnet import autograd
from mxnet.gluon import data as gdata batch_size = 1
num_epochs = 10
learning_rate = 0.003 train_iter = gdata.DataLoader(gdata.ArrayDataset(
train_features, train_labels), batch_size, shuffle=True)
loss = gloss.L2Loss()

训练并展示结果

gb.semilogy函数:绘制训练和测试数据的loss

from mxnet import gluon
from mxnet.gluon import nn def fit_and_plot(weight_decay):
net = nn.Sequential()
net.add(nn.Dense(1))
net.initialize(init.Normal(sigma=1))
# 对权重参数做 L2 范数正则化,即权重衰减。
trainer_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd', {
'learning_rate': learning_rate, 'wd': weight_decay})
# 不对偏差参数做 L2 范数正则化。
trainer_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd', {
'learning_rate': learning_rate})
train_ls = []
test_ls = []
for _ in range(num_epochs):
for X, y in train_iter:
with autograd.record():
l = loss(net(X), y)
l.backward()
# 对两个 Trainer 实例分别调用 step 函数。
trainer_w.step(batch_size)
trainer_b.step(batch_size)
train_ls.append(loss(net(train_features),
train_labels).mean().asscalar())
test_ls.append(loss(net(test_features),
test_labels).mean().asscalar())
gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
range(1, num_epochs + 1), test_ls, ['train', 'test'])
return 'w[:10]:', net[0].weight.data()[:, :10], 'b:', net[0].bias.data()
print fit_and_plot(5)
  • 使用 Gluon 的 wd 超参数可以使用权重衰减来应对过拟合问题。
  • 我们可以定义多个 Trainer 实例对不同的模型参数使用不同的迭代方法。

MXNET:权重衰减-gluon实现的更多相关文章

  1. MXNET:权重衰减

    权重衰减是应对过拟合问题的常用方法. \(L_2\)范数正则化 在深度学习中,我们常使用L2范数正则化,也就是在模型原先损失函数基础上添加L2范数惩罚项,从而得到训练所需要最小化的函数. L2范数惩罚 ...

  2. 调参过程中的参数 学习率,权重衰减,冲量(learning_rate , weight_decay , momentum)

    无论是深度学习还是机器学习,大多情况下训练中都会遇到这几个参数,今天依据我自己的理解具体的总结一下,可能会存在错误,还请指正. learning_rate , weight_decay , momen ...

  3. 权重衰减(weight decay)与学习率衰减(learning rate decay)

    本文链接:https://blog.csdn.net/program_developer/article/details/80867468“微信公众号” 1. 权重衰减(weight decay)L2 ...

  4. 从头学pytorch(六):权重衰减

    深度学习中常常会存在过拟合现象,比如当训练数据过少时,训练得到的模型很可能在训练集上表现非常好,但是在测试集上表现不好. 应对过拟合,可以通过数据增强,增大训练集数量.我们这里先不介绍数据增强,先从模 ...

  5. MxNet新前端Gluon模型转换到Symbol

    1. 导入各种包 from mxnet import gluon from mxnet.gluon import nn import matplotlib.pyplot as plt from mxn ...

  6. 使用MxNet新接口Gluon提供的预训练模型进行微调

    1. 导入各种包 from mxnet import gluon import mxnet as mx from mxnet.gluon import nn from mxnet import nda ...

  7. MXNET:丢弃法

    除了前面介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout)来应对过拟合问题. 方法与原理 为了确保测试模型的确定性,丢弃法的使用只发生在训练模型时,并非测试模型时.当神经网络中的某一层使 ...

  8. MXNET:监督学习

    线性回归 给定一个数据点集合 X 和对应的目标值 y,线性模型的目标就是找到一条使用向量 w 和位移 b 描述的线,来尽可能地近似每个样本X[i] 和 y[i]. 数学公式表示为\(\hat{y}=X ...

  9. mxnet深度学习实战学习笔记-9-目标检测

    1.介绍 目标检测是指任意给定一张图像,判断图像中是否存在指定类别的目标,如果存在,则返回目标的位置和类别置信度 如下图检测人和自行车这两个目标,检测结果包括目标的位置.目标的类别和置信度 因为目标检 ...

随机推荐

  1. HDU 1348 Wall 【凸包】

    <题目链接> 题目大意: 给出二维坐标轴上 n 个点,这 n 个点构成了一个城堡,国王想建一堵墙,城墙与城堡之间的距离总不小于一个数 L ,求城墙的最小长度,答案四舍五入. 解题分析: 求 ...

  2. POJ 2337 Catenyms(有向欧拉图:输出欧拉路径)

    题目链接>>>>>> 题目大意: 给出一些字符串,问能否将这些字符串  按照 词语接龙,首尾相接  的规则 使得每个字符串出现一次 如果可以 按字典序输出这个字符串 ...

  3. python 常用模块之random,os,sys 模块

    python 常用模块random,os,sys 模块 python全栈开发OS模块,Random模块,sys模块 OS模块 os模块是与操作系统交互的一个接口,常见的函数以及用法见一下代码: #OS ...

  4. go语言学习-函数

    函数声明 函数声明包括函数名,形参列表,返回值列表(可选),函数体组成 func test(parameters) (returns) { // ... } 其中 parameters 就是函数的形参 ...

  5. 洛谷P2982 [USACO10FEB]慢下来Slowing down(线段树 DFS序 区间增减 单点查询)

    To 洛谷.2982 慢下来Slowing down 题目描述 Every day each of Farmer John's N (1 <= N <= 100,000) cows con ...

  6. Oracle ORA-12541:TNS:无监听程序

    背景:自己机子做oracle服务器,其他机子可以ping得通我的机子,但是jdbc就是连不上,后来用plsql连出现无监听程序.... 我昨天重新安装Oracle后,用PL/SQL Developer ...

  7. 9、SQL逻辑查询语句执行顺序

    本篇导航: SELECT语句关键字的定义顺序 SELECT语句关键字的执行顺序 准备表和数据 准备SQL逻辑查询测试语句 执行顺序分析 一.SELECT语句关键字的定义顺序 SELECT DISTIN ...

  8. scrollView 刷新显示在中间的问题

    scrollView问题 打开activity之后 屏幕初始位置不是顶部 而是在中间 也就是scroll滚动条不在上面 而是在中间 楼主你好,我大概是和你遇见了同样的问题,你可以灵活处理一下,不要去管 ...

  9. 理解TCP之Keepalive

    理解HTTP之keep-alive 在前面一篇文章中讲了TCP的keepalive,这篇文章再讲讲HTTP层面keep-alive.两种keepalive在拼写上面就是不一样的,只是发音一样,于是乎大 ...

  10. [数据结构与算法分析(Mark Allen Weiss)]不相交集 @ Python

    最简单的不相交集的实现,来自MAW的<数据结构与算法分析>. 代码: class DisjSet: def __init__(self, NumSets): self.S = [0 for ...