大名鼎鼎的 GBDT 算法就是用回归树组合而成的。本文就回归树的基本原理进行讲解,并手把手、肩并肩地带您实现这一算法。

1. 原理篇

1.1 最简单的模型

如果预测某个连续变量的大小,最简单的模型之一就是用平均值。比如同事的平均年龄是 28 岁,那么新来了一批同事,在不知道这些同事的任何信息的情况下,直觉上用平均值 28 来预测是比较准确的,至少比 0 岁或者 100 岁要靠谱一些。我们不妨证明一下我们的直觉:

1.2 加一点难度

仍然是预测同事年龄,这次我们预先知道了同事的职级,假设职级的范围是整数1-10,如何能让这个信息帮助我们更加准确的预测年龄呢?

一个思路是根据职级把同事分为两组,这两组分别应用我们之前提到的“平均值”模型。比如职级小于 5 的同事分到A组,大于或等于5的分到 B 组,A 组的平均年龄是 25 岁,B 组的平均年龄是 35 岁。如果新来了一个同事,职级是 3,应该被分到 A 组,我们就预测他的年龄是 25 岁。

1.3 最佳分割点

还有一个问题待解决,如何取一个最佳的分割点对不同职级的同事进行分组呢?

我们尝试所有 m 个可能的分割点 P_i,沿用之前的损失函数,对 A、B 两组分别计算 Loss 并相加得到 L_i。最小的 L_i 所对应的 P_i 就是我们要找的“最佳分割点”。

1.4 运用多个变量

再复杂一些,如果我们不仅仅知道了同事的职级,还知道了同事的工资(貌似不科学),该如何预测同事的年龄呢?

我们可以分别根据职级、工资计算出职级和工资的最佳分割点P_1, P_2,对应的Loss L_1, L_2。然后比较L_1和L2,取较小者。假设L_1 < L_2,那么按照P_1把不同职级的同事分为A、B两组。在A、B组内分别计算工资所对应的分割点,再分为C、D两组。这样我们就得到了AC, AD, BC, BD四组同事以及对应的平均年龄用于预测。

1.5 答案揭晓

如何实现这种1 to 2, 2 to 4, 4 to 8的算法呢?

熟悉数据结构的同学自然会想到二叉树,这种树被称为回归树,顾名思义利用树形结构求解回归问题。

2. 实现篇

2.1 创建Node类

初始化,存储预测值、左右结点、特征和分割点

class Node(object):

def __init__(self, score=None):

self.score = score

self.left = None

self.right = None

self.feature = None

self.split = None

2.2 创建回归树类

初始化,存储根节点和树的高度。

class RegressionTree(object):

def __init__(self):

self.root = Node()

self.height = 0

2.3 计算分割点、MSE

根据自变量X、因变量y、X元素中被取出的行号idx,列号feature以及分割点split,计算分割后的MSE。注意这里为了减少计算量,用到了方差公式:

2.4 计算最佳分割点

遍历特征某一列的所有的不重复的点,找出MSE最小的点作为最佳分割点。如果特征中没有不重复的元素则返回None。

def _choose_split_point(self, X, y, idx, feature):

unique = set([X[i][feature] for i in idx])

if len(unique) == 1:

return None

unique.remove(min(unique))

mse, split, split_avg = min(

(self._get_split_mse(X, y, idx, feature, split)

for split in unique), key=lambda x: x[0])

return mse, feature, split, split_avg

2.5 选择最佳特征

遍历所有特征,计算最佳分割点对应的MSE,找出MSE最小的特征、对应的分割点,左右子节点对应的均值和行号。如果所有的特征都没有不重复元素则返回None

def _choose_feature(self, X, y, idx):

m = len(X[0])

split_rets = [x for x in map(lambda x: self._choose_split_point(

X, y, idx, x), range(m)) if x is not None]

if split_rets == []:

return None

_, feature, split, split_avg = min(

split_rets, key=lambda x: x[0])

idx_split = [[], []]

while idx:

i = idx.pop()

xi = X[i][feature]

if xi < split:

idx_split[0].append(i)

else:

idx_split[1].append(i)

return feature, split, split_avg, idx_split

2.6 规则转文字

将规则用文字表达出来,方便我们查看规则。

def _expr2literal(self, expr):

feature, op, split = expr

op = ">=" if op == 1 else "<"

return "Feature%d %s %.4f" % (feature, op, split)

2.7 获取规则

将回归树的所有规则都用文字表达出来,方便我们了解树的全貌。这里用到了队列+广度优先搜索。有兴趣也可以试试递归或者深度优先搜索。

def _get_rules(self):

