本文主要对 Spark ML库下模型评估指标的讲解,以下代码均以Jupyter Notebook进行讲解,Spark版本为2.4.5。模型评估指标位于包org.apache.spark.ml.evaluation下。

模型评估指标是指测试集的评估指标,而不是训练集的评估指标

1、回归评估指标

RegressionEvaluator

Evaluator for regression, which expects two input columns: prediction and label.

评估指标支持以下几种:

val metricName: Param[String]

  • "rmse" (default): root mean squared error
  • "mse": mean squared error
  • "r2": R2 metric
  • "mae": mean absolute error

Examples

  1. # import dependencies
  2. import org.apache.spark.ml.regression.LinearRegression
  3. import org.apache.spark.ml.evaluation.RegressionEvaluator
  4. // Load training data
  5. val data = spark.read.format("libsvm")
  6. .load("/data1/software/spark/data/mllib/sample_linear_regression_data.txt")
  7. val lr = new LinearRegression()
  8. .setMaxIter(10)
  9. .setRegParam(0.3)
  10. .setElasticNetParam(0.8)
  11. // Fit the model
  12. val lrModel = lr.fit(training)
  13. // Summarize the model over the training set and print out some metrics
  14. val trainingSummary = lrModel.summary
  15. println(s"Train MSE: ${trainingSummary.meanSquaredError}")
  16. println(s"Train RMSE: ${trainingSummary.rootMeanSquaredError}")
  17. println(s"Train MAE: ${trainingSummary.meanAbsoluteError}")
  18. println(s"Train r2: ${trainingSummary.r2}")
  19. val predictions = lrModel.transform(test)
  20. // 计算精度
  21. val evaluator = new RegressionEvaluator()
  22. .setLabelCol("label")
  23. .setPredictionCol("prediction")
  24. .setMetricName("mse")
  25. val accuracy = evaluator.evaluate(predictions)
  26. print(s"Test MSE: ${accuracy}")

输出:

  1. Train MSE: 101.57870147367461
  2. Train RMSE: 10.078625971513905
  3. Train MAE: 8.108865602095849
  4. Train r2: 0.039467152584195975
  5. Test MSE: 114.28454406581636

2、分类评估指标

2.1 BinaryClassificationEvaluator

Evaluator for binary classification, which expects two input columns: rawPrediction and label. The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities).

评估指标支持以下几种:

  1. val metricName: Param[String]
  2. param for metric name in evaluation (supports "areaUnderROC" (default), "areaUnderPR")

Examples

  1. import org.apache.spark.ml.classification.LogisticRegression
  2. import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
  3. import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
  4. // Load training data
  5. val data = spark.read.format("libsvm").load("/data1/software/spark/data/mllib/sample_libsvm_data.txt")
  6. val Array(train, test) = data.randomSplit(Array(0.8, 0.2))
  7. val lr = new LogisticRegression()
  8. .setMaxIter(10)
  9. .setRegParam(0.3)
  10. .setElasticNetParam(0.8)
  11. // Fit the model
  12. val lrModel = lr.fit(train)
  13. // Summarize the model over the training set and print out some metrics
  14. val trainSummary = lrModel.summary
  15. println(s"Train accuracy: ${trainSummary.accuracy}")
  16. println(s"Train weightedPrecision: ${trainSummary.weightedPrecision}")
  17. println(s"Train weightedRecall: ${trainSummary.weightedRecall}")
  18. println(s"Train weightedFMeasure: ${trainSummary.weightedFMeasure}")
  19. val predictions = lrModel.transform(test)
  20. predictions.show(5)
  21. // 模型评估
  22. val evaluator = new BinaryClassificationEvaluator()
  23. .setLabelCol("label")
  24. .setRawPredictionCol("rawPrediction")
  25. .setMetricName("areaUnderROC")
  26. val auc = evaluator.evaluate(predictions)
  27. print(s"Test AUC: ${auc}")
  28. val mulEvaluator = new MulticlassClassificationEvaluator()
  29. .setLabelCol("label")
  30. .setPredictionCol("prediction")
  31. .setMetricName("weightedPrecision")
  32. val precision = evaluator.evaluate(predictions)
  33. print(s"Test weightedPrecision: ${precision}")

