一、概念

决策树及其集合是分类和回归的机器学习任务的流行方法。 决策树被广泛使用,因为它们易于解释,处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征交互。 诸如随机森林和增强的树集合算法是分类和回归任务的最佳表现者。

决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。

二、基本原理

决策树学习通常包含三个方面:特征选择、决策树生成和决策树剪枝。决策树学习思想主要来源于:Quinlan在1986年提出的ID算法、在1993年提出的C4.5算法和Breiman等人在1984年提出的CART算法。

2.1、特征选择

特征选择在于选取对训练数据具有分类能力的特征,这样可以提高决策树学习的效率。通常特征选择的准则是信息增益(或信息增益比、基尼指数等),每次计算每个特征的信息增益,并比较它们的大小,选择信息增益最大(信息增益比最大、基尼指数最小)的特征。

那么问题来了:怎么找到这样的最优划分特征呢?如何来衡量最优?

什么是最优特征,通俗的理解是对训练数据具有很强的分类能力的特征,比如要看相亲的男女是否合适,他们的年龄差这个特征就远比他们的出生地重要,因为年龄差能更好得对相亲是否成功这个分类问题具有更强的分类能力。

但是计算机并不知道哪些特征是最优的,因此,就要找一个衡量特征是不是最优的指标,使得决策树在每一个分支上的数据尽可能属于同一类别的数据,即样本纯度最高。

我们用熵来衡量样本集合的纯度。

这是概率统计与信息论中的一个概念,定义为:

 

其中p(x)=pi表示随机变量X发生概率。

我们可以从两个角度理解这个概念。

第一就是不确定度的一个度量,我们的目标是为了找到一颗树,使得每个分枝上都代表一个分类,也就是说我们希望这个分枝上的不确定性最小,即确定性最大,也就是这些数据都是同一个类别的。熵越小,代表这些数据是同一类别的越多。

第二个角度就是从纯度理解。因为熵是不确定度的度量,如果他们不确定度越小,意味着这个群体的差异很小,也就是它的纯度很高。比如,在明大的某富翁聚会上,来的人大多是某总,普通工薪白领就会很少,如果新来了一个刘总,他是富翁的确定性就很大,不确定性就很小,同时这个群体的纯度很大。总结来说就是熵越小,纯度越大,而我们希望的就是纯度越大越好。

信息增益

我们用信息熵来衡量一个分支的纯度,以及哪个特征是最优的特征
在决策树学习中应用信息增益准则来选择最优特征。信息增益定义如下:

 
信息增益

特征A对训练数据集D的信息增益g(D,A) 等于D的不确定度H(D) 减去给定条件A下D的不确定度H(D|A),可以理解为由于特征A使得对数据集D的分类的不确定性减少的程度,信息增益大的特征具有更强的分类能力。

信息增益率

信息增益选择特征倾向于选择取值较多的特征,假设某个属性存在大量的不同值,决策树在选择属性时,将偏向于选择该属性,但这肯定是不正确(导致过拟合)的。因此有必要使用一种更好的方法,那就是信息增益率(Info Gain Ratio)来矫正这一问题。
其公式为:

 
信息增益率

其中

 
训练数据集D关于特征A的值的熵

,n为特征A取值的个数

基尼指数

概率分布的基尼指数定义为

 
基尼指数

其中K表示分类问题中类别的个数

2.2、决策树的生成

从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点,再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增均很小或没有特征可以选择为止,最后得到一个决策树。

决策树需要有停止条件来终止其生长的过程。一般来说最低的条件是:当该节点下面的所有记录都属于同一类,或者当所有的记录属性都具有相同的值时。这两种条件是停止决策树的必要条件,也是最低的条件。在实际运用中一般希望决策树提前停止生长,限定叶节点包含的最低数据量,以防止由于过度生长造成的过拟合问题。

2.3、决策树的剪枝

决策树生成只考虑了通过信息增益或信息增益比来对训练数据更好的拟合,但没有考虑到如果模型过于复杂,会导致过拟合的产生。而剪枝就是缓解过拟合的一种手段,单纯的决策树生成学习局部的模型,而剪枝后的决策树会生成学习整体的模型,因为剪枝的过程中,通过最小化损失函数,可以平衡决策树的对训练数据的拟合程度和整个模型的复杂度。

决策树的损失函数定义如下:

 
损失函数

其中,

 
图1

三、代码实现

