决策树ID3原理及R语言python代码实现(西瓜书)

摘要:

决策树是机器学习中一种非常常见的分类与回归方法,可以认为是if-else结构的规则。分类决策树是由节点和有向边组成的树形结构,节点表示特征或者属性,

而边表示的是属性值,边指向的叶节点为对应的分类。在对样本的分类过程中,由顶向下,根据特征或属性值选择分支,递归遍历直到叶节点,将实例分到叶节点对应的类别中。

决策树的学习过程就是构造出一个能正取分类(或者误差最小)训练数据集的且有较好泛化能力的树,核心是如何选择特征或属性作为节点,

通常的算法是利用启发式的算法如ID3,C4.5,CART等递归的选择最优特征。选择一个最优特征,然后按照此特征将数据集分割成多个子集,子集再选择最优特征,

直到所有训练数据都被正取分类,这就构造出了决策树。决策树有如下特点:

  1. 原理简单, 计算高效;使用基于信息熵相关的理论划分最优特征,原理清晰,计算效率高。
  2. 解释性强;决策树的属性结构以及if-else的判断逻辑,非常符合人的决策思维,使用训练数据集构造出一个决策树后,可视化决策树,

    可以非常直观的理解决策树的判断逻辑,可读性强。
  3. 效果好,应用广泛;其拟合效果一般很好,分类速度快,但也容易过拟合,决策树拥有非常广泛的应用。

本文主要介绍基于ID3的算法构造决策树。

决策树原理

训练数据集有多个特征,如何递归选择最优特征呢?信息熵增益提供了一个非常好的也非常符合人们日常逻辑的判断准则,即信息熵增益最大的特征为最优特征。在信息论中,熵是用来度量随机变量不确定性的量纲,熵越大,不确定性越大。熵定义如下:

此处log一般是以2为底,假设一个产品成品率为100%次品率为0%那么熵就为0,如果是成品率次品率各为50%,那么熵就为1,熵越大,说明不确定性越高,非常符合我们人类的思维逻辑。假设分类标记为随机变量Y,那么H(Y)表示随机变量Y的不确定性,我们依次选择可选特征,如果选择一个特征后,随机变量Y的熵减少的最多,表示得知特征X后,使得类Y不确定性减少最多,那么就把此特征选为最优特征。信息熵增益的公式如下:

ID3算法

决策树基于信息熵增益的ID3算法步骤如下:

  1. 如果数据集类别只有一类,选择这个类别作为,标记为叶节点。
  2. 从数据集的所有特征中,选择信息熵增益最大的作为节点,特征的属性分别作为节点的边。
  3. 选择最优特征后,按照对应的属性,将数据集分成多个,依次将子数据集从第1步递归进行构造子树。

python实现

#encoding:utf-8

import pandas as pd
import numpy as np class DecisionTree:
def __init__(self):
self.model = None
def calEntropy(self, y): # 计算熵
valRate = y.value_counts().apply(lambda x : x / y.size) # 频次汇总 得到各个特征对应的概率
valEntropy = np.inner(valRate, np.log2(valRate)) * -1
return valEntropy def fit(self, xTrain, yTrain = pd.Series()):
if yTrain.size == 0:#如果不传,自动选择最后一列作为分类标签
yTrain = xTrain.iloc[:,-1]
xTrain = xTrain.iloc[:,:len(xTrain.columns)-1]
self.model = self.buildDecisionTree(xTrain, yTrain)
return self.model
def buildDecisionTree(self, xTrain, yTrain):
propNamesAll = xTrain.columns
#print(propNamesAll)
yTrainCounts = yTrain.value_counts()
if yTrainCounts.size == 1:
#print('only one class', yTrainCounts.index[0])
return yTrainCounts.index[0]
entropyD = self.calEntropy(yTrain) maxGain = None
maxEntropyPropName = None
for propName in propNamesAll:
propDatas = xTrain[propName]
propClassSummary = propDatas.value_counts().apply(lambda x : x / propDatas.size)# 频次汇总 得到各个特征对应的概率 sumEntropyByProp = 0
for propClass, dvRate in propClassSummary.items():
yDataByPropClass = yTrain[xTrain[propName] == propClass]
entropyDv = self.calEntropy(yDataByPropClass)
sumEntropyByProp += entropyDv * dvRate
gainEach = entropyD - sumEntropyByProp
if maxGain == None or gainEach > maxGain:
maxGain = gainEach
maxEntropyPropName = propName
#print('select prop:', maxEntropyPropName, maxGain)
propDatas = xTrain[maxEntropyPropName]
propClassSummary = propDatas.value_counts().apply(lambda x : x / propDatas.size)# 频次汇总 得到各个特征对应的概率 retClassByProp = {}
for propClass, dvRate in propClassSummary.items():
whichIndex = xTrain[maxEntropyPropName] == propClass
if whichIndex.size == 0:
continue
xDataByPropClass = xTrain[whichIndex]
yDataByPropClass = yTrain[whichIndex]
del xDataByPropClass[maxEntropyPropName]#删除已经选择的属性列 #print(propClass)
#print(pd.concat([xDataByPropClass, yDataByPropClass], axis=1)) retClassByProp[propClass] = self.buildDecisionTree(xDataByPropClass, yDataByPropClass) return {'Node':maxEntropyPropName, 'Edge':retClassByProp}
def predictBySeries(self, modelNode, data):
if not isinstance(modelNode, dict):
return modelNode
nodePropName = modelNode['Node']
prpVal = data.get(nodePropName)
for edge, nextNode in modelNode['Edge'].items():
if prpVal == edge:
return self.predictBySeries(nextNode, data)
return None
def predict(self, data):
if isinstance(data, pd.Series):
return self.predictBySeries(self.model, data)
return data.apply(lambda d: self.predictBySeries(self.model, d), axis=1) dataTrain = pd.read_csv("xiguadata.csv", encoding = "gbk") decisionTree = DecisionTree()
treeData = decisionTree.fit(dataTrain)
print(pd.DataFrame({'预测值':decisionTree.predict(dataTrain), '正取值':dataTrain.iloc[:,-1]})) import json
print(json.dumps(treeData, ensure_ascii=False))

