1. /**
  2. * Created by lkl on 2017/12/6.
  3. */
  4. import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
  5. import org.apache.spark.mllib.linalg.Vectors
  6. import org.apache.spark.mllib.regression.LabeledPoint
  7. import org.apache.spark.mllib.tree.GradientBoostedTrees
  8. import org.apache.spark.mllib.tree.configuration.BoostingStrategy
  9. import org.apache.spark.sql.hive.HiveContext
  10. import org.apache.spark.{SparkConf, SparkContext}
  11. import scala.collection.mutable.ArrayBuffer
  12. object GradientBoostingClassificationForLK {
  13. //http://blog.csdn.net/xubo245/article/details/51499643
  14. def main(args: Array[String]): Unit = {
  15. val conf = new SparkConf().setAppName("GradientBoostingClassificationForLK")
  16. val sc = new SparkContext(conf)
  17.  
  18. // sc is an existing SparkContext.
  19. val hc = new HiveContext(sc)
  20.  
  21. if(args.length!=){
  22. println("请输入参数:trainingData对应的库名、表名、模型运行时间")
  23. System.exit()
  24. }
  25.  
  26. //分别传入库名、表名、对比效果路径
  27. // val database = args(0)
  28. // val table = args(1)
  29. // val date = args(2)
  30.  //lkl_card_score.overdue_result_all_new_woe
  31. val format = new java.text.SimpleDateFormat("yyyyMMdd")
  32. val database ="lkl_card_score"
  33. val table = "overdue_result_all_new_woe"
  34. val date =format.format(new java.util.Date())
  35. //提取数据集 RDD[LabeledPoint]
  36. //val data = hc.sql(s"select * from $database.$table").map{
  37.  
  38. val data = hc.sql(s"select * from lkl_card_score.overdue_result_all_new_woe").map{
  39. row =>
  40. var arr = new ArrayBuffer[Double]()
  41. //剔除label、contact字段
  42. for(i <- until row.size){
  43. if(row.isNullAt(i)){
  44. arr += 0.0
  45. }
  46. else if(row.get(i).isInstanceOf[Int])
  47. arr += row.getInt(i).toDouble
  48. else if(row.get(i).isInstanceOf[Double])
  49. arr += row.getDouble(i)
  50. else if(row.get(i).isInstanceOf[Long])
  51. arr += row.getLong(i).toDouble
  52. else if(row.get(i).isInstanceOf[String])
  53. arr += 0.0
  54. }
  55. LabeledPoint(row.getInt(), Vectors.dense(arr.toArray))
  56. }
  57. // Split the data into training and test sets (30% held out for testing)
  58. val splits = data.randomSplit(Array(0.7, 0.3))
  59. val (trainingData, testData) = (splits(), splits())
  60.  
  61. // Train a GradientBoostedTrees model.
  62. // The defaultParams for Classification use LogLoss by default.
  63. val boostingStrategy = BoostingStrategy.defaultParams("Classification")
  64. boostingStrategy.setNumIterations() // Note: Use more iterations in practice.
  65. boostingStrategy.treeStrategy.setNumClasses()
  66. boostingStrategy.treeStrategy.setMaxDepth()
  67. // Empty categoricalFeaturesInfo indicates all features are continuous.
  68. //boostingStrategy.treeStrategy.setCategoricalFeaturesInfo(Map[Int, Int]())
  69.  
  70. val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
  71.  
  72. // Evaluate model on test instances and compute test error
  73. val predictionAndLabels = testData.map { point =>
  74. val prediction = model.predict(point.features)
  75. (point.label, prediction)
  76. }
  77.  
  78. predictionAndLabels.map(x => {"predicts: "+x._1+"--> labels:"+x._2}).saveAsTextFile(s"hdfs://ns1/tmp/$date/predictionAndLabels")
  79. //===================================================================
  80. //使用BinaryClassificationMetrics评估模型
  81. val metrics = new BinaryClassificationMetrics(predictionAndLabels)
  82.  
  83. // Precision by threshold
  84. val precision = metrics.precisionByThreshold
  85. precision.map({case (t, p) =>
  86. "Threshold: "+t+"Precision:"+p
  87. }).saveAsTextFile(s"hdfs://ns1/tmp/$date/precision")
  88.  
  89. // Recall by threshold
  90. val recall = metrics.recallByThreshold
  91. recall.map({case (t, r) =>
  92. "Threshold: "+t+"Recall:"+r
  93. }).saveAsTextFile(s"hdfs://ns1/tmp/$date/recall")
  94.  
  95. //the beta factor in F-Measure computation.
  96. val f1Score = metrics.fMeasureByThreshold
  97. f1Score.map(x => {"Threshold: "+x._1+"--> F-score:"+x._2+"--> Beta = 1"})
  98. .saveAsTextFile(s"hdfs://ns1/tmp/$date/f1Score")
  99.  
  100. /**
  101. * 如果要选择Threshold, 这三个指标中, 自然F1最为合适
  102. * 求出最大的F1, 对应的threshold就是最佳的threshold
  103. */
  104. /*val maxFMeasure = f1Score.select(max("F-Measure")).head().getDouble(0)
  105. val bestThreshold = f1Score.where($"F-Measure" === maxFMeasure)
  106. .select("threshold").head().getDouble(0)*/
  107.  
  108. // Precision-Recall Curve
  109. val prc = metrics.pr
  110. prc.map(x => {"Recall: " + x._1 + "--> Precision: "+x._2 }).saveAsTextFile(s"hdfs://ns1/tmp/$date/prc")
  111.  
  112. // AUPRC,精度,召回曲线下的面积
  113. val auPRC = metrics.areaUnderPR
  114. sc.makeRDD(Seq("Area under precision-recall curve = " +auPRC)).saveAsTextFile(s"hdfs://ns1/tmp/$date/auPRC")
  115.  
  116. //roc
  117. val roc = metrics.roc
  118. roc.map(x => {"FalsePositiveRate:" + x._1 + "--> Recall: " +x._2}).saveAsTextFile(s"hdfs://ns1/tmp/$date/roc")
  119.  
  120. // AUC
  121. val auROC = metrics.areaUnderROC
  122. sc.makeRDD(Seq("Area under ROC = " + +auROC)).saveAsTextFile(s"hdfs://ns1/tmp/$date/auROC")
  123. println("Area under ROC = " + auROC)
  124.  
  125. val testErr = predictionAndLabels.filter(r => r._1 != r._2).count.toDouble / testData.count()
  126. sc.makeRDD(Seq("Test Mean Squared Error = " + testErr)).saveAsTextFile(s"hdfs://ns1/tmp/$date/testErr")
  127. sc.makeRDD(Seq("Learned regression tree model: " + model.toDebugString)).saveAsTextFile(s"hdfs://ns1/tmp/$date/GBDTclassification")
  128. }
  129.  
  130. }

