• 简介

决策树是一个预测模型,通过坐标数据进行多次分割,找出分界线,绘制决策树。

在机器学习中,决策树学习算法就是根据数据,使用计算机算法自动找出决策边界。

每一次分割代表一次决策,多次决策而形成决策树,决策树可以通过核技巧把简单的线性决策面转换为非线性决策面。

  • 基本思想

树是由节点和边两种元素组成的结构。有这几个关键词:根节点、父节点、子节点和叶子节点。

父节点和子节点是相对的,子节点由父节点根据某一规则分裂而来,然后子节点作为新的父亲节点继续分裂,直至不能分裂为止。而根节点是没有父节点的节点,即初始分裂节点,叶子节点是没有子节点的节点,如下图所示:

决策树利用如上图所示的树结构进行决策,每一个非叶子节点是一个判断条件,每一个叶子节点是结论。从跟节点开始,经过多次判断得出结论。

举个例子

如图,利用决策树将两类样本点分类。

先从X轴观察,在X = 3时,样本点有一次明显的“突变”,我们以X = 3作为一次决策,进行一次划分:

再从Y轴观察,两类样本点在Y = 4 和Y = 2处可以进行划分,进而进行两次划分:

通过这几次划分,样本点被划分为四个部分,其中两类样本点各划为两部分,而且无法再继续分割,这种分割的过程就是决策树:

  • 熵(entropy)

熵的作用:用于控制决策树在什么条件下做出决策,即在什么条件下分割数据

熵的定义:它是一系列样本中的不纯度的测量值(measure of impurity in a bunch of examples)

建立决策树的过程就是找到变量划分点从而产生尽可能的单一的子集,实际上决策树做决策的过程,就是对这个过程的递归重复。

熵描述了数据的混乱程度,熵越大,混乱程度越高,也就是纯度越低;反之,熵越小,混乱程度越低,纯度越高。 熵的计算公式如下所示:

                  

其中Pi表示类i的数量占比。以二分类问题为例,如果两类的数量相同,此时分类节点的纯度最低,熵等于1;如果节点的数据属于同一类时,此时节点的纯度最高,熵等于0。

熵的最大值为1,最小值为0

  • 信息增益

用信息增益表示分裂前后跟的数据复杂度和分裂节点数据复杂度的变化值,计算公式表示为:

                  

其中Gain表示节点的复杂度,Gain越高,说明复杂度越高。信息增益也可以说是分裂前的熵减去孩子节点的熵的和,信息增益越大,分裂后的熵减小得越多,分类的效果越明显。

  • 偏差(bias)与方差(variance)

高偏差机器学习算法实际上会忽略训练数据,它几乎没有能力学习任何数据,这被称为偏差。

另一个极端情况就是高方差,它只能复现曾经出现过的东西,对于没有出现过的情况,他的反应非常差。

通过调整参数让偏差与方差平衡,使算法具有一定泛化能力,但仍然对训练数据开放,能根据数据调整模型,是机器学习的要点。

  • 代码实现

环境:MacOS mojave  10.14.3

Python  3.7.0

使用库:scikit-learn    0.19.2

sklearn.tree官方库:https://scikit-learn.org/stable/modules/tree.html

  1. >>> from sklearn import tree
  2. >>> X = [[0, 0], [1, 1]] #两个样本点
  3. >>> Y = [0, 1] #分别属于两个标签
  4. >>> clf = tree.DecisionTreeClassifier() #进行分类
  5. >>> clf = clf.fit(X, Y)
  6. >>> clf.predict([[2., 2.]]) #预测新点
  7. array([1]) #新点通过分类属于标签1

Main.py  主程序

  1. import sys
  2. from class_vis import prettyPicture, output_image
  3. from prep_terrain_data import makeTerrainData
  4.  
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import pylab as pl
  8. from classifyDT import classify
  9.  
  10. features_train, labels_train, features_test, labels_test = makeTerrainData()
  11.  
  12. ### the classify() function in classifyDT is where the magic
  13. ### happens--fill in this function in the file 'classifyDT.py'!
  14. clf = classify(features_train, labels_train)
  15.  
  16. #### grader code, do not modify below this line
  17.  
  18. prettyPicture(clf, features_test, labels_test)
  19. accuracy = clf.score(features_test, labels_test)
  20.  
  21. # output_image("test.png", "png", open("test.png", "rb").read())
  22. print (accuracy)
  23. acc = accuracy ### you fill this in!

