梯度提升树是一种决策树的集成算法。它通过反复迭代训练决策树来最小化损失函数。决策树类似,梯度提升树具有可处理类别特征、易扩展到多分类问题、不需特征缩放等性质。Spark.ml通过使用现有decision tree工具来实现。


梯度提升树依次迭代训练一系列的决策树。在一次迭代中,算法使用现有的集成来对每个训练实例的类别进行预测,然后将预测结果与真实的标签值进行比较。通过重新标记,来赋予预测结果不好的实例更高的权重。所以,在下次迭代中,决策树会对先前的错误进行修正。


对实例标签进行重新标记的机制由损失函数来指定。每次迭代过程中,梯度迭代树在训练数据上进一步减少损失函数的值。spark.ml为分类问题提供一种损失函数(Log Loss),为回归问题提供两种损失函数(平方误差与绝对误差)。


Spark.ml支持二分类以及回归的随机森林算法,适用于连续特征以及类别特征。不支持多分类问题。

# -*- coding: utf-8 -*-
"""
Created on Wed May 9 09:53:30 2018 @author: admin
""" import numpy as np
import matplotlib.pyplot as plt from sklearn import ensemble
from sklearn import datasets
from sklearn.utils import shuffle
from sklearn.metrics import mean_squared_error # #############################################################################
# Load data
boston = datasets.load_boston()
X, y = shuffle(boston.data, boston.target, random_state=13)
X = X.astype(np.float32)
offset = int(X.shape[0] * 0.9)
X_train, y_train = X[:offset], y[:offset]
X_test, y_test = X[offset:], y[offset:] # #############################################################################
# Fit regression model
params = {'n_estimators': 500, 'max_depth': 4, 'min_samples_split': 2,
'learning_rate': 0.01, 'loss': 'ls'} #随便指定参数长度,也不用在传参的时候去特意定义一个数组传参
clf = ensemble.GradientBoostingRegressor(**params) clf.fit(X_train, y_train)
mse = mean_squared_error(y_test, clf.predict(X_test))
print("MSE: %.4f" % mse) # #############################################################################
# Plot training deviance # compute test set deviance
test_score = np.zeros((params['n_estimators'],), dtype=np.float64) for i, y_pred in enumerate(clf.staged_predict(X_test)):
test_score[i] = clf.loss_(y_test, y_pred) plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title('Deviance')
plt.plot(np.arange(params['n_estimators']) + 1, clf.train_score_, 'b-',
label='Training Set Deviance')
plt.plot(np.arange(params['n_estimators']) + 1, test_score, 'r-',
label='Test Set Deviance')
plt.legend(loc='upper right')
plt.xlabel('Boosting Iterations')
plt.ylabel('Deviance') # #############################################################################
# Plot feature importance
feature_importance = clf.feature_importances_
# make importances relative to max importance
feature_importance = 100.0 * (feature_importance / feature_importance.max())
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0]) + .5
plt.subplot(1, 2, 2)
plt.barh(pos, feature_importance[sorted_idx], align='center')
plt.yticks(pos, boston.feature_names[sorted_idx])
plt.xlabel('Relative Importance')
plt.title('Variable Importance')
plt.show()

房产数据介绍:

- CRIM     per capita crime rate by town 
- ZN       proportion of residential land zoned for lots over 25,000 sq.ft. 
- INDUS    proportion of non-retail business acres per town 
- CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise) 
- NOX      nitric oxides concentration (parts per 10 million) 
- RM       average number of rooms per dwelling 
- AGE      proportion of owner-occupied units built prior to 1940 
- DIS      weighted distances to five Boston employment centres 
- RAD      index of accessibility to radial highways 
- TAX      full-value property-tax rate per $10,000 
- PTRATIO  pupil-teacher ratio by town 
- B        1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town 
- LSTAT    % lower status of the population 
- MEDV     Median value of owner-occupied homes in $1000'

参考:http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regression.html#sphx-glr-auto-examples-ensemble-plot-gradient-boosting-regression-py