训练结束后,使用一个递归的字典保存决策树模型,使用格式json工具格式化输出后,可以简洁的看到树的结构。

R语言实现



dataTrain <- read.csv("xiguadata.csv", header = TRUE)

trainDecisionTree <- function(dataTrain){
calEntropy <- function(y){ # 计算熵 values <- table(unlist(y)); # 频次汇总 得到各个特征对应的概率 valuesRate <- values / sum(values); logVal = log2(valuesRate);# log2(0) == infinite
logVal[is.infinite(logVal)]=0; valuesEntropy <- -1 * t(valuesRate) %*% logVal;
if (is.nan(valuesEntropy)){
valuesEntropy = 0;
}
return(valuesEntropy);
} propNamesAll <- names(dataTrain)
propNamesAll <- propNamesAll[length(propNamesAll) * - 1]
print(propNamesAll)
buildDecisionTree <- function(propNames, dataSet){ classColumn = dataSet[, length(dataSet)]#最后一列是类别标签 classSummary <- table(unlist(classColumn))# 频次汇总 defaultRet = c(propNames[1], names(classSummary)[which.max(classSummary)]);
if (length(classSummary) == 1){#如果所有的都是同一类别,那么标记为叶节点
return(defaultRet);
}
if (length(propNames) == 1){#如果只剩一种属性了,那么返回样本数量最多的类别作为节点
return(defaultRet);
}
entropyD <- calEntropy(classColumn)
propGains = sapply(propNames, function(propName){ # propName 对应的是"色泽" "根蒂" "敲声" "纹理" "脐部" "触感"
propDatas <- dataSet[c(propName)] propClassSummary <- table(unlist(propDatas))# 频次汇总 retGain <- sapply(names(propClassSummary), function(propClass){# propClass 对应色泽的种类 如 浅白 青绿 乌黑
dataByPropClass <- subset(dataSet, dataSet[c(propName)] == propClass); #筛选出色泽等于 种类 propClass 的数据集
entropyDv <- calEntropy(dataByPropClass[, length(dataByPropClass)]) #最后一列是标记是否为好瓜
Dv = propClassSummary[c(propClass)][1]
return(entropyDv * Dv);# 这里没有直接除|D|,最后累加后再除,等价的
}); return(entropyD - sum(retGain)/sum(propClassSummary));
});
#print(propGains);
maxEntropyProp = propGains[which.max(propGains)];#选择信息熵增益最大的属性
propName = names(maxEntropyProp)[1]
#print(propName)
propDatas <- dataSet[c(propName)] propClassSummary <- table(unlist(propDatas))# 频次汇总 propClassSummary <- propClassSummary[which(propClassSummary > 0)]
propClassNames <- names(propClassSummary) #propClassNames = c(propClassNames[1])
retGain <- sapply(propClassNames, function(propClass){# propClass 对应色泽的种类 如 浅白 青绿 乌黑 dataByPropClass <- subset(dataSet, dataSet[c(propName)] == propClass); #筛选出色泽等于 种类 propClass 的数据集
leftClassNames = propNames[which(propNames==propName) * -1] #去掉这个属性,递归构造决策树
ret = buildDecisionTree(leftClassNames, dataByPropClass);
return(ret);
});
#names(retGain) = propClassNames
retList = retGain
#retList = list()
#for (propClass in propClassNames){
# retList[propClass] = retGain[propClass]
#}
#print(retList) #索引1表示选择的属性名称 索引2对应的类别,如果有子树那么就是frame,否则就是类别
ret = list(propName, retList)
#ret = data.frame(c(retList))
#names(ret) = c(propName)
return(ret);
}
retProp = buildDecisionTree(propNamesAll, dataTrain);
return(retProp);
}
decisionTree = trainDecisionTree(dataTrain)
#print(decisionTree) library("rpart")
library("rpart.plot")
dataTrain <- read.csv("xiguadata.csv", header = TRUE)
print(dataTrain)
fit <- rpart(HaoGua~.,data=dataTrain,control = rpart.control(minsplit = 1, minbucket = 1),method="class")
printcp(fit) rpart.plot(fit, branch = 1, branch.type = 1, type = 2, extra = 102,shadow.col='gray', box.col='green',border.col='blue', split.col='red',main="DecisionTree") #library(jsonlite)
#dataJson = toJSON(decisionTree)
#c <- file( "result.txt", "w" )
#writeLines(dataJson, c )
#close( c ) #这里需要主动关闭文件 #for (k in propNames) {
# eachData <- dataSet[c(k)]
# values <- table(unlist(eachData))# 频次汇总
# #print(values)
# print(k)
# total <- 0
# for (m in names(values)) {
# #print(m)
# #print(values[m][1])
# data3 <- subset(dataSet, dataSet[c(k)] == m)
# entropyDv <- calEntropy(data3[, length(data3)])
# #print(entropyDv)
# total = total + entropyDv*values[c(m)][1]
# }
# GainDv <- entropyD - total / sum(values);+
# print(GainDv)
#}

