MXNET:权重衰减-gluon实现
构建数据集
# -*- coding: utf-8 -*-
from mxnet import init
from mxnet import ndarray as nd
from mxnet.gluon import loss as gloss
import gb
n_train = 20
n_test = 100
num_inputs = 200
true_w = nd.ones((num_inputs, 1)) * 0.01
true_b = 0.05
features = nd.random.normal(shape=(n_train+n_test, num_inputs))
labels = nd.dot(features, true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]
数据迭代器
from mxnet import autograd
from mxnet.gluon import data as gdata
batch_size = 1
num_epochs = 10
learning_rate = 0.003
train_iter = gdata.DataLoader(gdata.ArrayDataset(
train_features, train_labels), batch_size, shuffle=True)
loss = gloss.L2Loss()
训练并展示结果
gb.semilogy函数:绘制训练和测试数据的loss
from mxnet import gluon
from mxnet.gluon import nn
def fit_and_plot(weight_decay):
net = nn.Sequential()
net.add(nn.Dense(1))
net.initialize(init.Normal(sigma=1))
# 对权重参数做 L2 范数正则化,即权重衰减。
trainer_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd', {
'learning_rate': learning_rate, 'wd': weight_decay})
# 不对偏差参数做 L2 范数正则化。
trainer_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd', {
'learning_rate': learning_rate})
train_ls = []
test_ls = []
for _ in range(num_epochs):
for X, y in train_iter:
with autograd.record():
l = loss(net(X), y)
l.backward()
# 对两个 Trainer 实例分别调用 step 函数。
trainer_w.step(batch_size)
trainer_b.step(batch_size)
train_ls.append(loss(net(train_features),
train_labels).mean().asscalar())
test_ls.append(loss(net(test_features),
test_labels).mean().asscalar())
gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
range(1, num_epochs + 1), test_ls, ['train', 'test'])
return 'w[:10]:', net[0].weight.data()[:, :10], 'b:', net[0].bias.data()
print fit_and_plot(5)
- 使用 Gluon 的 wd 超参数可以使用权重衰减来应对过拟合问题。
- 我们可以定义多个 Trainer 实例对不同的模型参数使用不同的迭代方法。
MXNET:权重衰减-gluon实现的更多相关文章
- MXNET:权重衰减
权重衰减是应对过拟合问题的常用方法. \(L_2\)范数正则化 在深度学习中,我们常使用L2范数正则化,也就是在模型原先损失函数基础上添加L2范数惩罚项,从而得到训练所需要最小化的函数. L2范数惩罚 ...
- 调参过程中的参数 学习率,权重衰减,冲量(learning_rate , weight_decay , momentum)
无论是深度学习还是机器学习,大多情况下训练中都会遇到这几个参数,今天依据我自己的理解具体的总结一下,可能会存在错误,还请指正. learning_rate , weight_decay , momen ...
- 权重衰减(weight decay)与学习率衰减(learning rate decay)
本文链接:https://blog.csdn.net/program_developer/article/details/80867468“微信公众号” 1. 权重衰减(weight decay)L2 ...
- 从头学pytorch(六):权重衰减
深度学习中常常会存在过拟合现象,比如当训练数据过少时,训练得到的模型很可能在训练集上表现非常好,但是在测试集上表现不好. 应对过拟合,可以通过数据增强,增大训练集数量.我们这里先不介绍数据增强,先从模 ...
- MxNet新前端Gluon模型转换到Symbol
1. 导入各种包 from mxnet import gluon from mxnet.gluon import nn import matplotlib.pyplot as plt from mxn ...
- 使用MxNet新接口Gluon提供的预训练模型进行微调
1. 导入各种包 from mxnet import gluon import mxnet as mx from mxnet.gluon import nn from mxnet import nda ...
- MXNET:丢弃法
除了前面介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout)来应对过拟合问题. 方法与原理 为了确保测试模型的确定性,丢弃法的使用只发生在训练模型时,并非测试模型时.当神经网络中的某一层使 ...
- MXNET:监督学习
线性回归 给定一个数据点集合 X 和对应的目标值 y,线性模型的目标就是找到一条使用向量 w 和位移 b 描述的线,来尽可能地近似每个样本X[i] 和 y[i]. 数学公式表示为\(\hat{y}=X ...
- mxnet深度学习实战学习笔记-9-目标检测
1.介绍 目标检测是指任意给定一张图像,判断图像中是否存在指定类别的目标,如果存在,则返回目标的位置和类别置信度 如下图检测人和自行车这两个目标,检测结果包括目标的位置.目标的类别和置信度 因为目标检 ...
随机推荐
- HDU 1348 Wall 【凸包】
<题目链接> 题目大意: 给出二维坐标轴上 n 个点,这 n 个点构成了一个城堡,国王想建一堵墙,城墙与城堡之间的距离总不小于一个数 L ,求城墙的最小长度,答案四舍五入. 解题分析: 求 ...
- POJ 2337 Catenyms(有向欧拉图:输出欧拉路径)
题目链接>>>>>> 题目大意: 给出一些字符串,问能否将这些字符串 按照 词语接龙,首尾相接 的规则 使得每个字符串出现一次 如果可以 按字典序输出这个字符串 ...
- python 常用模块之random,os,sys 模块
python 常用模块random,os,sys 模块 python全栈开发OS模块,Random模块,sys模块 OS模块 os模块是与操作系统交互的一个接口,常见的函数以及用法见一下代码: #OS ...
- go语言学习-函数
函数声明 函数声明包括函数名,形参列表,返回值列表(可选),函数体组成 func test(parameters) (returns) { // ... } 其中 parameters 就是函数的形参 ...
- 洛谷P2982 [USACO10FEB]慢下来Slowing down(线段树 DFS序 区间增减 单点查询)
To 洛谷.2982 慢下来Slowing down 题目描述 Every day each of Farmer John's N (1 <= N <= 100,000) cows con ...
- Oracle ORA-12541:TNS:无监听程序
背景:自己机子做oracle服务器,其他机子可以ping得通我的机子,但是jdbc就是连不上,后来用plsql连出现无监听程序.... 我昨天重新安装Oracle后,用PL/SQL Developer ...
- 9、SQL逻辑查询语句执行顺序
本篇导航: SELECT语句关键字的定义顺序 SELECT语句关键字的执行顺序 准备表和数据 准备SQL逻辑查询测试语句 执行顺序分析 一.SELECT语句关键字的定义顺序 SELECT DISTIN ...
- scrollView 刷新显示在中间的问题
scrollView问题 打开activity之后 屏幕初始位置不是顶部 而是在中间 也就是scroll滚动条不在上面 而是在中间 楼主你好,我大概是和你遇见了同样的问题,你可以灵活处理一下,不要去管 ...
- 理解TCP之Keepalive
理解HTTP之keep-alive 在前面一篇文章中讲了TCP的keepalive,这篇文章再讲讲HTTP层面keep-alive.两种keepalive在拼写上面就是不一样的,只是发音一样,于是乎大 ...
- [数据结构与算法分析(Mark Allen Weiss)]不相交集 @ Python
最简单的不相交集的实现,来自MAW的<数据结构与算法分析>. 代码: class DisjSet: def __init__(self, NumSets): self.S = [0 for ...