一、概念

决策树及其集合是分类和回归的机器学习任务的流行方法。 决策树被广泛使用,因为它们易于解释,处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征交互。 诸如随机森林和增强的树集合算法是分类和回归任务的最佳表现者。

决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。

二、基本原理

决策树学习通常包含三个方面:特征选择、决策树生成和决策树剪枝。决策树学习思想主要来源于:Quinlan在1986年提出的ID算法、在1993年提出的C4.5算法和Breiman等人在1984年提出的CART算法。

2.1、特征选择

特征选择在于选取对训练数据具有分类能力的特征,这样可以提高决策树学习的效率。通常特征选择的准则是信息增益(或信息增益比、基尼指数等),每次计算每个特征的信息增益,并比较它们的大小,选择信息增益最大(信息增益比最大、基尼指数最小)的特征。

那么问题来了:怎么找到这样的最优划分特征呢?如何来衡量最优?

什么是最优特征,通俗的理解是对训练数据具有很强的分类能力的特征,比如要看相亲的男女是否合适,他们的年龄差这个特征就远比他们的出生地重要,因为年龄差能更好得对相亲是否成功这个分类问题具有更强的分类能力。

但是计算机并不知道哪些特征是最优的,因此,就要找一个衡量特征是不是最优的指标,使得决策树在每一个分支上的数据尽可能属于同一类别的数据,即样本纯度最高。

我们用熵来衡量样本集合的纯度。

这是概率统计与信息论中的一个概念,定义为:

 

其中p(x)=pi表示随机变量X发生概率。

我们可以从两个角度理解这个概念。

第一就是不确定度的一个度量,我们的目标是为了找到一颗树,使得每个分枝上都代表一个分类,也就是说我们希望这个分枝上的不确定性最小,即确定性最大,也就是这些数据都是同一个类别的。熵越小,代表这些数据是同一类别的越多。

第二个角度就是从纯度理解。因为熵是不确定度的度量,如果他们不确定度越小,意味着这个群体的差异很小,也就是它的纯度很高。比如,在明大的某富翁聚会上,来的人大多是某总,普通工薪白领就会很少,如果新来了一个刘总,他是富翁的确定性就很大,不确定性就很小,同时这个群体的纯度很大。总结来说就是熵越小,纯度越大,而我们希望的就是纯度越大越好。

信息增益

我们用信息熵来衡量一个分支的纯度,以及哪个特征是最优的特征
在决策树学习中应用信息增益准则来选择最优特征。信息增益定义如下:

 
信息增益

特征A对训练数据集D的信息增益g(D,A) 等于D的不确定度H(D) 减去给定条件A下D的不确定度H(D|A),可以理解为由于特征A使得对数据集D的分类的不确定性减少的程度,信息增益大的特征具有更强的分类能力。

信息增益率

信息增益选择特征倾向于选择取值较多的特征,假设某个属性存在大量的不同值,决策树在选择属性时,将偏向于选择该属性,但这肯定是不正确(导致过拟合)的。因此有必要使用一种更好的方法,那就是信息增益率(Info Gain Ratio)来矫正这一问题。
其公式为:

 
信息增益率

其中

 
训练数据集D关于特征A的值的熵

,n为特征A取值的个数

基尼指数

概率分布的基尼指数定义为

 
基尼指数

其中K表示分类问题中类别的个数

2.2、决策树的生成

从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点,再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增均很小或没有特征可以选择为止,最后得到一个决策树。

决策树需要有停止条件来终止其生长的过程。一般来说最低的条件是:当该节点下面的所有记录都属于同一类,或者当所有的记录属性都具有相同的值时。这两种条件是停止决策树的必要条件,也是最低的条件。在实际运用中一般希望决策树提前停止生长,限定叶节点包含的最低数据量,以防止由于过度生长造成的过拟合问题。

2.3、决策树的剪枝

决策树生成只考虑了通过信息增益或信息增益比来对训练数据更好的拟合,但没有考虑到如果模型过于复杂,会导致过拟合的产生。而剪枝就是缓解过拟合的一种手段,单纯的决策树生成学习局部的模型,而剪枝后的决策树会生成学习整体的模型,因为剪枝的过程中,通过最小化损失函数,可以平衡决策树的对训练数据的拟合程度和整个模型的复杂度。

决策树的损失函数定义如下:

 
损失函数

其中,

 
图1

三、代码实现

