%matplotlib inline
import mxnet
from mxnet import nd,autograd
from mxnet import gluon,init
from mxnet.gluon import data as gdata,loss as gloss,nn
import gluonbook as gb n_train, n_test, num_inputs = 20,100,200 true_w = nd.ones((num_inputs, 1)) * 0.01
true_b = 0.05 features = nd.random.normal(shape=(n_train+n_test, num_inputs))
labels = nd.dot(features,true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape) train_feature = features[:n_train,:]
test_feature = features[n_train:,:]
train_labels = labels[:n_train]
test_labels = labels[n_train:] #print(features,train_feature,test_feature) # 初始化模型参数
def init_params():
w = nd.random.normal(scale=1, shape=(num_inputs, 1))
b = nd.zeros(shape=(1,))
w.attach_grad()
b.attach_grad()
return [w,b] # 定义,训练,测试 batch_size = 1
num_epochs = 100
lr = 0.03 train_iter = gdata.DataLoader(gdata.ArrayDataset(train_feature,train_labels),batch_size=batch_size,shuffle=True) # 定义网络
def linreg(X, w, b):
return nd.dot(X,w) + b # 损失函数
def squared_loss(y_hat, y):
"""Squared loss."""
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2 # L2 范数惩罚
def l2_penalty(w):
return (w**2).sum() / 2 def sgd(params, lr, batch_size):
for param in params:
param[:] = param - lr * param.grad / batch_size def fit_and_plot(lambd):
w, b = init_params()
train_ls, test_ls = [], []
for _ in range(num_epochs):
for X, y in train_iter:
with autograd.record():
# 添加了 L2 范数惩罚项。
l = squared_loss(linreg(X, w, b), y) + lambd * l2_penalty(w)
l.backward()
sgd([w, b], lr, batch_size)
train_ls.append(squared_loss(linreg(train_feature, w, b),
train_labels).mean().asscalar())
test_ls.append(squared_loss(linreg(test_feature, w, b),
test_labels).mean().asscalar())
gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
range(1, num_epochs + 1), test_ls, ['train', 'test'])
print('L2 norm of w:', w.norm().asscalar())
fit_and_plot(0)
fit_and_plot(3)

训练集太少,容易出现过拟合,即训练集loss远小于测试集loss,解决方案,权重衰减——(L2范数正则化)

例如线性回归:

loss(w1,w2,b) = 1/n * sum(x1w1 + x2w2 + b - y)^2 /2 ,平方损失函数。

权重参数 w = [w1,w2],

新损失函数 loss(w1,w2,b) += lambd / 2n *||w||^2

迭代方程:

L2范数惩罚项,高维线性回归的更多相关文章

  1. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  2. 机器学习中的范数规则化 L0、L1与L2范数 核范数与规则项参数选择

    http://blog.csdn.net/zouxy09/article/details/24971995 机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http: ...

  3. 机器学习中的范数规则化之 L0、L1与L2范数、核范数与规则项参数选择

    装载自:https://blog.csdn.net/u012467880/article/details/52852242 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化.我们先简单的来理 ...

  4. 《机器学习实战》学习笔记第八章 —— 线性回归、L1、L2范数正则项

    相关笔记: 吴恩达机器学习笔记(一) —— 线性回归 吴恩达机器学习笔记(三) —— Regularization正则化 ( 问题遗留: 小可只知道引入正则项能降低参数的取值,但为什么能保证 Σθ2  ...

  5. deep learning (五)线性回归中L2范数的应用

    cost function 加一个正则项的原因是防止产生过拟合现象.正则项有L1,L2 等范数,我看过讲的最好的是这个博客上的:机器学习中的范数规则化之(一)L0.L1与L2范数.看完应该就答题明白了 ...

  6. paper 126:[转载] 机器学习中的范数规则化之(一)L0、L1与L2范数

    机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http://blog.csdn.net/zouxy09 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化. ...

  7. 机器学习中的范数规则化之(一)L0、L1与L2范数(转)

    http://blog.csdn.net/zouxy09/article/details/24971995 机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http: ...

  8. L0、L1与L2范数、核范数(转)

    L0.L1与L2范数.核范数 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化.我们先简单的来理解下常用的L0.L1.L2和核范数规则化.最后聊下规则化项参数的选择问题.这里因为篇幅比较庞大 ...

  9. 机器学习中的范数规则化之(一)L0、L1与L2范数 非常好,必看

    机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http://blog.csdn.net/zouxy09 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化. ...

随机推荐

  1. mysql GPID学习

    1.为什么引入GPID? 解决主备复制的延时问题 单线程太慢, 多线程复制的问题是:最终数据可能不一致 MySQL主从延时这么长,要怎么优化? 2. 引入后有哪些缺点 不支持create table ...

  2. ZABBIX 监控基本报警故障

    CPU触发器: 1)Processor load is too high on {HOST.NAME} {HOST.NAME}上处理器负载太高 触发器表达式:{Zabbix server:system ...

  3. js栈内存和堆内存的区别

    首先JavaScript中的变量分为基本类型和引用类型.基本类型就是保存在栈内存中的简单数据段,而引用类型指的是那些保存在堆内存中的对象. 1.基本类型 基本类型有Undefined.Null.Boo ...

  4. Web请求过程总结

    Web请求过程总结 1.CND架构图 图片来源:深入分析JavaWeb技术内幕(许令波著) 2.发起HTTP请求 发起一个HTTP请求就是浏览器建立socket通信的过程,HttpClient开源的通 ...

  5. MYSQL冷知识——ON DUPLICATE KEY 批量增删改

    一 有个需求要批量增删改,并且是混合的,也就是仅不存在才增. 删简单,因为有个deleteStaute之类的字段,删除本质上就是就是一个修改 所以就是实现批量混合增改,然而组长说mysql不支持混合增 ...

  6. jqgrid 上移下移单元格

    在表格中常常需要调整表格中数据的显示顺序,我用的是jqgrid,实现原理就是将表中的行数保存到数据库中,取数据时按行进行排序 1.上移,下移按钮 <a href="javascript ...

  7. 深入理解JavaScript系列(3):全面解析Module模式

    简介 Module模式是JavaScript编程中一个非常通用的模式,一般情况下,大家都知道基本用法,本文尝试着给大家更多该模式的高级使用方式. 首先我们来看看Module模式的基本特征: 模块化,可 ...

  8. s中的闭包

    今天看了关于js闭包方面的文章,还是有些云里雾里,对于一个菜鸟来说,学习闭包确实有一定的难度,不说别的,能够在网上找到一篇优秀的是那样的不易. 当然之所以闭包难理解,个人觉得是基础知识掌握的不牢,因为 ...

  9. [转]C# - JSON详解

    本文转自:http://www.cnblogs.com/QLJ1314/p/3862583.html 最近在做微信开发时用到了一些json的问题,就是把微信返回回来的一些json数据做一些处理,但是之 ...

  10. C#基础(第一天)

    Ctrl+K+D:对其代码: #Region      #endRegion:折叠多余代码: Ctrl+K+S:可以折叠代码写注释: 语法格式:数据类型  变量名:                  ...