R语言代码包含本人自己编写的R语言ID3算法,最后使用R的rpart包训练了一个决策树。

总结:

  • ID3算法简洁清晰,符合人类思路方式。
  • 决策树的解释性强,可视化后也方便理解模型和验证正确性。
  • ID3算法时候标签类特征的样本,对应具有连续型数值的特征,无法运行此算法。
  • 有过拟合的风险,要通过剪枝来避免过拟合。
  • 信息增益有时候偏爱属性很多的特征,C4.5和CART算法可以对此有优化。
  • 这是我的github主页https://github.com/fanchy,有些有意思的分享。
  • python相比R语言写起来还是溜多了,主要是遍历和嵌套,python比R要容易很多,R的数据筛选和选择方便一点,这个python版本的id3算法写的还是很清晰简洁的 正是Talk is cheap. Show me the code。这是在网上可以看到原生实现版本中,最精简的版本之一。

对应的西瓜书数据集为

色泽	根蒂	敲声	纹理	脐部	触感	HaoGua
青绿 蜷缩 浊响 清晰 凹陷 硬滑 是
乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 是
乌黑 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 蜷缩 沉闷 清晰 凹陷 硬滑 是
浅白 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 稍蜷 浊响 清晰 稍凹 软粘 是
乌黑 稍蜷 浊响 稍糊 稍凹 软粘 是
乌黑 稍蜷 浊响 清晰 稍凹 硬滑 是
乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 否
青绿 硬挺 清脆 清晰 平坦 软粘 否
浅白 硬挺 清脆 模糊 平坦 硬滑 否
浅白 蜷缩 浊响 模糊 平坦 软粘 否
青绿 稍蜷 浊响 稍糊 凹陷 硬滑 否
浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 否
乌黑 稍蜷 浊响 清晰 稍凹 软粘 否
浅白 蜷缩 浊响 模糊 平坦 硬滑 否
青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 否