我们以iris数据集(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。

3.1、读取数据

首先,读取文本文件;然后,通过map将每行的数据用“,”隔开,在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类。把这里我们用LabeledPoint来存储标签列和特征列。LabeledPoint在监督学习中常用来存储标签和特征,其中要求标签的类型是double,特征的类型是Vector。所以,我们把莺尾花的分类进行了一下改变,”Iris-setosa”对应分类0,”Iris-versicolor”对应分类1,其余对应分类2;然后获取莺尾花的4个特征,存储在Vector中。

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import scala.Tuple2;
SparkConf conf = new SparkConf().setAppName("decisionTree").setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf); /**
* 读取数据
* 转化成 LabeledPoint类型
*/
JavaRDD<String> source = sc.textFile("data/mllib/iris.data");
JavaRDD<LabeledPoint> data = source.map(line->{
String[] parts = line.split(",");
double label = 0.0;
if(parts[4].equals("Iris-setosa")) {
label = 0.0;
}else if(parts[4].equals("Iris-versicolor")) {
label = 1.0;
}else {
label = 2.0;
}
return new LabeledPoint(label,Vectors.dense(Double.parseDouble(parts[0]),
Double.parseDouble(parts[1]),
Double.parseDouble(parts[2]),
Double.parseDouble(parts[3])));
});

3.2、划分数据集

接下来,首先进行数据集的划分,这里划分70%的训练集和30%的测试集:

