XGBoost模型 0基础小白也能懂(附代码)

原文链接

啥是XGBoost模型

XGBoost 是 eXtreme Gradient Boosting 的缩写称呼,它是一个非常强大的 Boosting 算法工具包,优秀的性能(效果与速度)让其在很长一段时间内霸屏数据科学比赛解决方案榜首,现在很多大厂的机器学习方案依旧会首选这个模型。

XGBoost 在并行计算效率、缺失值处理、控制过拟合、预测泛化能力上都变现非常优秀。本文我们给大家详细展开介绍 XGBoost,包含「算法原理」和「工程实现」两个方面。

关于 XGBoost 的原理,其作者陈天奇本人有一个非常详尽的Slides做了系统性的介绍。

Boosted Tree

Boosted Tree(提升树)是一种常用的机器学习方法,属于集成学习的一种。它通过将多个弱学习器(通常是决策树)组合起来,以提升整个模型的预测性能。Boosted Tree的核心思想是通过逐步训练多个决策树,每个树都试图修正前一个树的错误,最终得到一个更强大的模型。

模型:假设我们有\(K\)棵树\(\hat{y_i}=\sum_{k=1}^Kf_k(x_i),f_k\in{F}\),\(F\)为包含所有回归树的函数空间。

目标函数:\(Obj=\sum_{i=1}^nl(y_i,\hat{y_i})+\sum_{k=1}^K\Omega(f_k)\)

\(\sum_{i=1}^nl(y_i,\hat{y_i})\)是成本函数

\(\sum_{k=1}^K\Omega(f_k)\)是正则化项,代表树的复杂程度,树越复杂正则化项的值越高(正则化项如何定义我们会在后面详细说)。

当我们讨论决策树或相关的树模型时,通常是启发式的。启发式(heuristic)在机器学习中指的是使用经验法则或近似方法来解决问题,而不保证找到最优解。

Gradient Boosting(如何学习)

在做 GBDT 的时候,我们没有办法使用 SGD(Stochastic Gradient Descent,随机梯度下降),因为它们是树,而非数值向量——也就是说从原来我们熟悉的参数空间变成了函数空间。Gradient Boosting Decision Trees(GBDT)与深度学习或线性模型不同,它的核心不是直接通过参数更新来优化,而是通过构建新的决策树来逐步降低误差。

解决方案:初始化一个预测值,每次迭代添加一个新函数\((f)\)

1)目标函数变换

根据解决方案可以对目标函数进行初步变形

其中constant是常数项,比如\(\Omega(f_1),\Omega(f_2)\)之类的,然后第三行就是考虑平方损失,\(l(y_i,\hat{y_i})=\frac{1}{2}(y_i-\hat{y_i})^2\),代进去就行

所以我们的目的就是找到\(f(t)\)使得目标函数最低。然而,经过上面初次变形的目标函数仍然很复杂,目标函数会产生二次项。引入泰勒公式

这图也多少有点问题,是在还没考虑平方损失的地方引入泰勒公式,然后泰勒公式也有问题,后面两项应该是\(f(x)\)的一阶导数和二阶导数,所以才是\(g_i,h_i\)。

再把里面的常数项提取出,和\(f_t\)无关

2)重新定义树

前面已经用\(f_t(x)\)表示一棵树,在本小节,我们重新定义一下树:我们通过叶子结点中的分数向量和将实例映射到叶子结点的索引映射函数来定义树:(有点儿抽象,具体请看下图)

图里有问题,第一个叶子结点权重是+2

3)定义树的复杂程度

其中\(T\)才是叶子节点的个数,\(\gamma\)是控制树的复杂度的参数,树的叶子节点越多,复杂度越高。通过调节

\(\gamma\)可以控制模型的复杂度。后面一堆是 L2 Norm正则化系数

4)重新审视目标函数

定义在叶子结点\(j\)中的实例的集合为:\(I_j=\{i|q(x_i)=j\}\),这么定义也是为了能够构建出第三个式子,都写成\(\sum_{j=1}^T\)

