参数设置

α:

梯度上升算法迭代时候权重更新公式中包含 α :

 
http://blog.csdn.net/lu597203933/article/details/38468303 
为了更好理解 α和最大迭代次数的作用,给出Python版的函数计算过程。

# 梯度上升算法-计算回归系数
# 每个回归系数初始化为1
# 重复R次:
# 计算整个数据集的梯度
# 使用α*梯度更新回归系数的向量
# 返回回归系数
def gradAscent(dataMatIn, classLabels,alpha=0.001,maxCycles = ):
dataMatrix = mat(dataMatIn) #转换为numpy数据类型
labelMat = mat(classLabels).transpose()
m,n = shape(dataMatrix)
maxCycles =
weights = ones((n,))
for k in range(maxCycles):
h = sigmoid(dataMatrix*weights)
error = (labelMat - h)
#计算真实类别与预测类别的差值,按照该差值的方向调整回归系数
weights = weights + alpha* dataMatrix.transpose() * error
return weights

λ:

λ,正则化参数(泛化能力),加正则化的前提是特征值要进行归一化

在实际应该过程中,为了增强模型的泛化能力,防止我们训练的模型过拟合,特别是对于大量的稀疏特征,模型复杂度比较高,需要进行降维,我们需要保证在训练误差最小化的基础上,通过加上正则化项减小模型复杂度。在逻辑回归中,有L1、L2进行正则化。> 
损失函数如下: 
 
http://www.bkjia.com/yjs/996300.html 
在损失函数里加入一个正则化项,正则化项就是权重的L1或者L2范数乘以一个λ,用来控制损失函数和正则化项的比重,直观的理解,首先防止过拟合的目的就是防止最后训练出来的模型过分的依赖某一个特征,当最小化损失函数的时候,某一维度很大,拟合出来的函数值与真实的值之间的差距很小,通过正则化可以使整体的cost变大,从而避免了过分依赖某一维度的结果。当然加正则化的前提是特征值要进行归一化。

threshold:

threshold变量用来控制分类的阈值,默认值为0.5。表示如果预测值小于threshold则为分类0.0,否则为1.0。

在Spark Java中 
ElasticNetParam : α ;RegParam :λ。

LogisticRegression lr=new LogisticRegression()
.setMaxIter()
.setRegParam(0.3)
.setElasticNetParam(0.2)
.setThreshold(0.5);

分类效果评估

参考:http://www.cnblogs.com/tovin/p/3816289.html 
http://blog.sina.com.cn/s/blog_900690c60101czyo.html 
http://blog.chinaunix.net/uid-446337-id-94448.html 
http://blog.csdn.net/abcjennifer/article/details/7834256

混淆矩阵(Confusion matrix): 
考虑一个二分问题,即将实例分成正类(positive)或负类(negative)。对一个二分问题来说,会出现四种情况。如果一个实例是正类并且也被 预测成正类,即为真正类(True positive),如果实例是负类被预测成正类,称之为假正类(False positive)。相应地,如果实例是负类被预测成负类,称之为真负类(True negative),正类被预测成负类则为假负类(false negative)。 
TP:正确肯定的数目; 
FN:漏报,没有正确找到的匹配的数目; 
FP:误报,给出的匹配是不正确的; 
TN:正确拒绝的非匹配对数 
 
精确率,precision = TP / (TP + FP) 
模型判为正的所有样本中有多少是真正的正样本 
召回率,recall = TP / (TP + FN) 
准确率,accuracy = (TP + TN) / (TP + FP + TN + FN) 
反映了分类器统对整个样本的判定能力——能将正的判定为正,负的判定为负 
如何在precision和Recall中权衡? 
F1 Score = P*R/2(P+R),其中P和R分别为 precision 和 recall 
在precision与recall都要求高的情况下,可以用F1 Score来衡量

为什么会有这么多指标呢? 
这是因为模式分类和机器学习的需要。判断一个分类器对所用样本的分类能力或者在不同的应用场合时,需要有不同的指标。 当总共有个100 个样本(P+N=100)时,假如只有一个正例(P=1),那么只考虑精确度的话,不需要进行任何模型的训练,直接将所有测试样本判为正例,那么 A 能达到 99%,非常高了,但这并没有反映出模型真正的能力。另外在统计信号分析中,对不同类的判断结果的错误的惩罚是不一样的。举例而言,雷达收到100个来袭导弹的信号,其中只有 3个是真正的导弹信号,其余 97 个是敌方模拟的导弹信号。假如系统判断 98 个(97 个模拟信号加一个真正的导弹信号)信号都是模拟信号,那么Accuracy=98%,很高了,剩下两个是导弹信号,被截掉,这时Recall=2/3=66.67%,Precision=2/2=100%,Precision也很高。但剩下的那颗导弹就会造成灾害。