JavaRDD<LabeledPoint>[] splits =  data.randomSplit(new double[] {0.7,0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

3.3、构建模型

调用决策树的trainClassifier方法构建决策树模型,设置参数,比如分类数、信息增益的选择、树的最大深度等:

int numClasses = 3;//分类数
int maxDepth = 5; //树的最大深度
int maxBins = 30;//离散连续特征时使用的bin数。增加maxBins允许算法考虑更多的分割候选者并进行细粒度的分割决策。
String impurity = "gini";
Map<Integer,Integer> categoricalFeaturesInfo = new HashMap<Integer,Integer>();//空的categoricalFeaturesInfo表示所有功能都是连续的。
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

3.4、模型预测

接下来我们调用决策树模型的predict方法对测试数据集进行预测,并把模型结构打印出来:

JavaPairRDD<Double, Double> predictionAndLabel =  testData.mapToPair(point->{
return new Tuple2<>(model.predict(point.features()),point.label());
});
//打印预测和实际结果
predictionAndLabel.foreach(x->{
System.out.println("predictionAndLabel:"+x);
});
System.out.println("Learned classification tree model:"+model.toDebugString());
/**
*控制台输出结果:
-----------------------
Learned classification tree model:DecisionTreeModel classifier of depth 5 with 15 nodes
If (feature 2 <= 2.45)
Predict: 0.0
Else (feature 2 > 2.45)
If (feature 2 <= 4.75)
Predict: 1.0
Else (feature 2 > 4.75)
If (feature 2 <= 4.95)
If (feature 0 <= 6.25)
If (feature 1 <= 3.05)
Predict: 2.0
Else (feature 1 > 3.05)
Predict: 1.0
Else (feature 0 > 6.25)
Predict: 1.0
Else (feature 2 > 4.95)
If (feature 3 <= 1.7000000000000002)
If (feature 0 <= 6.05)
Predict: 1.0
Else (feature 0 > 6.05)
Predict: 2.0
Else (feature 3 > 1.7000000000000002)
Predict: 2.0
------------------------
**/

3.5、准确性评估

最后,我们把模型预测的准确性打印出来:

double testErr = predictionAndLabel.filter(pl ->  !pl._1().equals(pl._2())).count() / (double)  testData.count();
System.out.println("Test Error:"+testErr);
/**
*控制台输出结果:
------------------------------
Test Error:0.06976744186046512
------------------------------
**/

spark机器学习从0到1决策树(六)的更多相关文章

  1. spark机器学习从0到1介绍入门之(一)

      一.什么是机器学习 机器学习(Machine Learning, ML)是一门多领域交叉学科,涉及概率论.统计学.逼近论.凸分析.算法复杂度理论等多门学科.专门研究计算机怎样模拟或实现人类的学习行 ...

  2. spark机器学习从0到1特征提取 TF-IDF(十二)

        一.概念 “词频-逆向文件频率”(TF-IDF)是一种在文本挖掘中广泛使用的特征向量化方法,它可以体现一个文档中词语在语料库中的重要程度. 词语由t表示,文档由d表示,语料库由D表示.词频TF ...

  3. spark机器学习从0到1特征变换-标签和索引的转化(十六)

      一.原理 在机器学习处理过程中,为了方便相关算法的实现,经常需要把标签数据(一般是字符串)转化成整数索引,或是在计算结束后将整数索引还原为相应的标签. Spark ML 包中提供了几个相关的转换器 ...

  4. spark机器学习从0到1机器学习工作流 (十一)

        一.概念 一个典型的机器学习过程从数据收集开始,要经历多个步骤,才能得到需要的输出.这非常类似于流水线式工作,即通常会包含源数据ETL(抽取.转化.加载),数据预处理,指标提取,模型训练与交叉 ...

  5. spark机器学习从0到1特征选择-卡方选择器(十五)

      一.公式 卡方检验的基本公式,也就是χ2的计算公式,即观察值和理论值之间的偏差   卡方检验公式 其中:A 为观察值,E为理论值,k为观察值的个数,最后一个式子实际上就是具体计算的方法了 n 为总 ...

  6. spark机器学习从0到1奇异值分解-SVD (七)

      降维(Dimensionality Reduction) 是机器学习中的一种重要的特征处理手段,它可以减少计算过程中考虑到的随机变量(即特征)的个数,其被广泛应用于各种机器学习问题中,用于消除噪声 ...

  7. spark机器学习从0到1基本的统计工具之(三)

      给定一个数据集,数据分析师一般会先观察一下数据集的基本情况,称之为汇总统计或者概要性统计.一般的概要性统计用于概括一系列观测值,包括位置或集中趋势(比如算术平均值.中位数.众数和四分位均值),展型 ...

  8. spark机器学习从0到1基本数据类型之(二)

        MLlib支持存储在单个机器上的局部向量和矩阵,以及由一个或多个RDD支持的分布式矩阵. 局部向量和局部矩阵是用作公共接口的简单数据模型. 底层线性代数操作由Breeze提供. 在监督学习中使 ...

  9. spark机器学习从0到1特征抽取–Word2Vec(十四)

      一.概念 Word2vec是一个Estimator,它采用一系列代表文档的词语来训练word2vecmodel.该模型将每个词语映射到一个固定大小的向量.word2vecmodel使用文档中每个词 ...

随机推荐

  1. Dockerfle创建镜像

    简介 Dockerfile 由一行行命令语句组成,并且支持以 # 开头的注释行. 一般的,Dockerfile 分为四部分:基础镜像信息.维护者信息.镜像操作指令和容器启动时执行指令. # This ...

  2. Linux网络服务第二章DHCP原理与配置

    1.笔记 服务端端口:67 客户端端口:68 dhcliemt -r:释放IP地址 dhcliemt -d:重新获取IP地址 :.,$ s/190.168.200 / 192.168.100 /g 从 ...

  3. web 之 session

    Session? 在WEB开发中,服务器可以为每个用户浏览器创建一个会话对象(session对象),注意:一个浏览器独占一个session对象(默认情况下).因此,在需要保存用户数据时,服务器程序可以 ...

  4. mysql不同端口的连接

    连接mysql3306端口命令 mysql -h58.64.217.120 -ushop -p123456 连接非3306端口(指定其他端口) 的命令 mysql -h58.64.217.120 -P ...

  5. Linux从入门到精通系列之NFS

    网络文件系统(NFS,Network File System)是一种将远程主机上的分区(目录)经网络挂载到本地系统的一种机制,通过对网络文件系统的支持,用户可以在本地系统上像操作本地分区一样来对远程主 ...

  6. 【学习笔记:Python-网络编程】Socket 之初见

    Socket 是任何一种计算机网络通讯中最基础的内容.当你在浏览器地址栏中输入一个地址时,你会打开一个套接字,可以说任何网络通讯都是通过 Socket 来完成的. Socket 的 python 官方 ...

  7. qt creator源码全方面分析(4-3)

    内外命名空间 QtCreator源码中,每一个子项目都有内外两层命名空间,一个是外部的,一个是内部的. 示例如下 namespace ExtensionSystem { namespace Inter ...

  8. 带"反悔"的贪心-超市

    题面:https://www.acwing.com/problem/content/description/147/ 超市里有N件商品,每个商品都有利润pi和过期时间di,每天只能卖一件商品,过期商品 ...

  9. C. Jury Marks 思维

    C. Jury Marks 这个题目虽然是只有1600,但是还是挺思维的. 有点难想. 应该可以比较快的推出的是这个肯定和前缀和有关, x x+a1 x+a1+a2 x+a1+a2+a3... x+s ...

  10. IDEA 2020.1 安装教程

    目录 IDEA 2020.1 安装教程 准备工作 破解教程 IDEA 2020.1 安装教程 IDEA 2020.1 安装教程 Win 10 版 64位操作系统 准备工作 IDEA旗舰版下载地址 je ...