同时也会发现上式是\(T\)个独立二次函数的和

5)计算叶子结点的值

搞了一大坨,其实也就是先把值换成\(G_j,H_j\),然后用一元二次方程求一个最优值就完了。

下图是前面公式讲解对应的一个实际例子。

这里再次总结一下,我们已经把目标函数变成了仅与\(G,H,\gamma,\lambda,T\)这五项已知参数有关的函数,把之前的变量\(f_t\)消灭掉了,也就不需要对每一个叶子进行打分了!

那么现在问题来,刚才我们提到,以上这些是假设树结构确定的情况下得到的结果。但是树的结构有好多种,我们应该如何确定呢?

6) 贪婪算法生成树

上一部分中我们假定树的结构是固定的。但是,树的结构其实是有无限种可能的,本小节我们使用贪婪算法生成树:

首先生成一个深度为0的树(只有一个根结点,也叫叶子结点)

对于每棵树的每个叶子结点,尝试去做分裂(生成两个新的叶子结点,原来的叶子结点不再是叶子结点)。在增加了分裂后的目标函数前后变化为(我们希望增加了树之后的目标函数小于之前的目标函数,所以用之前的目标函数减去之后的目标函数):

\(Gain=\frac{1}{2}(\frac{G_L^2}{H_L+\lambda}+\frac{G_R^2}{H_R+\lambda}-\frac{(G_L+G_R)^2}{H_L+H_R+\lambda})-\gamma\)

接下来要考虑的是如何寻找最佳分裂点。

例如,如果\(x_j\)是年龄,当分裂点是\(a\)的时候的增益\(Gain\)是多少?

其实这里对排序后的实例进行从左到右的线性扫描就足以决定特征的最佳分裂点。从左到右依次扫描:一旦数据按照特征值进行了排序,我们从第一个样本开始,依次计算每个可能的分裂点。对于每个分裂点,我们把样本分为“左侧”和“右侧”两个子集,分别计算划分前后目标函数的变化。下面还有别的一些办法

7)如何处理分类型变量

在很多情况下,我们不需要为分类变量设计特殊的处理方式,可以将其转换为one-hot 编码来处理。

\(z_j=
\begin{cases}
0& \text{if x is in category y}\\
1& \text{otherwise}
\end{cases}
\)

如果有太多的分类的话,矩阵会非常稀疏,算法会优先处理稀疏数据。

8) 修剪和正则化

回顾之前的增益,当训练损失减少的值小于正则化带来的复杂度时,增益有可能会是负数,此时就是模型的简单性和可预测性之间的权衡

XGBoost核心原理归纳解析

铺垫了那么多,总算到这里了。XGBoost 也是一个 Boosting 加法模型,每一步迭代只优化当前步中的子模型。

第\(m\)步我们有:\(F_m(x_i)=F_{m-1}(x_i)+f_m(x_i)\)

\(f_m(x_i)\)为当前步的子模型。

\(F_{m-1}(x_i)\)为前\(m-1\)个完成训练且固定了的子模型。

泰勒展开

然后去掉常数,带入复杂度(和之前一样)

1)近似算法

基于性能的考量,XGBoost 还对贪心准则做了一个近似版本,简单说,处理方式是「将特征分位数作为划分候选点」。这样将划分候选点集合由全样本间的遍历缩减到了几个分位数之间的遍历。

展开来看,特征分位数的选取还有 global 和 local 两种可选策略:

精确贪心准则:这是默认的精确算法,遍历所有可能分裂点,找到能最大化增益的点。计算量最大,但分裂效果最优。

Global 近似分裂:使用全体样本的特征分位数进行一次性划分,分裂点在所有节点中复用,计算量大幅减少,适合较大的数据集。

Local 近似分裂:在每个节点分裂前根据当前节点的样本重新计算特征分位数,能够更加灵活适应不同节点的特征分布,适合样本分布差异较大的情况。

