Spark2.0机器学习系列之4:Logistic回归及Binary分类(二分问题)结果评估
参数设置
α:
梯度上升算法迭代时候权重更新公式中包含 α :
http://blog.csdn.net/lu597203933/article/details/38468303
为了更好理解 α和最大迭代次数的作用,给出Python版的函数计算过程。
# 梯度上升算法-计算回归系数
# 每个回归系数初始化为1
# 重复R次:
# 计算整个数据集的梯度
# 使用α*梯度更新回归系数的向量
# 返回回归系数
def gradAscent(dataMatIn, classLabels,alpha=0.001,maxCycles = ):
dataMatrix = mat(dataMatIn) #转换为numpy数据类型
labelMat = mat(classLabels).transpose()
m,n = shape(dataMatrix)
maxCycles =
weights = ones((n,))
for k in range(maxCycles):
h = sigmoid(dataMatrix*weights)
error = (labelMat - h)
#计算真实类别与预测类别的差值,按照该差值的方向调整回归系数
weights = weights + alpha* dataMatrix.transpose() * error
return weights
λ:
λ,正则化参数(泛化能力),加正则化的前提是特征值要进行归一化。
在实际应该过程中,为了增强模型的泛化能力,防止我们训练的模型过拟合,特别是对于大量的稀疏特征,模型复杂度比较高,需要进行降维,我们需要保证在训练误差最小化的基础上,通过加上正则化项减小模型复杂度。在逻辑回归中,有L1、L2进行正则化。>
损失函数如下:
http://www.bkjia.com/yjs/996300.html
在损失函数里加入一个正则化项,正则化项就是权重的L1或者L2范数乘以一个λ,用来控制损失函数和正则化项的比重,直观的理解,首先防止过拟合的目的就是防止最后训练出来的模型过分的依赖某一个特征,当最小化损失函数的时候,某一维度很大,拟合出来的函数值与真实的值之间的差距很小,通过正则化可以使整体的cost变大,从而避免了过分依赖某一维度的结果。当然加正则化的前提是特征值要进行归一化。
threshold:
threshold变量用来控制分类的阈值,默认值为0.5。表示如果预测值小于threshold则为分类0.0,否则为1.0。
在Spark Java中
ElasticNetParam : α ;RegParam :λ。
LogisticRegression lr=new LogisticRegression()
.setMaxIter()
.setRegParam(0.3)
.setElasticNetParam(0.2)
.setThreshold(0.5);
分类效果评估
参考:http://www.cnblogs.com/tovin/p/3816289.html
http://blog.sina.com.cn/s/blog_900690c60101czyo.html
http://blog.chinaunix.net/uid-446337-id-94448.html
http://blog.csdn.net/abcjennifer/article/details/7834256
混淆矩阵(Confusion matrix):
考虑一个二分问题,即将实例分成正类(positive)或负类(negative)。对一个二分问题来说,会出现四种情况。如果一个实例是正类并且也被 预测成正类,即为真正类(True positive),如果实例是负类被预测成正类,称之为假正类(False positive)。相应地,如果实例是负类被预测成负类,称之为真负类(True negative),正类被预测成负类则为假负类(false negative)。
TP:正确肯定的数目;
FN:漏报,没有正确找到的匹配的数目;
FP:误报,给出的匹配是不正确的;
TN:正确拒绝的非匹配对数
精确率,precision = TP / (TP + FP)
模型判为正的所有样本中有多少是真正的正样本
召回率,recall = TP / (TP + FN)
准确率,accuracy = (TP + TN) / (TP + FP + TN + FN)
反映了分类器统对整个样本的判定能力——能将正的判定为正,负的判定为负
如何在precision和Recall中权衡?
F1 Score = P*R/2(P+R),其中P和R分别为 precision 和 recall
在precision与recall都要求高的情况下,可以用F1 Score来衡量
为什么会有这么多指标呢?
这是因为模式分类和机器学习的需要。判断一个分类器对所用样本的分类能力或者在不同的应用场合时,需要有不同的指标。 当总共有个100 个样本(P+N=100)时,假如只有一个正例(P=1),那么只考虑精确度的话,不需要进行任何模型的训练,直接将所有测试样本判为正例,那么 A 能达到 99%,非常高了,但这并没有反映出模型真正的能力。另外在统计信号分析中,对不同类的判断结果的错误的惩罚是不一样的。举例而言,雷达收到100个来袭导弹的信号,其中只有 3个是真正的导弹信号,其余 97 个是敌方模拟的导弹信号。假如系统判断 98 个(97 个模拟信号加一个真正的导弹信号)信号都是模拟信号,那么Accuracy=98%,很高了,剩下两个是导弹信号,被截掉,这时Recall=2/3=66.67%,Precision=2/2=100%,Precision也很高。但剩下的那颗导弹就会造成灾害。
ROC曲线和AUC
有时候我们需要在精确率与召回率间进行权衡,
调整分类器threshold取值,以FPR(假正率False-positive rate)为横坐标,TPR(True-positive rate)为纵坐标做ROC曲线;
Area Under roc Curve(AUC):处于ROC curve下方的那部分面积的大小通常,AUC的值介于0.5到1.0之间,较大的AUC代表了较好的性能;
精确率和召回率是互相影响的,理想情况下肯定是做到两者都高,但是一般情况下准精确率、召回率就低,召回率低、精确率高,当然如果两者都低,那是什么地方出问题了
Spark 2.0分类评估
//获得回归模型训练的Summary
LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); // Obtain the loss per iteration.
//每次迭代的损失,一般会逐渐减小
double[] objectiveHistory = trainingSummary.objectiveHistory();
for (double lossPerIteration : objectiveHistory) {
System.out.println(lossPerIteration);
} // Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
// classification problem.
//强制类型转换为二类LR的Summary,然后就可以用混淆矩阵,ROC等评估方法了。Spark2.0还无法针对多类
BinaryLogisticRegressionSummary binarySummary =
(BinaryLogisticRegressionSummary) trainingSummary; // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
Dataset<Row> roc = binarySummary.roc();//获得ROC
roc.show();//显示ROC数据表,可以用这个数据自己画ROC曲线
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());//AUC // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
//不同的阈值,计算不同的F1,然后通过最大的F1找出并重设模型的最佳阈值。
Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble();//获得最大的F1值
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
.select("threshold").head().getDouble();//找出最大F1值对应的阈值(最佳阈值)
lrModel.setThreshold(bestThreshold);//并将模型的Threshold设置为选择出来的最佳分类阈值
Logistic回归完整的代码
http://spark.apache.org/docs/latest/ml-classification-regression.html
package my.spark.ml.practice.classification; import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions; public class myLogisticRegression { public static void main(String[] args) {
SparkSession spark=SparkSession
.builder()
.appName("LR")
.master("local[4]")
.config("spark.sql.warehouse.dir","file///:G:/Projects/Java/Spark/spark-warehouse" )
.getOrCreate();
String path="G:/Projects/CgyWin64/home/pengjy3/softwate/spark-2.0.0-bin-hadoop2.6/"
+ "data/mllib/sample_libsvm_data.txt"; //屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF); //Load trainning data
Dataset<Row> trainning_dataFrame=spark.read().format("libsvm").load(path); LogisticRegression lr=new LogisticRegression()
.setMaxIter()
.setRegParam(0.3)
.setElasticNetParam(0.2)
.setThreshold(0.5); //fit the model
LogisticRegressionModel lrModel=lr.fit(trainning_dataFrame); //print the coefficients and intercept for logistic regression
System.out.println
("Coefficient:"+lrModel.coefficients()+"Itercept"+lrModel.intercept()); //Extract the summary from the returned LogisticRegressionModel
LogisticRegressionTrainingSummary summary=lrModel.summary(); //Obtain the loss per iteration.
double[] objectiveHistory=summary.objectiveHistory();
for(double lossPerIteration:objectiveHistory){
System.out.println(lossPerIteration);
}
// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
// classification problem.
BinaryLogisticRegressionTrainingSummary binarySummary=
(BinaryLogisticRegressionTrainingSummary)summary;
//Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
Dataset<Row> roc=binarySummary.roc();
roc.show((int) roc.count());//显示全部的信息,roc.show()默认只显示20行
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC()); // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble();
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
.select("threshold").head().getDouble();
lrModel.setThreshold(bestThreshold);
} }
BinaryClassificationEvaluator
除了上述Logistic回归结果评估方法,在Spark2.0中,二分问题结果评估用BinaryClassificationEvaluator。
参数:
(1)labelCol: label column name (default: label, current: label)
(2)metricName: metric name in evaluation (areaUnderROC|areaUnderPR) (default: areaUnderROC)
看来是没有其它的评估方法
(3)rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction, current: prediction):注意名字不是PredictionCol,
只有这三个参数可以设置!
自定义accuracy
//自定义计算accuracy,
Dataset<Row> predictDF=naiveBayesModel.transform(test); double total=(double) predictDF.count();
Encoder<Double> doubleEncoder=Encoders.DOUBLE();
Dataset<Double> accuracyDF=predictDF.map(new MapFunction<Row,Double>() {
@Override
public Double call(Row row) throws Exception {
if((double)row.get()==(double)row.get()){return 1.0;}
else {return 0.0;}
}
}, doubleEncoder);
accuracyDF.createOrReplaceTempView("view");
double correct=(double) spark.sql("SELECT value FROM view WHERE value=1.0").count();
System.out.println("accuracy "+(correct/total));
Spark2.0机器学习系列之4:Logistic回归及Binary分类(二分问题)结果评估的更多相关文章
- Spark2.0机器学习系列之7: MLPC(多层神经网络)
Spark2.0 MLPC(多层神经网络分类器)算法概述 MultilayerPerceptronClassifier(MLPC)这是一个基于前馈神经网络的分类器,它是一种在输入层与输出层之间含有一层 ...
- Spark2.0机器学习系列之3:决策树
概述 分类决策树模型是一种描述对实例进行分类的树形结构. 决策树可以看为一个if-then规则集合,具有“互斥完备”性质 .决策树基本上都是 采用的是贪心(即非回溯)的算法,自顶向下递归分治构造. 生 ...
- Spark2.0机器学习系列之1: 聚类算法(LDA)
在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法: (1)K-means (2)Latent Dirichlet allocation (LDA) ...
- Spark2.0机器学习系列之12: 线性回归及L1、L2正则化区别与稀疏解
概述 线性回归拟合一个因变量与一个自变量之间的线性关系y=f(x). Spark中实现了: (1)普通最小二乘法 (2)岭回归(L2正规化) (3)La ...
- Spark2.0机器学习系列之11: 聚类(幂迭代聚类, power iteration clustering, PIC)
在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法: (1)K-means (2)Latent Dirichlet all ...
- Spark2.0机器学习系列之10: 聚类(高斯混合模型 GMM)
在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法: (1)K-means (2)Latent Dirichlet allocation (LDA) ...
- Spark2.0机器学习系列之9: 聚类(k-means,Bisecting k-means,Streaming k-means)
在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法: (1)K-means (2)Latent Dirichlet allocation (LDA) ...
- Spark2.0机器学习系列之8:多类分类问题(方法归总和分类结果评估)
一对多(One-vs-Rest classifier) 将只能用于二分问题的分类(如Logistic回归.SVM)方法扩展到多类. 参考:http://www.cnblogs.com/CheeseZH ...
- Spark2.0机器学习系列之6:GBDT(梯度提升决策树)、GBDT与随机森林差异、参数调试及Scikit代码分析
概念梳理 GBDT的别称 GBDT(Gradient Boost Decision Tree),梯度提升决策树. GBDT这个算法还有一些其他的名字,比如说MART(Multiple Addi ...
随机推荐
- [Java] java调用wsdl接口
前提: ① 已经提供了一个wsdl接口 ② 该接口能正常调用 步骤1:使用cxf的wsdl2java工具生成本地类 下载CXF:http://cxf.apache.org/download.html ...
- 更新加子查询加相同的表解决办法 mysql
UPDATE ofuser SET auid = '0' WHERE uid in (SELECT uid FROM (select tmp.* from ofuser tmp)a WHERE aui ...
- 关于Unity5.5中2D动画的制作
1.首先要创建一个精灵 GameProject--2Dproject--Sprite 叫bird 2.给这个精灵附加纹理,并让它显示自己想让它显示的场景层中,一般它的静止纹理就是动画的第一张图片 3. ...
- linux 下简单的ftp客户端程序
该ftp的客服端是在linux下面写,涉及的东西也比较简单,如前ftp的简单介绍,知道ftp主要的工作流程架构,套接字的创建,还有就是字符串和字符的处理.使用的函数都是比较简单平常易见的,写的时候感觉 ...
- c#方法生成mysql if方法(算工作日)
public static string retunSQl(string s,string e){ return @"IF ( "+s+ ">" +e+ ...
- hdu 1233:还是畅通工程(数据结构,图,最小生成树,普里姆(Prim)算法)
还是畅通工程 Time Limit : 4000/2000ms (Java/Other) Memory Limit : 65536/32768K (Java/Other) Total Submis ...
- 【mysql-python】安装+基本使用
安装:从SourceForge.net上下载最新的MySQLdb,http://sourceforge.net/projects/mysql-python/ 运行exe文件 使用 From:http: ...
- C# 流总结(Stream)
本篇文章简单总结了在C#编程中经常会用到的一些流.比如说FileStream.MemoryStream. BufferedStream. NetWorkStream. StreamReader/Str ...
- iOS开发之--MVC 架构模式
随着项目开发时间的增加,从刚开始那种很随意的代码风格,逐渐会改变,现在就介绍下MVC的架构模式,MVC的架构模式,从字面意思上讲,即:MVC 即 Modal View Controller(模型 视图 ...
- 在visual studio中运行C++心得
1.在visual studio中建立C++项目 (1)新建->项目->空项目 C++ (2)右击项目->添加->新建项->C++文件(.app) (3编写C++文件 ...