机器学习 | 详解GBDT梯度提升树原理,看完再也不怕面试了
本文始发于个人公众号:TechFlow,原创不易,求个关注
今天是机器学习专题的第30篇文章,我们今天来聊一个机器学习时代可以说是最厉害的模型——GBDT。
虽然文无第一武无第二,在机器学习领域并没有什么最厉害的模型这一说。但在深度学习兴起和流行之前,GBDT的确是公认效果最出色的几个模型之一。虽然现在已经号称进入了深度学习以及人工智能时代,但是GBDT也没有落伍,它依然在很多的场景和公司当中被广泛使用。也是面试当中经常会问到的模型之一。
遗憾的是市面上关于GBDT的资料虽然不少,但是很少有人把其中的核心精髓介绍清楚的。新手在初学的时候往往会被”梯度“,”残差“等这些令人费解的概念给困惑住,耽误了算法原理的学习和理解。但其实GBDT整体的原理还是比较直观和简单的,只要我们找对了方法,抓住了核心,我相信对于绝大多数人来说,应该都不会问题。
GBDT基础概念
GBDT的英文原文是Gradient Boosting Decision Tree,即梯度提升决策树。从它的英文表述我们可以看出来,GBDT的基础还是决策树。决策树我们在之前的文章当中曾经有过详细的讨论,我们这里就不再赘述了。在GBDT当中用到的主要是决策树的CART算法,在CART算法当中,我们每次都会选择一个特征并且寻找一个阈值进行二分。将样本根据阈值分成小于等于阈值的以及大于阈值的两个部分,在CART树当中,同一个特征可以重复使用,其他类似的ID3和C4.5都没有这个性质。
另外一个关键词是Boosting,Boosting表示一种集成模型的训练方法,我们之前在介绍AdaBoost模型的时候曾经提到过。它最大的特点就是会训练多个模型,通过不断地迭代来降低整体模型的偏差。比如在Adaboost模型当中,会设置多个弱分类器,根据这些分类器的表现我们会给与它们不同的权值。通过这种设计尽可能让效果好的分类器拥有高权重,从而保证模型的拟合能力。
但GBDT的Boosting方法与众不同,它是一个由多棵CART决策回归树构成的加法模型。我们可以简单理解成最后整个模型的预测结果是所有回归树预测结果的和,理解了这一点对于后面理解梯度和残差非常重要。
我们可以试着写一下GBDT的预测公式:
公式中的M表示CART树的个数,表示第i棵回归树对于样本的预测结果,其中的表示每一棵回归树当中的参数。所以整个过程就和我刚才说的一样,GBDT模型最后的结果是所有回归树预测结果的加和。
但是这就有了一个问题,如果是回归问题那还好说,如果是分类问题那怎么办?难道分类结果也能加和吗?
其实也是可以的,我们知道在逻辑回归当中,我们用到的公式是,这个式子的结果表示样本的类别是1的概率。我们当然不能直接来拟合这个概率,但是我们可以用加和的方式来拟合的结果,这样我们就间接得到了概率。
今天的文章当中我们主要先来讲解回归问题,因为它的公式和理解最直观简单。分类的问题我们将会放到下一篇文章当中,因此这里稍作了解即可。
梯度和残差
下面我们要介绍到梯度和残差的概念了,我们先来回顾一下线性回归当中梯度下降的用法。
在线性回归当中我们使用梯度下降法是为了寻找最佳的参数,使得损失函数最小。实际上目前绝大多数的模型都是这么做的,计算梯度的目的是为了调整参数。但是GBDT不同,计算梯度是为了下一轮的迭代,这句话非常关键,一定要理解。
我们来举个例子,假设我们用线性回归拟合一个值,这里的目标y是20。我们当前的得到的是10,那么我们应该计算梯度来调整参数,明显应该将它调大一些从而降低偏差。
但是GBDT不是这么干的,同样假设我们第一棵回归树得到的结果也是10,和真实结果相差了10,我们一样来计算梯度。在回归问题当中,我们通常使用均方差MSE作为损失函数,那么我们可以来算一下这个函数的梯度。我们先写出损失函数的公式:
L关于的负梯度值刚好等于,看起来刚好是我们要预测的目标值减去之前模型预测的结果。这个值也就是我们常说的残差。
我们用表示第m棵回归树对于样本i的训练目标,它的公式为:
从直观上来讲究很简单了,我们要预测的结果是20,第一棵树预测了10,相差还剩10,于是我们用第二棵树来逼近。第二棵树预测了5,相差变成了5,我们继续创建第三棵树……
一直到我们已经逼近到了非常接近小于我们设定的阈值的时候,或者子树的数量达到了上限,这个时候模型的训练就停止了。
这里要注意,不能把残差简单理解成目标值和的差值,它本质是由损失函数计算负梯度得到的。
训练过程
我们再把模型训练的整个过程给整理一下,把所有的细节串联起来。
首先我们先明确几个参数,M表示决策树的数量。表示第m轮训练之后的整体,即为最终输出的GBDT模型。
初始化
首先,我们创建第一棵回归树即,在回归问题当中,它是直接用回归树拟合目标值的结果,所以:
迭代
i. 对于第2到第m棵回归树,我们要计算出每一棵树的训练目标, 也就是前面结果的残差:
ii. 对于当前第m棵子树而言,我们需要遍历它的可行的切分点以及阈值,找到最优的预测值c对应的参数,使得尽可能逼近残差,我们来写出这段公式:
这里的指的是第m棵子树所有的划分方法中叶子节点预测值的集合,也就是第m棵回归树可能达到的预测值。其中j的范围是1,2,3...J。
接着,我们更新,这里的I是一个函数,如果样本落在了节点上,那么I=1,否则I=0。
最后我们得到回归树
上述的公式看起来有些复杂,其实就是我们借助和I把回归树的情况表示了出来而已。因为我们训练模型最终希望得到的其实是模型的参数,对于回归树而言,它的参数表示比较复杂,所以看起来可能会有些迷惑。
我们可以简单一点理解,GBDT就是利用的加法模型训练多棵回归树,预测的结果是这些回归树的和。而每一棵回归树的训练目标都是之前模型的残差。
Shrinkage
Shinkage是一种优化避免GBDT陷入过拟合的方法,这个方法的本质是减小每一次迭代对于残差的收敛程度,认为每一次逼近少一些多次收敛的效果好于一次逼近很多,逼近次数较少的结果。具体的表现措施就是给我们的每一棵回归树的结果乘上一个类似于学习率的参数,通过增大回归树的个数来弥补。
说白了就和梯度下降的时候我们乘上学习率是一样的,只不过在梯度下降的问题当中,我们明确知道不乘学习率的话会陷入震荡无法收敛的问题。而在GBDT当中,Shrinkage的机制并没有一个明确的证明或者是感性的认识,它的效果更多是基于经验的。
我们写一下加上Shrinkage之后的方程来做个对比:
这里的就是我们的Shrinkage的参数,一般取值在0.001到0.01之间。
总结
到这里,关于GBDT模型的基本原理就算是介绍完了。如果你对于之前关于决策树的相关文章都认真阅读的话,相信理解GBDT对于你来说应该不是一件困难的事。如果你没有读过或者是错过了之前的文章的话,可以看一下文末的相关阅读的部分,回顾一下之前的内容。
GBDT最大的创新就在于,将传统的调整参数来降低梯度的过程转化成了创建新的树模型来逼近,我第一次看到的时候深深为之惊艳。和传统的模型相比,由于GBDT是综合了多个分类器的结果,所以更加不容易陷入过拟合,并且对于一些复杂的场景的拟合效果会更好。今天我们介绍的只是最基本的回归问题当中的解法,在分类问题当中,公式会稍稍有些不同,这部分内容我们放在下篇文章当中。
今天的文章到这里就结束了,如果喜欢本文的话,请来一波素质三连,给我一点支持吧(关注、转发、点赞)。
相关阅读
机器学习——打开集成方法的大门,手把手带你实现AdaBoost模型
GBDT开源代码: https://github.com/RRdmlearning/Machine-Learning-From-Scratch/tree/master/gradient_boosting_decision_tree
机器学习 | 详解GBDT梯度提升树原理,看完再也不怕面试了的更多相关文章
- 机器学习 | 详解GBDT在分类场景中的应用原理与公式推导
本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第31篇文章,我们一起继续来聊聊GBDT模型. 在上一篇文章当中,我们学习了GBDT这个模型在回归问题当中的原理.GBD ...
- GBDT(梯度提升树)scikit-klearn中的参数说明及简汇
1.GBDT(梯度提升树)概述: GBDT是集成学习Boosting家族的成员,区别于Adaboosting.adaboosting是利用前一次迭代弱学习器的误差率来更新训练集的权重,在对更新权重后的 ...
- 机器学习 之梯度提升树GBDT
目录 1.基本知识点简介 2.梯度提升树GBDT算法 2.1 思路和原理 2.2 梯度代替残差建立CART回归树 1.基本知识点简介 在集成学习的Boosting提升算法中,有两大家族:第一是AdaB ...
- 【Spark机器学习速成宝典】模型篇07梯度提升树【Gradient-Boosted Trees】(Python版)
目录 梯度提升树原理 梯度提升树代码(Spark Python) 梯度提升树原理 待续... 返回目录 梯度提升树代码(Spark Python) 代码里数据:https://pan.baidu.co ...
- 梯度提升树 Gradient Boosting Decision Tree
Adaboost + CART 用 CART 决策树来作为 Adaboost 的基础学习器 但是问题在于,需要把决策树改成能接收带权样本输入的版本.(need: weighted DTree(D, u ...
- 梯度提升树(GBDT)原理小结(转载)
在集成学习值Adaboost算法原理和代码小结(转载)中,我们对Boosting家族的Adaboost算法做了总结,本文就对Boosting家族中另一个重要的算法梯度提升树(Gradient Boos ...
- 梯度提升树(GBDT)原理小结
在集成学习之Adaboost算法原理小结中,我们对Boosting家族的Adaboost算法做了总结,本文就对Boosting家族中另一个重要的算法梯度提升树(Gradient Boosting De ...
- GBDT(梯度提升树) 原理小结
在之前博客中,我们对Boosting家族的Adaboost算法做了总结,本文就对Boosting家族中另一个重要的算法梯度提升树(Gradient Boosting Decison Tree, 以下简 ...
- scikit-learn 梯度提升树(GBDT)调参小结
在梯度提升树(GBDT)原理小结中,我们对GBDT的原理做了总结,本文我们就从scikit-learn里GBDT的类库使用方法作一个总结,主要会关注调参中的一些要点. 1. scikit-learn ...
随机推荐
- Mysql基础(六):索引、数据库备份、锁和事务、慢查询优化、索引命中相关
目录 数据库05 /索引.数据库备份.锁和事务.慢查询优化.索引命中相关 1. 什么是索引 2. 索引的原理 3. 索引的数据结构(聚集索引.辅助索引) 4. 索引操作 5. 索引的两大类型hash与 ...
- bzoj3375[Usaco2004 Mar]Paranoid Cows 发疯的奶牛*
bzoj3375[Usaco2004 Mar]Paranoid Cows 发疯的奶牛 题意: 依次给出n只奶牛的产奶时间段,求最大的k使得前k只奶牛不存在一个时间段被另一个时间段完全覆盖的情况.n≤1 ...
- Json对象,Json数组,Json字符串的区别
Json对象: var str = {"姓名":"张三","性别":"男","年龄":"2 ...
- SSM框架前后端信息交互
一.从前端向后端传送数据 常见的3种方式 1.form表单的action:此方法可以提交form表单内的输入数据,也可同时提交某些隐藏但设置有默认值的<input>,如修改问题时,我们除了 ...
- MSSQL系列 (一):数据库的相关操作(增删改查)
1.创建数据库 --创建数据库 create database stuDb on primary ( --表示属于primary文件组 name='stuDb', --逻辑名称 fileName='D ...
- javascript : 找到一个树型数据的一个节点及其所有父节点
如题. (function () { let tree = { "id": 0, "label": "all", "childre ...
- 【JVM之内存与垃圾回收篇】方法区
方法区 前言 这次所讲述的是运行时数据区的最后一个部分 从线程共享与否的角度来看 ThreadLocal:如何保证多个线程在并发环境下的安全性?典型应用就是数据库连接管理,以及会话管理 栈.堆.方法区 ...
- Flutter日常笔记
factory修饰的构造方法 表示不是每次返回的都是新创建出来的对象, 可以取内存中已有的, 比如单例模式的书写 每次返回的都是一个实例, 这时要使用factory修饰构造方法 flutter不要求显 ...
- leetcode题库练习_数组中重复的数字
题目:数组中重复的数字 找出数组中重复的数字. 在一个长度为 n 的数组 nums 里的所有数字都在 0-n-1 的范围内.数组中某些数字是重复的,但不知道有几个数字重复了,也不知道每个数字重复了几次 ...
- Region Normalization for Image Inpainting, AAAI 2020
论文:Region Normalization for Image Inpainting, AAAI 2020 代码:https://github.com/geekyutao/RN 图像修复的目的是重 ...