que = [[self.root, []]]

self.rules = []

while que:

nd, exprs = que.pop(0)

if not(nd.left or nd.right):

literals = list(map(self._expr2literal, exprs))

self.rules.append([literals, nd.score])

if nd.left:

rule_left = copy(exprs)

rule_left.append([nd.feature, -1, nd.split])

que.append([nd.left, rule_left])

if nd.right:

rule_right = copy(exprs)

rule_right.append([nd.feature, 1, nd.split])

que.append([nd.right, rule_right])

2.8 训练模型

仍然使用队列+广度优先搜索,训练模型的过程中需要注意:

  1. 控制树的最大深度max_depth;

  2. 控制分裂时最少的样本量min_samples_split;

  3. 叶子结点至少有两个不重复的y值;

  4. 至少有一个特征是没有重复值的。

def fit(self, X, y, max_depth=5, min_samples_split=2):

self.root = Node()

que = [[0, self.root, list(range(len(y)))]]

while que:

depth, nd, idx = que.pop(0)

if depth == max_depth:

break

if len(idx) < min_samples_split or \

set(map(lambda i: y[i], idx)) == 1:

continue

feature_rets = self._choose_feature(X, y, idx)

if feature_rets is None:

continue

nd.feature, nd.split, split_avg, idx_split = feature_rets

nd.left = Node(split_avg[0])

nd.right = Node(split_avg[1])

que.append([depth+1, nd.left, idx_split[0]])

que.append([depth+1, nd.right, idx_split[1]])

self.height = depth

self._get_rules()

2.9 打印规则

模型训练完毕,查看一下模型生成的规则

def print_rules(self):

for i, rule in enumerate(self.rules):

literals, score = rule

print("Rule %d: " % i, ' | '.join(

literals) + ' => split_hat %.4f' % score)

2.10 预测一个样本

def _predict(self, row):

nd = self.root

while nd.left and nd.right:

if row[nd.feature] < nd.split:

nd = nd.left

else:

nd = nd.right

return nd.score

2.11 预测多个样本

def predict(self, X):

return [self._predict(Xi) for Xi in X]

3 效果评估

3.1 main函数

使用著名的波士顿房价数据集,按照7:3的比例拆分为训练集和测试集,训练模型,并统计准确度。

@run_time

def main():

print("Tesing the accuracy of RegressionTree...")

# Load data

X, y = load_boston_house_prices()

# Split data randomly, train set rate 70%

X_train, X_test, y_train, y_test = train_test_split(

X, y, random_state=10)

# Train model

reg = RegressionTree()

reg.fit(X=X_train, y=y_train, max_depth=4)

# Show rules

reg.print_rules()

# Model accuracy

get_r2(reg, X_test, y_test)

3.2 效果展示

最终生成了15条规则,拟合优度0.801,运行时间1.74秒,效果还算不错~

3.3 工具函数

本人自定义了一些工具函数,可以在github上查看 https://github.com/tushushu/Imylu/blob/master/utils.py 1. run_time – 测试函数运行时间 2. load_boston_house_prices – 加载波士顿房价数据 3. train_test_split – 拆分训练集、测试机 4. get_r2 – 计算拟合优度

总结

回归树的原理:

损失最小化,平均值大法。 最佳行与列,效果顶呱呱。

回归树的实现:

一顿操作猛如虎,加减乘除二叉树。

微信:https://mp.weixin.qq.com/s?__biz=MzAxMjUyNDQ5OA==&mid=2653557207&idx=1&sn=65d635e4b2b2d514c5cf472317001f2d&chksm=806e3f6ab719b67c01d1891eaf4524e8318983481cfbc6cefe555156d6858e02f8ffeb12b859&scene=21#wechat_redirect

