xgboost: 速度快效果好的boosting模型
转自:http://cos.name/2015/03/xgboost/
本文作者:何通,SupStat Inc(总部在纽约,中国分部为北京数博思达信息科技有限公司)数据科学家,加拿大Simon Fraser University计算机学院研究生,研究兴趣为数据挖掘和生物信息学。
引言
在数据分析的过程中,我们经常需要对数据建模并做预测。在众多的选择中,randomForest, gbm和glmnet是三个尤其流行的R包,它们在Kaggle的各大数据挖掘竞赛中的出现频率独占鳌头,被坊间人称为R数据挖掘包中的三驾马车。根据我的个人经验,gbm包比同样是使用树模型的randomForest包占用的内存更少,同时训练速度较快,尤其受到大家的喜爱。在python的机器学习库sklearn里也有GradientBoostingClassifier的存在。
Boosting分类器属于集成学习模型,它基本思想是把成百上千个分类准确率较低的树模型组合起来,成为一个准确率很高的模型。这个模型会不断地迭代,每次迭代就生成一颗新的树。对于如何在每一步生成合理的树,大家提出了很多的方法,我们这里简要介绍由Friedman提出的Gradient Boosting Machine。它在生成每一棵树的时候采用梯度下降的思想,以之前生成的所有树为基础,向着最小化给定目标函数的方向多走一步。在合理的参数设置下,我们往往要生成一定数量的树才能达到令人满意的准确率。在数据集较大较复杂的时候,我们可能需要几千次迭代运算,如果生成一个树模型需要几秒钟,那么这么多迭代的运算耗时,应该能让你专心地想静静……
现在,我们希望能通过xgboost工具更好地解决这个问题。xgboost的全称是eXtreme Gradient Boosting。正如其名,它是Gradient Boosting Machine的一个c++实现,作者为正在华盛顿大学研究机器学习的大牛陈天奇。他在研究中深感自己受制于现有库的计算速度和精度,因此在一年前开始着手搭建xgboost项目,并在去年夏天逐渐成型。xgboost最大的特点在于,它能够自动利用CPU的多线程进行并行,同时在算法上加以改进提高了精度。它的处女秀是Kaggle的希格斯子信号识别竞赛,因为出众的效率与较高的预测准确度在比赛论坛中引起了参赛选手的广泛关注,在1700多支队伍的激烈竞争中占有一席之地。随着它在Kaggle社区知名度的提高,最近也有队伍借助xgboost在比赛中夺得第一。
为了方便大家使用,陈天奇将xgboost封装成了python库。我有幸和他合作,制作了xgboost工具的R语言接口,并将其提交到了CRAN上。也有用户将其封装成了julia库。python和R接口的功能一直在不断更新,大家可以通过下文了解大致的功能,然后选择自己最熟悉的语言进行学习。
功能介绍
一、基础功能
首先,我们从github上安装这个包:
devtools::install_github('dmlc/xgboost',subdir='R-package')
动手时间到!第一步,运行下面的代码载入样例数据:
require(xgboost)
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
train <- agaricus.train
test <- agaricus.test
这份数据需要我们通过一些蘑菇的若干属性判断这个品种是否有毒。数据以1或0来标记某个属性存在与否,所以样例数据为稀疏矩阵类型:
> class(train$data)
[1] "dgCMatrix"
attr(,"package")
[1] "Matrix"
不用担心,xgboost支持稀疏矩阵作为输入。下面就是训练模型的命令
> bst <- xgboost(data = train$data, label = train$label, max.depth = 2, eta = 1,
+ nround = 2, objective = "binary:logistic")
[0] train-error:0.046522
[1] train-error:0.022263
我们迭代了两次,可以看到函数输出了每一次迭代模型的误差信息。这里的数据是稀疏矩阵,当然也支持普通的稠密矩阵。如果数据文件太大不希望读进R中,我们也可以通过设置参数data = 'path_to_file'
使其直接从硬盘读取数据并分析。目前支持直接从硬盘读取libsvm格式的文件。
做预测只需要一句话:
pred <- predict(bst, test$data)
做交叉验证的函数参数与训练函数基本一致,只需要在原有参数的基础上设置nfold
:
> cv.res <- xgb.cv(data = train$data, label = train$label, max.depth = 2,
+ eta = 1, nround = 2, objective = "binary:logistic",
+ nfold = 5)
[0] train-error:0.046522+0.001102 test-error:0.046523+0.004410
[1] train-error:0.022264+0.000864 test-error:0.022266+0.003450
> cv.res
train.error.mean train.error.std test.error.mean test.error.std
1: 0.046522 0.001102 0.046523 0.004410
2: 0.022264 0.000864 0.022266 0.003450
交叉验证的函数会返回一个data.table类型的结果,方便我们监控训练集和测试集上的表现,从而确定最优的迭代步数。
二、高速准确
上面的几行代码只是一个入门,使用的样例数据没法表现出xgboost高效准确的能力。xgboost通过如下的优化使得效率大幅提高:
- xgboost借助OpenMP,能自动利用单机CPU的多核进行并行计算。需要注意的是,Mac上的Clang对OpenMP的支持较差,所以默认情况下只能单核运行。
- xgboost自定义了一个数据矩阵类DMatrix,会在训练开始时进行一遍预处理,从而提高之后每次迭代的效率。
在尽量保证所有参数都一致的情况下,我们使用希格斯子竞赛的数据做了对照实验。
MODEL AND PARAMETER | GBM | XGBOOST | |||
---|---|---|---|---|---|
1 thread | 2 threads | 4 threads | 8 threads | ||
Time (in secs) | 761.48 | 450.22 | 102.41 | 44.18 | 34.04 |
以上实验使用的CPU是i7-4700MQ。python的sklearn速度与gbm相仿。如果想要自己对这个结果进行测试,可以在比赛的官方网站下载数据,并参考这份demo中的代码。
除了明显的速度提升外,xgboost在比赛中的效果也非常好。在这个竞赛初期,大家惊讶地发现R和python中的gbm竟然难以突破组织者预设的benchmark。而xgboost横空出世,用不到一分钟的训练时间便打入当时的top 10,引起了大家的兴趣与关注。准确度提升的主要原因在于,xgboost的模型和传统的GBDT相比加入了对于模型复杂度的控制以及后期的剪枝处理,使得学习出来的模型更加不容易过拟合。更多算法上的细节可以参考这份陈天奇给出的介绍性讲义。
三、进阶特征
除了速度快精度高,xgboost还有一些很有用的进阶特性。下面的“demo”链接对应着相应功能的简单样例代码。
- 只要能够求出目标函数的梯度和Hessian矩阵,用户就可以自定义训练模型时的目标函数。demo
- 允许用户在交叉验证时自定义误差衡量方法,例如回归中使用RMSE还是RMSLE,分类中使用AUC,分类错误率或是F1-score。甚至是在希格斯子比赛中的“奇葩”衡量标准AMS。demo
- 交叉验证时可以返回模型在每一折作为预测集时的预测结果,方便构建ensemble模型。demo
- 允许用户先迭代1000次,查看此时模型的预测效果,然后继续迭代1000次,最后模型等价于一次性迭代2000次。demo
- 可以知道每棵树将样本分类到哪片叶子上,facebook介绍过如何利用这个信息提高模型的表现。demo
- 可以计算变量重要性并画出树状图。demo
- 可以选择使用线性模型替代树模型,从而得到带L1+L2惩罚的线性回归或者logistic回归。demo
这些丰富的功能来源于对日常使用场景的总结,数据挖掘比赛需求以及许多用户给出的精彩建议。
四、未来计划
现在,机器学习工具在实用中会不可避免地遇到“单机性能不够”的问题。目前,xgboost的多机分布式版本正在开发当中。基础设施搭建完成之日,便是新一轮R包开始设计与升级之时。
结语
我为xgboost制作R接口的目的就是希望引进好的工具,让大家使用R的时候心情更愉悦。总结下来,xgboost的特点有三个:速度快,效果好,功能多,希望它能受到大家的喜爱,成为一驾新的马车。
xgboost功能较多,参数设置比较繁杂,希望在上手之后有更全面了解的读者可以参考项目wiki。欢迎大家多多交流,在项目issue区提出疑问与建议。我们也邀请有兴趣的读者提交代码完善功能,让xgboost成为更好用的工具。
另外,在使用github开发的过程中,我深切地感受到了协作写代码带来的变化。一群人在一起的时候,可以写出更有效率的代码,在丰富的使用场景中发现新的需求,在极端情况发现隐藏很深的bug,甚至在主代码手拖延症较为忙碌的时候有人挺身而出拿下一片issue。这样的氛围,能让一个语言社区的交流丰富起来,从而充满生命力地活跃下去。
xgboost: 速度快效果好的boosting模型的更多相关文章
- 百度DMLC分布式深度机器学习开源项目(简称“深盟”)上线了如xgboost(速度快效果好的Boosting模型)、CXXNET(极致的C++深度学习库)、Minerva(高效灵活的并行深度学习引擎)以及Parameter Server(一小时训练600T数据)等产品,在语音识别、OCR识别、人脸识别以及计算效率提升上发布了多个成熟产品。
百度为何开源深度机器学习平台? 有一系列领先优势的百度却选择开源其深度机器学习平台,为何交底自己的核心技术?深思之下,却是在面对业界无奈时的远见之举. 5月20日,百度在github上开源了其 ...
- 集成学习-Boosting 模型深度串讲
首先强调一下,这篇文章适合有很好的基础的人 梯度下降 这里不系统讲,只介绍相关的点,便于理解后文 先放一个很早以前写的 梯度下降 实现 logistic regression 的代码 def tidu ...
- Cesium专栏-裁剪效果(基于3dtiles模型,附源码下载)
Cesium Cesium 是一款面向三维地球和地图的,世界级的JavaScript开源产品.它提供了基于JavaScript语言的开发包,方便用户快速搭建一款零插件的虚拟地球Web应用,并在性能,精 ...
- R语言︱XGBoost极端梯度上升以及forecastxgb(预测)+xgboost(回归)双案例解读
XGBoost不仅仅可以用来做分类还可以做时间序列方面的预测,而且已经有人做的很好,可以见最后的案例. 应用一:XGBoost用来做预测 ------------------------------- ...
- xgboost入门与实战(原理篇)
sklearn实战-乳腺癌细胞数据挖掘 https://study.163.com/course/introduction.htm?courseId=1005269003&utm_campai ...
- XGBoost浅入浅出
http://wepon.me/ XGBoost风靡Kaggle.天池.DataCastle.Kesci等国内外数据竞赛平台,是比赛夺冠的必备大杀器.我在之前参加过的一些比赛中,着实领略了其威力,也取 ...
- Boosting(提升方法)之XGBoost
XGBoost是一个机器学习味道非常浓厚的模型,在数学上非常规范,运用正则化.L2范数.二阶梯度.泰勒公式和分布式计算方法,对GBDT等提升树模型进行优化,不仅能处理更大规模的数据,而且运行效率特别高 ...
- 论文笔记 XGBoost: A Scalable Tree Boosting System
XGBoost是boosting算法的其中一种.Boosting算法的思想是将许多弱分类器集成在一起形成一个强分类器,其更关注与降低基模型的偏差.XGBoost是一种提升树模型(Gradient bo ...
- xgboost原理总结和代码展示
关于xgboost的学习推荐两篇博客,每篇看2遍,我都能看懂,你肯定没问题 两篇方法互通,知识点互补!记录下来,方便以后查看 第一篇:作者:milter链接:https://www.jianshu.c ...
随机推荐
- Prometheus Node_exporter 之 Basic CPU / Mem Graph
1. CPU Basic cpu 的基本信息 /proc/stat type: GraphUnit: shortBusy System: cpu 处于核心态的占比 metrics: sum by (i ...
- 对MBProgressHUD进行二次封装并精简使用
对MBProgressHUD进行二次封装并精简使用 https://github.com/jdg/MBProgressHUD 几个效果图: 以下源码是MBProgressHUD支持最新的iOS8的版本 ...
- 利用Linode面板Clone克隆搬家迁移不同VPS数据及利用IP Swap迁移IP地址
在众多海外VPS服务商中,老蒋个人认为Linode提供的VPS方案和性价比还是比较高的,尤其目前基础1GB方案仅需10美元每月且全部是SSD固态硬盘,无论是流量还是硬盘大小,基本上可以满足我们大部分用 ...
- 一、JDBC的概述 二、通过JDBC实现对数据的CRUD操作 三、封装JDBC访问数据的工具类 四、通过JDBC实现登陆和注册 五、防止SQL注入
一.JDBC的概述###<1>概念 JDBC:java database connection ,java数据库连接技术 是java内部提供的一套操作数据库的接口(面向接口编程),实现对数 ...
- 【转】【Flex】FLEX 学习网站分享
[转:http://hi.baidu.com/tanghecaiyu/item/d662fbd7f5fbe02c38f6f764 ] FLEX 学习网站分享 http://blog.minidx.co ...
- 面对对象程序设计_task2_C++视频教程
lessons about C++ 1月份的事情不该留到2月份来做,这几天看了几个地方的C++视频教程,不习惯于云课堂的话多等等,最终还是选择了慕课网上面的资源,也安下心来看了一些内容,下面附上课程详 ...
- [T-ARA][yayaya]
歌词来源:http://music.163.com/#/song?id=22704449 U look at me Right T-ARA U Ready Let me seeya LaLaLaLa ...
- 基于jquery分页插件
今天终于完成了基于jquery的分页插件的代码编写,也通过了功能测试,实现了分页功能:由于刚开始写jquery的插件,所以梳理逻辑的时间也很长,整个过程整整一周时间,今天终于搞完了,先将整个分页插件的 ...
- etherlime-3-Etherlime Library API-Deployed Contract Wrapper
Deployed Contract Wrapper部署合约的封装 Wrappers封装 One of the advancements of the etherlime is the result o ...
- 多线程之并发容器ConcurrentHashMap
这部分内容转载自: http://www.haogongju.net/art/2350374 JDK5中添加了新的concurrent包,相对同步容器而言,并发容器通过一些机制改进了并发性能.因为同步 ...