---------------------------------------------------------------------------------------

本系列文章为《机器学习实战》学习笔记,内容整理自书本,网络以及自己的理解,如有错误欢迎指正。

源码在Python3.5上测试均通过,代码及数据 --> https://github.com/Wellat/MLaction

---------------------------------------------------------------------------------------

1、算法概述及实现

1.1 算法理论

决策树构建的伪代码如下:

显然,决策树的生成是一个递归过程,在决策树基本算法中,有3种情导致递归返回:(1)当前结点包含的样本全属于同一类别,无需划分;(2)属性集为空,或是所有样本在所有属性上取值相同,无法划分;(3)当前结点包含的样本集合为空,不能划分。


划分选择

决策树学习的关键是第8行,即如何选择最优划分属性,一般而言,随着划分过程不断进行,我们希望决策树的分支节点所包含的样本尽可能属于同一类别,即结点的“纯度”越来越高。

①ID3算法

ID3算法通过对比选择不同特征下数据集的信息增益香农熵来确定最优划分特征。

香农熵:

  

from collections import Counter
import operator
import math def calcEnt(dataSet):
classCount = Counter(sample[-1] for sample in dataSet)
prob = [float(v)/sum(classCount.values()) for v in classCount.values()]
return reduce(operator.add, map(lambda x: -x*math.log(x, 2), prob))

纯度差,也称为信息增益,表示为

  

上面公式实际上就是当前节点的不纯度减去子节点不纯度的加权平均数,权重由子节点记录数与当前节点记录数的比例决定。信息增益越大,则意味着使用属性a来进行划分所获得的“纯度提升”越大,效果越好。

信息增益准则对可取值数目较多的属性有所偏好。

②C4.5算法

C4.5决策树生成算法相对于ID3算法的重要改进是使用信息增益率来选择节点属性。它克服了ID3算法存在的不足:ID3算法只适用于离散的描述属性,对于连续数据需离散化,而C4.5则离散连续均能处理。

增益率定义为:

  

其中

  

需注意的是,增益率准则对可取值数目较少的属性有所偏好,因此,C4.5算法并不是直接选择增益率最大的候选划分属性,而是使用了一个启发式:先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。

③CART算法

CART决策树使用“基尼指数”来选择划分属性,数据集D的纯度可用基尼值来度量:

  

它反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率。因此Gini(D)越小,则数据集D的纯度越高

from  collections import Counter
import operator def calcGini(dataSet):
labelCounts = Counter(sample[-1] for sample in dataSet)
prob = [float(v)/sum(labelCounts.values()) for v in labelCounts.values()]
return 1 - reduce(operator.add, map(lambda x: x**2, prob))

剪枝处理

为避免过拟合,需要对生成树剪枝。决策树剪枝的基本策略有“预剪枝”和“后剪枝”。预剪枝是指在决策树生成过程中,对每个结点划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶结点(有欠拟合风险)。后剪枝则是先从训练集生成一棵完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶节点。

1.2 构造决策树

本书使用ID3算法划分数据集,即通过对比选择不同特征下数据集的信息增益和香农熵来确定最优划分特征。

1.2.1 计算香农熵:

 from math import log
import operator def createDataSet():
'''
产生测试数据
'''
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing','flippers']
return dataSet, labels def calcShannonEnt(dataSet):
'''
计算给定数据集的香农熵
'''
numEntries = len(dataSet)
labelCounts = {}
#统计每个类别出现的次数,保存在字典labelCounts中
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 #如果当前键值不存在,则扩展字典并将当前键值加入字典
shannonEnt = 0.0
for key in labelCounts:
#使用所有类标签的发生频率计算类别出现的概率
prob = float(labelCounts[key])/numEntries
#用这个概率计算香农熵
shannonEnt -= prob * log(prob,2) #取2为底的对数
return shannonEnt if __name__== "__main__":
'''
计算给定数据集的香农熵
'''
dataSet,labels = createDataSet()
shannonEnt = calcShannonEnt(dataSet)

