sklearn学习2-----LogisticsRegression
1、官网地址:
http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
2、class sklearn.linear_model.
LogisticRegression
(penalty=’l2’, dual=False, tol=0.0001, C=1.0, fit_intercept=True, intercept_scaling=1, class_weight=None, random_state=None, solver=’liblinear’, max_iter=100, multi_class=’ovr’, verbose=0, warm_start=False, n_jobs=1)
3、以上14个参数说明:(来自链接https://blog.csdn.net/jark_/article/details/78342644)
- penalty:惩罚项,str类型,可选参数为l1和l2,默认为l2。用于指定惩罚项中使用的规范。newton-cg、sag和lbfgs求解算法只支持L2规范。L1G规范假设的是模型的参数满足拉普拉斯分布,L2假设的模型参数满足高斯分布,所谓的范式就是加上对参数的约束,使得模型更不会过拟合(overfit),但是如果要说是不是加了约束就会好,这个没有人能回答,只能说,加约束的情况下,理论上应该可以获得泛化能力更强的结果。
- dual:对偶或原始方法,bool类型,默认为False。对偶方法只用在求解线性多核(liblinear)的L2惩罚项上。当样本数量>样本特征的时候,dual通常设置为False。
- tol:停止求解的标准,float类型,默认为1e-4。就是求解到多少的时候,停止,认为已经求出最优解。
- c:正则化系数λ的倒数,float类型,默认为1.0。必须是正浮点型数。像SVM一样,越小的数值表示越强的正则化。
- fit_intercept:是否存在截距或偏差,bool类型,默认为True。
- intercept_scaling:仅在正则化项为”liblinear”,且fit_intercept设置为True时有用。float类型,默认为1。
- class_weight:用于标示分类模型中各种类型的权重,可以是一个字典或者’balanced’字符串,默认为不输入,也就是不考虑权重,即为None。如果选择输入的话,可以选择balanced让类库自己计算类型权重,或者自己输入各个类型的权重。举个例子,比如对于0,1的二元模型,我们可以定义class_weight={0:0.9,1:0.1},这样类型0的权重为90%,而类型1的权重为10%。如果class_weight选择balanced,那么类库会根据训练样本量来计算权重。某种类型样本量越多,则权重越低,样本量越少,则权重越高。当class_weight为balanced时,类权重计算方法如下:n_samples / (n_classes * np.bincount(y))。n_samples为样本数,n_classes为类别数量,np.bincount(y)会输出每个类的样本数,例如y=[1,0,0,1,1],则np.bincount(y)=[2,3]。
- 那么class_weight有什么作用呢?
- 在分类模型中,我们经常会遇到两类问题:
- 第一种是误分类的代价很高。比如对合法用户和非法用户进行分类,将非法用户分类为合法用户的代价很高,我们宁愿将合法用户分类为非法用户,这时可以人工再甄别,但是却不愿将非法用户分类为合法用户。这时,我们可以适当提高非法用户的权重。
- 第二种是样本是高度失衡的,比如我们有合法用户和非法用户的二元样本数据10000条,里面合法用户有9995条,非法用户只有5条,如果我们不考虑权重,则我们可以将所有的测试集都预测为合法用户,这样预测准确率理论上有99.95%,但是却没有任何意义。这时,我们可以选择balanced,让类库自动提高非法用户样本的权重。提高了某种分类的权重,相比不考虑权重,会有更多的样本分类划分到高权重的类别,从而可以解决上面两类问题。
- 那么class_weight有什么作用呢?
- random_state:随机数种子,int类型,可选参数,默认为无,仅在正则化优化算法为sag,liblinear时有用。
- solver:优化算法选择参数,只有五个可选参数,即newton-cg,lbfgs,liblinear,sag,saga。默认为liblinear。solver参数决定了我们对逻辑回归损失函数的优化方法,有四种算法可以选择,分别是:
- liblinear:使用了开源的liblinear库实现,内部使用了坐标轴下降法来迭代优化损失函数。
- lbfgs:拟牛顿法的一种,利用损失函数二阶导数矩阵即海森矩阵来迭代优化损失函数。
- newton-cg:也是牛顿法家族的一种,利用损失函数二阶导数矩阵即海森矩阵来迭代优化损失函数。
- sag:即随机平均梯度下降,是梯度下降法的变种,和普通梯度下降法的区别是每次迭代仅仅用一部分的样本来计算梯度,适合于样本数据多的时候。
- saga:线性收敛的随机优化算法的的变重。
- 总结:
- liblinear适用于小数据集,而sag和saga适用于大数据集因为速度更快。
- 对于多分类问题,只有newton-cg,sag,saga和lbfgs能够处理多项损失,而liblinear受限于一对剩余(OvR)。啥意思,就是用liblinear的时候,如果是多分类问题,得先把一种类别作为一个类别,剩余的所有类别作为另外一个类别。一次类推,遍历所有类别,进行分类。
- newton-cg,sag和lbfgs这三种优化算法时都需要损失函数的一阶或者二阶连续导数,因此不能用于没有连续导数的L1正则化,只能用于L2正则化。而liblinear和saga通吃L1正则化和L2正则化。
- 同时,sag每次仅仅使用了部分样本进行梯度迭代,所以当样本量少的时候不要选择它,而如果样本量非常大,比如大于10万,sag是第一选择。但是sag不能用于L1正则化,所以当你有大量的样本,又需要L1正则化的话就要自己做取舍了。要么通过对样本采样来降低样本量,要么回到L2正则化。
- 从上面的描述,大家可能觉得,既然newton-cg, lbfgs和sag这么多限制,如果不是大样本,我们选择liblinear不就行了嘛!错,因为liblinear也有自己的弱点!我们知道,逻辑回归有二元逻辑回归和多元逻辑回归。对于多元逻辑回归常见的有one-vs-rest(OvR)和many-vs-many(MvM)两种。而MvM一般比OvR分类相对准确一些。郁闷的是liblinear只支持OvR,不支持MvM,这样如果我们需要相对精确的多元逻辑回归时,就不能选择liblinear了。也意味着如果我们需要相对精确的多元逻辑回归不能使用L1正则化了。
- max_iter:算法收敛最大迭代次数,int类型,默认为10。仅在正则化优化算法为newton-cg, sag和lbfgs才有用,算法收敛的最大迭代次数。
- multi_class:分类方式选择参数,str类型,可选参数为ovr和multinomial,默认为ovr。ovr即前面提到的one-vs-rest(OvR),而multinomial即前面提到的many-vs-many(MvM)。如果是二元逻辑回归,ovr和multinomial并没有任何区别,区别主要在多元逻辑回归上。
- OvR和MvM有什么不同*?*
- OvR的思想很简单,无论你是多少元逻辑回归,我们都可以看做二元逻辑回归。具体做法是,对于第K类的分类决策,我们把所有第K类的样本作为正例,除了第K类样本以外的所有样本都作为负例,然后在上面做二元逻辑回归,得到第K类的分类模型。其他类的分类模型获得以此类推。
- 而MvM则相对复杂,这里举MvM的特例one-vs-one(OvO)作讲解。如果模型有T类,我们每次在所有的T类样本里面选择两类样本出来,不妨记为T1类和T2类,把所有的输出为T1和T2的样本放在一起,把T1作为正例,T2作为负例,进行二元逻辑回归,得到模型参数。我们一共需要T(T-1)/2次分类。
- 可以看出OvR相对简单,但分类效果相对略差(这里指大多数样本分布情况,某些样本分布下OvR可能更好)。而MvM分类相对精确,但是分类速度没有OvR快。如果选择了ovr,则4种损失函数的优化方法liblinear,newton-cg,lbfgs和sag都可以选择。但是如果选择了multinomial,则只能选择newton-cg, lbfgs和sag了。
- OvR和MvM有什么不同*?*
- verbose:日志冗长度,int类型。默认为0。就是不输出训练过程,1的时候偶尔输出结果,大于1,对于每个子模型都输出。
- warm_start:热启动参数,bool类型。默认为False。如果为True,则下一次训练是以追加树的形式进行(重新使用上一次的调用作为初始化)。
- n_jobs:并行数。int类型,默认为1。1的时候,用CPU的一个内核运行程序,2的时候,用CPU的2个内核运行程序。为-1的时候,用所有CPU的内核运行程序。
4、使用(参考链接https://blog.csdn.net/loveliuzz/article/details/78708359)
#导入包
from sklearn.linear_model import LogisticRegression #构建并训练模型
for C in (10,1,0.1):
lr = LogisticRegression(multi_class="ovr",C=C,penalty="l2",solver="lbfgs",tol=0.01)
re = lr.fit(X_train,Y_train) #模型效果获取
r = re.score(X_train,Y_train)
#预测
Y_predict = lr.predict(X_test)
LogisticRegressionCV
from sklearn.linear_model import LogisticRegressionCV
#对数据的训练集进行标准化
ss = StandardScaler()
X_train = ss.fit_transform(X_train) #先拟合数据在进行标准化 #构建并训练模型
## multi_class:分类方式选择参数,有"ovr(默认)"和"multinomial"两个值可选择,在二元逻辑回归中无区别
## cv:几折交叉验证
## solver:优化算法选择参数,当penalty为"l1"时,参数只能是"liblinear(坐标轴下降法)"
## "lbfgs"和"cg"都是关于目标函数的二阶泰勒展开
## 当penalty为"l2"时,参数可以是"lbfgs(拟牛顿法)","newton_cg(牛顿法变种)","seg(minibactch随机平均梯度下降)"
## 维度<10000时,选择"lbfgs"法,维度>10000时,选择"cs"法比较好,显卡计算的时候,lbfgs"和"cs"都比"seg"快
## penalty:正则化选择参数,用于解决过拟合,可选"l1","l2"
## tol:当目标函数下降到该值是就停止,叫:容忍度,防止计算的过多
lr = LogisticRegressionCV(multi_class="ovr",fit_intercept=True,Cs=np.logspace(-2,2,20),cv=2,penalty="l2",solver="lbfgs",tol=0.01)
re = lr.fit(X_train,Y_train) #模型效果获取
r = re.score(X_train,Y_train)
print("R值(准确率):",r)
print("参数:",re.coef_)
print("截距:",re.intercept_)
print("稀疏化特征比率:%.2f%%" %(np.mean(lr.coef_.ravel()==0)*100))
print("=========sigmoid函数转化的值,即:概率p=========")
print(re.predict_proba(X_test)) #sigmoid函数转化的值,即:概率p #模型的保存与持久化
from sklearn.externals import joblib
joblib.dump(ss,"logistic_ss.model") #将标准化模型保存
joblib.dump(lr,"logistic_lr.model") #将训练后的线性模型保存
joblib.load("logistic_ss.model") #加载模型,会保存该model文件
joblib.load("logistic_lr.model") #预测
X_test = ss.transform(X_test) #数据标准化
Y_predict = lr.predict(X_test) #预测
5、属性
coef_:表示参数w,intercept:表示截距b,n_iter表示迭代次数。
6、方法:
注意:predict预测出来的是类别(阈值固定为0.5),predict_proba预测出来的是概率值,可以通过改变阈值来设定类别。
sklearn学习2-----LogisticsRegression的更多相关文章
- sklearn学习笔记之简单线性回归
简单线性回归 线性回归是数据挖掘中的基础算法之一,从某种意义上来说,在学习函数的时候已经开始接触线性回归了,只不过那时候并没有涉及到误差项.线性回归的思想其实就是解一组方程,得到回归函数,不过在出现误 ...
- sklearn学习总结(超全面)
https://blog.csdn.net/fuqiuai/article/details/79495865 前言sklearn想必不用我多介绍了,一句话,她是机器学习领域中最知名的python模块之 ...
- sklearn学习 第一篇:knn分类
K临近分类是一种监督式的分类方法,首先根据已标记的数据对模型进行训练,然后根据模型对新的数据点进行预测,预测新数据点的标签(label),也就是该数据所属的分类. 一,kNN算法的逻辑 kNN算法的核 ...
- sklearn 学习 第一篇:分类
分类属于监督学习算法,是指根据已有的数据和标签(分类)进行学习,预测未知数据的标签.分类问题的目标是预测数据的类别标签(class label),可以把分类问题划分为二分类和多分类问题.二分类是指在两 ...
- SKlearn | 学习总结
1 简介 scikit-learn,又写作sklearn,是一个开源的基于python语言的机器学习工具包.它通过NumPy, SciPy和Matplotlib等python数值计算的库实现高效的算法 ...
- sklearn学习笔记3
Explaining Titanic hypothesis with decision trees decision trees are very simple yet powerful superv ...
- sklearn学习笔记2
Text classifcation with Naïve Bayes In this section we will try to classify newsgroup messages using ...
- sklearn学习笔记1
Image recognition with Support Vector Machines #our dataset is provided within scikit-learn #let's s ...
- 莫烦sklearn学习自修第九天【过拟合问题处理】
1. 过拟合问题可以通过调整机器学习的参数来完成,比如sklearn中通过调节gamma参数,将训练损失和测试损失降到最低 2. 代码实现(显示gamma参数对训练损失和测试损失的影响) from _ ...
- 莫烦sklearn学习自修第八天【过拟合问题】
1. 什么是过拟合问题 所谓过拟合问题指的是使用训练样本进行训练时100%正确分类或规划,当使用测试样本时则不能正确分类和规划 2. 代码实战(模拟过拟合问题) from __future__ imp ...
随机推荐
- windows 2003一个网卡绑定多个IP地址
1.打开“网络连接”,选中需要添加多个IP的“本地连接”-->右键-->“属性”: 2.从“常规”中找到“Internet 协议(TCP/IP)属性”: 3.选择手动设置IP地址.网关.掩 ...
- SCU - 4110 - PE class
先上题目: 4110: PE class Submit your solution Discuss this problem Best solutions Description ...
- HDU 2224 The shortest path
The shortest path Time Limit: 1000/1000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others ...
- MySQL Workbench出现:Error Code: 2013. Lost connection to MySQL server during query的问题解决
解决办法: [Edit]->[Preference]->[SQL Editor] 将下图DBMS connection read time out (in seconds)适当调大: 参考 ...
- ps -ef与ps aux的区别
ps -ef与ps aux的区别 学习:http://www.linuxidc.com/Linux/2016-07/133515.htm ps aux可以查看其内存使用情况:
- EditText焦点问题
1.在一个Activity中加入一个EditText后,每次进入这个Activity时输入法都会自己主动弹出来.非常烦,找了些资料,在此记下解决的方法: 方法:在EditText的父控件中获得焦点.这 ...
- strcpy函数使用方法以及底层实现
strcpy(s1, s2); strcpy函数的意思是:把字符串s2中的内容copy到s1中.连字符串结束标志也一起copy. 这样s1在内存中的存放为:ch\0; 在cout<<s ...
- hdu 2883 kebab(时间区间压缩 && dinic)
kebab Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others) Total Subm ...
- Mysql 索引需要了解的几个注意
索引是做什么的? 索引用于快速找出在某个列中有一特定值的行.不使用索引,MySQL必须从第1条记录开始然后读完整个表直到找出相关的行.表越大,花费的时间越多.如果表中查询的列有一个索引,MySQL能快 ...
- 南海区行政审批管理系统接口规范v0.3(规划)
1. 会话API 1.1. login [登录验证] {"r_code":"500","r_msg":"操作失败",&q ...