GBDT梯度提升树算法及官方案例的更多相关文章

  1. 【小白学AI】GBDT梯度提升详解

    文章来自微信公众号:[机器学习炼丹术] 文章目录: 目录 0 前言 1 基本概念 2 梯度 or 残差 ? 3 残差过于敏感 4 两个基模型的问题 0 前言 先缕一缕几个关系: GBDT是gradie ...

  2. GBDT(梯度提升树)scikit-klearn中的参数说明及简汇

    1.GBDT(梯度提升树)概述: GBDT是集成学习Boosting家族的成员,区别于Adaboosting.adaboosting是利用前一次迭代弱学习器的误差率来更新训练集的权重,在对更新权重后的 ...

  3. 一文读懂:GBDT梯度提升

    先缕一缕几个关系: GBDT是gradient-boost decision tree GBDT的核心就是gradient boost,我们搞清楚什么是gradient boost就可以了 GBDT是 ...

  4. 机器学习 | 详解GBDT梯度提升树原理,看完再也不怕面试了

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第30篇文章,我们今天来聊一个机器学习时代可以说是最厉害的模型--GBDT. 虽然文无第一武无第二,在机器学习领域并没有 ...

  5. GBDT 梯度提升决策树简述

    首先明确一点,gbdt 无论用于分类还是回归一直都是使用的CART 回归树.不会因为我们所选择的任务是分类任务就选用分类树,这里面的核心是因为gbdt 每轮的训练是在上一轮的训练的残差基础之上进行训练 ...

  6. 梯度提升决策树(GBDT)与XGBoost、LightGBM

    今天是周末,之前给自己定了一个小目标:每周都要写一篇博客,不管是关于什么内容的都行,关键在于总结和思考,今天我选的主题是梯度提升树的一些方法,主要从这些方法的原理以及实现过程入手讲解这个问题. 本文按 ...

  7. 机器学习 之梯度提升树GBDT

    目录 1.基本知识点简介 2.梯度提升树GBDT算法 2.1 思路和原理 2.2 梯度代替残差建立CART回归树 1.基本知识点简介 在集成学习的Boosting提升算法中,有两大家族:第一是AdaB ...

  8. 梯度提升树 Gradient Boosting Decision Tree

    Adaboost + CART 用 CART 决策树来作为 Adaboost 的基础学习器 但是问题在于,需要把决策树改成能接收带权样本输入的版本.(need: weighted DTree(D, u ...

  9. R︱Yandex的梯度提升CatBoost 算法(官方述:超越XGBoost/lightGBM/h2o)

    俄罗斯搜索巨头 Yandex 昨日宣布开源 CatBoost ,这是一种支持类别特征,基于梯度提升决策树的机器学习方法. CatBoost 是由 Yandex 的研究人员和工程师开发的,是 Matri ...

随机推荐

  1. MySQL root密码重置问题

    1:进入cmd,停止mysql服务:Net stop mysql (进入服务---->MySql----->停止) 到mysql的安装路径启动mysql,在bin目录下使用mysqld-n ...

  2. mac 使用命令行向 github 提交代码

    让 mac 本地和自己的 github 网站建立连接(ssh) 下载安装 git 网址: https://git-scm.com/downloads 查看安装是否成功: git -version $ ...

  3. Newman+Jenkins实现接口自动化测试

    目录 一.是什么Newman 二.如何安装 三.如何使用 1.运行本地文件 2.运行在线文件 3.以node.js库运行 4.导出报告 四.命令行测试真实接口 1.导出collection文件 2.导 ...

  4. Django中的session的使用

    一.Session 的概念 cookie 是在浏览器端保存键值对数据,而 session 是在服务器端保存键值对数据 session 的使用依赖 cookie:在使用 Session 后,会在 Coo ...

  5. js 面向对象中,定义一个函数的过程

    定义一个函数做的两件事:1: 实例化一个Function对象:2: 实例化一个Object对象,并给该函数扩展prototype属性指向这个构造函数 大致过程如图所示: 每一种引用类型(函数,对象,数 ...

  6. 前端Tips#6 - 在 async iterator 上使用 for-await-of 语法糖

    视频讲解 前往原文 前端Tips 专栏#6,点击观看 文字讲解 本期主要是讲解如何使用 for-await-of 语法糖进行异步操作迭代,让组织异步操作的代码更加简洁易读. 1.场景简述 以下代码中的 ...

  7. MapReduce 简单数据统计

    1. 准备数据源 摘录了一片散文,保存格式为utf-8 2. 准备环境 2.1 搭建伪分布式环境 https://www.cnblogs.com/cjq10029/p/12336446.html 上传 ...

  8. docker的安装使用

    目录 Docker 入门到精通 CentOS安装Docker 设置管理Docker的仓库 安装Docker Engine-Community Docker基础命令 开启关闭 镜像操作 容器操作 Doc ...

  9. 032.核心组件-kube-proxy

    一 kube-proxy原理 1.1 kube-proxy概述 Kubernetes为了支持集群的水平扩展.高可用性,抽象出了Service的概念.Service是对一组Pod的抽象,它会根据访问策略 ...

  10. 学习webpack基础笔记01

    学习webpack基础笔记 1.webpack搭建环境最重要的就是如何使用loader和plugins,使用yarn/npm安装插件.预处理器,正确的配置好去使用 2.从0配置webpack - 1. ...