02-05 scikit-learn库之线性回归
- scikit-learn库之线性回归
- 一、LinearRegression
- 二、ARDRegression
- 三、BayesianRidge
- 四、ElasticNet
- 五、ElasticNetCV
- 六、Lasso
- 七、LassoCV
- 八、LassoLars
- 九、LassoLarsCV
- 十、LassoLarsIC
- 十一、MutilTaskLasso
- 十二、MutilTaskElasticNet
- 十三、MutilTaskLassoCV
- 十四、MutilTaskElasticNetCV
- 十五、OrthogonalMatchingPursuit
- 十六、OrthogonalMatchingPursuitCV
- 十七、RANSACRegressor
- 十八、Ridge
- 十九、RidgeCV
更新、更全的《机器学习》的更新网站,更有python、go、数据结构与算法、爬虫、人工智能教学等着你:https://www.cnblogs.com/nickchen121/p/11686958.html
scikit-learn库之线性回归
由于scikit-learn库中sclearn.linear_model
提供了多种支持线性回归分析的类,本文主要总结一些常用的线性回归的类,并且由于是从官方文档翻译而来,翻译会略有偏颇,如果有兴趣想了解其他类的使用方法的同学也可以去scikit-learn官方文档查看https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model
在讲线性回归理论的时候讲到了,线性回归的目的是找到一个线性回归系数向量\(\omega\),使得输入特征\(X\)和输出向量\(Y\)之间有一个
\]
的映射关系,接下来的线性回归模型和线性回归模型的思想类似。假设一个数据集有\(m\)实例,每个实例有\(n\)个特征,则其中\(Y\)的维度是\(m*1\),\(X\)的维度是\(m*n\),\(\omega\)的维度是\(n*1\)。
使用线性回归的目的就是找到一个合适的线性回归系数\(\omega\)能够最小化我们定义的目标函数,又由于最小化目标函数的优化方法的不同,会有不同的线性回归算法。
由于其他版本的线性回归模型的参数类似于LinearRegression
,即其他类型的线性回归模型的参数详解都会跳过,只会讲解它与LinearRegression
的不同之处。我们接下来的目的就是为了给大家介绍scikit-learn库中常用的线性回归模型。
一、LinearRegression
1.1 使用场景
LinearRegression
回归模型,即我们在线性回归中讲到的普通线性回归,该普通线性回归可以处理一元线性回归,也可以处理多元线性回归,但是该类使用的优化方法是最小二乘法。
通常情况下该类是我们使用线性回归处理线性问题的首选方法,因为它的目标函数较其他线性回归简单,计算量小,如果它拟合数据出现过拟合问题则可以考虑使用正则化形式的线性回归。
1.2 代码
import numpy as np
from sklearn.linear_model import LinearRegression
X = np.array([[2, 0], [1, 9], [6, 6], [8, 8]])
# y = 1 * x_0 + 2 * x_1 + 3
y = np.dot(X, np.array([6, 8])) + 3
reg = LinearRegression()
reg.fit(X, y)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
normalize=False)
reg.score(X, y)
1.0
reg.coef_
array([6., 8.])
reg.intercept_
2.999999999999986
reg.predict(np.array([[8, 6]]))
array([99.])
1.3 参数详解
- fit_intercept:截距(偏置单元),bool类型。是否存在截距或者偏置单元。如果使用中心化的数据(中心点为0的数据),可以考虑设置fit_intercept=False。默认为True。
- normalize:标准化数据,bool类型。当fit_intercept=False的时候,这个参数会被自动忽略;如果fit_intercept=True,回归器会标准化输入数据,该标准化方式为:减去平均值,并且除以相应的二范数。建议在使用fit()训练模型之前使用sklearn.preprocessing.StandardScaler对数据标准化,同时设置normalize=False。默认为False。
- copy_X:复制数据,bool类型。如果copy_X=False,可能会因为对数据中心化把原始X数据覆盖。默认为True。
- n_jobs:并行数,int类型。n_jobs=1使用1个cpu运行程序;n_jobs=2,使用2个cpu运行程序;n_jobs=-1,使用所有cpu运行程序。默认为1。
1.4 属性
- coef_:array类型,线性回归系数。
- intercept_:array类型,截距。
1.5 方法
- fit(X,y,sample_weight=None):把数据放入模型中训练模型,其中sample_weight=None是array类型可以对训练集中实例添加权重,即对训练集中不同的数据增加不同的权重。
- get_params([deep]):返回模型的参数,例如可以用于Pipeline中。
from sklearn.pipeline import Pipeline
p =Pipeline([
('poly', PolynomialFeatures()),
('linear', LinearRegression(fit_intercept=False))])
lin = p.get_params('linear')['linear']
print(lin.coef_)
- predict(X):通过样本X得到X对应的预测值。
- score(X, y[, sample_weight]):基于报告决定系数\(R^2\)评估模型。
- set_prams(**params):创建模型参数。
1.5.1 报告决定系数
报告决定系数\((R^2)\),可以理解成MSE的标准版,\(R^2\)的公式为
\]
其中\(\mu_{(y)}\)是\(y\)的平均值,即\({{\frac{1}{n}}\sum_{i=1}^n(y^{(i)}-\mu_{(y)})^2}\)为\(y\)的方差,公式可以写成
\]
\(R^2\)的取值范围在\(0-1\)之间,如果\(R^2=1\),则均方误差\(MSE=0\),即模型完美的拟合数据。
二、ARDRegression
当数据集中有很多缺失值或异常值时使用ARDRegression
模型,该模型属于贝叶斯回归模型。该模型会对模型输出\(Y\)和模型参数\(\omega\)作出分布假设,并且正则化参数alpha也会从数据中估计得到,虽然该模型对异常值鲁棒性很好,但由于该模型计算量大,耗时,一般情况不推荐使用,此处不多赘述。
三、BayesianRidge
该模型类似于ARDRegression
模型,两者都属于贝叶斯回归,不同之处在于对\(\omega\)的分布假设不同。由于该模型的目标函数类似于Ridge
模型的目标函数,因此取名BayesianRidge
。但由于该模型同样计算量大,耗时,一般情况下也不推荐使用,此处不多赘述。
四、ElasticNet
ElasticNet
模型的优化方法是坐标轴下降法,该模型由L1正则化和L2正则化的加权得到,如果使用L1正则化和L2正则化都不行的时候,可以考虑使用该模型。
该模型由于增加了参数alpha和l1_ratio,需要手动调参,通常使用接下来的LassoCV
。
五、ElasticNetCV
ElasticNetCV
模型在目标函数和优化方式类似于ElasticNet
,但是可以自己手动输入10组、100组参数alpha和l1_ratio,该模型会通过交叉验证后给你这组参数中最优模型。
六、Lasso
Lasso
模型的优化方法是坐标轴下降法,该模型即线性回归L1正则化,该。如果数据集的特征维度较高,可以使用该模型,该模型可以把一些较小的回归系数直接变为\(0\),由于减少了数据集的特征维度,也会间接的减轻模型过拟合问题,增强模型的泛化能力。
该模型由于会把一些较小的回归系数变为\(0\),既可以找出重要的特征,对数据集的解释能力强。
该模型由于增加了参数alpha,需要手动调参,通常使用接下来的LassoCV
。
七、LassoCV
LassoCV
模型在目标函数和优化方式类似于Lasso
,但是可以自己手动输入10组、100组参数alpha,该模型会通过交叉验证后给你这组参数中最优模型。
八、LassoLars
LassoLars
模型的优化方法是最小角回归法,该模型类似于Lasso
模型,但是该模型优化方法为。
该模型由于增加了参数alpha,需要手动调参,通常使用接下来的LassoLarsCV
。
九、LassoLarsCV
LassoLarsCV
模型在目标函数和优化方式类似于LassoLars
,但是可以自己手动输入10组、100组参数alpha,该模型会通过交叉验证后给你这组参数中最优模型。
十、LassoLarsIC
LassoLarsIC
模型类似于Lasso
模型,不同之处在于它并不使用交叉验证的方式得到最优模型。它基于AIC和BIC准则,一轮就可以找到找到一个最优alpha和最优模型,而交叉验证如果使用\(k\)折交叉验证,则需要\(k-1\)次才能找到最优模型。
该模型从上述讲述看起来是很完美的,但是该模型要求数据集是由某个假设的模型产生的,并且如果当特征数量大于实例数量的时候该模型可能会成为一个较差的模型,所以在工业上一般不推荐使用。
十一、MutilTaskLasso
MutilTaskLasso
模型的优化方法是坐标轴下降法,模型中的MutilTask可以理解成“多个”而不是“多进程”,即一次性使用多个L1正则化线性回归模型拟合数据,有时候也称之为共享特征协同回归。
普通线性回归的模型是
\]
其中假设一个数据集有\(m\)实例,每个实例有\(n\)个特征,则其中\(Y\)的维度是\(m*1\),\(X\)的维度是\(m*n\),\(\omega\)的维度是\(n*1\)。
该模型去掉正则化项是
\]
其中假设一个数据集有\(m\)实例,每个实例有\(n\)个特征,则其中\(Y\)的维度是\(m*k\),\(X\)的维度是\(m*n\),\(W\)的维度是\(n*k\),其中\(k\)为回归模型的个数,即该模型的fit()方法可以传入\(k\)维的特征。
该模型由于增加了参数alpha和\(k\),需要手动调参,通常使用接下来的MutilTaskLassoCV
。
十二、MutilTaskElasticNet
MutilTaskElasticNet
模型的优化方法是坐标轴下降法,该模型类似于MutilTaskLasso
模型,只是在正则化项上前者使用了L1正则项,后者使用了弹性网络正则项。
该模型由于增加了参数alpha和l1_ratio,需要手动调参,通常使用接下来的LassoCV
。
十三、MutilTaskLassoCV
MutilTaskLassoCV
模型在目标函数和优化方式类似于MutilTaskLasso
,但是可以自己手动输入10组、100组参数alpha,该模型会通过交叉验证后给你这组参数中最优模型。
十四、MutilTaskElasticNetCV
该模型在目标函数和优化方式类似于Lasso
,但是可以自己手动输入10组、100组参数alpha和l1_ratio,该模型会通过交叉验证后给你这组参数中最优模型。
十五、OrthogonalMatchingPursuit
OrthogonalMatchingPursuit
模型优化方法是前向选择算法,优化方法速度虽然快,但是精确度较低。
该模型使用参数n_nonzero_coefs限制模型参数\(\omega\)向量中元素非\(0\)的个数,由于该特征可以用于稀疏特征模型的特征选择上,这一点类似于Lasso
模型,但是由于优化方法前向选择算法,一般不推荐使用。
该模型由于增加了参数n_nonzero_coefs,需要手动调参,通常使用接下来的LassoCV
。
十六、OrthogonalMatchingPursuitCV
OrthogonalMatchingPursuitCV
模型在目标函数和优化方式类似于OrthogonalMatchingPursuitCV
,但是可以自己手动输入10组、100组参数n_nonzero_coefs,该模型会通过交叉验证后给你这组参数中最优模型。
十七、RANSACRegressor
RANSACRegressor
模型使用的优化算法是RANSACR算法,该算法可以控制使用部分区域的数据集训练模型。
可以参考《RANSAC算法线性回归(波斯顿房价预测)》。
十八、Ridge
Ridge
模型的优化方法是最小二乘法,该模型即线性回归L2正则化,一般使用LinearRegression
模型时模型过拟合时可以使用该方法。
由于额外增加了alpha参数,一般情况下需要自己手动调参,所以可以在自己测试的时候使用,一般工业上使用较多的是接下来的RidgeCV
。
十九、RidgeCV
RidgeCV
模型在目标函数和优化方式类似于Ridge
,但是可以自己手动输入10组、100组参数alpha,该模型会通过交叉验证后给你这组参数中最优模型。
02-05 scikit-learn库之线性回归的更多相关文章
- (原创)(三)机器学习笔记之Scikit Learn的线性回归模型初探
一.Scikit Learn中使用estimator三部曲 1. 构造estimator 2. 训练模型:fit 3. 利用模型进行预测:predict 二.模型评价 模型训练好后,度量模型拟合效果的 ...
- Scikit Learn: 在python中机器学习
转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...
- scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类 (python代码)
scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类数据集 fetch_20newsgroups #-*- coding: UTF-8 -*- import ...
- 【网络爬虫入门02】HTTP客户端库Requests的基本原理与基础应用
[网络爬虫入门02]HTTP客户端库Requests的基本原理与基础应用 广东职业技术学院 欧浩源 1.引言 实现网络爬虫的第一步就是要建立网络连接并向服务器或网页等网络资源发起请求.urllib是 ...
- (原创)(四)机器学习笔记之Scikit Learn的Logistic回归初探
目录 5.3 使用LogisticRegressionCV进行正则化的 Logistic Regression 参数调优 一.Scikit Learn中有关logistics回归函数的介绍 1. 交叉 ...
- 怎样用Python的Scikit-Learn库实现线性回归?
来源商业新知号网,原标题:用Python的Scikit-Learn库实现线性回归 回归和分类是两种 监督 机器 学习算法, 前者预测连续值输出,而后者预测离散输出. 例如,用美元预测房屋的价格是回归问 ...
- Scikit Learn
Scikit Learn Scikit-Learn简称sklearn,基于 Python 语言的,简单高效的数据挖掘和数据分析工具,建立在 NumPy,SciPy 和 matplotlib 上.
- python进阶05 常用问题库(1)json os os.path模块
python进阶05 常用问题库(1)json os os.path模块 一.json模块(数据交互) web开发和爬虫开发都离不开数据交互,web开发是做网站后台的,要跟网站前端进行数据交互 1.什 ...
- 使用TensorFlow v2库实现线性回归
使用TensorFlow v2库实现线性回归 此示例使用简单方法来更好地理解训练过程背后的所有机制 from __future__ import absolute_import, division, ...
随机推荐
- SpringBoot+SpringMVC+MyBatis快速整合搭建
作为开发人员,大家都知道,SpringBoot是基于Spring4.0设计的,不仅继承了Spring框架原有的优秀特性,而且还通过简化配置来进一步简化了Spring应用的整个搭建和开发过程.另外Spr ...
- git 中文乱码-一次被坑经历
git log和gitcommit中文出现乱码,花了大半天的时间试了网上的各种方法,还是搞不定. 只好放大招. 卸载软件后重装,还是不行.然后git config --list 发现一些奇怪的配置信息 ...
- tarjan缩点(洛谷P387)
此题解部分借鉴于九野的博客 题目分析 给定一个 \(n\) 个点 \(m\) 条边有向图,每个点有一个权值,求一条路径,使路径经过的点权值之和最大.你只需要求出这个权值和. 允许多次经过一条边或者一个 ...
- DevOps平台
DevOps定义(来自维基百科): DevOps(Development和Operations的组合词)是一种重视"软件开发人员(Dev)"和"IT运维技术人员(Ops) ...
- FreeSql (一)入门
FreeSql 是一个功能强大的对象关系映射程序(O/RM),支持 .NETCore 2.1+ 或 .NETFramework 4.5+(QQ群:4336577) FreeSql采用MIT开源协议托管 ...
- 如何部署 H5 游戏到云服务器?
在自学游戏开发的路上,最有成就感的时刻就是将自己的小游戏做出来分享给朋友试玩,原生的游戏开可以打包分享,小游戏上线流程又长,那 H5 小游戏该怎么分享呢?本文就带大家通过 nginx 将构建好的 H5 ...
- 疑难杂症----Windows10
现在大多数个人电脑所用的操作系统都是win10,而我们使用win10时总是会碰上各种各样的问题,所以专门写一篇博客来记录我碰上的各种问题,便于以后更快的解决问题. 一.小娜搜索不到应用问题解决方案 小 ...
- Sublime Text 3 中实现编译C语言程序
这个是真坑,感觉用devc++写c程序特别的不爽,所以就用了sublime,但是,编译的时候又有不少问题, 下面就把我踩的坑记录下来 tools>Build System>New Buil ...
- Android手机QQ文件夹解析
注:切勿修改手机QQ文件夹,以免造成不必要的使用问题及无法修复的数据丢失] 安卓手机QQ tencent文件夹解析 QQ下载的聊天背景:tencent→MobileQQ→system_backgrou ...
- web-文件上传漏洞总结
思维导图: 一,js验证绕过 1.我们直接删除代码中onsubmit事件中关于文件上传时验证上传文件的相关代码即可. 或者可以不加载所有js,还可以将html源码copy一份到本地,然后对相应代码进行 ...