回归树的原理及Python实现的更多相关文章

  1. cart中回归树的原理和实现

    前面说了那么多,一直围绕着分类问题讨论,下面我们开始学习回归树吧, cart生成有两个关键点 如何评价最优二分结果 什么时候停止和如何确定叶子节点的值 cart分类树采用gini系数来对二分结果进行评 ...

  2. 连续值的CART(分类回归树)原理和实现

    上一篇我们学习和实现了CART(分类回归树),不过主要是针对离散值的分类实现,下面我们来看下连续值的cart分类树如何实现 思考连续值和离散值的不同之处: 二分子树的时候不同:离散值需要求出最优的两个 ...

  3. CART(分类回归树)原理和实现

    前面我们了解了决策树和adaboost的决策树墩的原理和实现,在adaboost我们看到,用简单的决策树墩的效果也很不错,但是对于更多特征的样本来说,可能需要很多数量的决策树墩 或许我们可以考虑使用更 ...

  4. GBDT回归的原理及Python实现

    一.原理篇 1.1 温故知新回归树是GBDT的基础,之前的一篇文章曾经讲过回归树的原理和实现.链接如下: 回归树的原理及Python实现 1.2 预测年龄仍然以预测同事年龄来举例,从<回归树&g ...

  5. 集成方法:渐进梯度回归树GBRT(迭代决策树)

    http://blog.csdn.net/pipisorry/article/details/60776803 单决策树C4.5由于功能太简单.而且非常easy出现过拟合的现象.于是引申出了很多变种决 ...

  6. 《机器学习Python实现_10_10_集成学习_xgboost_原理介绍及回归树的简单实现》

    一.简介 xgboost在集成学习中占有重要的一席之位,通常在各大竞赛中作为杀器使用,同时它在工业落地上也很方便,目前针对大数据领域也有各种分布式实现版本,比如xgboost4j-spark,xgbo ...

  7. 机器学习之分类回归树(python实现CART)

    之前有文章介绍过决策树(ID3).简单回顾一下:ID3每次选取最佳特征来分割数据,这个最佳特征的判断原则是通过信息增益来实现的.按照某种特征切分数据后,该特征在以后切分数据集时就不再使用,因此存在切分 ...

  8. 机器学习之路: python 回归树 DecisionTreeRegressor 预测波士顿房价

    python3 学习api的使用 git: https://github.com/linyi0604/MachineLearning 代码: from sklearn.datasets import ...

  9. 机器学习——手把手教你用Python实现回归树模型

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天这篇是机器学习专题的第24篇文章,我们来聊聊回归树模型. 所谓的回归树模型其实就是用树形模型来解决回归问题,树模型当中最经典的自然还是决 ...

随机推荐

  1. iOS常用设计模式——工厂方法(简单工厂模式,工厂方法模式, 抽象工厂模式)

    1. 简单工厂模式 如何理解简单工厂,工厂方法, 抽象工厂三种设计模式? 简单工厂方法包含:父类拥有共同基础接口,具体子类实现子类特殊功能,工厂类根据参数区分创建不同子类实例.该场景对应的UML图如下 ...

  2. Python中list作为默认参数的陷阱

    在Python中,作为默认参数的一定要是不可变对象,如果是可变对象,就会出现问题,稍不注意,就会调入陷阱,尤其是初学者,比如我(┬_┬). 我们来看一个例子. def add(L=[]): L.app ...

  3. MacOS下,Python2和Python3完美兼容使用(转)

    问题阐述: MacOS默认Python版本是2.7.10,随着Python3的进一步占有市场,Python2.7也将在2020年结束维护,所以在同一台电脑上安装多个Python版本势在必行. 一.py ...

  4. 使用open live writer客户端写博客(亲测有效)

    博客都开了这么久了,才开始将资料上传,但是每次都要登录网页确实很麻烦,所以就用open live writer,使用起来真的是挺方便的,所以将我在安装配置时,发现的问题汇总起来以便日后再次碰到忘记怎么 ...

  5. 解决 This application requires Java Runtime Environment XX

    已经安装了 jdk ,并且设置好了 java 环境变量,CMD 运行 java 或 javac 都正常,其他依赖 jdk 的应用程序都能正常运行.但是在运行 jd-gui 1.1.0 的时候,出现错误 ...

  6. adb 调系统时间

    1.修改前提 获取系统root权限,然后adb shell进入shell界面 adb shell su 2.时区设置 cat /data/property/persist.sys.timezone / ...

  7. 操作手册_MyEclipse

    前言 假 如 你 的 人 生 有 理 想,那 么 就 一 定 要 去 追,不 管 你 现 在 的 理 想 在 别 人 看 来是 多 么 的 可 笑 , 你 也 不 用 在 乎 , 人 生 蹉 跎 几  ...

  8. 说说C#中的enum吧

    enum,就是枚举类型,它是struct,int,single,double一样,都属于值类型,从ValueType类型中派生,存储在栈中.它在被创建时,不需要分配内在空间,所以对程序的性能是有好处的 ...

  9. bootstrap中使用modal加载kindeditor时弹出层文本框不能输入的问题

    答案来自老外http://stackoverflow.com/questions/14795035/twitter-bootstrap-modal-blocks-text-input-field $( ...

  10. 【前端】Chrome DevTools 笔记

    1. 查看网络耗时 timeline 生命周期按照以下类别显示花费的时间: Queuing Stalled 如果适用:DNS lookup.initial connection.SSL handsha ...