LASSOLeast Absolute Shrinkage and Selection Operator)回归模型一般都是用英文缩写表示,
硬要翻译的话,可翻译为 最小绝对收缩和选择算子

它是一种线性回归模型的扩展,其主要目标是解决高维数据中的特征选择和正则化问题。

1. 概述

LASSO中,通过使用L1正则化项,它能够在回归系数中引入稀疏性,
也就是允许某些系数在优化过程中缩减为零,从而实现特征的选择。

与岭回归不同的是,LASSO的损失函数一般定义为:\(L(w) = (y-wX)^2+\lambda\parallel w\parallel_1\)
其中 \(\lambda\parallel w\parallel_1\),也就是 L1正则化项(岭回归中用的是 L2正则化项)。

模型训练的过程就是寻找让损失函数\(L(w)\)最小的参数\(w\)。
也就等价于:\(\begin{align}
& arg\ min(y-wX)^2 \\
& s.t. \sum |w_{ij}| < s
\end{align}\)
这两个公式表示,在满足约束条件 \(\sum |w_{ij}| < s\)的情况下,计算 \((y-wX)^2\)的最小值。

2. 创建样本数据

相比于岭回归模型,LASSO回归模型不仅对于共线性数据集友好,
对于高维数据的数据集,也有不错的性能表现。

它通过将不重要的特征的系数压缩为零,帮助我们选择最重要的特征,从而提高模型的预测准确性和可解释性。
下面我们模拟创建一些高维数据,创建一个特征数比样本数还多的样本数据集。

from sklearn.datasets import make_regression

X, y = make_regression(n_samples=80, n_features=100, noise=10)

这个数据集中,只有80个样本,每个样本却有100个特征,并且噪声也设置的很大(noise=10)。

3. 模型训练

第一步,分割训练集测试集

from sklearn.model_selection import train_test_split

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

scikit-learn中的LASSO模型来训练:

from sklearn.linear_model import Lasso

# 初始化LASSO线性模型
reg = Lasso()
# 训练模型
reg.fit(X_train, y_train)

这里使用的 Lasso()的默认参数来训练模型,它的主要参数包括:

  1. alpha:正则化项系数。它控制了L1正则化项的强度,即对模型复杂度的惩罚。alpha越大,模型越简单,但过大的alpha可能会导致模型欠拟合;alpha越小,模型越复杂,但过小的alpha可能会导致模型过拟合。默认值为1.0
  2. fit_intercept:布尔值,指定是否需要计算截距b值。如果设为False,则不计算b值默认值为True
  3. normalize:布尔值。如果设为True,则在模型训练之前将数据归一化。默认值为False
  4. precompute:布尔值,指定是否预先计算X的平方和。如果设为True,则在每次迭代之前计算X的平方和。默认值为False
  5. copy_X:布尔值,指定是否在训练过程中复制X。如果设为True,则在训练过程中复制X默认值为True
  6. max_iter:最大迭代次数。默认值为1000
  7. tol:阈值,用于判断是否达到收敛条件。默认值为1e-4
  8. warm_start:布尔值,如果设为True,则使用前一次的解作为本次迭代的起始点。默认值为False
  9. positive:布尔值,如果设为True,则强制系数为正。默认值为False
  10. selection:用于在每次迭代中选择系数的算法(有“cyclic”和“random”两种选择)。默认值为“cyclic”,即循环选择。

最后验证模型的训练效果:

from sklearn import metrics

y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred) print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error)) # 运行结果
均方误差:441.07830708712186
复相关系数:0.9838880665687711
中位数绝对误差:11.643348614829785

误差看上去不小,因为这次实际生成的样本,不仅数量小(80件)且噪声大(noise=10)。

3.1. 与岭回归模型比较

单独看LASSO模型的训练结果,看不出其处理高维数据的优势。
同样用上面分割好的训练集测试集,来看看岭回归模型的拟合效果。

from sklearn.linear_model import Ridge
# from sklearn.model_selection import train_test_split mse, r2, m_error = 0.0, 0.0, 0.0 # 初始化岭回归线性模型
reg = Ridge()
# 训练模型
reg.fit(X_train, y_train) y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred) print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error)) # 运行结果
均方误差:6315.046844910431
复相关系数:0.7693207470296398
中位数绝对误差:60.65140692273637

对于高维数据,可以看出,岭回归模型的误差 远远大于 LASSO模型。

3.2. 与最小二乘法模型比较

同样用上面分割好的训练集测试集,再来看看线性模型(最小二乘法)的拟合效果。

from sklearn.linear_model import LinearRegression

mse, r2, m_error = 0.0, 0.0, 0.0

# 初始化最小二乘法线性模型
reg = LinearRegression()
# 训练模型
reg.fit(X_train, y_train) y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred) print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error)) # 运行结果
均方误差:5912.442445894787
复相关系数:0.7840272859181612
中位数绝对误差:62.89225147465376

可以看出,线性模型的训练效果和岭回归模型差不多,但是都远远不如LASSO模型

4. 总结

总的来说,LASSO回归模型是一种流行的线性回归扩展,具有一些显著的优势和劣势。
比如,在特征选择上,LASSO通过将某些系数压缩为零,能够有效地进行特征选择,这在高维数据集中特别有用。
此外,LASSO可以作为正则化工具,有助于防止过拟合。