ROC曲线和AUC 
有时候我们需要在精确率与召回率间进行权衡, 
调整分类器threshold取值,以FPR(假正率False-positive rate)为横坐标,TPR(True-positive rate)为纵坐标做ROC曲线; 
Area Under roc Curve(AUC):处于ROC curve下方的那部分面积的大小通常,AUC的值介于0.5到1.0之间,较大的AUC代表了较好的性能; 
精确率和召回率是互相影响的,理想情况下肯定是做到两者都高,但是一般情况下准精确率、召回率就低,召回率低、精确率高,当然如果两者都低,那是什么地方出问题了 

Spark 2.0分类评估

//获得回归模型训练的Summary
LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); // Obtain the loss per iteration.
//每次迭代的损失,一般会逐渐减小
double[] objectiveHistory = trainingSummary.objectiveHistory();
for (double lossPerIteration : objectiveHistory) {
System.out.println(lossPerIteration);
} // Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
// classification problem.
//强制类型转换为二类LR的Summary,然后就可以用混淆矩阵,ROC等评估方法了。Spark2.0还无法针对多类
BinaryLogisticRegressionSummary binarySummary =
(BinaryLogisticRegressionSummary) trainingSummary; // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
Dataset<Row> roc = binarySummary.roc();//获得ROC
roc.show();//显示ROC数据表,可以用这个数据自己画ROC曲线
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());//AUC // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
//不同的阈值,计算不同的F1,然后通过最大的F1找出并重设模型的最佳阈值。
Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble();//获得最大的F1值
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
.select("threshold").head().getDouble();//找出最大F1值对应的阈值(最佳阈值)
lrModel.setThreshold(bestThreshold);//并将模型的Threshold设置为选择出来的最佳分类阈值

Logistic回归完整的代码 
http://spark.apache.org/docs/latest/ml-classification-regression.html

package my.spark.ml.practice.classification;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions; public class myLogisticRegression { public static void main(String[] args) {
SparkSession spark=SparkSession
.builder()
.appName("LR")
.master("local[4]")
.config("spark.sql.warehouse.dir","file///:G:/Projects/Java/Spark/spark-warehouse" )
.getOrCreate();
String path="G:/Projects/CgyWin64/home/pengjy3/softwate/spark-2.0.0-bin-hadoop2.6/"
+ "data/mllib/sample_libsvm_data.txt"; //屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF); //Load trainning data
Dataset<Row> trainning_dataFrame=spark.read().format("libsvm").load(path); LogisticRegression lr=new LogisticRegression()
.setMaxIter()
.setRegParam(0.3)
.setElasticNetParam(0.2)
.setThreshold(0.5); //fit the model
LogisticRegressionModel lrModel=lr.fit(trainning_dataFrame); //print the coefficients and intercept for logistic regression
System.out.println
("Coefficient:"+lrModel.coefficients()+"Itercept"+lrModel.intercept()); //Extract the summary from the returned LogisticRegressionModel
LogisticRegressionTrainingSummary summary=lrModel.summary(); //Obtain the loss per iteration.
double[] objectiveHistory=summary.objectiveHistory();
for(double lossPerIteration:objectiveHistory){
System.out.println(lossPerIteration);
}
// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
// classification problem.
BinaryLogisticRegressionTrainingSummary binarySummary=
(BinaryLogisticRegressionTrainingSummary)summary;
//Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
Dataset<Row> roc=binarySummary.roc();
roc.show((int) roc.count());//显示全部的信息,roc.show()默认只显示20行
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC()); // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble();
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
.select("threshold").head().getDouble();
lrModel.setThreshold(bestThreshold);
} }

BinaryClassificationEvaluator

除了上述Logistic回归结果评估方法,在Spark2.0中,二分问题结果评估用BinaryClassificationEvaluator。 
参数: 
(1)labelCol: label column name (default: label, current: label) 
(2)metricName: metric name in evaluation (areaUnderROC|areaUnderPR) (default: areaUnderROC) 
看来是没有其它的评估方法 
(3)rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction, current: prediction):注意名字不是PredictionCol, 
只有这三个参数可以设置! 
自定义accuracy

//自定义计算accuracy,
Dataset<Row> predictDF=naiveBayesModel.transform(test); double total=(double) predictDF.count();
Encoder<Double> doubleEncoder=Encoders.DOUBLE();
Dataset<Double> accuracyDF=predictDF.map(new MapFunction<Row,Double>() {
@Override
public Double call(Row row) throws Exception {
if((double)row.get()==(double)row.get()){return 1.0;}
else {return 0.0;}
}
}, doubleEncoder);
accuracyDF.createOrReplaceTempView("view");
double correct=(double) spark.sql("SELECT value FROM view WHERE value=1.0").count();
System.out.println("accuracy "+(correct/total));

