1、简单概念描述

决策树的类型有很多,有CART、ID3和C4.5等,其中CART是基于基尼不纯度(Gini)的,这里不做详解,而ID3和C4.5都是基于信息熵的,它们两个得到的结果都是一样的,本次定义主要针对ID3算法。下面我们介绍信息熵的定义。

p(ai):事件ai发生的概率

  I(ai)=-log2(p(ai)):表示为事件ai的不确定程度,称为ai的自信息量

  H=sum(p(ai)*I(ai)):称为信源S的平均信息量—信息熵

  Gain = BaseEntropy – newEntropy:信息增益

    决策树学习采用的是自顶向下的递归方法,其基本思想是以信息熵为度量构造一棵熵值下降最快的树,到叶子节点处的熵值为零,此时每个叶节点中的实例都属于同一类。ID3的原理是基于信息熵增益Gain达到最大,设原始问题的标签有正例和负例,p和n表示其相应的个数。则原始问题的信息熵为

其中N为该特征所取值的个数,比如{rain,sunny},则N即为2

  ID3易出现的问题:如果是取值更多的属性,更容易使得数据更“纯”(尤其是连续型数值),其信息增益更大,决策树会首先挑选这个属性作为树的顶点。结果训练出来的形状是一棵庞大且深度很浅的树,这样的划分是极为不合理的。 此时可以采用C4.5来解决,C4.5的思想是最大化Gain除以下面这个公式即得到信息增益率:

  其中底为2

2、决策树的优缺点

优点:计算复杂度不高,输出结果易于理解,对中间值缺失不敏感,可以处理不相关特征数据

缺点:可能产生过度匹配问题

适用数据类型:数值型和标称型

3、python代码的实现

以下的代码根据这些数据理解

数据1中包含5个海洋动物,特征包括:不浮出水面是否可以生存,以及是否有脚蹼。我们可以将这些动物分成两类:鱼类和非鱼类。

  不浮出水面是否可以生存 是否有脚蹼 属于鱼类
1
2
3
4
5
  特征[0](no surfacing) 特征[1](flippers) 特征[-1]fish
dataSet[0] 1 1 yes
dataSet[1] 1 1 yes
dataSet[2] 0 1 no
dataSet[3] 0 1 no
dataSet[4] 0 1 no

创建名为trees.py的文件,下面代码内容都在此文件中。

(1)计算信息熵

  1. # -*- coding: utf-8 -*-
    #计算给定数据集的香农熵
  2. def calcShannonEnt(dataSet):
  3. numEntries=len(dataSet) #数据实例总数
  4. labelCounts={} #对类别数量创建了一个数据字典,键值是最后一列的数值
  5. for featVec in dataSet: #featVec表示特征集
  6. currentLabel=featVec[-1] # currentLabel表示当前键值,featVec[-1]表示数据集中的最后一列
  7. #如果当前键值不存在,扩展字典将当前键值加入字典,设置当前键值表示的类别数量为0
  8. if currentLabel not in labelCounts.keys():
  9. labelCounts[currentLabel]=0
  10. #如果当前键值存在,则类别数量累加
  11. labelCounts[currentLabel]+=1
  12. shannonEnt=0.0
  13. for key in labelCounts:
  14. prob=float(labelCounts[key])/numEntries #每个键值都记录了当前类别出现的次数
  15. shannonEnt -=prob*log(prob,2)
  16. return shannonEnt 

(2)创建数据集

  1. #创建数据集
  2. def createDataSet():
  3. dataSet=[[1,1,'yes'],[1,1,'yes'],[0,1,'no'],[0,1,'no'],[0,1,'no']]
  4. labels=['no surfacing','flippers']
  5. return dataSet,labels

在python命令提示符下输入下列命令:

  1. >>> import trees
  2. >>> reload(trees)
  3. <module 'trees' from 'E:\python excise\trees.pyc'>
  4. >>> myDat,labels=trees.createDataSet()
  5. >>> myDat
  6. [[1, 1, 'yes'], [1, 1, 'yes'], [0, 1, 'no'], [0, 1, 'no'], [0, 1, 'no']]
  7. >>> trees.calcShannonEnt(myDat)
  8. 0.9709505944546686
  9. >>>