决策树ID3原理及R语言python代码实现(西瓜书)的更多相关文章

  1. 随机森林入门攻略(内含R、Python代码)

    随机森林入门攻略(内含R.Python代码) 简介 近年来,随机森林模型在界内的关注度与受欢迎程度有着显著的提升,这多半归功于它可以快速地被应用到几乎任何的数据科学问题中去,从而使人们能够高效快捷地获 ...

  2. 主成分分析(PCA)原理及R语言实现

    原理: 主成分分析 - stanford 主成分分析法 - 智库 主成分分析(Principal Component Analysis)原理 主成分分析及R语言案例 - 文库 主成分分析法的原理应用及 ...

  3. 决策树(ID3 )原理及实现

    1.决策树原理 1.1.定义 分类决策树模型是一种描述对实例进行分类的树形结构.决策树由结点和有向边组成.结点有两种类型:内部节点和叶节点,内部节点表示一个特征或属性,叶节点表示一个类. 举一个通俗的 ...

  4. 主成分分析(PCA)原理及R语言实现 | dimension reduction降维

    如果你的职业定位是数据分析师/计算生物学家,那么不懂PCA.t-SNE的原理就说不过去了吧.跑通软件没什么了不起的,网上那么多教程,copy一下就会.关键是要懂其数学原理,理解算法的假设,适合解决什么 ...

  5. 决策树---ID3算法(介绍及Python实现)

    决策树---ID3算法   决策树: 以天气数据库的训练数据为例. Outlook Temperature Humidity Windy PlayGolf? sunny 85 85 FALSE no ...

  6. 基于卡方的独立性检验原理及R语言实现

    在读到<R语言实战>(第二版)P143页有关卡方独立性检验所记 假设检验 假设检验(Test of Hypothesis)又称为显著性检验(Test of Ststistical Sign ...

  7. python实现简单决策树(信息增益)——基于周志华的西瓜书数据

    数据集如下: 色泽 根蒂 敲声 纹理 脐部 触感 好瓜 青绿 蜷缩 浊响 清晰 凹陷 硬滑 是 乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 是 乌黑 蜷缩 浊响 清晰 凹陷 硬滑 是 青绿 蜷缩 沉闷 清晰 ...

  8. 命令行运行R语言脚本(代码)

    1 Windows: 键入 cd C:\Program Files\R\R-3.2.0\bin   工作目录切换到R的核心程序目录 键入 R BATCH F:\Test.R 或 Rscript F:\ ...

  9. R语言学习笔记—决策树分类

    一.简介 决策树分类算法(decision tree)通过树状结构对具有某特征属性的样本进行分类.其典型算法包括ID3算法.C4.5算法.C5.0算法.CART算法等.每一个决策树包括根节点(root ...

随机推荐

  1. Java编程思想:XML

    /* 本次实验需要在www.xom.nu上下载XOM类库 */ import nu.xom.*; import java.io.BufferedOutputStream; import java.io ...

  2. C#跟Lua如何超高性能传递数据

    前言 在UWA学堂上线那天,我买了招文勇这篇Lua交互的课程,19块还算值,但是前段时间太忙,一直没空研究,他的demo是基于xlua的,今天终于花了大半天时间在tolua下跑起来了,记录一下我的理解 ...

  3. 个人永久性免费-Excel催化剂功能第26波-正确的Excel密码管理之道

    Excel等文档肩负着我们日常大量的信息存储和传递工作,难免出现数据安全的问题,OFFICE自带的密码设置,在什么样的场景下才有必要使用?网上所宣称的OFFICE文档密码保护不安全,随时可被破解,究竟 ...

  4. fjnu2016-2017 低程 PROBLEM C 汪老司机

    动态规划 方程 #include <iostream>#include <iomanip>#include <cmath>#include <algorith ...

  5. PHP与ECMAScript_4_常用数学相关函数

    PHP ECMAScript 向上取整 ceil($number) Math.ceil( number ) 向下取整 floor($number) Math.floor( number ) 绝对值 a ...

  6. springcloud-eureka

    作者:纯洁的微笑出处:http://www.ityouknow.com/ 版权归作者所有,转载请注明出处 Eureka是Netflix开源的一款提供服务注册和发现的产品,它提供了完整的Service ...

  7. .net RabbitMQ 介绍、安装、运行

    RabbitMQ介绍 什么是MQ Message Queue(简称:MQ),消息队列 顾名思义将内容存入到队列中,存入取出的原则是先进先出.后进后出. 其主要用途:不同进程Process/线程Thre ...

  8. 林大妈的JavaScript基础知识(三):JavaScript编程(2)函数

    JavaScript是一门函数式的面向对象编程语言.了解函数将会是了解对象创建和操作.原型及原型方法.模块化编程等的重要基础.函数包含一组语句,它的主要功能是代码复用.隐藏信息和组合调用.我们编程就是 ...

  9. 深入理解HashMap(jdk7)

    HashMap的结构图示 ​ jdk1.7的HashMap采用数组+单链表实现,尽管定义了hash函数来避免冲突,但因为数组长度有限,还是会出现两个不同的Key经过计算后在数组中的位置一样,1.7版本 ...

  10. 【iOS】手动抛出异常

    之前没遇到过需要手动抛出异常的时候,这次见到了,记录一下.示例代码如下: /** 如果调用 [[BNRItemStore alloc] init],就提示应该使用 [BNRItemStore shar ...