Spark2.0机器学习系列之4:Logistic回归及Binary分类(二分问题)结果评估的更多相关文章

  1. Spark2.0机器学习系列之7: MLPC(多层神经网络)

    Spark2.0 MLPC(多层神经网络分类器)算法概述 MultilayerPerceptronClassifier(MLPC)这是一个基于前馈神经网络的分类器,它是一种在输入层与输出层之间含有一层 ...

  2. Spark2.0机器学习系列之3:决策树

    概述 分类决策树模型是一种描述对实例进行分类的树形结构. 决策树可以看为一个if-then规则集合,具有“互斥完备”性质 .决策树基本上都是 采用的是贪心(即非回溯)的算法,自顶向下递归分治构造. 生 ...

  3. Spark2.0机器学习系列之1: 聚类算法(LDA)

    在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法:      (1)K-means      (2)Latent Dirichlet allocation (LDA)  ...

  4. Spark2.0机器学习系列之12: 线性回归及L1、L2正则化区别与稀疏解

    概述 线性回归拟合一个因变量与一个自变量之间的线性关系y=f(x).       Spark中实现了:       (1)普通最小二乘法       (2)岭回归(L2正规化)       (3)La ...

  5. Spark2.0机器学习系列之11: 聚类(幂迭代聚类, power iteration clustering, PIC)

    在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法:             (1)K-means             (2)Latent Dirichlet all ...

  6. Spark2.0机器学习系列之10: 聚类(高斯混合模型 GMM)

    在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法:      (1)K-means      (2)Latent Dirichlet allocation (LDA)  ...

  7. Spark2.0机器学习系列之9: 聚类(k-means,Bisecting k-means,Streaming k-means)

    在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法:      (1)K-means      (2)Latent Dirichlet allocation (LDA)  ...

  8. Spark2.0机器学习系列之8:多类分类问题(方法归总和分类结果评估)

    一对多(One-vs-Rest classifier) 将只能用于二分问题的分类(如Logistic回归.SVM)方法扩展到多类. 参考:http://www.cnblogs.com/CheeseZH ...

  9. Spark2.0机器学习系列之6:GBDT(梯度提升决策树)、GBDT与随机森林差异、参数调试及Scikit代码分析

    概念梳理 GBDT的别称 GBDT(Gradient Boost Decision Tree),梯度提升决策树.     GBDT这个算法还有一些其他的名字,比如说MART(Multiple Addi ...

随机推荐

  1. jq 按回车键 登陆

    $("body").keydown(function(e) {             if (event.which == "13") {//keyCode= ...

  2. [android] AndroidManifest.xml - 【 manifest -> 其他次要配置】

    <uses-sdk> 作用:使应用程序的兼容性更好,指明应用程序需要的最小API,编译API以及最大支持的API.值都是整数 <uses-sdk android:minSdkVers ...

  3. 【BZOJ】1045: [HAOI2008]糖果传递(中位数)

    http://www.lydsy.com/JudgeOnline/problem.php?id=1045 白书上有讲 没ac的坑点在,数据范围n<=1,000,000 #include < ...

  4. hdu 3791:二叉搜索树(数据结构,二叉搜索树 BST)

    二叉搜索树 Time Limit : 2000/1000ms (Java/Other)   Memory Limit : 32768/32768K (Java/Other) Total Submiss ...

  5. 有用的Python代码片段

    我列出的这些有用的Python代码片段,为我节省了大量的时间,并且我希望他们也能为你节省一些时间.大多数的这些片段出自寻找解决方案,查找博客和StackOverflow解决类似问题的答案.下面所有的代 ...

  6. ajax 请求登录超时跳转登录页的示例代码

    Ajax AJAX即“Asynchronous Javascript + XML”(异步JavaScript和XML),是指一种创建交互式网页应用的网页开发技术. 在Filter里判断是否登录,如果未 ...

  7. Django学习笔记第五篇--实战练习一--查询数据库并操作cookie

    一.启动项目: django-admin start mysite1 cd mysite1 python manage.py startapp loginapp 根据上文敲命令就可以创建好了一个项目结 ...

  8. HDFS编程

    HDFS编程主要API Hadoop类 功能 org.apache.hadoop.fs.FileSystem 一个通用文件系统的抽象基类,可以被分布式文件系统继承.所有的可能使用Hadoop文件系统的 ...

  9. 桌面不显示IE图标解决方法

    如果你的是WIN7,恰好你的桌面上IE图标不知道怎么被搞咩了. 向这里看过来,这里有一个文件你正好能用上. 使用方法很简单: 猛戳下载,猛戳两下安装 下载解决

  10. 170410、java Socket通信的简单例子(TCP)

    服务端代码: package com.bobohe.socket; import java.io.*; import java.net.*; import java.applet.Applet; pu ...