近似算法的性能与精确贪心算法几乎相同,但大大降低了计算成本。

2)加权分位数

在 XGBoost 中,加权分位数(Weighted Quantile Sketch)用于加速分裂点的寻找过程。加权分位数算法并不是直接根据样本的特征值来划分分位点,而是考虑了样本的二阶导数(Hessian)作为权重,从而更好地平衡分裂点的选择,特别是在近似算法中。

令偏导为0易得\(f_m^*(x_i)=-\frac{g_i}{h_i}\)

3) 列采样与学习率

列采样指的是在构建每棵决策树时,XGBoost 不会使用全部特征,而是随机选择部分特征用于分裂。这种方法源自于随机森林的思想,目的是增加模型的多样性,从而防止过拟合。

学习率在梯度提升树(GBDT)中是一个非常重要的超参数,用于控制每棵树对模型的贡献。学习率可以防止模型更新过快,从而提升模型的稳定性和性能。也叫步长、shrinkage,具体的操作是在每个子模型前(即每个叶节点的回归值上)乘上该系数,不让单颗树太激进地拟合,留有一定空间,使迭代更稳定。XGBoost默认设定为 。

4) 特征缺失与稀疏性

简单说,它的做法是将缺失值和稀疏\(0\)值等同视作缺失值,将其「绑定」在一起,分裂节点的遍历会跳过缺失值的整体。这样大大提高了运算效率。

比如在下面的例子中有六种划分情况,XGBoost 会遍历以上6种情况(3个非缺失值的切分点×缺失值的两个方向),最大的分裂收益就是本特征上的分裂收益

XGBoost工程优化

1)并行列块设计

XGBoost 将每一列特征提前进行排序,以块(Block)的形式储存在缓存中,并以索引将特征值和梯度统计量对应起来,每次节点分裂时会重复调用排好序的块。而且不同特征会分布在独立的块中,因此可以进行分布式或多线程的计算。

2)缓存访问优化

特征值排序后通过索引来取梯度\(g_i,h_i\)会导致访问的内存空间不一致,进而降低缓存的命中率,影响算法效率。为解决这个问题,XGBoost为每个线程分配一个单独的连续缓存区,用来存放梯度信息。

3) 核外块计算

数据量非常大的情形下,无法同时全部载入内存。XGBoost 将数据分为多个 blocks 储存在硬盘中,使用一个独立的线程专门从磁盘中读取数据到内存中,实现计算和读取数据的同时进行。

为了进一步提高磁盘读取数据性能,XGBoost 还使用了两种方法:

① 压缩 block,用解压缩的开销换取磁盘读取的开销。

② 将 block 分散储存在多个磁盘中,提高磁盘吞吐量。

XGBoost vs GBDT

GBDT 是机器学习算法,XGBoost 在算法基础上还有一些工程实现方面的优化。

GBDT 使用的是损失函数一阶导数,相当于函数空间中的梯度下降;XGBoost 还使用了损失函数二阶导数,相当于函数空间中的牛顿法。

正则化:XGBoost 显式地加入了正则项来控制模型的复杂度,能有效防止过拟合。

列采样:XGBoost 采用了随机森林中的做法,每次节点分裂前进行列随机采样。

缺失值:XGBoost 运用稀疏感知策略处理缺失值,GBDT无缺失值处理策略。

并行高效:XGBoost 的列块设计能有效支持并行运算,效率更优。

代码实现

需要先下载xgboost

pip install xgboost

代码如下