我们以iris数据集(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。

3.1、读取数据

首先,读取文本文件;然后,通过map将每行的数据用“,”隔开,在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类。把这里我们用LabeledPoint来存储标签列和特征列。LabeledPoint在监督学习中常用来存储标签和特征,其中要求标签的类型是double,特征的类型是Vector。所以,我们把莺尾花的分类进行了一下改变,”Iris-setosa”对应分类0,”Iris-versicolor”对应分类1,其余对应分类2;然后获取莺尾花的4个特征,存储在Vector中。

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import scala.Tuple2;
SparkConf conf = new SparkConf().setAppName("decisionTree").setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf); /**
* 读取数据
* 转化成 LabeledPoint类型
*/
JavaRDD<String> source = sc.textFile("data/mllib/iris.data");
JavaRDD<LabeledPoint> data = source.map(line->{
String[] parts = line.split(",");
double label = 0.0;
if(parts[4].equals("Iris-setosa")) {
label = 0.0;
}else if(parts[4].equals("Iris-versicolor")) {
label = 1.0;
}else {
label = 2.0;
}
return new LabeledPoint(label,Vectors.dense(Double.parseDouble(parts[0]),
Double.parseDouble(parts[1]),
Double.parseDouble(parts[2]),
Double.parseDouble(parts[3])));
});

3.2、划分数据集

接下来,首先进行数据集的划分,这里划分70%的训练集和30%的测试集:

JavaRDD<LabeledPoint>[] splits =  data.randomSplit(new double[] {0.7,0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

3.3、构建模型

调用决策树的trainClassifier方法构建决策树模型,设置参数,比如分类数、信息增益的选择、树的最大深度等:

int numClasses = 3;//分类数
int maxDepth = 5; //树的最大深度
int maxBins = 30;//离散连续特征时使用的bin数。增加maxBins允许算法考虑更多的分割候选者并进行细粒度的分割决策。
String impurity = "gini";
Map<Integer,Integer> categoricalFeaturesInfo = new HashMap<Integer,Integer>();//空的categoricalFeaturesInfo表示所有功能都是连续的。
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

3.4、模型预测

接下来我们调用决策树模型的predict方法对测试数据集进行预测,并把模型结构打印出来:

JavaPairRDD<Double, Double> predictionAndLabel =  testData.mapToPair(point->{
return new Tuple2<>(model.predict(point.features()),point.label());
});
//打印预测和实际结果
predictionAndLabel.foreach(x->{
System.out.println("predictionAndLabel:"+x);
});
System.out.println("Learned classification tree model:"+model.toDebugString());
/**
*控制台输出结果:
-----------------------
Learned classification tree model:DecisionTreeModel classifier of depth 5 with 15 nodes
If (feature 2 <= 2.45)
Predict: 0.0
Else (feature 2 > 2.45)
If (feature 2 <= 4.75)
Predict: 1.0
Else (feature 2 > 4.75)
If (feature 2 <= 4.95)
If (feature 0 <= 6.25)
If (feature 1 <= 3.05)
Predict: 2.0
Else (feature 1 > 3.05)
Predict: 1.0
Else (feature 0 > 6.25)
Predict: 1.0
Else (feature 2 > 4.95)
If (feature 3 <= 1.7000000000000002)
If (feature 0 <= 6.05)
Predict: 1.0
Else (feature 0 > 6.05)
Predict: 2.0
Else (feature 3 > 1.7000000000000002)
Predict: 2.0
------------------------
**/

3.5、准确性评估

最后,我们把模型预测的准确性打印出来:

double testErr = predictionAndLabel.filter(pl ->  !pl._1().equals(pl._2())).count() / (double)  testData.count();
System.out.println("Test Error:"+testErr);
/**
*控制台输出结果:
------------------------------
Test Error:0.06976744186046512
------------------------------
**/

spark机器学习从0到1决策树(六)的更多相关文章

  1. spark机器学习从0到1介绍入门之(一)

      一.什么是机器学习 机器学习(Machine Learning, ML)是一门多领域交叉学科,涉及概率论.统计学.逼近论.凸分析.算法复杂度理论等多门学科.专门研究计算机怎样模拟或实现人类的学习行 ...

  2. spark机器学习从0到1特征提取 TF-IDF(十二)

        一.概念 “词频-逆向文件频率”(TF-IDF)是一种在文本挖掘中广泛使用的特征向量化方法,它可以体现一个文档中词语在语料库中的重要程度. 词语由t表示,文档由d表示,语料库由D表示.词频TF ...

  3. spark机器学习从0到1特征变换-标签和索引的转化(十六)

      一.原理 在机器学习处理过程中,为了方便相关算法的实现,经常需要把标签数据(一般是字符串)转化成整数索引,或是在计算结束后将整数索引还原为相应的标签. Spark ML 包中提供了几个相关的转换器 ...

  4. spark机器学习从0到1机器学习工作流 (十一)

        一.概念 一个典型的机器学习过程从数据收集开始,要经历多个步骤,才能得到需要的输出.这非常类似于流水线式工作,即通常会包含源数据ETL(抽取.转化.加载),数据预处理,指标提取,模型训练与交叉 ...

  5. spark机器学习从0到1特征选择-卡方选择器(十五)

      一.公式 卡方检验的基本公式,也就是χ2的计算公式,即观察值和理论值之间的偏差   卡方检验公式 其中:A 为观察值,E为理论值,k为观察值的个数,最后一个式子实际上就是具体计算的方法了 n 为总 ...

  6. spark机器学习从0到1奇异值分解-SVD (七)

      降维(Dimensionality Reduction) 是机器学习中的一种重要的特征处理手段,它可以减少计算过程中考虑到的随机变量(即特征)的个数,其被广泛应用于各种机器学习问题中,用于消除噪声 ...

  7. spark机器学习从0到1基本的统计工具之(三)

      给定一个数据集,数据分析师一般会先观察一下数据集的基本情况,称之为汇总统计或者概要性统计.一般的概要性统计用于概括一系列观测值,包括位置或集中趋势(比如算术平均值.中位数.众数和四分位均值),展型 ...

  8. spark机器学习从0到1基本数据类型之(二)

        MLlib支持存储在单个机器上的局部向量和矩阵,以及由一个或多个RDD支持的分布式矩阵. 局部向量和局部矩阵是用作公共接口的简单数据模型. 底层线性代数操作由Breeze提供. 在监督学习中使 ...

  9. spark机器学习从0到1特征抽取–Word2Vec(十四)

      一.概念 Word2vec是一个Estimator,它采用一系列代表文档的词语来训练word2vecmodel.该模型将每个词语映射到一个固定大小的向量.word2vecmodel使用文档中每个词 ...

随机推荐

  1. 2019-2020-1 20199328《Linux内核原理与分析》第二周作业

    冯诺依曼体系结构的核心是: 冯诺依曼体系结构五大部分:控制器,运算器,存储器,输入输出设备. 常用的寄存器 AX.BX.CX.DX一般存放一些一般的数据,被称为通用寄存器,分别拥有高8位和低8位. 段 ...

  2. vue项目中使用bpmn-基础篇

    内容概述 本系列“vue项目中使用bpmn-xxxx”分为五篇,均为自己使用过程中用到的实例,手工原创,目前属于陆续更新中.主要包括vue项目中bpmn使用实例.应用技巧.基本知识点总结和需要注意事项 ...

  3. 怎么在Chrome和Firefox浏览器中清除HSTS设置?

    HSTS代表的是HTTPS严格传输安全协议,它是一个网络安全政策机制,能够强迫浏览器只通过安全的HTTPS连接(永远不能通过HTTP)与网站交互.这能够帮助防止协议降级攻击和cookie劫持. HST ...

  4. Java的循环语句

    一.while 循环 while(循环条件){ 循环操作语句 } * 循环3要素: 变量的初值.变量的判断.变量的更新 * 缺少循环变量的更新,循环将一直进行下去 public class Whlie ...

  5. Clickhouse 条形图📊函数展示

    Clickhouse 条形图

  6. 什么是动态规划?动态规划的意义是什么?https://www.zhihu.com/question/23995189

    阮行止 上海洛谷网络科技有限公司 讲师 intro 很有意思的问题.以往见过许多教材,对动态规划(DP)的引入属于"奉天承运,皇帝诏曰"式:不给出一点引入,见面即拿出一大堆公式吓人 ...

  7. C# 多线程(18):一篇文章就理解async和await

    目录 前言 async await 从以往知识推导 创建异步任务 创建异步任务并返回Task 异步改同步 说说 await Task 说说 async Task 同步异步? Task封装异步任务 关于 ...

  8. 一只简单的网络爬虫(基于linux C/C++)————socket相关及HTTP

    socket相关 建立连接 网络通信中少不了socket,该爬虫没有使用现成的一些库,而是自己封装了socket的相关操作,因为爬虫属于客户端,建立套接字和发起连接都封装在build_connect中 ...

  9. 在linux上搭建nacos集群(步骤详细,linux小白也搞得定)

    (1)nacos官网:https://github.com/alibaba/nacos/releases/tag/1.2.1下载nacos安装包到window本地(后缀为tar.zip) (2)在li ...

  10. golang之method

    method Go does not have classes. However, you can define methods on types. package main import ( &qu ...