xgboost: 速度快效果好的boosting模型
转自:http://cos.name/2015/03/xgboost/
本文作者:何通,SupStat Inc(总部在纽约,中国分部为北京数博思达信息科技有限公司)数据科学家,加拿大Simon Fraser University计算机学院研究生,研究兴趣为数据挖掘和生物信息学。
引言
在数据分析的过程中,我们经常需要对数据建模并做预测。在众多的选择中,randomForest, gbm和glmnet是三个尤其流行的R包,它们在Kaggle的各大数据挖掘竞赛中的出现频率独占鳌头,被坊间人称为R数据挖掘包中的三驾马车。根据我的个人经验,gbm包比同样是使用树模型的randomForest包占用的内存更少,同时训练速度较快,尤其受到大家的喜爱。在python的机器学习库sklearn里也有GradientBoostingClassifier的存在。
Boosting分类器属于集成学习模型,它基本思想是把成百上千个分类准确率较低的树模型组合起来,成为一个准确率很高的模型。这个模型会不断地迭代,每次迭代就生成一颗新的树。对于如何在每一步生成合理的树,大家提出了很多的方法,我们这里简要介绍由Friedman提出的Gradient Boosting Machine。它在生成每一棵树的时候采用梯度下降的思想,以之前生成的所有树为基础,向着最小化给定目标函数的方向多走一步。在合理的参数设置下,我们往往要生成一定数量的树才能达到令人满意的准确率。在数据集较大较复杂的时候,我们可能需要几千次迭代运算,如果生成一个树模型需要几秒钟,那么这么多迭代的运算耗时,应该能让你专心地想静静……
现在,我们希望能通过xgboost工具更好地解决这个问题。xgboost的全称是eXtreme Gradient Boosting。正如其名,它是Gradient Boosting Machine的一个c++实现,作者为正在华盛顿大学研究机器学习的大牛陈天奇。他在研究中深感自己受制于现有库的计算速度和精度,因此在一年前开始着手搭建xgboost项目,并在去年夏天逐渐成型。xgboost最大的特点在于,它能够自动利用CPU的多线程进行并行,同时在算法上加以改进提高了精度。它的处女秀是Kaggle的希格斯子信号识别竞赛,因为出众的效率与较高的预测准确度在比赛论坛中引起了参赛选手的广泛关注,在1700多支队伍的激烈竞争中占有一席之地。随着它在Kaggle社区知名度的提高,最近也有队伍借助xgboost在比赛中夺得第一。
为了方便大家使用,陈天奇将xgboost封装成了python库。我有幸和他合作,制作了xgboost工具的R语言接口,并将其提交到了CRAN上。也有用户将其封装成了julia库。python和R接口的功能一直在不断更新,大家可以通过下文了解大致的功能,然后选择自己最熟悉的语言进行学习。
功能介绍
一、基础功能
首先,我们从github上安装这个包:
devtools::install_github('dmlc/xgboost',subdir='R-package')
动手时间到!第一步,运行下面的代码载入样例数据:
require(xgboost)
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
train <- agaricus.train
test <- agaricus.test
这份数据需要我们通过一些蘑菇的若干属性判断这个品种是否有毒。数据以1或0来标记某个属性存在与否,所以样例数据为稀疏矩阵类型:
> class(train$data)
[1] "dgCMatrix"
attr(,"package")
[1] "Matrix"
不用担心,xgboost支持稀疏矩阵作为输入。下面就是训练模型的命令
> bst <- xgboost(data = train$data, label = train$label, max.depth = 2, eta = 1,
+ nround = 2, objective = "binary:logistic")
[0] train-error:0.046522
[1] train-error:0.022263
我们迭代了两次,可以看到函数输出了每一次迭代模型的误差信息。这里的数据是稀疏矩阵,当然也支持普通的稠密矩阵。如果数据文件太大不希望读进R中,我们也可以通过设置参数data = 'path_to_file'
使其直接从硬盘读取数据并分析。目前支持直接从硬盘读取libsvm格式的文件。
做预测只需要一句话:
pred <- predict(bst, test$data)
做交叉验证的函数参数与训练函数基本一致,只需要在原有参数的基础上设置nfold
:
> cv.res <- xgb.cv(data = train$data, label = train$label, max.depth = 2,
+ eta = 1, nround = 2, objective = "binary:logistic",
+ nfold = 5)
[0] train-error:0.046522+0.001102 test-error:0.046523+0.004410
[1] train-error:0.022264+0.000864 test-error:0.022266+0.003450
> cv.res
train.error.mean train.error.std test.error.mean test.error.std
1: 0.046522 0.001102 0.046523 0.004410
2: 0.022264 0.000864 0.022266 0.003450
交叉验证的函数会返回一个data.table类型的结果,方便我们监控训练集和测试集上的表现,从而确定最优的迭代步数。
二、高速准确
上面的几行代码只是一个入门,使用的样例数据没法表现出xgboost高效准确的能力。xgboost通过如下的优化使得效率大幅提高:
- xgboost借助OpenMP,能自动利用单机CPU的多核进行并行计算。需要注意的是,Mac上的Clang对OpenMP的支持较差,所以默认情况下只能单核运行。
- xgboost自定义了一个数据矩阵类DMatrix,会在训练开始时进行一遍预处理,从而提高之后每次迭代的效率。
在尽量保证所有参数都一致的情况下,我们使用希格斯子竞赛的数据做了对照实验。
MODEL AND PARAMETER | GBM | XGBOOST | |||
---|---|---|---|---|---|
1 thread | 2 threads | 4 threads | 8 threads | ||
Time (in secs) | 761.48 | 450.22 | 102.41 | 44.18 | 34.04 |
以上实验使用的CPU是i7-4700MQ。python的sklearn速度与gbm相仿。如果想要自己对这个结果进行测试,可以在比赛的官方网站下载数据,并参考这份demo中的代码。
除了明显的速度提升外,xgboost在比赛中的效果也非常好。在这个竞赛初期,大家惊讶地发现R和python中的gbm竟然难以突破组织者预设的benchmark。而xgboost横空出世,用不到一分钟的训练时间便打入当时的top 10,引起了大家的兴趣与关注。准确度提升的主要原因在于,xgboost的模型和传统的GBDT相比加入了对于模型复杂度的控制以及后期的剪枝处理,使得学习出来的模型更加不容易过拟合。更多算法上的细节可以参考这份陈天奇给出的介绍性讲义。
三、进阶特征
除了速度快精度高,xgboost还有一些很有用的进阶特性。下面的“demo”链接对应着相应功能的简单样例代码。
- 只要能够求出目标函数的梯度和Hessian矩阵,用户就可以自定义训练模型时的目标函数。demo
- 允许用户在交叉验证时自定义误差衡量方法,例如回归中使用RMSE还是RMSLE,分类中使用AUC,分类错误率或是F1-score。甚至是在希格斯子比赛中的“奇葩”衡量标准AMS。demo
- 交叉验证时可以返回模型在每一折作为预测集时的预测结果,方便构建ensemble模型。demo
- 允许用户先迭代1000次,查看此时模型的预测效果,然后继续迭代1000次,最后模型等价于一次性迭代2000次。demo
- 可以知道每棵树将样本分类到哪片叶子上,facebook介绍过如何利用这个信息提高模型的表现。demo
- 可以计算变量重要性并画出树状图。demo
- 可以选择使用线性模型替代树模型,从而得到带L1+L2惩罚的线性回归或者logistic回归。demo
这些丰富的功能来源于对日常使用场景的总结,数据挖掘比赛需求以及许多用户给出的精彩建议。
四、未来计划
现在,机器学习工具在实用中会不可避免地遇到“单机性能不够”的问题。目前,xgboost的多机分布式版本正在开发当中。基础设施搭建完成之日,便是新一轮R包开始设计与升级之时。
结语
我为xgboost制作R接口的目的就是希望引进好的工具,让大家使用R的时候心情更愉悦。总结下来,xgboost的特点有三个:速度快,效果好,功能多,希望它能受到大家的喜爱,成为一驾新的马车。
xgboost功能较多,参数设置比较繁杂,希望在上手之后有更全面了解的读者可以参考项目wiki。欢迎大家多多交流,在项目issue区提出疑问与建议。我们也邀请有兴趣的读者提交代码完善功能,让xgboost成为更好用的工具。
另外,在使用github开发的过程中,我深切地感受到了协作写代码带来的变化。一群人在一起的时候,可以写出更有效率的代码,在丰富的使用场景中发现新的需求,在极端情况发现隐藏很深的bug,甚至在主代码手拖延症较为忙碌的时候有人挺身而出拿下一片issue。这样的氛围,能让一个语言社区的交流丰富起来,从而充满生命力地活跃下去。
xgboost: 速度快效果好的boosting模型的更多相关文章
- 百度DMLC分布式深度机器学习开源项目(简称“深盟”)上线了如xgboost(速度快效果好的Boosting模型)、CXXNET(极致的C++深度学习库)、Minerva(高效灵活的并行深度学习引擎)以及Parameter Server(一小时训练600T数据)等产品,在语音识别、OCR识别、人脸识别以及计算效率提升上发布了多个成熟产品。
百度为何开源深度机器学习平台? 有一系列领先优势的百度却选择开源其深度机器学习平台,为何交底自己的核心技术?深思之下,却是在面对业界无奈时的远见之举. 5月20日,百度在github上开源了其 ...
- 集成学习-Boosting 模型深度串讲
首先强调一下,这篇文章适合有很好的基础的人 梯度下降 这里不系统讲,只介绍相关的点,便于理解后文 先放一个很早以前写的 梯度下降 实现 logistic regression 的代码 def tidu ...
- Cesium专栏-裁剪效果(基于3dtiles模型,附源码下载)
Cesium Cesium 是一款面向三维地球和地图的,世界级的JavaScript开源产品.它提供了基于JavaScript语言的开发包,方便用户快速搭建一款零插件的虚拟地球Web应用,并在性能,精 ...
- R语言︱XGBoost极端梯度上升以及forecastxgb(预测)+xgboost(回归)双案例解读
XGBoost不仅仅可以用来做分类还可以做时间序列方面的预测,而且已经有人做的很好,可以见最后的案例. 应用一:XGBoost用来做预测 ------------------------------- ...
- xgboost入门与实战(原理篇)
sklearn实战-乳腺癌细胞数据挖掘 https://study.163.com/course/introduction.htm?courseId=1005269003&utm_campai ...
- XGBoost浅入浅出
http://wepon.me/ XGBoost风靡Kaggle.天池.DataCastle.Kesci等国内外数据竞赛平台,是比赛夺冠的必备大杀器.我在之前参加过的一些比赛中,着实领略了其威力,也取 ...
- Boosting(提升方法)之XGBoost
XGBoost是一个机器学习味道非常浓厚的模型,在数学上非常规范,运用正则化.L2范数.二阶梯度.泰勒公式和分布式计算方法,对GBDT等提升树模型进行优化,不仅能处理更大规模的数据,而且运行效率特别高 ...
- 论文笔记 XGBoost: A Scalable Tree Boosting System
XGBoost是boosting算法的其中一种.Boosting算法的思想是将许多弱分类器集成在一起形成一个强分类器,其更关注与降低基模型的偏差.XGBoost是一种提升树模型(Gradient bo ...
- xgboost原理总结和代码展示
关于xgboost的学习推荐两篇博客,每篇看2遍,我都能看懂,你肯定没问题 两篇方法互通,知识点互补!记录下来,方便以后查看 第一篇:作者:milter链接:https://www.jianshu.c ...
随机推荐
- Java——并发编程
1.在java中守护线程和本地线程区别? java中的线程分为两种:守护线程(Daemon)和用户线程(User). 任何线程都可以设置为守护线程和用户线程,通过方法Thread.setDaemon( ...
- Gerrit安装配置
环境: CentOS 1611 + gerrit-2.11.4 (review.openstack.org) 1. 安装java1.8 (>1.7) [root@review ~]# yum i ...
- 发现微信支付bug
第一张银行卡支付金额不足无法付款,选择另一张同样密码的银行卡,居然不用重新输入密码即可直接付款成功!
- 第 6 章 C控制语句:循环
6.16.3 使用嵌套循环,按下面格式打印字母: F FE FED FEDC FEDCB FEDCBA #include <stdio.h> int main() { ; ); row ! ...
- java Math数学工具及Random随机函数
Math类包含用于执行基本数学运算的方法,如绝对值.对数.平方根和三角函数.它是一个final类,其中定义的都是一些常量和静 态方法.常用方法如下:public static double sqrt( ...
- JS定时器相关用法
一.定时器在javascript中的作用 1.制作动画 <!DOCTYPE html> <html lang="en"> <head> < ...
- css3实现 依次出现三个点(一般用于提示加载中。。。 提交中。。。)
<a href="javascript:" class="login">登录中<span class="dotting"& ...
- linux下添加用户到sudo组 并禁止sudo用户修改密码
linux下添加用户到sudo组 创建用户 useradd hanli 为新用户设置密码 passwd hanli 创建用户组 groupadd op 将用户添加到用户组 usermod - ...
- ElasticSearch 简单的 搜索 聚合 分析
一. 搜索1.DSL搜索 全部数据没有任何条件 GET /shop/goods/_search { "query": { "match_all": {} } } ...
- Spring源码分析(二十五)finishRefresh
摘要: 本文结合<Spring源码深度解析>来分析Spring 5.0.6版本的源代码.若有描述错误之处,欢迎指正. 在 Spring 中还提供了 Lifecycle 接口, Lifecy ...