# 导入所需的库
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes # 替换为 load_diabetes
from sklearn.metrics import mean_squared_error # 1. 加载糖尿病数据集
# 这个数据集包含442个样本,10个特征,用于预测一个连续目标变量
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target # X是特征数据,y是标签(目标变量) # 2. 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 3. 将数据转换为 DMatrix 格式
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test) # 4. 设置 XGBoost 模型的超参数
params = {
'objective': 'reg:squarederror', # 回归任务使用的目标函数,平方误差
'max_depth': 3, # 决策树的最大深度,控制模型的复杂度
'eta': 0.05, # 学习率,控制每棵树对整体模型的贡献
'eval_metric': 'rmse' , # 评估指标,使用均方根误差(RMSE)
'lambda': 2, # L2 正则化项,防止过拟合
'alpha': 0.5 # L1 正则化项
} # 5. 设定训练轮数
num_round = 200 # 训练的轮数,即构建多少棵树 # 6. 定义评估数据集
evals = [(dtrain, 'train'), (dtest, 'eval')] # (数据集, 数据集名称) # 7. 训练 XGBoost 模型,加入 early_stopping_rounds早停机制,防止过拟合
bst = xgb.train(params, dtrain, num_round, evals, early_stopping_rounds=10) # 8. 使用训练好的模型对测试集进行预测
y_pred = bst.predict(dtest) # 9. 评估模型性能
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}") # 10. 保存训练好的模型
bst.save_model('xgboost_model.json') # 11. 加载已保存的模型
loaded_bst = xgb.Booster()
loaded_bst.load_model('xgboost_model.json') # 12. 使用加载的模型进行预测
y_pred_loaded = loaded_bst.predict(dtest)
mse_loaded = mean_squared_error(y_test, y_pred_loaded)
print(f"Mean Squared Error from loaded model: {mse_loaded}")

结果如下

[0]	train-rmse:76.08309	eval-rmse:71.75905
[1] train-rmse:74.34324 eval-rmse:70.47408
[2] train-rmse:72.66427 eval-rmse:69.24759
[3] train-rmse:71.10664 eval-rmse:68.09809
[4] train-rmse:69.63498 eval-rmse:67.14668
[5] train-rmse:68.24045 eval-rmse:66.09854
[6] train-rmse:66.93042 eval-rmse:64.91738
[7] train-rmse:65.73304 eval-rmse:64.08775
[8] train-rmse:64.58640 eval-rmse:63.26052
[9] train-rmse:63.51304 eval-rmse:62.49745
[10] train-rmse:62.44810 eval-rmse:61.64759
[11] train-rmse:61.51387 eval-rmse:60.96222
[12] train-rmse:60.61767 eval-rmse:60.32972
[13] train-rmse:59.77722 eval-rmse:59.74329
[14] train-rmse:59.01348 eval-rmse:59.13121
[15] train-rmse:58.24704 eval-rmse:58.55106
[16] train-rmse:57.57392 eval-rmse:58.15165
[17] train-rmse:56.92761 eval-rmse:57.68188
[18] train-rmse:56.33319 eval-rmse:57.37781
[19] train-rmse:55.72582 eval-rmse:56.97001
[20] train-rmse:55.14420 eval-rmse:56.45029
[21] train-rmse:54.61096 eval-rmse:55.97904
[22] train-rmse:54.12594 eval-rmse:55.57225
[23] train-rmse:53.68383 eval-rmse:55.39305
[24] train-rmse:53.24822 eval-rmse:55.01127
[25] train-rmse:52.85214 eval-rmse:54.85699
[26] train-rmse:52.43814 eval-rmse:54.49904
[27] train-rmse:52.07004 eval-rmse:54.42905
[28] train-rmse:51.68191 eval-rmse:54.25354
[29] train-rmse:51.28268 eval-rmse:54.09452
[30] train-rmse:50.94229 eval-rmse:54.06703
[31] train-rmse:50.58475 eval-rmse:53.88010
[32] train-rmse:50.24739 eval-rmse:53.74475
[33] train-rmse:49.97042 eval-rmse:53.49905
[34] train-rmse:49.65855 eval-rmse:53.41597
[35] train-rmse:49.38190 eval-rmse:53.34692
[36] train-rmse:49.07203 eval-rmse:53.32202
[37] train-rmse:48.81472 eval-rmse:53.22084
[38] train-rmse:48.57124 eval-rmse:53.24058
[39] train-rmse:48.33730 eval-rmse:53.13983
[40] train-rmse:47.97171 eval-rmse:53.05406
[41] train-rmse:47.75619 eval-rmse:52.87405
[42] train-rmse:47.43067 eval-rmse:52.80852
[43] train-rmse:47.18844 eval-rmse:52.70296
[44] train-rmse:46.96694 eval-rmse:52.61260
[45] train-rmse:46.79053 eval-rmse:52.58588
[46] train-rmse:46.58746 eval-rmse:52.51602
[47] train-rmse:46.38476 eval-rmse:52.50433
[48] train-rmse:46.15591 eval-rmse:52.44922
[49] train-rmse:46.00542 eval-rmse:52.36981
[50] train-rmse:45.84480 eval-rmse:52.27445
[51] train-rmse:45.63700 eval-rmse:52.23794
[52] train-rmse:45.49250 eval-rmse:52.25740
[53] train-rmse:45.31208 eval-rmse:52.16836
[54] train-rmse:45.15374 eval-rmse:52.22044
[55] train-rmse:45.00284 eval-rmse:52.15072
[56] train-rmse:44.87677 eval-rmse:52.04112
[57] train-rmse:44.71921 eval-rmse:52.08482
[58] train-rmse:44.55626 eval-rmse:52.02783
[59] train-rmse:44.41483 eval-rmse:52.09304
[60] train-rmse:44.27997 eval-rmse:52.03098
[61] train-rmse:44.15710 eval-rmse:52.08378
[62] train-rmse:44.00683 eval-rmse:52.02136
[63] train-rmse:43.84878 eval-rmse:52.06178
[64] train-rmse:43.74180 eval-rmse:52.06495
[65] train-rmse:43.59775 eval-rmse:52.08875
[66] train-rmse:43.44009 eval-rmse:52.20317
[67] train-rmse:43.29717 eval-rmse:52.14245
[68] train-rmse:43.10437 eval-rmse:52.15464
[69] train-rmse:43.00768 eval-rmse:52.17011
[70] train-rmse:42.87951 eval-rmse:52.11852
[71] train-rmse:42.79951 eval-rmse:52.21249
[72] train-rmse:42.66769 eval-rmse:52.22331
Mean Squared Error: 2727.2736118611274
Mean Squared Error from loaded model: 2727.2736118611274

