再xgboost的源码中有xgboost的SparkWithDataFrame的实现,如下:https://github.com/dmlc/xgboost/tree/master/jvm-packages。但是由于各种各样的原因吧,这些代码在我的IDE里面编译不过,因此又写了如下代码以供以后查阅使用。

package xgboost

import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{Row, DataFrame, SparkSession} object App{
def main(args: Array[String]): Unit ={
val trainPath: String = "xxx/train.txt"
val testPath: String = "xxx/test.txt"
val binaryModelPath: String = "xxx/model.binary"
val textModelPath: String = "xxx/model.txt"
val spark = SparkSession
.builder()
.master("yarn")
.getOrCreate() // define xgboost parameters
val maxDepth = 3
val numRound = 4
val nworker = 1
val paramMap = List(
"eta" -> 0.1,
"max_depth" -> maxDepth,
"objective" -> "binary:logistic").toMap //read libsvm file
var dfTrain = spark.read.format("libsvm").load(trainPath).toDF("labelCol", "featureCol")
var dfTest = spark.read.format("libsvm").load(testPath).toDF("labelCol", "featureCol")
dfTrain.show(true)
printf("begin...")
val model:XGBoostModel = XGBoost.trainWithDataFrame(dfTrain, paramMap, numRound, nworker,
useExternalMemory = true,
featureCol = "featureCol", labelCol = "labelCol",
missing = 0.0f) //predict the test set
val predict:DataFrame = model.transform(dfTest)
val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)
.rdd
.map{case Row(score:Double, label:Double) => (score, label)} //get the auc
val metric = new BinaryClassificationMetrics(scoreAndLabels)
val auc = metric.areaUnderROC()
println("auc:" + auc) //save model
this.saveBinaryModel(model, spark, binaryModelPath)
this.saveTextModel(model, spark, textModelPath, numRound, maxDepth)
} def saveBinaryModel(model:XGBoostModel, spark: SparkSession, path: String): Unit = {
model.saveModelAsHadoopFile(path)(spark.sparkContext)
} def saveTextModel(model:XGBoostModel, spark: SparkSession, path: String, numRound: Int, maxDepth: Int): Unit = {
val dumpModel = model
.booster
.getModelDump()
.toList
.zipWithIndex
.map(x => s"booster:[${x._2}]\n${x._1}") val header = s"numRound: $numRound, maxDepth: $maxDepth"
print(dumpModel)
import spark.implicits._
val text: List[String] = header +: dumpModel
text.toDF
.coalesce(1)
.write
.mode("overwrite")
.text(path)
}
}

  其中:

  1.训练集和测试集都是libsvm格式,如下所示:

1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1
0 3:1 10:1 20:1 21:1 23:1 34:1 36:1 39:1 41:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 120:1

  2.最终生成的模型如下所示:

numRound: 4, maxDepth: 3
booster:[0]
0:[f29<] yes=1,no=2,missing=2
1:leaf=0.152941
2:leaf=-0.191209 booster:[1]
0:[f29<2] yes=1,no=2,missing=2
1:leaf=0.141901
2:leaf=-0.174499 booster:[2]
0:[f29<2] yes=1,no=2,missing=2
1:leaf=0.132731
2:leaf=-0.161685 booster:[3]
0:[f29<2] yes=1,no=2,missing=2
1:leaf=0.124972
2:leaf=-0.15155

  相关解释:”numRound: 4, maxDepth: 3”表示生成树的个数为4,树的最大深度为3;booster[n]表示第n棵树;以下保存树的结构,0号节点为根节点,每个节点有两个子节点,节点序号按层序技术,即1号和2号节点为根节点0号节点的子节点,相同层的节点有相同缩进,且比父节点多一级缩进。
  在节点行,首先声明节点序号,中括号里写明该节点采用第几个特征(如f29即为训练数据的第29个特征),同时表明特征值划分条件,“[f29<2] yes=1,no=2,missing=2”:表示f29号特征大于2时该样本划分到1号叶子节点,f29>=2时划分到2号叶子节点,当没有该特征(None)划分到2号叶子节点。

  3.预测的结果如下:

|labelCol|featureCol                                                                                                                                                  |probabilities                          |prediction|
|1.0 |(126,[2,9,10,20,29,33,35,39,40,52,57,64,68,76,85,87,91,94,101,104,116,123],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.3652743101119995,0.6347256898880005]|1.0 |
|0.0 |(126,[2,9,19,20,22,33,35,38,40,52,55,64,68,76,85,87,91,94,101,105,115,119],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.6635029911994934,0.3364970088005066]|0.0 |

  

