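This article walks through the model evaluation metrics in the Spark ML library. All of the code below was run in a Jupyter Notebook against Spark 2.4.5. The evaluators live in the package org.apache.spark.ml.evaluation.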

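The evaluation metrics discussed here are meant to be computed on the test set, not on the training set. A model scored on the same data it was fit on will usually look better than it really is.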

1. Regression evaluation metrics

RegressionEvaluator

Evaluator for regression, which expects two input columns: prediction and label.

The following metrics are supported:

val metricName: Param[String]

  • "rmse" (default): root mean squared error
  • "mse": mean squared error
  • "r2": R2 metric
  • "mae": mean absolute error
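For reference, all four regression metrics reduce to a few sums over (prediction, label) pairs. The plain-Scala sketch below computes them that way under these assumptions:

- Each metric is a mean over all n rows.
- R2 is 1 - SS_res / SS_tot, with SS_tot taken around the mean label. As far as I understand, this matches RegressionEvaluator's default behaviour.

RegressionMetricsSketch is a hypothetical name, not a Spark class.

```scala
// Plain-Scala sketch of the four RegressionEvaluator metrics (no Spark required).
// Each pair is (prediction, label), mirroring the two columns the evaluator reads.
object RegressionMetricsSketch {
  type Pairs = Seq[(Double, Double)]

  // Mean of squared residuals, dividing by n (not n - 1).
  def mse(pairs: Pairs): Double =
    pairs.map { case (p, y) => (y - p) * (y - p) }.sum / pairs.size

  // Square root of the MSE, in the same units as the label.
  def rmse(pairs: Pairs): Double = math.sqrt(mse(pairs))

  // Mean of absolute residuals.
  def mae(pairs: Pairs): Double =
    pairs.map { case (p, y) => math.abs(y - p) }.sum / pairs.size

  // R2 = 1 - SS_res / SS_tot, where SS_tot is measured around the mean label.
  def r2(pairs: Pairs): Double = {
    val meanLabel = pairs.map(_._2).sum / pairs.size
    val ssRes = pairs.map { case (p, y) => (y - p) * (y - p) }.sum
    val ssTot = pairs.map { case (_, y) => (y - meanLabel) * (y - meanLabel) }.sum
    1.0 - ssRes / ssTot
  }
}
```

Example: for the pairs Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), mse returns 0.375 and mae returns 0.5. An R2 close to 0, like the training r2 printed in the example below, means the model explains little more of the variance than always predicting the mean label.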

Examples

// Import dependencies
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.evaluation.RegressionEvaluator

// Load the data and hold out 20% of it as a test set
val data = spark.read.format("libsvm")
  .load("/data1/software/spark/data/mllib/sample_linear_regression_data.txt")
val Array(training, test) = data.randomSplit(Array(0.8, 0.2))

val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

// Fit the model
val lrModel = lr.fit(training)

// Summarize the model over the training set and print out some metrics
val trainingSummary = lrModel.summary
println(s"Train MSE: ${trainingSummary.meanSquaredError}")
println(s"Train RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"Train MAE: ${trainingSummary.meanAbsoluteError}")
println(s"Train r2: ${trainingSummary.r2}")

// Evaluate on the held-out test set
val predictions = lrModel.transform(test)
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("mse")
val mse = evaluator.evaluate(predictions)
println(s"Test MSE: ${mse}")

Output:

Train MSE: 101.57870147367461
Train RMSE: 10.078625971513905
Train MAE: 8.108865602095849
Train r2: 0.039467152584195975
Test MSE: 114.28454406581636

2. Classification evaluation metrics

2.1 BinaryClassificationEvaluator

Evaluator for binary classification, which expects two input columns: rawPrediction and label. The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities).
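As I read the Spark 2.4 source, the evaluator reduces rawPrediction to a single ranking score:

- A double column value is used as the score directly.
- A length-2 vector contributes its element at index 1, the entry for label 1.

The plain-Scala sketch below mimics that rule. RawScore and toScore are hypothetical names, and an Either stands in for Spark's Double/Vector column types.

```scala
// Hypothetical helper (not a Spark API): reduce a rawPrediction value to one ranking score.
object RawScore {
  // Left = a double column value, Right = the elements of a length-2 vector.
  def toScore(raw: Either[Double, Array[Double]]): Double = raw match {
    case Left(score) => score // already a 0/1 prediction or P(label = 1)
    case Right(vec) =>
      require(vec.length == 2, s"expected a length-2 vector, got ${vec.length}")
      vec(1) // entry for label 1: a larger value means "more likely positive"
  }
}
```

Only the ordering of these scores matters for the area metrics, so raw margins, scores and probabilities all work as long as larger means more positive.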

The following metrics are supported:

val metricName: Param[String]
param for metric name in evaluation (supports "areaUnderROC" (default), "areaUnderPR")
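areaUnderROC can be read as the probability that a randomly chosen positive example gets a higher score than a randomly chosen negative one, with ties counting as half.

The plain-Scala sketch below computes it exactly that way by comparing every positive/negative pair. It is an O(P·N) illustration, not Spark's implementation: as I understand it, Spark's BinaryClassificationMetrics integrates the ROC curve with the trapezoidal rule instead. The two should agree on the same scores.

areaUnderPR is not covered here, and AucSketch is a hypothetical name.

```scala
// Pairwise (Mann-Whitney style) estimate of the area under the ROC curve.
object AucSketch {
  // Each pair is (score, label); label 1.0 = positive, 0.0 = negative, anything else is ignored.
  def areaUnderROC(pairs: Seq[(Double, Double)]): Double = {
    val pos = pairs.collect { case (s, 1.0) => s }
    val neg = pairs.collect { case (s, 0.0) => s }
    require(pos.nonEmpty && neg.nonEmpty, "need at least one positive and one negative")
    // Credit 1 when a positive outranks a negative, 0.5 on a tie, 0 otherwise.
    val credits = for (p <- pos; n <- neg) yield (if (p > n) 1.0 else if (p == n) 0.5 else 0.0)
    credits.sum / (pos.size.toDouble * neg.size)
  }
}
```

Take the scores (0.1, negative), (0.4, negative), (0.35, positive) and (0.8, positive). Three of the four positive/negative pairs are ordered correctly, so the result is 0.75. A "Test AUC: 1.0" means every positive in the test set outranked every negative.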

Examples

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Load the data and hold out 20% of it as a test set
val data = spark.read.format("libsvm").load("/data1/software/spark/data/mllib/sample_libsvm_data.txt")
val Array(train, test) = data.randomSplit(Array(0.8, 0.2))

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

// Fit the model
val lrModel = lr.fit(train)

// Summarize the model over the training set and print out some metrics
val trainSummary = lrModel.summary
println(s"Train accuracy: ${trainSummary.accuracy}")
println(s"Train weightedPrecision: ${trainSummary.weightedPrecision}")
println(s"Train weightedRecall: ${trainSummary.weightedRecall}")
println(s"Train weightedFMeasure: ${trainSummary.weightedFMeasure}")

// Predict on the test set
val predictions = lrModel.transform(test)
predictions.show(5)

// Area under ROC, computed from the rawPrediction column
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")
val auc = evaluator.evaluate(predictions)
println(s"Test AUC: ${auc}")

// Weighted precision, computed from the prediction column
val mulEvaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("weightedPrecision")
val precision = mulEvaluator.evaluate(predictions)
println(s"Test weightedPrecision: ${precision}")

Output:

Train accuracy: 0.9873417721518988
Train weightedPrecision: 0.9876110961486668
Train weightedRecall: 0.9873417721518987
Train weightedFMeasure: 0.9873124561568825
+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[122,123,148...|[0.29746771419036...|[0.57382336211209...|       0.0|
|  0.0|(692,[125,126,127...|[0.42262389447949...|[0.60411095396791...|       0.0|
|  0.0|(692,[126,127,128...|[0.74220898710237...|[0.67747871191347...|       0.0|
|  0.0|(692,[126,127,128...|[0.77729372618481...|[0.68509655708828...|       0.0|
|  0.0|(692,[127,128,129...|[0.70928896866149...|[0.67024402884354...|       0.0|
+-----+--------------------+--------------------+--------------------+----------+
Test AUC: 1.0
Test weightedPrecision: 1.0

Note: in the run that produced this output, the last metric was computed with evaluator.evaluate (the BinaryClassificationEvaluator) rather than mulEvaluator.evaluate. The 1.0 it prints is therefore the test AUC, not the weighted precision.

2.2 MulticlassClassificationEvaluator

Evaluator for multiclass classification, which expects two input columns: prediction and label.

Note: because this evaluator handles multiclass classification, it also works for the binary case above.

The following metrics are supported:

val metricName: Param[String]
param for metric name in evaluation (supports "f1" (default), "weightedPrecision", "weightedRecall", "accuracy")
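All four metric names reduce to simple counts over (prediction, label) pairs. The plain-Scala sketch below reproduces the weighted variants as I understand Spark's MulticlassMetrics defines them:

- Each class's precision, recall or F1 is weighted by how often that class occurs among the true labels.
- A class that is never predicted gets precision 0.

One consequence is that weightedRecall always equals accuracy. MulticlassMetricsSketch is a hypothetical name; this is not the Spark API.

```scala
// Plain-Scala sketch of the four MulticlassClassificationEvaluator metrics (not the Spark API).
object MulticlassMetricsSketch {
  type Pairs = Seq[(Double, Double)] // (prediction, label)

  def accuracy(pairs: Pairs): Double =
    pairs.count { case (p, y) => p == y }.toDouble / pairs.size

  // Precision of class c; 0.0 if c is never predicted.
  private def precision(pairs: Pairs, c: Double): Double = {
    val tp = pairs.count { case (p, y) => p == c && y == c }
    val predicted = pairs.count { case (p, _) => p == c }
    if (predicted == 0) 0.0 else tp.toDouble / predicted
  }

  // Recall of class c; c always comes from the true labels, so the denominator is > 0.
  private def recall(pairs: Pairs, c: Double): Double = {
    val tp = pairs.count { case (p, y) => p == c && y == c }
    tp.toDouble / pairs.count { case (_, y) => y == c }
  }

  // Harmonic mean of precision and recall for class c.
  private def f1(pairs: Pairs, c: Double): Double = {
    val (p, r) = (precision(pairs, c), recall(pairs, c))
    if (p + r == 0.0) 0.0 else 2 * p * r / (p + r)
  }

  // Average a per-class metric over the true labels, weighted by label frequency.
  private def weighted(pairs: Pairs)(metric: Double => Double): Double =
    pairs.groupBy(_._2).map { case (c, rows) => metric(c) * rows.size / pairs.size }.sum

  def weightedPrecision(pairs: Pairs): Double = weighted(pairs)(c => precision(pairs, c))
  def weightedRecall(pairs: Pairs): Double = weighted(pairs)(c => recall(pairs, c))
  def weightedF1(pairs: Pairs): Double = weighted(pairs)(c => f1(pairs, c))
}
```

In the Spark evaluator you would instead pick one of the four names with setMetricName. To my understanding, "f1" there corresponds to weightedF1 here. The training summary in the binary example above shows the same relationship: its weightedRecall matches its accuracy up to floating-point rounding.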

Examples

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Load the data stored in LIBSVM format as a DataFrame.
val data = spark.read.format("libsvm").load("/data1/software/spark/data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

// Automatically identify categorical features, and index them.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a DecisionTree model.
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and tree in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${(1.0 - accuracy)}")

Output:

+--------------+-----+--------------------+
|predictedLabel|label|            features|
+--------------+-----+--------------------+
|           0.0|  0.0|(692,[95,96,97,12...|
|           0.0|  0.0|(692,[122,123,124...|
|           0.0|  0.0|(692,[122,123,148...|
|           0.0|  0.0|(692,[126,127,128...|
|           0.0|  0.0|(692,[126,127,128...|
+--------------+-----+--------------------+
only showing top 5 rows

Test Error = 0.040000000000000036