classifyDT.py  决策树分类

  1. def classify(features_train, labels_train):
  2.  
  3. ### your code goes here--should return a trained decision tree classifer
  4. from sklearn.tree import DecisionTreeClassifier
  5. clf = DecisionTreeClassifier(random_state=0)
  6. clf.fit(features_train,labels_train)
  7.  
  8. return clf

perp_terrain_data.py  生成训练点

  1. import random
  2.  
  3. def makeTerrainData(n_points=1000):
  4. ###############################################################################
  5. ### make the toy dataset
  6. random.seed(42)
  7. grade = [random.random() for ii in range(0,n_points)]
  8. bumpy = [random.random() for ii in range(0,n_points)]
  9. error = [random.random() for ii in range(0,n_points)]
  10. y = [round(grade[ii]*bumpy[ii]+0.3+0.1*error[ii]) for ii in range(0,n_points)]
  11. for ii in range(0, len(y)):
  12. if grade[ii]>0.8 or bumpy[ii]>0.8:
  13. y[ii] = 1.0
  14.  
  15. ### split into train/test sets
  16. X = [[gg, ss] for gg, ss in zip(grade, bumpy)]
  17. split = int(0.75*n_points)
  18. X_train = X[0:split]
  19. X_test = X[split:]
  20. y_train = y[0:split]
  21. y_test = y[split:]
  22.  
  23. grade_sig = [X_train[ii][0] for ii in range(0, len(X_train)) if y_train[ii]==0]
  24. bumpy_sig = [X_train[ii][1] for ii in range(0, len(X_train)) if y_train[ii]==0]
  25. grade_bkg = [X_train[ii][0] for ii in range(0, len(X_train)) if y_train[ii]==1]
  26. bumpy_bkg = [X_train[ii][1] for ii in range(0, len(X_train)) if y_train[ii]==1]
  27.  
  28. # training_data = {"fast":{"grade":grade_sig, "bumpiness":bumpy_sig}
  29. # , "slow":{"grade":grade_bkg, "bumpiness":bumpy_bkg}}
  30.  
  31. grade_sig = [X_test[ii][0] for ii in range(0, len(X_test)) if y_test[ii]==0]
  32. bumpy_sig = [X_test[ii][1] for ii in range(0, len(X_test)) if y_test[ii]==0]
  33. grade_bkg = [X_test[ii][0] for ii in range(0, len(X_test)) if y_test[ii]==1]
  34. bumpy_bkg = [X_test[ii][1] for ii in range(0, len(X_test)) if y_test[ii]==1]
  35.  
  36. test_data = {"fast":{"grade":grade_sig, "bumpiness":bumpy_sig}
  37. , "slow":{"grade":grade_bkg, "bumpiness":bumpy_bkg}}
  38.  
  39. return X_train, y_train, X_test, y_test
  40. # return training_data, test_data

