【机器学习】决策树(Decision Tree) 学习笔记
【机器学习】决策树(decision tree) 学习笔记
标签(空格分隔): 机器学习
决策树简介
决策树(decision tree)是一个树结构(可以是二叉树或非二叉树)。其每个非叶节点表示一个特征属性上的测试,每个分支代表这个特征属性在某个值域上的输出,而每个叶节点存放一个类别。使用决策树进行决策的过程就是从根节点开始,测试待分类项中相应的特征属性,并按照其值选择输出分支,直到到达叶子节点,将叶子节点存放的类别作为决策结果。
本文采用的是ID3算法,ID3算法就是在每次需要分裂时,计算每个属性的增益率,然后选择增益率最大的属性进行分裂。
更为详细的介绍见这个博客:算法杂货铺——分类算法之决策树(Decision tree)
以及这个博客:机器学习——决策树算法原理及案例
这个博客的内容来自《机器学习实战》一书。
这个博客主要讲解决策树的python实现,把每行的代码都弄明白。
决策树代码实现
下面的代码分为两个问价:tree.py和treePlotter.py。tree.py包含了计算香农信息增益,分割数据集,选择最佳特征,表决叶节点的标签,创建树,对测试集数据做分类,存储树,读取树,以及一个对隐形眼镜进行分类的例子代码。treePlotter.py是把决策树画出来的代码。
tree.py
# coding=utf-8
from math import log
import operator
import treePlotter
def calcShannonEnt(dataSet):
"""
计算香农信息增益
:param dataSet:输入的数据集
:return: 熵
"""
numEntries = len(dataSet) # 数据集实例总数
labelCounts = {} # 数据字典,键值是最后一列的数值,记录当前类别出现的次数
for featVec in dataSet: # 对于每个数据进行循环
currentLabel = featVec[-1] # 最后一列
labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1 # 统计这个标签出现的次数
shannonEnt = 0.0 # 香农信息增益
for key in labelCounts: # 对于每个标签
prob = float(labelCounts[key]) / numEntries # 获取标签出现的概率
shannonEnt -= prob * log(prob, 2) # 信息增益-=xi出现的概率*log2(xi出现的概率)
return shannonEnt
def createDataSet():
"""
创造数据集
:return:数据集,标签
"""
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
# change to discrete values
return dataSet, labels
def splitDataSet(dataSet, axis, value):
"""
划分数据集
:param dataSet:带划分的数据集
:param axis: 划分数据集的特征
:param value: 需要返回的特征的值
:return:
"""
retDataSet = []
for featVec in dataSet: # 遍历数据集中的每一组数据
if featVec[axis] == value: # 该组数据符合特征
reducedFeatVec = featVec[:axis] # 截取该组数据的前半段
reducedFeatVec.extend(featVec[axis + 1:]) # 截取数据的后半段
# 这样两次操作删除了以axis为下标的元素
# 不能直接删除,否则影响原始dataSet
retDataSet.append(reducedFeatVec) # 返回的数据集添加上满足条件的数据组去除了特征的数据组
return retDataSet
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 最后一列是标签,不是特征
baseEntropy = calcShannonEnt(dataSet) # 计算原始香农增益
bestInfoGain = 0.0 # 最佳信息增益
bestFeature = -1 # 最好的特征
for i in range(numFeatures): # iterate over all the features
featList = [example[i] for example in dataSet] # create a list of all the examples of this feature
uniqueVals = set(featList) # get a set of unique values
print "uniqueVals", uniqueVals
newEntropy = 0.0 # 对于此特征的熵
for value in uniqueVals: # 遍历此特征所有的唯一属性值
print "value", value
subDataSet = splitDataSet(dataSet, i, value) # 按照这个唯一属性值划分数据
print "subDataSet", subDataSet
prob = len(subDataSet) / float(len(dataSet)) # 这个唯一属性值出现的概率
print "prob", prob
newEntropy += prob * calcShannonEnt(subDataSet) # 对所有唯一属性值得到的熵求和
print "newEntropy", newEntropy
infoGain = baseEntropy - newEntropy # calculate the info gain; ie reduction in entropy
print "infoGain", infoGain
if (infoGain > bestInfoGain): # compare this to the best gain so far
bestInfoGain = infoGain # if better than current best, set to best
print "bestInfoGain", bestInfoGain
bestFeature = i
return bestFeature # returns an integer
def majorityCnt(classList):
"""
如果所有属性都参与了划分,但类标签依然不是唯一的,定义叶子节点的方法
:param classList: 叶子节点的所有标签
:return: 该叶子节点的标签定义
"""
classCount = {} # 叶子节点的统计
for vote in classList: # 投票表决
if vote not in classCount.keys(): classCount[vote] = 0 # 如果没有该类标签就初始化为0
classCount[vote] += 1 # 类标签个数加一
# 也可以用下面代码代替上面两行
# classCount[vote] = classCount.get(vote, 0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
print "sortedClassCount", sortedClassCount
# 按照类标签个数排序
return sortedClassCount[0][0] # 返回个数最多的标签名称
def createTree(dataSet, labels):
"""
创建树
:param dataSet: 数据集
:param labels: 标签列表,其实用不到
:return:
"""
classList = [example[-1] for example in dataSet] # 所有类别标签
print "classList", classList
if classList.count(classList[0]) == len(classList): # 判断类标签全部相同
return classList[0] # stop splitting when all of the classes are equal
if len(dataSet[0]) == 1: # stop splitting when there are no more features in dataSet
return majorityCnt(classList) # 已无法再使用特征分类,用标签的大多数代表这个节点
bestFeat = chooseBestFeatureToSplit(dataSet) # 选择最佳分类标签的序号
print "bestFeat", bestFeat
bestFeatLabel = labels[bestFeat] # 最佳分类标签
print "bestFeatLabel", bestFeatLabel
myTree = {bestFeatLabel: {}} # 保存树的所有信息
del (labels[bestFeat]) # 删除标签列表中的最佳标签
featValues = [example[bestFeat] for example in dataSet] # 最佳标签对应的所有特征值
print "featValues", featValues
uniqueVals = set(featValues) # 把最佳标签对应的所有特征值去重
print "uniqueVals", uniqueVals
for value in uniqueVals: # 对于每个唯一的最佳标签对应的所有特征值
subLabels = labels[:] # copy all of labels, so trees don't mess up existing labels
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
def classify(inputTree, featLabels, testVec):
"""
使用决策树的分类函数
:param inputTree:输入的树
:param featLabels:特征标签
:param testVec:要进行分类的向量
:return:
"""
firstStr = inputTree.keys()[0] # 输入树的第一个分类标签字符串
print "firstStr", firstStr
secondDict = inputTree[firstStr] # 标签字符串指向的树
print "secondDict", secondDict
featIndex = featLabels.index(firstStr) # 将标签字符串转换为索引
print "featIndex", featIndex
key = testVec[featIndex] # 找出测试的向量此索引下的值
print "key", key
valueOfFeat = secondDict[key] # 根据索引下的值找出下一个子树
print "valueOfFeat", valueOfFeat
if isinstance(valueOfFeat, dict): # 循环判断是否已经到了叶节点
classLabel = classify(valueOfFeat, featLabels, testVec) # 不是叶子节点,分类标签继续循环
else:
classLabel = valueOfFeat # 已经到了叶节点
return classLabel # 返回最后预测的分类标签
def storeTree(inputTree, filename):
"""
存储决策树
:param inputTree:要保存的决策树
:param filename:保存的文件名
:return:
"""
import pickle
fw = open(filename, 'w') # 文件写
pickle.dump(inputTree, fw) # 把决策树对象序列化写
fw.close() # 关闭文件操作
def grabTree(filename):
"""
从磁盘上读取决策树
:param filename:文件名字
:return: 决策树
"""
import pickle
fr = open(filename)
return pickle.load(fr)
dataSet, labels = createDataSet()
print "dataSet", dataSet
myTree = treePlotter.retrieveTree(0)
print "myTree", myTree
treePlotter.createPlot(myTree)
print classify(myTree, labels, [1, 0])
storeTree(myTree, 'classifierStorage.txt')
print grabTree('classifierStorage.txt')
treePlotter.py主要是画图功能。
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def getNumLeafs(myTree):
numLeafs = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt): # if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) # this determines the x width of this tree
depth = getTreeDepth(myTree)
firstStr = myTree.keys()[0] # the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key], cntrPt, str(key)) # recursion
else: # it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# if you do get a dictonary you know it's a tree, and the first element will be another dict
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # no ticks
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW;
plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
# def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
# createPlot(thisTree)
为了让大家更明白整个过程的运行结果,可以看下面的输出数据。
dataSet [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
myTree {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
firstStr no surfacing
secondDict {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
featIndex 0
key 1
valueOfFeat {'flippers': {0: 'no', 1: 'yes'}}
firstStr flippers
secondDict {0: 'no', 1: 'yes'}
featIndex 1
key 0
valueOfFeat no
no
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
画出的决策树:
决策树实战 使用决策树预测隐形眼镜类型
数据集是这个lenses.txt:
young myope no reduced no lenses
young myope no normal soft
young myope yes reduced no lenses
young myope yes normal hard
young hyper no reduced no lenses
young hyper no normal soft
young hyper yes reduced no lenses
young hyper yes normal hard
pre myope no reduced no lenses
pre myope no normal soft
pre myope yes reduced no lenses
pre myope yes normal hard
pre hyper no reduced no lenses
pre hyper no normal soft
pre hyper yes reduced no lenses
pre hyper yes normal no lenses
presbyopic myope no reduced no lenses
presbyopic myope no normal no lenses
presbyopic myope yes reduced no lenses
presbyopic myope yes normal hard
presbyopic hyper no reduced no lenses
presbyopic hyper no normal soft
presbyopic hyper yes reduced no lenses
presbyopic hyper yes normal no lenses
下面的代码就是通过上文的决策树算法实现了预测,并且画出了具体的决策树的结构图。
def classifyLenses():
"""
分类隐形眼镜
:return:
"""
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
print "lenses", lenses
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print "lensesTree", lensesTree
treePlotter.createPlot(lensesTree)
classifyLenses()
画出来的决策树的结构图如下。
决策树算法在做分类时同样存在问题。比如过度匹配,ID3算法可以用于划分标称数据集,无法直接处理数值型数据。
这篇博客是对《机器学习实战》一书的学习笔记,如有不明白之处,请阅读该书。
【机器学习】决策树(Decision Tree) 学习笔记的更多相关文章
- 机器学习-决策树 Decision Tree
咱们正式进入了机器学习的模型的部分,虽然现在最火的的机器学习方面的库是Tensorflow, 但是这里还是先简单介绍一下另一个数据处理方面很火的库叫做sklearn.其实咱们在前面已经介绍了一点点sk ...
- 【转载】决策树Decision Tree学习
本文转自:http://www.cnblogs.com/v-July-v/archive/2012/05/17/2539023.html 最近在研究规则引擎,需要学习决策树.决策表等算法.发现篇好文对 ...
- 机器学习算法实践:决策树 (Decision Tree)(转载)
前言 最近打算系统学习下机器学习的基础算法,避免眼高手低,决定把常用的机器学习基础算法都实现一遍以便加深印象.本文为这系列博客的第一篇,关于决策树(Decision Tree)的算法实现,文中我将对决 ...
- 数据挖掘 决策树 Decision tree
数据挖掘-决策树 Decision tree 目录 数据挖掘-决策树 Decision tree 1. 决策树概述 1.1 决策树介绍 1.1.1 决策树定义 1.1.2 本质 1.1.3 决策树的组 ...
- 机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别
一.问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片.已经预先进行过处理,读取了各像素点的灰度值,并进行了标记. 其中第0列是序号(不参与运算).1-64列是像 ...
- 机器学习框架ML.NET学习笔记【3】文本特征分析
一.要解决的问题 问题:常常一些单位或组织召开会议时需要录入会议记录,我们需要通过机器学习对用户输入的文本内容进行自动评判,合格或不合格.(同样的问题还类似垃圾短信检测.工作日志质量分析等.) 处理思 ...
- 机器学习框架ML.NET学习笔记【2】入门之二元分类
一.准备样本 接上一篇文章提到的问题:根据一个人的身高.体重来判断一个人的身材是否很好.但我手上没有样本数据,只能伪造一批数据了,伪造的数据比较标准,用来学习还是蛮合适的. 下面是我用来伪造数据的代码 ...
- 机器学习框架ML.NET学习笔记【1】基本概念与系列文章目录
一.序言 微软的机器学习框架于2018年5月出了0.1版本,2019年5月发布1.0版本.期间各版本之间差异(包括命名空间.方法等)还是比较大的,随着1.0版发布,应该是趋于稳定了.之前在园子里也看到 ...
- 机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)
一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...
- 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类
一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...
随机推荐
- 在linux下查看python已经安装的模块
一.命令行下使用pydoc命令 在命令行下运行$ pydoc modules即可查看 二.在python交互解释器中使用help()查看 python--->在交互式解释器中输入>> ...
- UE4之Slate:App默认窗口的创建流程
UE4版本:4.24.3源码编译 Windows10 + VS2019开发环境 在先前分享的基础上,现在来梳理下App启动时默认窗口的创建流程,以及相关的类.对象之间的抽象层级: 纯C++工程配置 S ...
- Docker环境中部署Prometheus及node-exporter监控主机资源
前提条件 已部署docker 已部署grafana 需要开放 3000 9100 和 9090 端口 启动node-exporter docker run --name node-exporter - ...
- 学习java的第二十一天
一.今日收获 1.java完全学习手册第三章算法的3.2排序,比较了跟c语言排序上的不同 2.观看哔哩哔哩上的教学视频 二.今日问题 1.快速排序法的运行调试多次 2.哔哩哔哩教学视频的一些术语不太理 ...
- Flume(二)【入门】
目录 一.安装部署 1.安装地址 2.安装步骤 二.入门案例 1.官方案例(nestat->logger) 2.实时监控单个追加文件(exec->hdfs) 3.实时监控目录下多个新文件( ...
- Oracle—网络配置文件
Oracle网络配置文件详解 三个配置文件 listener.ora.sqlnet.ora.tnsnames.ora ,都是放在$ORACLE_HOME/network/admin目录下. 1 ...
- Spring Cloud Feign原理详解
目录 1.什么是Feign? 2.Open Feign vs Spring Cloud Feign 2.1.OpenFeign 2.2.Spring Cloud Open Feign 3.Spring ...
- zabbix之邮件报警
创建媒介类型 如果用QQ邮箱的话,先设置一下授权码 为用户设置报警 创建一个用户 配置动作 测试
- 【Services】【Web】【Nginx】静态下载页面的安装与配置
1. 拓扑 F5有自动探活机制,如果一台机器宕机,请求会转发到另外一台,省去了IPVS漂移的麻烦 F5使用轮询算法,向两台服务器转发请求,实现了负载均衡 2. 版本: 2.1 服务器版本:RHEL7. ...
- feignclient发送get请求,传递参数为对象
feignclient发送get请求,传递参数为对象.此时不能使用在地址栏传递参数的方式,需要将参数放到请求体中. 第一步: 修改application.yml中配置feign发送请求使用apache ...