train-rmse是训练集上的预测值与真实值之间的误差。eval-rmse是模型在测试集上的 RMSE

分析下早停机制下最后的数据,42.66769 表示在训练集上,模型的预测误差为 42.67。RMSE 越低,表示模型在训练集上拟合得越好。52.22 说明模型在测试集上的预测误差明显高于训练集,表明模型可能存在一定的过拟合问题,模型在训练集上表现良好,但在新数据(测试集)上的泛化能力不如在训练集上的表现。

XGBoost模型 0基础小白也能懂(附代码)的更多相关文章

  1. Docker_入门?只要这篇就够了!(纯干货适合0基础小白)

    与sgy一起开启你的Docker之路 关键词: Docker; mac; Docker中使用gdb无法进入断点,无法调试; 更新1: 看起来之前那一版博文中参考资料部分引用的外站链接太多,被系统自动屏 ...

  2. 0基础小白怎么学好Java?

    自身零基础,我们应该先学好Java,小编给大家介绍一下Java的特性: Java语言是简单的 Java语言的语法与C语言和C++语言很接近,使得大多数程序员很容易学习和使用Java.Java丢弃了C+ ...

  3. 大一0基础小白用最基础C写哥德巴赫猜想

    #include <stdio.h>int main (){ int a,b,c,k,count1,count2; for(a=4;a<=1200;a=a+2){ for(b=2;b ...

  4. MySQL下载,安装,配置环境变量【0基础小白用】

    一,下载 选择社区版的,下载地址:https://dev.mysql.com/downloads/installer/  ,选择离线安装包 二,安装 1,双击安装包文件,这里选择服务模式,会安装在默认 ...

  5. (五)SpringBoot2.0基础篇- Mybatis与插件生成代码

    SpringBoot与Mybatis合并 一.创建SpringBoot项目,引入相关依赖包: <?xml version="1.0" encoding="UTF-8 ...

  6. 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)

    文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...

  7. (六)SpringBoot2.0基础篇- Redis整合(JedisCluster集群连接)

    一.环境 Redis:4.0.9 SpringBoot:2.0.1 Redis安装:Linux(Redhat)安装Redis 二.SpringBoot整合Redis 1.项目基本搭建: 我们基于(五) ...

  8. 0基础的小白怎么学习Java?

    自身零基础,那么我们应该先学好Java,首先我们来了解下Java的特性: Java语言是简单的 Java语言的语法与C语言和C++语言很接近,使得大多数程序员很容易学习和使用Java.另一方面,Jav ...

  9. 手把手0基础Centos下安装与部署paddleOcr 教程

    !!!以下内容为作者原创,首发于个人博客园&掘金平台.未经原作者同意与许可,任何人.任何组织不得以任何形式转载.原创不易,如果对您的问题提供了些许帮助,希望得到您的点赞支持. 0.paddle ...

  10. Android程序开发0基础教程(一)

    程序猿学英语就上视觉英语网 Android程序开发0基础教程(一)   平台简单介绍   令人激动的Google手机操作系统平台-Android在2007年11月13日正式公布了,这是一个开放源码的操 ...