class_vis.py  绘图与保存图像

  1. import warnings
  2. warnings.filterwarnings("ignore")
  3.  
  4. import matplotlib
  5. matplotlib.use('agg')
  6.  
  7. import matplotlib.pyplot as plt
  8. import pylab as pl
  9. import numpy as np
  10.  
  11. #import numpy as np
  12. #import matplotlib.pyplot as plt
  13. #plt.ioff()
  14.  
  15. def prettyPicture(clf, X_test, y_test):
  16. x_min = 0.0; x_max = 1.0
  17. y_min = 0.0; y_max = 1.0
  18.  
  19. # Plot the decision boundary. For that, we will assign a color to each
  20. # point in the mesh [x_min, m_max]x[y_min, y_max].
  21. h = .01 # step size in the mesh
  22. xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  23. Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
  24.  
  25. # Put the result into a color plot
  26. Z = Z.reshape(xx.shape)
  27. plt.xlim(xx.min(), xx.max())
  28. plt.ylim(yy.min(), yy.max())
  29.  
  30. plt.pcolormesh(xx, yy, Z, cmap=pl.cm.seismic)
  31.  
  32. # Plot also the test points
  33. grade_sig = [X_test[ii][0] for ii in range(0, len(X_test)) if y_test[ii]==0]
  34. bumpy_sig = [X_test[ii][1] for ii in range(0, len(X_test)) if y_test[ii]==0]
  35. grade_bkg = [X_test[ii][0] for ii in range(0, len(X_test)) if y_test[ii]==1]
  36. bumpy_bkg = [X_test[ii][1] for ii in range(0, len(X_test)) if y_test[ii]==1]
  37.  
  38. plt.scatter(grade_sig, bumpy_sig, color = "b", label="fast")
  39. plt.scatter(grade_bkg, bumpy_bkg, color = "r", label="slow")
  40. plt.legend()
  41. plt.xlabel("bumpiness")
  42. plt.ylabel("grade")
  43.  
  44. plt.savefig("test.png")

得到结果,正确率90.8%

其中,狭长区域为过拟合

  • 决策树的参数

min_samples_split可分割的样本数量下限,默认值为2

对于决策树最下层的每一个节点,是否还要继续分割,min_samples_split决定了能够继续进行分割的最少分割样本

acc_min_samples.py  acc_min_samples对比

  1. import sys
  2. from class_vis import prettyPicture
  3. from prep_terrain_data import makeTerrainData
  4.  
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import pylab as pl
  8.  
  9. features_train, labels_train, features_test, labels_test = makeTerrainData()
  10.  
  11. ########################## DECISION TREE #################################
  12.  
  13. ### your code goes here--now create 2 decision tree classifiers,
  14. ### one with min_samples_split=2 and one with min_samples_split=50
  15. ### compute the accuracies on the testing data and store
  16. ### the accuracy numbers to acc_min_samples_split_2 and
  17. ### acc_min_samples_split_50, respectively
  18.  
  19. from sklearn.tree import DecisionTreeClassifier
  20. clf1 = DecisionTreeClassifier(min_samples_split=2)
  21. clf2 = DecisionTreeClassifier(min_samples_split=50)
  22.  
  23. clf1.fit(features_train,labels_train)
  24. clf2.fit(features_train,labels_train)
  25.  
  26. acc_min_samples_split_2 = clf1.score(features_test, labels_test)
  27. acc_min_samples_split_50 = clf2.score(features_test, labels_test)
  28.  
  29. print (acc_min_samples_split_2)
  30. print (acc_min_samples_split_50)
  31.  
  32. #choose one of two
  33. prettyPicture(clf1, features_test, labels_test)
  34. # prettyPicture(clf2, features_test, labels_test)

上图,min_samples_split分别为2 和50

得到正确率分别为90.8%和91.2%

  • 决策树的优点与缺点

易于使用,易于理解

容易过拟合,尤其对于具有包含大量特征的数据时,复杂的决策树可能会过拟合数据,通过仔细调整参数,避免过拟合(对于节点上只有单个数据点的决策树,几乎肯定是过拟合)