熵越高,则混合的数据越多,在数据集中添加更多的分类,观察熵是如何变化的,这里增加第三个名为maybe的分类,测试熵的变化:

  1. >>> myDat[0][-1]='maybe'
  2. >>> myDat
  3. [[1, 1, 'maybe'], [1, 1, 'yes'], [0, 1, 'no'], [0, 1, 'no'], [0, 1, 'no']]
  4. >>> trees.calcShannonEnt(myDat)
  5. 1.3709505944546687

得到熵后,我们可以按照获取最大信息增益的方法划分数据集

(3)划分数据集

我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式

  1. #按照给定特征划分数据集
  2. #dataSet:待划分的数据集,axis:划分数据集的特征,value:需要返回的特征的值
  3. def splitDataSet(dataSet,axis,value):
  4. retDataSet=[] #为了不修改原始数据dataSet,创建一个新的列表对象
  5. for featVec in dataSet:
  6. if featVec[axis]==value:
  7. reducedFeatVec=featVec[:axis] #获取从第0列到特征列的数据
  8. reducedFeatVec.extend(featVec[axis+1:]) #获取从特征列之后的数据
  9. retDataSet.append(reducedFeatVec) #目前reducedFeatVec表示除了特征列的数据
  10. return retDataSet
  1. >>> reload(trees)
  2. <module 'trees' from 'E:\python excise\trees.pyc'>
  3. >>> myDat,labels=trees.createDataSet()
  4. >>> myDat
  5. [[1, 1, 'yes'], [1, 1, 'yes'], [0, 1, 'no'], [0, 1, 'no'], [0, 1, 'no']]
  6. >>> trees.splitDataSet(myDat,0,1)
  7. [[1, 'yes'], [1, 'yes']]
  8. >>> trees.splitDataSet(myDat,0,0)
  9. [[1, 'no'], [1, 'no'], [1, 'no']]

(4)选择最好的特征进行划分

  1. #选择最好的数据集划分方式
  2. def chooseBestFeatureToSplit(dataSet):
  3. numFeatures=len(dataSet[0])-1 #减去类别那一列
  4. baseEntropy=calcShannonEnt(dataSet) #计算整个数据集的原始香农熵
  5. bestInfoGain=0.0;bestFeature=-1 #现在最好的特征是数据集中的最后一列
    #i=0,新熵,增益
    #i=1,新熵,增益
  6. for i in range(numFeatures): #循环遍历数据集中的所有特征
  7. featList=[example[i] for example in dataSet] #获取第i个特征所有可能的取值,特征0一个列表,特征1一个列表...
  8. uniqueVals=set(featList) #集合数据类型(set)与列表类型相似,不同之处仅在于集合类型中每个值互不相同
  9. newEntropy=0.0
  10. for value in uniqueVals:
  11. subDataSet=splitDataSet(dataSet,i,value) #划分后的数据集
  12. prob=len(subDataSet)/float(len(dataSet))
  13. newEntropy+=prob*calcShannonEnt(subDataSet) #求划分完的数据集的熵
  14. infoGain=baseEntropy-newEntropy
  15. if(infoGain>bestInfoGain):
  16. bestInfoGain=infoGain
  17. bestFeature=i
  18. return bestFeature

注意:这里数据集需要满足以下两个办法:

<1>所有的列元素都必须具有相同的数据长度

<2>数据的最后一列或者每个实例的最后一个元素是当前实例的类别标签。

  1. >>> reload(trees)
  2. <module 'trees' from 'E:\python excise\trees.pyc'>
  3. >>> myDat,labels=trees.createDataSet()
  4. >>> trees.chooseBestFeatureToSplit(myDat)
  5. 0

(5)创建树的代码

Python用字典类型来存储树的结构,返回的结果是myTree-字典

  1. #创建树的函数代码
  2. def createTree(dataSet,labels):
  3. classList=[example[-1] for example in dataSet]
  4. if classList.count(classList[0])==len(classList): #类别完全相同规则停止继续划分
  5. return classList[0]
  6. if len(dataSet[0])==1: #确认至少有数据集
    return majorityCnt(classList)
  7. bestFeat=chooseBestFeatureToSplit(dataSet)
  8. bestFeatLabel=labels[bestFeat]
  9. myTree={bestFeatLabel:{}}
  10. del(labels[bestFeat]) #得到列表包含的所有属性
  11. featValues=[example[bestFeat] for example in dataSet]
  12. uniqueVals=set(featValues)
  13. for value in uniqueVals:
  14. subLabels=labels[:]
  15. myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
  16. return myTree 

