L2范数惩罚项,高维线性回归
- %matplotlib inline
- import mxnet
- from mxnet import nd,autograd
- from mxnet import gluon,init
- from mxnet.gluon import data as gdata,loss as gloss,nn
- import gluonbook as gb
- n_train, n_test, num_inputs = 20,100,200
- true_w = nd.ones((num_inputs, 1)) * 0.01
- true_b = 0.05
- features = nd.random.normal(shape=(n_train+n_test, num_inputs))
- labels = nd.dot(features,true_w) + true_b
- labels += nd.random.normal(scale=0.01, shape=labels.shape)
- train_feature = features[:n_train,:]
- test_feature = features[n_train:,:]
- train_labels = labels[:n_train]
- test_labels = labels[n_train:]
- #print(features,train_feature,test_feature)
- # 初始化模型参数
- def init_params():
- w = nd.random.normal(scale=1, shape=(num_inputs, 1))
- b = nd.zeros(shape=(1,))
- w.attach_grad()
- b.attach_grad()
- return [w,b]
- # 定义,训练,测试
- batch_size = 1
- num_epochs = 100
- lr = 0.03
- train_iter = gdata.DataLoader(gdata.ArrayDataset(train_feature,train_labels),batch_size=batch_size,shuffle=True)
- # 定义网络
- def linreg(X, w, b):
- return nd.dot(X,w) + b
- # 损失函数
- def squared_loss(y_hat, y):
- """Squared loss."""
- return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
- # L2 范数惩罚
- def l2_penalty(w):
- return (w**2).sum() / 2
- def sgd(params, lr, batch_size):
- for param in params:
- param[:] = param - lr * param.grad / batch_size
- def fit_and_plot(lambd):
- w, b = init_params()
- train_ls, test_ls = [], []
- for _ in range(num_epochs):
- for X, y in train_iter:
- with autograd.record():
- # 添加了 L2 范数惩罚项。
- l = squared_loss(linreg(X, w, b), y) + lambd * l2_penalty(w)
- l.backward()
- sgd([w, b], lr, batch_size)
- train_ls.append(squared_loss(linreg(train_feature, w, b),
- train_labels).mean().asscalar())
- test_ls.append(squared_loss(linreg(test_feature, w, b),
- test_labels).mean().asscalar())
- gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
- range(1, num_epochs + 1), test_ls, ['train', 'test'])
- print('L2 norm of w:', w.norm().asscalar())
- fit_and_plot(0)
fit_and_plot(3)
训练集太少,容易出现过拟合,即训练集loss远小于测试集loss,解决方案,权重衰减——(L2范数正则化)
例如线性回归:
loss(w1,w2,b) = 1/n * sum(x1w1 + x2w2 + b - y)^2 /2 ,平方损失函数。
权重参数 w = [w1,w2],
新损失函数 loss(w1,w2,b) += lambd / 2n *||w||^2
迭代方程:
L2范数惩罚项,高维线性回归的更多相关文章
- 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播
下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...
- 机器学习中的范数规则化 L0、L1与L2范数 核范数与规则项参数选择
http://blog.csdn.net/zouxy09/article/details/24971995 机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http: ...
- 机器学习中的范数规则化之 L0、L1与L2范数、核范数与规则项参数选择
装载自:https://blog.csdn.net/u012467880/article/details/52852242 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化.我们先简单的来理 ...
- 《机器学习实战》学习笔记第八章 —— 线性回归、L1、L2范数正则项
相关笔记: 吴恩达机器学习笔记(一) —— 线性回归 吴恩达机器学习笔记(三) —— Regularization正则化 ( 问题遗留: 小可只知道引入正则项能降低参数的取值,但为什么能保证 Σθ2 ...
- deep learning (五)线性回归中L2范数的应用
cost function 加一个正则项的原因是防止产生过拟合现象.正则项有L1,L2 等范数,我看过讲的最好的是这个博客上的:机器学习中的范数规则化之(一)L0.L1与L2范数.看完应该就答题明白了 ...
- paper 126:[转载] 机器学习中的范数规则化之(一)L0、L1与L2范数
机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http://blog.csdn.net/zouxy09 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化. ...
- 机器学习中的范数规则化之(一)L0、L1与L2范数(转)
http://blog.csdn.net/zouxy09/article/details/24971995 机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http: ...
- L0、L1与L2范数、核范数(转)
L0.L1与L2范数.核范数 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化.我们先简单的来理解下常用的L0.L1.L2和核范数规则化.最后聊下规则化项参数的选择问题.这里因为篇幅比较庞大 ...
- 机器学习中的范数规则化之(一)L0、L1与L2范数 非常好,必看
机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http://blog.csdn.net/zouxy09 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化. ...
随机推荐
- echarts Y轴数据类型不同怎么让折线图显示差距不大
如果希望在同一grid中展示不同数据类型的折线(1000或10%),那么展现出来的折线肯定显示差距很大,那么怎么让这两条折线显示效果差不多,在之前的项目中碰到了这个问题 每条折线对应的是不同的数据组, ...
- spark第二篇:Application Submission Guide
提交应用 Spark的bin目录中的spark-submit脚本用于启动集群上的应用程序.它可以通过一个统一的接口使用所有Spark支持的集群管理器. 绑定应用程序的依赖 如果你的代码依赖其他项目,你 ...
- 有关tensorflow一些问题
1.python版本 采用64位的python 2.系统不支持高版本tensorflow(>1.6),运行报错如下: 问题描述如下: ImportError: DLL load failed: ...
- nginx 问题总结
1, 403错误 403是很常见的错误代码,一般就是未授权被禁止访问的意思. 可能的原因有两种:Nginx程序用户无权限访问web目录文件Nginx需要访问目录,但是autoindex选项被关闭 修复 ...
- Dubbo解析及原理浅析
原文链接:https://blog.csdn.net/chao_19/article/details/51764150 一.Duboo基本概念解释 Dubbo是一种分布式服务框架. Webservic ...
- Composite Design Pattern in Java--转
https://dzone.com/articles/composite-design-pattern-in-java-1 The composite pattern is meant to &quo ...
- WPF 窗体在Alt+Tab中隐藏
问题: 近段时间由于项目上的需求,需要在WPF中使用COM组件,并且由于软件界面设计等等原因,需要将部分控件显示在COM组件之上,由于WindowsFormsHost的一些原因,导致继承在WPF中的W ...
- BNU34058——干了这桶冰红茶!——————【递推】
干了这桶冰红茶! Time Limit: 1000ms Memory Limit: 65536KB 64-bit integer IO format: %lld Java class nam ...
- 当Activity出现Exception时是如何处理的?
1.ActivityThread 2.PerformStop 在这里会调用mWindow.closeAllPanels(),从而关闭OptionMenu, ContextMenu.如果自己通过Wind ...
- shiro 配置导图
1 web.xml配置:shiro filter必须放在其他filter之前 <filter> <filter-name>shiroFilter</filter-name ...