xgboost的SparkWithDataFrame版本实现的更多相关文章

  1. 在Window平台下安装xgboost的Python版本

    原文:http://blog.csdn.net/pengyulong/article/details/50515916 原文修改了两个地方才安装成功,第3步可以不用,第2步重新生成所有的就行了. 第4 ...

  2. 小巧玲珑:机器学习届快刀XGBoost的介绍和使用

    欢迎大家前往腾讯云技术社区,获取更多腾讯海量技术实践干货哦~ 作者:张萌 序言 XGBoost效率很高,在Kaggle等诸多比赛中使用广泛,并且取得了不少好成绩.为了让公司的算法工程师,可以更加方便的 ...

  3. xgboost 参数调优指南

    一.XGBoost的优势 XGBoost算法可以给预测模型带来能力的提升.当我对它的表现有更多了解的时候,当我对它的高准确率背后的原理有更多了解的时候,我发现它具有很多优势: 1 正则化 标准GBDT ...

  4. XGBoost 与 Boosted Tree

    http://www.52cs.org/?p=429 作者:陈天奇,毕业于上海交通大学ACM班,现就读于华盛顿大学,从事大规模机器学习研究. 注解:truth4sex  编者按:本文是对开源xgboo ...

  5. xgboost入门与实战(原理篇)

    sklearn实战-乳腺癌细胞数据挖掘 https://study.163.com/course/introduction.htm?courseId=1005269003&utm_campai ...

  6. 机器学习--boosting家族之XGBoost算法

    一.概念 XGBoost全名叫(eXtreme Gradient Boosting)极端梯度提升,经常被用在一些比赛中,其效果显著.它是大规模并行boosted tree的工具,它是目前最快最好的开源 ...

  7. xgboost 参数

    XGBoost 参数 在运行XGBoost程序之前,必须设置三种类型的参数:通用类型参数(general parameters).booster参数和学习任务参数(task parameters). ...

  8. XGBoost:在Python中使用XGBoost

    原文:http://blog.csdn.net/zc02051126/article/details/46771793 在Python中使用XGBoost 下面将介绍XGBoost的Python模块, ...

  9. 【转】XGBoost 与 Boosted Tree

    XGBoost 与 Boosted Tree http://www.52cs.org/?p=429 作者:陈天奇,毕业于上海交通大学ACM班,现就读于华盛顿大学,从事大规模机器学习研究. 注解:tru ...

随机推荐

  1. 缓存淘汰算法之LRU实现

    Java中最简单的LRU算法实现,就是利用 LinkedHashMap,覆写其中的removeEldestEntry(Map.Entry)方法即可 如果你去看LinkedHashMap的源码可知,LR ...

  2. csa Round #73 (Div. 2 only)

    Three Equal Time limit: 1000 msMemory limit: 256 MB   You are given an array AA of NN integers betwe ...

  3. Tyk-Hybrid模式安装—抽象方法论,重用它

    最近,公司有计划运用API网关.那么,在经过权衡之后,使用了Tyk的Hybrid模式!现在环境没问题了,API调用也测通了.我得想想合并服务,监控API实时情况的东西.但在这个环境搭建的过程中,我目前 ...

  4. php preg_replace去除html xml 注释

    php preg_replace去除html xml 注释 //不确定是否最优 $content = preg_replace('/<!--((?!-->).)*-->/s', '' ...

  5. [LOJ#516]「LibreOJ β Round #2」DP 一般看规律

    [LOJ#516]「LibreOJ β Round #2」DP 一般看规律 试题描述 给定一个长度为 \(n\) 的序列 \(a\),一共有 \(m\) 个操作. 每次操作的内容为:给定 \(x,y\ ...

  6. spring入门到放弃——spring事务管理

    Spring事务提供了两种管理的的方式:编程式事务和声明式事务 简单回顾下事务: 事务:逻辑上的一组操作,组成操作的各个单元,要么全部成功,要么全部失败. 事务特性: 原子性:一个事务包含的各个操作单 ...

  7. docker (centOS 7) 使用笔记3 - 修改docker默认的虚拟网址

    近日在使用VPN时发现和docker的虚拟网址发生了冲突,都是172.17.0.1,故需要修改docker的默认网址. 1. 当前状态 # ifconfig docker0: flags=<UP ...

  8. python 缺少包

    https://pypi.python.org/pypi/pdfminer/20140328 到这里下载相应的包,再进行安装. tar  –xivf  pybloomfilter-1.0 cd  py ...

  9. 方格取数(hdu 1565)

    Problem Description 给你一个n*n的格子的棋盘,每个格子里面有一个非负数.从中取出若干个数,使得任意的两个数所在的格子没有公共边,就是说所取的数所在的2个格子不能相邻,并且取出的数 ...

  10. JSTL <C:if></C:if> 和<C:ForEach></C:ForEach> 入门级~

    一.<C:If>标签:条件判断语句 <c:if test="${objList.nodetype == 1}">上级节点</c:if>   te ...