其中递归结束当且仅当该类别中标签完全相同或者遍历所有的特征此时返回次数最多的

  1. >>> reload(trees)
  2. <module 'trees' from 'E:\python excise\trees.pyc'>
  3. >>> myDat,labels=trees.createDataSet()
  4. >>> myTree=trees.createTree(myDat,labels)
  5. >>> myTree
  6. {'no surfacing': {0: 'no', 1: 'yes'}}

其中当所有的特征都用完时,采用多数表决的方法来决定该叶子节点的分类,即该叶节点中属于某一类最多的样本数,那么我们就说该叶节点属于那一类。即为如果数据集已经处理了所有的属性,但是类标签依然不是唯一的,此时我们要决定如何定义该叶子节点,在这种情况下,我们通常采用多数表决的方法来决定该叶子节点的分类。代码如下:

  1. def majorityCnt(classList):
  2. classCount={}
  3. for vote in classList:
  4. if vote not in classCount.keys():classCount[vote]=0
  5. classCount[vote]+=1
  6. sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
  7. return sortedClassCount[0][0]

(6)使用决策树执行分类

  1. #测试算法:使用决策树执行分类
  2. def classify(inputTree,featLabels,testVec):
  3. firstStr=inputTree.keys()[0]
  4. secondDict=inputTree[firstStr]
  5. featIndex=featLabels.index(firstStr)
  6. for key in secondDict.keys():
  7. if testVec[featIndex]==key:
  8. if type(secondDict[key]).__name__=='dict':
  9. classLabel=classify(secondDict[key],featLabels,testVec)
  10. else:classLabel=secondDict[key]
  11. return classLabel
  1. >>> import trees
  2. >>> myDat,labels=trees.createDataSet()
  3. >>> labels
  4. ['no surfacing', 'flippers']
  5. >>> trees.classify(myTree,labels,[1,0])
  6. 'no'
  7. >>> trees.classify(myTree,labels,[1,1])
  8. 'yes'

注意递归的思想很重要。

(7)决策树的存储