不过,LASSO会假设特征是线性相关的,对于非线性关系的数据,效果可能不佳。
而且,如果数据存在复杂模式或噪声,LASSO可能会过度拟合这些模式。

【scikit-learn基础】--『监督学习』之 LASSO回归的更多相关文章

  1. Python基础『一』

    内置数据类型 数据名称 例子 数字: Bool,Complex,Float,Integer True/False; z=a+bj; 1.23; 123 字符串: String '123456' 元组: ...

  2. Python基础『二』

    目录 语句,表达式 赋值语句 打印语句 分支语句 循环语句 函数 函数的作用 函数的三要素 函数定义 DEF语句 RETURN语句 函数调用 作用域 闭包 递归函数 匿名函数 迭代 语句,表达式 赋值 ...

  3. 『cs231n』计算机视觉基础

    线性分类器损失函数明细: 『cs231n』线性分类器损失函数 最优化Optimiz部分代码: 1.随机搜索 bestloss = float('inf') # 无穷大 for num in range ...

  4. Scikit Learn: 在python中机器学习

    转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...

  5. [原创] 【2014.12.02更新网盘链接】基于EasySysprep4.1的 Windows 7 x86/x64 『视频』封装

    [原创] [2014.12.02更新网盘链接]基于EasySysprep4.1的 Windows 7 x86/x64 『视频』封装 joinlidong 发表于 2014-11-29 14:25:50 ...

  6. (原创)(三)机器学习笔记之Scikit Learn的线性回归模型初探

    一.Scikit Learn中使用estimator三部曲 1. 构造estimator 2. 训练模型:fit 3. 利用模型进行预测:predict 二.模型评价 模型训练好后,度量模型拟合效果的 ...

  7. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  8. 『TensorFlow』批处理类

    『教程』Batch Normalization 层介绍 基础知识 下面有莫凡的对于批处理的解释: fc_mean,fc_var = tf.nn.moments( Wx_plus_b, axes=[0] ...

  9. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  10. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

随机推荐

  1. 比 nvm 更好用的 node 版本管理工具

    什么是 Volta Volta 是一种管理 JavaScript 命令行工具的便捷方式. volta 的特点: 速度 无缝,每个项目的版本切换 跨平台支持,包括 Windows 和所有 Unix sh ...

  2. KRPano插件一键解密大师 支持最新版KRPano XML/JS解密 ,支持分析下载静态/动态网站资源

    KRPano插件一键解密大师,可以一键解密KRPano的XML/JS插件,并可以分析下载静态和动态网站的所有资源.软件下载安装即可使用,解密仅需鼠标一键点击即可,无需配置任何开发环境,方便全景开发人员 ...

  3. DAY005_异或运算

    运算规则 二进制:相同为0 相异为1 十进制:相同为0 任何数字和0异或都是它本身 不利用额外变量交换两个数 数组中一种数字出现了奇数次,其他数都出现了偶数次,怎么得到这个出现了奇数次的数 将所有的数 ...

  4. Visual Studio 2022 设置代码补全

    Visual Studio 2022 设置代码补全 VS默认使用 Tab 键进行代码补全. 若要使用回车补全需要重新设置,具体路径如下: ​ 工具----选项----文本编辑器----C/C++--- ...

  5. 关于.Net 6.0 在Linux ,Docker容器中,不安装任何依赖就生成图形验证码!!!!!!!!!!!

    在.Net Framework时代,我们生成验证码大多都是用System.Drawing. 在.Net 6中使用也是没有问题的. 但是,System.Drawing却依赖于Windows GDI+. ...

  6. netstat命令输出详解

    netstat命令输出详解 1. 列出所有的TCP和UDP端口 2. 命令输出详解 Proto:协议名(tcp协议还是udp协议) recv-Q:网络接收队列,send-Q:网路发送队列 a. rec ...

  7. Java SE 21 新增特性

    Java SE 21 新增特性 作者:Grey 原文地址: 博客园:Java SE 21 新增特性 CSDN:Java SE 21 新增特性 源码 源仓库: Github:java_new_featu ...

  8. ElasticSearch系列——介绍、安装、插件介绍、安装ElasticSearch插件、安装Kibana、安装中文分词器、倒排索引、索引操作、映射管理

    文章目录 ElasticSearch之介绍 一 Elasticsearch产生背景 1.1 大规模数据如何检索 1.2 传统数据库的应对解决方案 1.3 非关系型数据库解决方案 1.4 内存数据库解决 ...

  9. 模块化打包工具-Webpack插件与其他功能

    1.Webpack插件机制 上一篇提到的webpack的loader可以用来加载资源,包括各种css,图片文件资源,实现打包文件的功能,而webpack的插件则起到了加强webpack的作用,可以完成 ...

  10. CSS 浮动和清除浮动方法总结

    作者:WangMin 格言:努力做好自己喜欢的每一件事 什么是浮动 float? 标准流:盒子会各占整行位置.子盒子若是标准流,父盒子虽然没有高度,但是会撑开父盒子高度. 浮动:盒子浮了起来,不会占据 ...