随机推荐

  1. JavaSe 统计字符串中字符出现的次数

    public static void main(String[] args) { // 1.字符串 String str = "*Constructs a new <tt>Has ...

  2. 新版SpringBoot-Spring-Mybatis 数据库相关配置

    application.properties server.port=8081 # ========================数据库相关配置===================== sprin ...

  3. Vscode控制台乱码的最终解决方案

    Vscode控制台乱码的最终解决方案 vscode运行项目时控制台打印日志乱码.网上也有许多解决办法. 方法一[管用]推荐,避免过多设置 Java项目时,像Springboot微服务项目默认使用的是l ...

  4. 魔百和s905l3a蓝牙系列 在armbian驱动并使用蓝牙!

    文章已废弃,因为现在x大的dtb不需要驱动直接可以使用 之后我会重新写文章,感谢大家

  5. CF466E Information Graph 题解

    题目链接 Luogu Codeforces 题意简述 某公司中有 \(n\) 名员工.为方便起见,将这些员工从 1 至 \(n\) 编号.起初,员工之间相互独立.接下来,会有以下 \(m\) 次操作: ...

  6. 使用 useLazyFetch 进行异步数据获取

    title: 使用 useLazyFetch 进行异步数据获取 date: 2024/7/20 updated: 2024/7/20 author: cmdragon excerpt: 摘要:&quo ...

  7. GUN/Linux 基础教程

    GUN/Linux 基础教程 控制台 shell 超级用户 root 辅助管理 CLI 软件 文件基础 目录 链接 设备文件 控制台 shell 在启动 Linux 系统后,如果没有安装 GUI 的话 ...

  8. canvas实现截图功能

    开篇 最近在做一个图片截图的功能. 因为工作时间很紧张, 当时是使用的是一个截图插件. 周末两天无所事事,来写一个简单版本的截图功能. 因为写的比较简单,如果写的不好,求大佬轻一点喷 读取图片并获取图 ...

  9. 【Mybatis-Plus】06 代码生成器 CodeGenerator

    导入生成器需要的依赖坐标: <dependency> <groupId>com.baomidou</groupId> <artifactId>mybat ...

  10. python高性能计算:cython使用openmp并行(示例)

    y.pyx import cython from cython import parallel from cython.parallel import prange cdef int i cdef i ...