构造决策树是一个很耗时的任务。为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用python模块pickle序列化对象,序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。

  1. #使用算法:决策树的存储
  2. def storeTree(inputTree,filename):
  3. import pickle
  4. fw=open(filename,'w')
  5. pickle.dump(inputTree,fw)
  6. fw.close()
  7. def grabTree(filename):
  8. import pickle
  9. fr=open(filename)
  10. return pickle.load(fr)
  1. >>> reload(trees)
  2. >>><module 'tree' from 'trees.py'>
  3. >>> trees.storeTree(myTree,'classifierStorage.txt')
  4. >>> trees.grabTree('classifierStorage.txt')
  5. {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

classifierStorage.txt如下:

补充:

用matplotlib注解上述形成的决策树

Matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注释。注解通常用于解释数据的内容。

创建名为treePlotter.py文件,下面代码都在此文件中

其中index方法为查找当前列表中第一个匹配firstStr的元素 返回的为索引。

【Machine Learning in Action --3】决策树ID3算法的更多相关文章

  1. 机器学习实战(Machine Learning in Action)学习笔记————03.决策树原理、源码解析及测试

    机器学习实战(Machine Learning in Action)学习笔记————03.决策树原理.源码解析及测试 关键字:决策树.python.源码解析.测试作者:米仓山下时间:2018-10-2 ...

  2. 《Machine Learning in Action》—— Taoye给你讲讲决策树到底是支什么“鬼”

    <Machine Learning in Action>-- Taoye给你讲讲决策树到底是支什么"鬼" 前面我们已经详细讲解了线性SVM以及SMO的初步优化过程,具体 ...

  3. 《Machine Learning in Action》—— 小朋友,快来玩啊,决策树呦

    <Machine Learning in Action>-- 小朋友,快来玩啊,决策树呦 在上篇文章中,<Machine Learning in Action>-- Taoye ...

  4. 机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集

    机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集 关键字:FPgrowth.频繁项集.条件FP树.非监督学习作者:米 ...

  5. 机器学习实战(Machine Learning in Action)学习笔记————07.使用Apriori算法进行关联分析

    机器学习实战(Machine Learning in Action)学习笔记————07.使用Apriori算法进行关联分析 关键字:Apriori.关联规则挖掘.频繁项集作者:米仓山下时间:2018 ...

  6. 机器学习实战(Machine Learning in Action)学习笔记————06.k-均值聚类算法(kMeans)学习笔记

    机器学习实战(Machine Learning in Action)学习笔记————06.k-均值聚类算法(kMeans)学习笔记 关键字:k-均值.kMeans.聚类.非监督学习作者:米仓山下时间: ...

  7. 机器学习实战(Machine Learning in Action)学习笔记————02.k-邻近算法(KNN)

    机器学习实战(Machine Learning in Action)学习笔记————02.k-邻近算法(KNN) 关键字:邻近算法(kNN: k Nearest Neighbors).python.源 ...

  8. Machine Learning in Action(5) SVM算法

    做机器学习的一定对支持向量机(support vector machine-SVM)颇为熟悉,因为在深度学习出现之前,SVM一直霸占着机器学习老大哥的位子.他的理论很优美,各种变种改进版本也很多,比如 ...

  9. Machine Learning In Action 第二章学习笔记: kNN算法

    本文主要记录<Machine Learning In Action>中第二章的内容.书中以两个具体实例来介绍kNN(k nearest neighbors),分别是: 约会对象预测 手写数 ...

  10. 【机器学习实战】Machine Learning in Action 代码 视频 项目案例

    MachineLearning 欢迎任何人参与和完善:一个人可以走的很快,但是一群人却可以走的更远 Machine Learning in Action (机器学习实战) | ApacheCN(apa ...

随机推荐

  1. css2和CSS3的background属性简写

    1.css2:background:background-color || url("") || no-repeat || scroll || 0 0;  css3:  backg ...

  2. android相关内容

    一: 前台进程: 前台的进程的优先级最高, 可见进程: android系统一般存在少量的可见进程. 服务进程: 没有用户界面, 后台进程: 一般存在较多的后台进程. 空进程: 不包括任何活跃组件的进程 ...

  3. C#项目间循环引用的解决办法,有图有真相

    C#项目间循环引用的解决办法,有图有真相 程序间的互相调用接口,c#禁止互相引用,海宏软件,20160315 /// c#禁止互相引用,如果项目[订单]中有一个orderEdit单元,要在项目[进销存 ...

  4. typeof做类型判断时容易犯下的错

    学过js同学都知道js的数据类型有 字符串.数字.布尔.Null.Undefined和object(数组.function......) 作为一个初学者我一直认为每个数据类型返回的结果是这样的 typ ...

  5. 面试题-Java基础-线程部分

    1.进程和线程的区别是什么? 进程是执行着的应用程序,而线程是进程内部的一个执行序列.一个进程可以有多个线程.线程又叫做轻量级进程. 2.创建线程有几种不同的方式?你喜欢哪一种?为什么? 有三种方式可 ...

  6. iOS TableView的分割线

    if ([self.tableView respondsToSelector:@selector(setSeparatorInset:)]) { [self.tableView setSeparato ...

  7. C#学习心得,记录学习

  8. H5获取的经纬度,该怎么在百度地图中查看?

    之前理所当然的的到百度的坐标拾取系统, 输入H5获取的经纬度坐标,然后查询,然后发现老是有误差,而且误差都是一样的规律:偏实际位置西南方约1000~1500米左右. 以为是H5获取经纬度必然会产生这么 ...

  9. java练习 - 字符串反转

    思路: 1. 首先将字符串转换成数组,一个数组元素放一个字符. 2. 循环遍历字符串,将所有字符串前后字符调换位置,比如:第一个和最后一个调换,第二个和倒数第三调换,第三个和倒数第三调换,直到所有字符 ...

  10. matlab imshow()函数显示白色图像问题

    在MATLAB中,我们常使用imshow()函数来显示图像,而此时的图像矩阵可能经过了某种运算.在MATLAB中,为了保证精度,经过了运算的图像矩阵I其数据类型会从uint8型变成double型.如果 ...