1.2.2 划分数据集

 def splitDataSet(dataSet, axis, value):
'''
按照给定特征划分数据集
dataSet:待划分的数据集
axis: 划分数据集的第axis个特征
value: 特征的返回值(比较值)
'''
retDataSet = []
#遍历数据集中的每个元素,一旦发现符合要求的值,则将其添加到新创建的列表中
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
#extend()和append()方法功能相似,但在处理列表时,处理结果完全不同
#a=[1,2,3] b=[4,5,6]
#a.append(b) = [1,2,3,[4,5,6]]
#a.extend(b) = [1,2,3,4,5,6]
return retDataSet

划分数据集的结果如下所示:

选择最好的数据集划分方式。接下来我们将遍历整个数据集,循环计算香农熵和 splitDataSet() 函数,找到最好的特征划分方式。

 def chooseBestFeatureToSplit(dataSet):
'''
选择最好的数据集划分方式
输入:数据集
输出:最优分类的特征的index
'''
#计算特征数量
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures):
#创建唯一的分类标签列表
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
#计算每种划分方式的信息熵
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
#计算最好的信息增益,即infoGain越大划分效果越好
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature

1.2.3 递归构建决策树

目前我们已经学习了从数据集构造决策树算法所需要的子功能模块,其工作原理如下:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。

由于特征数目并不是在每次划分数据分组时都减少,因此这些算法在实际使用时可能引起一定的问题。目前我们并不需要考虑这个问题,只需要在算法开始运行前计算列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们通常会采用多数表决的方法决定该叶子节点的分类。

在 trees.py 中增加如下投票表决代码:

 import operator
def majorityCnt(classList):
'''
投票表决函数
输入classList:标签集合,本例为:['yes', 'yes', 'no', 'no', 'no']
输出:得票数最多的分类名称
'''
classCount={}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
#把分类结果进行排序,然后返回得票数最多的分类结果
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]

创建树的函数代码(主函数):

 def createTree(dataSet,labels):
'''
创建树
输入:数据集和标签列表
输出:树的所有信息
'''
# classList为数据集的所有类标签
classList = [example[-1] for example in dataSet]
# 停止条件1:所有类标签完全相同,直接返回该类标签
if classList.count(classList[0]) == len(classList):
return classList[0]
# 停止条件2:遍历完所有特征时仍不能将数据集划分成仅包含唯一类别的分组,则返回出现次数最多的类标签
#
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 选择最优分类特征
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
# myTree存储树的所有信息
myTree = {bestFeatLabel:{}}
# 以下得到列表包含的所有属性值
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
# 遍历当前选择特征包含的所有属性值
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree

本例返回 myTree 为字典类型,如下:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

2、测试分类和存储分类器

利用决策树的分类函数:

 def classify(inputTree,featLabels,testVec):
'''
决策树的分类函数
inputTree:训练好的树信息
featLabels:标签列表
testVec:测试向量
'''
# 在2.7中,找到key所对应的第一个元素为:firstStr = myTree.keys()[0],
# 这在3.4中运行会报错:‘dict_keys‘ object does not support indexing,这是因为python3改变了dict.keys,
# 返回的是dict_keys对象,支持iterable 但不支持indexable,
# 我们可以将其明确的转化成list,则此项功能在3中应这样实现:
firstSides = list(inputTree.keys())
firstStr = firstSides[0]
secondDict = inputTree[firstStr]
# 将标签字符串转换成索引
featIndex = featLabels.index(firstStr)
key = testVec[featIndex]
valueOfFeat = secondDict[key]
# 递归遍历整棵树,比较testVec变量中的值与树节点的值,如果到达叶子节点,则返回当前节点的分类标签
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else: classLabel = valueOfFeat
return classLabel if __name__== "__main__":
'''
测试分类效果
'''
dataSet,labels = createDataSet()
myTree = createTree(dataSet,labels)
ans = classify(myTree,labels,[1,0])

