GBDT源码剖析
如今,GBDT被广泛运用于互联网行业,他的原理与优点这里就不细说了,网上google一大把。但是,我自认为自己不是一个理论牛人,对GBDT的理论理解之后也做不到从理论举一反三得到更深入的结果。但是学习一个算法,务必要深入细致才能领会到这个算法的精髓。因此,在了解了足够的GBDT理论之后,就需要通过去阅读其源码来深入学习GBDT了。但是,网上有关这类资料甚少,因此,我不得不自己亲自抄刀,索性自己从头学习了一下GBDT源码。幸好,这个算法在机器学习领域中的其它算法还是非常简单的。这里将心得简单分享,欢迎指正。源码可以去GBDT源码下载。
首先,这里需要介绍一下程序中用到的结构体,具体的每一个结构体的内容这里就不再赘述了,源码里面都有。这里只再细说一下每个结构体的作用,当然一些重要的结构体会详细解释。
struct gbdt_model_t:GBDT模型的结构体,也就是最终我们训练得到的由很多棵决策树组成的模型。
typedef struct {
int* nodestatus; //!<
int* depth; //
int* splitid; //!<
double* splitvalue; //!<
int* ndstart; //!< 节点对应于 Index 的开始位置
int* ndcount; //!< 节点内元素的个数
double* ndavg; //!< 节点内元素的均值
//double* vpredict;
int* lson; //!< 左子树
int* rson; //!< 右子树
int nodesize; //!< 树的节点个数
}gbdt_tree_t;
struct gbdt_tree_t:当然就代表模型中的一棵树的各种信息了。为了后面能理解,这里需要详细解释一下这个结构体。splitid[k]保存该棵树的第k个结点分裂的feature下标,splitvalue[k]保存该棵树第k个结点的分裂值,nodestatus[k]代表该棵树的第k个结点的状态,如果为GBDT_INTERIOR,代表该结点已分裂,如果为GBDT_TOSPLIT,代表该结点需分裂,如果为GBDT_TERMINAL表示该结点不需再分裂,一般是由于该结点的样本数ndcount[k]少于等于一阈值gbdt_min_node_size;depth[ncur+1]代表左子树的深度,depth[ncur+2]表示右子树的深度,其中ncur的增长步长为2,表示每次+2都相关于跳过当前结点的左子树和右子树,到达下一个结点。ndstart[ncur+1]代表划分到左子树开始样本的下标,ndstart[ncur+2]代表划分到右子树开始样本的下标,其中到底这个下标是代表第几个样本是由index的一个结构保存。ndcount[ncur + 1]代表划分到左子树的样本数量,ndcount[ncur + 2]代表划分到右子树的样本数量。ndavg[ncur+1]代表左子树样本的均值,同理是右子树样本的均值。nodestatus[ncur+1] = GBDT_TOSPLIT表示左子树可分裂。lson[k]=ncur+1表示第k个结点的左子树,同理表示第k个结点的右子树。
gbdt_info_t保存模型配置参数。
typedef struct
{
int* fea_pool; //!< 随机 feature 候选池
double* fvalue_list; //!< 以feature i 为拉链的特征值 x_i
double* fv; //!< 特征值排序用的buffer版本
double* y_list; //!< 回归的y值集合
int* order_i; //!< 排序的标号
} bufset; //!< 训练数据池
bufset代表训练数据池,它保存了训练当前一棵树所用到的一些数据。fea_pool保存了训练数据的特征的下标,循环rand_fea_num(feature随机采样量)次,随机地从fea_pool中选取特征来计算分裂的损失函数(先过的feature不会再选)。fvalue_list保存在当前选择特征fid时,所有采样的样本特征fid对应的值。fv与favlue_list一样。y_list表示采样样本的y值。order_i保存左子树与右子树结点下标。
nodeinfo代表节点的信息。
typedef struct
{
int bestid; //!< 分裂使用的Feature ID
double bestsplit; //!< 分裂边界的x值
int pivot; //!< 分裂边界的数据标号
} splitinfo; //!< 分裂的信息
splitinfo代表分裂的信息。pivot代表分裂点在order_i中的下标。bestsplit表示分裂值。bestid表示分裂的feature。
好了,解释完关键的一些结构体,下面要看懂整个gbdt的流程就非常简单了。这里我就简单的从头至尾叙述一下整个训练的流程。
首先申请分配模型空间gbdt_model,并且计算所有样本在每一维特征上的平均值。假如我们需要训练infbox.tree_num棵树,每一棵的训练流程为:从x_fea_value中采样gbdt_inf.sample_num个样本,index[i]记录了第i个结点所对应的样本集合x_fea_value中的下标,其始终保存了训练本棵树的所有采样样本对应样本空间的下标值,同时,结点的顺序是按该棵树所有结点按广度优先遍历算法遍历的结果的。即当前树gbdt_single_tree只有一个根结点0,其中gbdt_single_tree->nodestatus为GBDT_TOSPLIT,ndstart[0]=0,ndcount[0]=sample_num,ndavg为所有采样样本的y的梯度值均值。下面就是对这个结点进行分裂的过程:首先nodeinfo ninf这个结构体保存了当前分裂结点的一些信息,比如结点中样本开始的下标(指相对于index的下标值,index指向的值才是样本空间中该样本的下标),样本结束下标(同上),样本结点数,样本结点的y的梯度之和等。循环rand_fea_num次,随机采样feature,来计算在该feature分裂的信息增益,计算方式为(左子树样子目标值和的平方均值+右子树目标值和的平方均值-父结点所有样本和的平方均值)。选过的feature就不会再选中来计算信息增益了。利用data_set来保存当前分裂过程所用到的一些信息,包括候选feature池,选中feature对应的采样样本的特征值及其y值。data_set->order_i保存了左右子树对应结点在样本集合中的下标。计算每个feature的信息增益,并取最大的,保存分点信息到spinf中,包括最优分裂值,最优分裂feature。然后,将该结点小于分裂值的结点样本下标与大于分裂值的结点样本下标都保存在data_set->order_i中,nl记录了order_i中右子树开始的位置。更新index数组,将order_i中copy到index中。将nl更新到spinf中。注意index数组从左至右保存了最终分裂的左子树与右子树样本对应在样本空间的下标。
至此,我们找到了这个结点的最优分裂点。gbdt_single_tree->ndstart[1]保存了左孩子的开始下标(指相对于index的下标值,index指向的值才是样本下标),gbdt_single_tree->ndstart[2]保存了右孩子的开始下标,即nl的值。同理,ndcount,depth等也是对就保存了左右孩子信息。gbdt_single_tree->lson[0]=1,gbdt_single_tree->lson[0]=2即表示当前结点0的左子树是1,右子树是2。当前结点分裂完了之后,下一次就同理广度优先算法,对该结点的孩子继续上述步骤。
该棵树分裂完成之后,对每一个样本,都用目前模型(加上分裂完成的这棵树)计算预测值,并且更新每一个样本的残差y_gradient。计算过程:选取当前结点的分裂feature以及分裂值,小于则走左子树,大于则走右子树,直到叶子结点。预测值为shrink*该叶子结点的样本目标值的均值。
训练第二棵树同理,只是训练的样本的目标值变成了前面模型预测结果的残差了。这点就体现在梯度下降的寻优过程。
好了,这里只是简单的对gbdt代码做了说明,当然如果没有看过本文引用的源码,是不怎么能看懂的,如果结合源码来看,就很容易看懂了。总之,个人感觉,只有结合原码来学习gbdt,才真正能体会到事个模型的学习以及树的生成过程。
GBDT源码剖析的更多相关文章
- jQuery之Deferred源码剖析
一.前言 大约在夏季,我们谈过ES6的Promise(详见here),其实在ES6前jQuery早就有了Promise,也就是我们所知道的Deferred对象,宗旨当然也和ES6的Promise一样, ...
- Nodejs事件引擎libuv源码剖析之:高效线程池(threadpool)的实现
声明:本文为原创博文,转载请注明出处. Nodejs编程是全异步的,这就意味着我们不必每次都阻塞等待该次操作的结果,而事件完成(就绪)时会主动回调通知我们.在网络编程中,一般都是基于Reactor线程 ...
- Apache Spark源码剖析
Apache Spark源码剖析(全面系统介绍Spark源码,提供分析源码的实用技巧和合理的阅读顺序,充分了解Spark的设计思想和运行机理) 许鹏 著 ISBN 978-7-121-25420- ...
- 基于mybatis-generator-core 1.3.5项目的修订版以及源码剖析
项目简单说明 mybatis-generator,是根据数据库表.字段反向生成实体类等代码文件.我在国庆时候,没事剖析了mybatis-generator-core源码,写了相当详细的中文注释,可以去 ...
- STL"源码"剖析-重点知识总结
STL是C++重要的组件之一,大学时看过<STL源码剖析>这本书,这几天复习了一下,总结出以下LZ认为比较重要的知识点,内容有点略多 :) 1.STL概述 STL提供六大组件,彼此可以组合 ...
- SpringMVC源码剖析(四)- DispatcherServlet请求转发的实现
SpringMVC完成初始化流程之后,就进入Servlet标准生命周期的第二个阶段,即“service”阶段.在“service”阶段中,每一次Http请求到来,容器都会启动一个请求线程,通过serv ...
- 自己实现多线程的socket,socketserver源码剖析
1,IO多路复用 三种多路复用的机制:select.poll.epoll 用的多的两个:select和epoll 简单的说就是:1,select和poll所有平台都支持,epoll只有linux支持2 ...
- Java多线程9:ThreadLocal源码剖析
ThreadLocal源码剖析 ThreadLocal其实比较简单,因为类里就三个public方法:set(T value).get().remove().先剖析源码清楚地知道ThreadLocal是 ...
- JS魔法堂:mmDeferred源码剖析
一.前言 avalon.js的影响力愈发强劲,而作为子模块之一的mmDeferred必然成为异步调用模式学习之旅的又一站呢!本文将记录我对mmDeferred的认识,若有纰漏请各位指正,谢谢.项目请见 ...
随机推荐
- unity 获取水平FOV
unity中Camera的Field of View是指的垂直FOV,水平FOV可以经过计算得到. 创建脚本如下,把脚本挂载到摄像机上即可得到水平FOV: public class GetHorizo ...
- 记一款bug管理系统(bugdone.cn)的开发过程(3) - 永久免费化
BugDone永久免费了! BugDone(bug管理工具)已经发布有一阵子了,自发布以来注册用户量.项目创建量稳步提升,并且得到了很多用户的好评. 在开发BugDone工具之前,我们团队也曾为找不到 ...
- 脚本设置IP bat 命令行设置自动获取IP和固定IP
由于办公室网络需要固定IP和DNS才能上网, 在连接公共网络或者家里又需要自动获取IP和DNS才能上网. 频繁手动切换很麻烦,就搞了两个脚本一键设置. 1.新建文本文件, 命名为固定IP.bat 复制 ...
- jmeter函数简介
1._char:把一组数字转化成Unicode字符. 2._counter:记录线程的迭代次数. 3._CSVRead:可以从文件中指定列的值. 4.${_CSVRead(D:\test.txt,0, ...
- SQL Server中怎么查看每个数据库的日志大小,以及怎么确定数据库的日志文件,怎么用语句收缩日志文件
一,找到每个数据库的日志文件大小 SQL Server:查看SQL日志文件大小命令:dbcc sqlperf(logspace) DBA 日常管理工作中,很重要一项工作就是监视数据库文件大小,及日志文 ...
- MySQl新特性 GTID
GTID简介 概念 全局事务标识符(GTID)是创建的唯一标识符,并与在源(主)服务器上提交的每个事务相关联.此标识符不但是唯一的,而且在给定复制设置中的所有服务器上都是唯一的.所有交易和所有GTID ...
- python基础学习11----函数
一.函数的定义 def 函数名(参数列表): 函数体 return语句 return语句不写或后边不加任何对象即为return None 二.函数的参数 无参数 def func1(): print( ...
- 【转】Java学习---Java中volatile关键字实现原理
[原文]https://www.toutiao.com/i6592879392400081412/ 前言 我们知道volatile关键字的作用是保证变量在多线程之间的可见性,它是java.util.c ...
- 【转】Nginx学习---Nginx&&Redis&&hcache三层缓存架构总结
[原文]https://www.toutiao.com/i6594307974817120782/ 摘要: 对于高并发架构,毫无疑问缓存是最重要的一环,对于大量的高并发,可以采用三层缓存架构来实现,n ...
- 数据库迁移之从oracle 到 MySQL最简单的方法
数据库迁移之从oracle 到 MySQL最简单的方法 因工作需要将oracle数据库换到MySQL数据库,数据量比较大,百万级别的数据,表也比较多,有没有一种既快捷又安全的方法呢?答案是肯定的,下面 ...