决策树(Decision Trees)的更多相关文章

  1. 海量数据挖掘MMDS week6: 决策树Decision Trees

    http://blog.csdn.net/pipisorry/article/details/49445465 海量数据挖掘Mining Massive Datasets(MMDs) -Jure Le ...

  2. Decision Trees 决策树

    Decision Trees (DT)是用于分类和回归的非参数监督学习方法. 目标是创建一个模型,通过学习从数据特征推断出的简单决策规则来预测目标变量的值. 例如,在下面的例子中,决策树从数据中学习用 ...

  3. Facebook Gradient boosting 梯度提升 separate the positive and negative labeled points using a single line 梯度提升决策树 Gradient Boosted Decision Trees (GBDT)

    https://www.quora.com/Why-do-people-use-gradient-boosted-decision-trees-to-do-feature-transform Why ...

  4. CatBoost使用GPU实现决策树的快速梯度提升CatBoost Enables Fast Gradient Boosting on Decision Trees Using GPUs

    python机器学习-乳腺癌细胞挖掘(博主亲自录制视频)https://study.163.com/course/introduction.htm?courseId=1005269003&ut ...

  5. Logistic Regression vs Decision Trees vs SVM: Part II

    This is the 2nd part of the series. Read the first part here: Logistic Regression Vs Decision Trees ...

  6. Logistic Regression Vs Decision Trees Vs SVM: Part I

    Classification is one of the major problems that we solve while working on standard business problem ...

  7. 机器学习算法 --- Pruning (decision trees) & Random Forest Algorithm

    一.Table for Content 在之前的文章中我们介绍了Decision Trees Agorithms,然而这个学习算法有一个很大的弊端,就是很容易出现Overfitting,为了解决此问题 ...

  8. 机器学习算法 --- Decision Trees Algorithms

    一.Decision Trees Agorithms的简介 决策树算法(Decision Trees Agorithms),是如今最流行的机器学习算法之一,它即能做分类又做回归(不像之前介绍的其他学习 ...

  9. 机器学习算法实践:决策树 (Decision Tree)(转载)

    前言 最近打算系统学习下机器学习的基础算法,避免眼高手低,决定把常用的机器学习基础算法都实现一遍以便加深印象.本文为这系列博客的第一篇,关于决策树(Decision Tree)的算法实现,文中我将对决 ...

随机推荐

  1. [noip2011 luogu1312] Mayan游戏(模拟)

    原题:传送门 大模拟- 两个剪枝: 1.如果左边不为空就不往左边走(因为一定不如左边的移到右边优) 2.如果相邻两颜色相同不需移动 当然也有别的小剪枝(我没写)比如如果当前某一颜色剩余块数满足1< ...

  2. Maven 从安装到环境配置到项目搭建

    maven是基于项目对象模型(pom),可以通过一小段的描述信息来管理项目的构建,报告和文档的软件项目管理工具. Maven是构建项目的管理工具,白话就是说:“Maven的核心功能便是合理叙述项目间的 ...

  3. 提高生产力:发送邮件API和Web服务(包含源码)

    在Web开发中,发邮件是一种非常常见的功能或任务. 发送邮件的6种方式 一文提到了6种方法,文章发表后,有网友指出了还有另外一种方法,Ant中也可以发送邮件. 打开Foxmail之类的邮件客户端或者在 ...

  4. nodejs-路由(待补充)

    path Router 1 2 3 4 5 var express = require('express'); var Router = express.Router(); Router.get('/ ...

  5. nodejs-函数

    使用表达式定义的函数要提到使用之前,要不然无法解析,自然的function xx(xx)不用,ECMAscript自动提前 with关键字 引入空间命令空间,然后可以直接使用里面的对象了 label标 ...

  6. Android获取图片实际大小兼容平板电脑

    项目中有个图片在平板电脑中显示特别小的原因.一直苦于没找到原因,也没有平板电脑測试,今天找了个改动分辨率的,编写相关方法最终处理了,记录下比較: 好让以后不造轮子. 主要是获取文章相关图片显示问题.直 ...

  7. NEFU 2

    其实就是筛选素数. 如,若能被2是质数,则2的倍数全是合数.如此循环. #include <iostream> #include <math.h> #include <c ...

  8. JAVAEE之--------过滤器设置是否缓存(Filter)

    在网页中.每次的client訪问server.有部分不用反复请求.如有些图片,视频等就没有必要每次都请求,这样会让server增大工作量.为了防止这样.我们採用过滤器来设置client是都缓存. 參考 ...

  9. 关于vue 自定义组件的写法与用法

    最近在网上看到很多大神都有写博客的习惯,坚持写博客不但可以为自己的平时的学习做好记录积累 无意之中也学还能帮助到一些其他的朋友所以今天我也注册一个账号记录一下学习的点滴!当然本人能力实在有限写出的文章 ...

  10. ubuntu16.04+caffe训练mnist数据集

    1.   caffe-master文件夹权限修改 下载的caffe源码编译的caffe-master文件夹貌似没有写入权限,输入以下命令修改: sudo chmod -R 777 ~/caffe-ma ...