lakala GradientBoostedTrees的更多相关文章

  1. lakala反欺诈建模实际应用代码GBDT监督学习

    /** * Created by lkl on 2018/1/16. */ import org.apache.spark.mllib.evaluation.BinaryClassificationM ...

  2. lakala proportion轨迹分析代码

    /** * Created by lkl on 2017/12/7. */ import breeze.numerics.abs import org.apache.spark.sql.SQLCont ...

  3. 决策树和基于决策树的集成方法(DT,RF,GBDT,XGBT)复习总结

    摘要: 1.算法概述 2.算法推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 内容: 1.算法概述 1.1 决策树(DT)是一种基本的分类和回归方法.在分类问题中它可以认为是if-the ...

  4. 《Spark 官方文档》机器学习库(MLlib)指南

    spark-2.0.2 机器学习库(MLlib)指南 MLlib是Spark的机器学习(ML)库.旨在简化机器学习的工程实践工作,并方便扩展到更大规模.MLlib由一些通用的学习算法和工具组成,包括分 ...

  5. ORACLE11G常用函数

    1 单值函数 1.1 日期函数 1.1.1 Round [舍入到最接近的日期](day:舍入到最接近的星期日) select sysdate S1, round(sysdate) S2 , round ...

  6. 决策树和基于决策树的集成方法(DT,RF,GBDT,XGB)复习总结

    摘要: 1.算法概述 2.算法推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 内容: 1.算法概述 1.1 决策树(DT)是一种基本的分类和回归方法.在分类问题中它可以认为是if-the ...

  7. MLlib--GBDT算法

    转载请标明出处http://www.cnblogs.com/haozhengfei/p/8b9cb1875288d9f6cfc2f5a9b2f10eac.html GBDT算法 江湖传言:GBDT算法 ...

  8. spark MLlib Classification and regression 学习

    二分类:SVMs,logistic regression,decision trees,random forests,gradient-boosted trees,naive Bayes 多分类:  ...

  9. Oracle分析函数及常用函数: over(),rank()over()作用及用法--分区(分组)求和& 不连续/连续排名

    (1)   函数:  over()的作用及用法:    -- 分区(分组)求和. sum() over( partition by column1 order by column2 )主要用来对某个字 ...