决策树模型的存储

 def storeTree(inputTree,filename):
'''
使用pickle模块存储决策树
'''
import pickle
fw = open(filename,'wb+')
pickle.dump(inputTree,fw)
fw.close() def grabTree(filename):
'''
导入决策树模型
'''
import pickle
fr = open(filename,'rb')
return pickle.load(fr) if __name__== "__main__":
'''
存取操作
'''
storeTree(myTree,'mt.txt')
myTree2 = grabTree('mt.txt')

3、使用 Matplotlib 绘制树形图

上节我们已经学习如何从数据集中创建决策树,然而字典的表示形式非常不易于理解,决策树的主要优点就是直观易于理解,如果不能将其直观显示出来,就无法发挥其优势。本节使用 Matplotlib 库编写代码绘制决策树。

创建名为 treePlotter.py 的新文件:

3.1 绘制树节点

 import matplotlib.pyplot as plt

 # 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-") # 绘制带箭头的注释
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) def createPlot():
fig = plt.figure(1, facecolor='grey')
fig.clf()
# 定义绘图区
createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show() if __name__== "__main__":
'''
绘制树节点
'''
createPlot()

结果如下:

3.2 构造注解树

绘制一棵完整的树需要一些技巧。我们虽然有 x, y 坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便可以正确确x轴的长度;我们还需要知道树有多少层,来确定y轴的高度。这里另一两个新函数 getNumLeafs() 和 getTreeDepth() ,来获取叶节点的数目和树的层数,createPlot() 为主函数,完整代码如下:

 import matplotlib.pyplot as plt

 # 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-") # 绘制带箭头的注释
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) def createPlot(inTree):
'''
绘树主函数
'''
fig = plt.figure(1, facecolor='white')
fig.clf()
# 设置坐标轴数据
axprops = dict(xticks=[], yticks=[])
# 无坐标轴
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
# 带坐标轴
# createPlot.ax1 = plt.subplot(111, frameon=False)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
# 两个全局变量plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置,
# 以及放置下一个节点的恰当位置
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show() def getNumLeafs(myTree):
'''
获取叶节点的数目
'''
numLeafs = 0
firstSides = list(myTree.keys())
firstStr = firstSides[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
# 判断节点是否为字典来以此判断是否为叶子节点
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs def getTreeDepth(myTree):
'''
获取树的层数
'''
maxDepth = 0
firstSides = list(myTree.keys())
firstStr = firstSides[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth def plotMidText(cntrPt, parentPt, txtString):
'''
计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息
'''
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
# 计算宽与高
numLeafs = getNumLeafs(myTree) #this determines the x width of this tree
depth = getTreeDepth(myTree)
firstSides = list(myTree.keys())
firstStr = firstSides[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
# 标记子节点属性值
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
# 减少y偏移
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key],cntrPt,str(key)) #recursion
else: #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD def retrieveTree(i):
'''
保存了树的测试数据
'''
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i] if __name__== "__main__":
'''
绘制树
'''
createPlot(retrieveTree(1))

测试结果:

4、实例:使用决策树预测隐形眼镜类型

4.1 处理流程

数据格式如下所示,其中最后一列表示类标签:

4.2 Python实现代码

 import trees
import treePlotter fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree = trees.createTree(lenses,lensesLabels)
treePlotter.createPlot(lensesTree)

产生的决策树:

本节使用的算法成为ID3,它是一个号的算法但无法直接处理数值型数据,尽管我们可以通过量化的方法将数值型数据转化为标称型数值,但如果存在太多的特征划分,ID3算法仍然会面临其他问题。