输出结果:

  1. Train accuracy: 0.9873417721518988
  2. Train weightedPrecision: 0.9876110961486668
  3. Train weightedRecall: 0.9873417721518987
  4. Train weightedFMeasure: 0.9873124561568825
  5. +-----+--------------------+--------------------+--------------------+----------+
  6. |label| features| rawPrediction| probability|prediction|
  7. +-----+--------------------+--------------------+--------------------+----------+
  8. | 0.0|(692,[122,123,148...|[0.29746771419036...|[0.57382336211209...| 0.0|
  9. | 0.0|(692,[125,126,127...|[0.42262389447949...|[0.60411095396791...| 0.0|
  10. | 0.0|(692,[126,127,128...|[0.74220898710237...|[0.67747871191347...| 0.0|
  11. | 0.0|(692,[126,127,128...|[0.77729372618481...|[0.68509655708828...| 0.0|
  12. | 0.0|(692,[127,128,129...|[0.70928896866149...|[0.67024402884354...| 0.0|
  13. +-----+--------------------+--------------------+--------------------+----------+
  14. Test AUC: 1.0
  15. Test weightedPrecision: 1.0

2.2 MulticlassClassificationEvaluator

Evaluator for multiclass classification, which expects two input columns: prediction and label.

注:既然适用于多分类,当然适用于上面的二分类

评估指标支持如下几种:

  1. val metricName: Param[String]
  2. param for metric name in evaluation (supports "f1" (default), "weightedPrecision", "weightedRecall", "accuracy")

Examples

  1. import org.apache.spark.ml.Pipeline
  2. import org.apache.spark.ml.classification.DecisionTreeClassificationModel
  3. import org.apache.spark.ml.classification.DecisionTreeClassifier
  4. import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
  5. import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
  6. // Load the data stored in LIBSVM format as a DataFrame.
  7. val data = spark.read.format("libsvm").load("/data1/software/spark/data/mllib/sample_libsvm_data.txt")
  8. // Index labels, adding metadata to the label column.
  9. // Fit on whole dataset to include all labels in index.
  10. val labelIndexer = new StringIndexer()
  11. .setInputCol("label")
  12. .setOutputCol("indexedLabel")
  13. .fit(data)
  14. // Automatically identify categorical features, and index them.
  15. val featureIndexer = new VectorIndexer()
  16. .setInputCol("features")
  17. .setOutputCol("indexedFeatures")
  18. .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
  19. .fit(data)
  20. // Split the data into training and test sets (30% held out for testing).
  21. val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
  22. // Train a DecisionTree model.
  23. val dt = new DecisionTreeClassifier()
  24. .setLabelCol("indexedLabel")
  25. .setFeaturesCol("indexedFeatures")
  26. // Convert indexed labels back to original labels.
  27. val labelConverter = new IndexToString()
  28. .setInputCol("prediction")
  29. .setOutputCol("predictedLabel")
  30. .setLabels(labelIndexer.labels)
  31. // Chain indexers and tree in a Pipeline.
  32. val pipeline = new Pipeline()
  33. .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
  34. // Train model. This also runs the indexers.
  35. val model = pipeline.fit(trainingData)
  36. // Make predictions.
  37. val predictions = model.transform(testData)
  38. // Select example rows to display.
  39. predictions.select("predictedLabel", "label", "features").show(5)
  40. // Select (prediction, true label) and compute test error.
  41. val evaluator = new MulticlassClassificationEvaluator()
  42. .setLabelCol("indexedLabel")
  43. .setPredictionCol("prediction")
  44. .setMetricName("accuracy")
  45. val accuracy = evaluator.evaluate(predictions)
  46. println(s"Test Error = ${(1.0 - accuracy)}")

输出结果:

  1. +--------------+-----+--------------------+
  2. |predictedLabel|label| features|
  3. +--------------+-----+--------------------+
  4. | 0.0| 0.0|(692,[95,96,97,12...|
  5. | 0.0| 0.0|(692,[122,123,124...|
  6. | 0.0| 0.0|(692,[122,123,148...|
  7. | 0.0| 0.0|(692,[126,127,128...|
  8. | 0.0| 0.0|(692,[126,127,128...|
  9. +--------------+-----+--------------------+
  10. only showing top 5 rows
  11. Test Error = 0.040000000000000036

Spark ML机器学习库评估指标示例的更多相关文章

  1. 【Udacity】机器学习性能评估指标

    评估指标 Evaluation metrics 机器学习性能评估指标 选择合适的指标 分类与回归的不同性能指标 分类的指标(准确率.精确率.召回率和 F 分数) 回归的指标(平均绝对误差和均方误差) ...

  2. Spark ML机器学习

    Spark提供了常用机器学习算法的实现, 封装于spark.ml和spark.mllib中. spark.mllib是基于RDD的机器学习库, spark.ml是基于DataFrame的机器学习库. ...

  3. [机器学习] 性能评估指标(精确率、召回率、ROC、AUC)

    混淆矩阵 介绍这些概念之前先来介绍一个概念:混淆矩阵(confusion matrix).对于 k 元分类,其实它就是一个k x k的表格,用来记录分类器的预测结果.对于常见的二元分类,它的混淆矩阵是 ...

  4. UDA机器学习基础—评估指标

    这里举例说明 混淆矩阵  精确率 召回率  F1

  5. 机器学习性能评估指标(精确率、召回率、ROC、AUC)

    http://blog.csdn.net/u012089317/article/details/52156514 ,y^)=1nsamples∑i=1nsamples(yi−y^i)2

  6. Spark 中的机器学习库及示例

    MLlib 是 Spark 的机器学习库,旨在简化机器学习的工程实践工作,并方便扩展到更大规模.MLlib 由一些通用的学习算法和工具组成,包括分类.回归.聚类.协同过滤.降维等,同时还包括底层的优化 ...

  7. 《Spark 官方文档》机器学习库(MLlib)指南

    spark-2.0.2 机器学习库(MLlib)指南 MLlib是Spark的机器学习(ML)库.旨在简化机器学习的工程实践工作,并方便扩展到更大规模.MLlib由一些通用的学习算法和工具组成,包括分 ...

  8. 掌握Spark机器学习库(课程目录)

    第1章 初识机器学习 在本章中将带领大家概要了解什么是机器学习.机器学习在当前有哪些典型应用.机器学习的核心思想.常用的框架有哪些,该如何进行选型等相关问题. 1-1 导学 1-2 机器学习概述 1- ...

  9. [DeeplearningAI笔记]ML strategy_1_1正交化/单一数字评估指标

    机器学习策略 ML strategy 觉得有用的话,欢迎一起讨论相互学习~Follow Me 1.1 什么是ML策略 机器学习策略简介 情景模拟 假设你正在训练一个分类器,你的系统已经达到了90%准确 ...

随机推荐

  1. VUEJS文件扩展名esm.js和common.js是什么意思

    vue.js : vue.js则是直接用在<script>标签中的,完整版本,直接就可以通过script引用. vue.common.js :预编译调试时,CommonJS规范的格式,可以 ...

  2. 吴裕雄--天生自然python学习笔记:python 用 Open CV抓取脸部图形及保存

    将面部的范围识别出来后,可以对识别出来的部分进行抓取.抓取一张图片中 的部分图形是通过 pillow 包中的 crop 方法来实现的 我们首先学习用 pillow 包来读取图片文件,语法为: 例如,打 ...

  3. 在Linux中#!/usr/bin/python之后把后面的代码当成程序来执行。 但是在windows中用IDLE编程的话#后面的都是注释,之后的代码都被当成文本了。 该怎么样才能解决这个问题呢?

    本文转自:http://bbs.csdn.net/topics/392027744?locationNum=6&fps=1 这种问题是大神不屑于解答,小白又完全不懂的问题... 同遇到这个问题 ...

  4. OC门与OD门以及线与逻辑

    OC(Open Collector)门又叫集电极开路门,主要针对的是BJT电路(从上往下依次是基极,集电极,发射极)OD(Open Drain)门又叫漏极开路门,主要针对的是MOS管(从上往下依次是漏 ...

  5. confessed to doing|conform|confined|entitle|

    to admit that you have done something wrong or something that you feel guilty or bad about 坦白:供认,招认: ...

  6. Facebook要做约会服务,国内社交眼红吗?

    看看现在的各种相亲趣事就能深深感悟到,中国还是以家庭为重的国家.在传统文化的浸染下,国人始终是将家庭摆在第一位.而对于欧美等发达国家来说,他们固然也以家庭为重,但更注重的是男女之间的关系定位--恋爱也 ...

  7. 求求你,下次面试别再问我什么是 Spring AOP 和代理了!

    https://mbd.baidu.com/newspage/data/landingsuper?context=%7B%22nid%22%3A%22news_9403056301388627935% ...

  8. Javascript 表达式中连续的 && 和 || 之赋值区别

    为了区分赋值表达式中出现的连续的 ‘&&’和 ‘||’的不同的赋值含义,做了一个小测试,代码如下: function write(msg){     for(var i = 0; i ...

  9. SQL提高性能

    1.对外键建立索引,大数据量时性能提高明显(建索引可以直接[Merge Join],否则还须在查询时生成HASH表作[Hash Join]) 2.尽量少使用inner join,使用left join ...

  10. Java 获取Enumeration类型的集合

    学习到java的io流中关于序列流SequenceInputStream使用,其中把3个以上的流串联起来操作, 使用的参数是生成运行时类型为 InputStream 对象的 Enumeration 型 ...