随机推荐

  1. [转载]生活在 Emacs 中

    Brian Bilbrey2002 年 8 月 20 日发布 教程简介 本教程讲什么? Emacs 是一个流行的无模式文本编辑器,有许多强大的功能.本教程将教您使用 Emacs 的基础知识.为了让您很 ...

  2. TRUNC 截取日期或数字,返回指定的值。

    TRUNC(number,num_digits) Number 需要截尾取整的数字. Num_digits 用于指定取整精度的数字.Num_digits 的默认值为 0.   /*********** ...

  3. 基于jQuery图片元素网格布局插件

    基于jQuery图片元素网格布局插件是一款可以将图片或HTML元素均匀分布排列为网格布局的jQuery插件jMosaic.效果图如下: 在线预览   源码下载 实现的代码. html代码: <c ...

  4. CTF之PHP黑魔法总结

    继上一篇php各版本的姿势(不同版本的利用特性),文章总结了php版本差异,现在在来一篇本地日记总结的php黑魔法,是以前做CTF时遇到并记录的,很适合在做CTF代码审计的时候翻翻看看. 一.要求变量 ...

  5. python.pandas read and write CSV file

    #read and write csv of pandasimport pandas as pd goog =pd.read_csv(r'C:\python\demo\LiaoXueFeng\data ...

  6. JQUERY根据值将input控件选中!

    <select>: $('#country').find("option[value = " + data.country + "]").attr( ...

  7. 使用Task代替ThreadPool和Thread

    转载:改善C#程序的建议9:使用Task代替ThreadPool和Thread 一:Task的优势 ThreadPool相比Thread来说具备了很多优势,但是ThreadPool却又存在一些使用上的 ...

  8. webbrowser取页面验证码

    碰到一个无比坑爹,外加蛋疼乳酸的问题.从昨天晚上发现bug,到今天下午解决问题,搞了大半天的时间.光是找问题就花了半天,解决问题的方法简单,但是方案的形成也是无比纠结的过程. 背景:webbrowse ...

  9. mysql深坑之--group_concat有长度限制!!!!默认1024

    在mysql中,有个函数叫“group_concat”,平常使用可能发现不了问题,在处理大数据的时候,会发现内容被截取了,其实MYSQL内部对这个是有设置的,默认不设置的长度是1024,如果我们需要更 ...

  10. SSH-CLIENT : gSTM

    Linux环境下可以使用终端命令行直接登录SSH帐号.但是对Linux新手,可能不太习惯用命令行,于是我就琢磨找一款Linux环境下可以图形化管理ssh帐号的客户端软件,还真让我找着了. gSTM,是 ...