机器学习实战笔记(Python实现)-02-决策树的更多相关文章

  1. 机器学习实战笔记(Python实现)-09-树回归

    ---------------------------------------------------------------------------------------- 本系列文章为<机 ...

  2. 机器学习实战笔记(Python实现)-08-线性回归

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  3. 机器学习实战笔记(Python实现)-06-AdaBoost

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  4. 机器学习实战笔记(Python实现)-05-支持向量机(SVM)

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  5. 机器学习实战笔记(Python实现)-04-Logistic回归

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  6. 机器学习实战笔记(Python实现)-03-朴素贝叶斯

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  7. 机器学习实战笔记(Python实现)-01-K近邻算法(KNN)

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  8. 机器学习实战笔记(Python实现)-00-readme

    近期学习机器学习,找到一本不错的教材<机器学习实战>.特此做这份学习笔记,以供日后翻阅. 机器学习算法分为有监督学习和无监督学习.这本书前两部分介绍的是有监督学习,第三部分介绍的是无监督学 ...

  9. 机器学习实战笔记(Python实现)-07-模型评估与分类性能度量

    1.经验误差与过拟合 通常我们把分类错误的样本数占样本总数的比例称为“错误率”(error rate),即如果在m个样本中有a个样本分类错误,则错误率E=a/m:相应的,1-a/m称为“精度”(acc ...

随机推荐

  1. 你想不到的!CSS 实现的各种球体效果【附在线演示】

    CSS 可以实现很多你想不到的效果,今天我们来尝试使用 CSS 实现各种球体效果.有两种方法可以实现,第一种是使用大量的元素创建实际的 3D 球体,这种方法有潜在的性能问题:另外一种是使用 CSS3 ...

  2. CodeIgniter-Lottery - php ci 抽奖辅助函数

    CodeIgniter-Lottery - php ci 抽奖辅助函数 Github https://github.com/xjnotxj/CodeIgniter-Lottery 用法 1. 移入文件 ...

  3. 安卓SQLite常见错误

    利用闲时写了一个简单的Sql语句操作SQLite数据库,在用SimpleCursorAdapter时出了一个异常好久都没解决 Process: com.example.chunchuner.usesq ...

  4. LINQ(集成化查询)

    LINQ可以对数组.集合等数据结构进行查询.筛选.排序等操作:也可以用于与数据库交互:也支持对XML的操作,使用LINQ技术可以动态创建.筛选和修改XML数据和直接操作XML文件. 一). LINQ基 ...

  5. 7.2 数据注解属性--TimeStamp特性【Code-First 系列】

    TimeStamp特性可以应用到领域类中,只有一个字节数组的属性上面,这个特性,给列设定的是tiemStamp类型.在并发的检查中,Code-First会自动使用这个TimeStamp类型的字段. 下 ...

  6. 【WP8】WebBrowser相关

    2014年09月02日更新 今天用了一下WebBrowser,在使用过程中也遇到了一些问题,在这里做一下记录 虽然WebBrowser比较重,会比较影响性能(除非一定要用到它,否则尽量少用),但有时候 ...

  7. Windows Service--Write a Better Windows Service

    原文地址: http://visualstudiomagazine.com/Articles/2005/10/01/Write-a-Better-Windows-Service.aspx?Page=1 ...

  8. T- SQL性能优化详解

    摘自:http://www.cnblogs.com/Shaina/archive/2012/04/22/2464576.html 故事开篇:你和你的团队经过不懈努力,终于使网站成功上线,刚开始时,注册 ...

  9. 总结shell

    总结shell里面一些初学者不容易懂得点,因为我本身就是初学者,所以有一些知识点是不容易通过字面意思理解的,下面写在这里. (便于理解的一个方法就是举例子)举个例子就是哪些容易学,哪些不容易理解:丁是 ...

  10. java必备基础知识点

    Java基础 1. 简述Java的基本历史 java起源于SUN公司的一个GREEN的项目,其原先目的是:为家用消费电子产品发送一个信息的分布式代码系统,通过发送信息控制电视机.冰箱等 2. 简单写出 ...