《机器学习实战》学习笔记第九章 —— 决策树之CART算法
相关博文:
主要内容:
一.CART算法简介
二.分类树
三.回归树
四.构建回归树
五.回归树的剪枝
六.模型树
七.树回归与标准回归的比较
一.CART算法简介
1.对于上一篇博客所介绍的决策树,其使用的算法是ID3算法或者是C4.5算法,他们都是根据特征的所有取值情况来进行分割结点的。也正因如此,这两种算法都只能用于离散型的特征,而不能处理连续型的特征。为了解决这个问题,我们使用二元切分法来对连续型的特征进行处理。所谓二元切分法,其实就是一个对特征进行 True or False的判断(最简单如:是或不是、小于或大于等于),这个判断就将数据分割成两半,而不管其特征是连续型的还是离散型的。显而易见,以这种方法构建出来的决策树是一棵二叉树。这就是CART算法最基本的思路。
二.分类树
1.CART算法使用在离散型特征的数据上,则称为分类树(分类还是回归不是以Y来界定的吗?为什么这里以特征X来界定?)。 CART算法使用在离散型标签的数据上,称为分类树。在这里不再像ID3算法那样使用熵来衡量数据(指的是Y)的不确定性,而是使用“基尼指数”。基尼指数的详情如下:
2.CART算法之分类树:
(为什么选择基尼系数最小的?基尼系数与熵类似,其值越大,不确定性越大,那么选择分割后基尼系数最小的,表明不确定性越小了,类别就越能确定了)
三.回归树
1.同样地,在CART算法之回归树中,数据的不确定性不再是用熵来衡量,但也不是用基尼指数,而是用总方差。
问1)为什么分类树中不用熵,而是用基尼指数呢?
这个问题倒真的想不到解释,感觉两种都可以用来衡量离散型变量的不确定性。待解决……
问2)为什么回归树中不用熵或者基尼指数呢?
这个很容易解释,离散型变量可以数出现的个数来计算概率,但是连续型变量,对于单单一个点的值而言,是没有概率的,所以熵或者基尼指数不能用来衡量连续型变量的不确定性。
问3)为什么使用总方差而不是方差(总方差/m)呢?
《统计学习方法》里面提到一句话“基尼指数值越大,样本集合的不确定性也越大”,其中个人觉得“样本集合”这个词是关键。所谓样本集合,一个表现特征就是规模。我想,规模也是影响“数据混乱程度”的一个因素。所以使用总方差。当然这只是个人感性的理解,并没有任何的理论参考或推敲。
2.CART算法之回归树:
四.构建回归树
1.算法流程:
2.代码及注释:
- def binSplitDataSet(dataSet, feature, value): #根据分割特征及其值将数据分成两半
- mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
- mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
- return mat0,mat1
- '''选择最好的分割特征及其值'''
- def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
- tolS = ops[0]; tolN = ops[1] #tolS是分割误差减少的下限,tolN是分割后每个子树的结点个数下限
- if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #如果所有值都相等则退出
- return None, leafType(dataSet)
- m,n = shape(dataSet)
- S = errType(dataSet) #分割前的总方差
- bestS = inf; bestIndex = 0; bestValue = 0
- for featIndex in range(n-1): #枚举特征
- for splitVal in set(dataSet[:,featIndex]): #枚举该特征下在训练数据中所有出现的值,所谓分割线
- mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) #将数据集切割成两半
- if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue #如果切割后某一子树的结点数少于下限,则此次分割无效
- newS = errType(mat0) + errType(mat1)
- if newS < bestS: #更新最小总方差下的分割特征及其值
- bestIndex = featIndex
- bestValue = splitVal
- bestS = newS
- if (S - bestS) < tolS: #如果在最好的情况下(即总方差减少得最多),总方差的减少量仍然少于下限,则此次分割无效,直接返回当前数据集作为叶子结点
- return None, leafType(dataSet) #exit cond 2
- mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
- if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #话说这一步不是已经在双重循环里面了吗?
- return None, leafType(dataSet)
- return bestIndex,bestValue
- def regLeaf(dataSet): #生成叶子结点,均值作为返回值(即预测值)
- return mean(dataSet[:,-1])
- def regErr(dataSet): #计算总方差
- return var(dataSet[:,-1]) * shape(dataSet)[0]
- '''构建回归树:leafType是建立叶子结点的函数,errType是计算误差的函数,ops是参数元组(用于预剪枝)'''
- def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
- feat, val = chooseBestSplit(dataSet, leafType, errType, ops) #选择最好的分割特征及其值
- if feat == None: return val #if the splitting hit a stop condition return val
- retTree = {}
- retTree['spInd'] = feat #记录特征
- retTree['spVal'] = val #记录值
- lSet, rSet = binSplitDataSet(dataSet, feat, val) #将数据集分割成两部分,然后递归左右子树继续生产回归树
- retTree['left'] = createTree(lSet, leafType, errType, ops)
- retTree['right'] = createTree(rSet, leafType, errType, ops)
- return retTree
五.回归树的剪枝
1.剪枝有预剪枝和后剪枝两种。预剪枝对设定的参数非常敏感,如上面代码中tolS和tolN两个参数,分别是分割误差减少的下限、分割后每个子树的结点个数下限。基于预剪枝的性能不太好控制,我们就应着手于后剪枝的研究,其伪代码如下:
2.代码及注释如下:
- def isTree(obj): #判断是否是一棵树,即非叶子结点
- return (type(obj).__name__=='dict')
- def getMean(tree): #递归地求树(子树)的(平均?)方差
- if isTree(tree['right']): tree['right'] = getMean(tree['right'])
- if isTree(tree['left']): tree['left'] = getMean(tree['left'])
- return (tree['left']+tree['right'])/2.0
- '''利用测试数据进行后剪枝'''
- def prune(tree, testData):
- if shape(testData)[0] == 0: #没有测试数据(特殊情况),则塌陷这棵子树,即缩成一个叶子结点。但为什么要这样做?
- return getMean(tree)
- if (isTree(tree['right']) or isTree(tree['left'])): #如果某个儿子是一棵树,则可以对该儿子进行剪枝,因此需要分割数据,注意分割的是测试数据
- lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
- if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet) #如果左儿子是树,则对其进行剪枝
- if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet) #如果右儿子是树,则对其进行剪枝
- if not isTree(tree['left']) and not isTree(tree['right']): #注意:所谓剪枝其实就是动态地合并两个叶子结点,所以当当前的两个儿子都是叶子结点时,可以尝试合并
- lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) #首先分割测试数据集
- errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) + sum(power(rSet[:,-1] - tree['right'],2)) #计算不合并的误差
- treeMean = (tree['left']+tree['right'])/2.0
- errorMerge = sum(power(testData[:,-1] - treeMean,2)) #计算合并的误差
- if errorMerge < errorNoMerge: #如果合并后的误差小于合并前的误差,则对其进行合并
- print "merging"
- return treeMean
- else: return tree #否则返回当前的树
- else: return tree
六.模型树
1.上面所介绍的决策树中,所有叶子结点,也就是预测值都是直接设定为在该叶子节点上的数据的Y的均值。简而言之,就是叶子结点放的是均值,是一个确定的值。但除此之外,我们还可以在叶子节点上放一个函数,以此进行预测。
2.例如,我们可以在叶子结点上放一个线性回归模型,也正因如此,前面代码中生成叶子结点的方式以及计算总方差的方式方式了改变。详情如下:
- '''线性回归模型'''
- def linearSolve(dataSet): #利用最小二乘法计算线性回归模型的参数ws
- m,n = shape(dataSet)
- X = mat(ones((m,n))); Y = mat(ones((m,1))) #create a copy of data with 1 in 0th postion
- X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
- xTx = X.T*X
- if linalg.det(xTx) == 0.0:
- raise NameError('This matrix is singular, cannot do inverse,\n\
- try increasing the second value of ops')
- ws = xTx.I * (X.T * Y)
- return ws,X,Y
- '''叶子结点为线性回归函数'''
- def modelLeaf(dataSet):#返回线性回归模型的参数ws
- ws,X,Y = linearSolve(dataSet)
- return ws
- def modelErr(dataSet): #计算总方差,因而Y值是X的函数,而不再是之前的均值,所以不能利用以前的误差计算方法
- ws,X,Y = linearSolve(dataSet)
- yHat = X * ws
- return sum(power(Y - yHat,2))
七.回归树、模型树、线性回归的比较
既然介绍了回归树与模型树,且模型树又用到了线性回归模型,且三者都能对同样的数据进行预测,那就理所当然地对它们作出一些比较,分出优劣。具体实现实现如下:
- '''普通回归树叶子结点返回均值。为什么要这个没有用的inData?其实只是为了方便统一输入参数,因为模型树需要输入参数'''
- def regTreeEval(model,inData ):
- return float(model)
- def modelTreeEval(model, inDat): #模型树叶子结点对测试数据的预测值, 与上面的regTreeEval()函数是同类型
- n = shape(inDat)[1]
- X = mat(ones((1,n+1)))
- X[:,1:n+1]=inDat
- return float(X*model)
- '''modelEval是计算叶子结点的值的函数,可以是regLeaf()对应普通回归树,可以是modelTreeEval()对应模型树'''
- def treeForeCast(tree, inData, modelEval=regTreeEval): #搜索回归树,找到合适的预测值.
- if not isTree(tree): return modelEval(tree, inData)
- if inData[tree['spInd']] > tree['spVal']:
- return treeForeCast(tree['left'], inData, modelEval)
- else:
- return treeForeCast(tree['right'], inData, modelEval)
- def createForeCast(tree, testData, modelEval=regTreeEval): #对测试数据集进行预测
- m=len(testData)
- yHat = mat(zeros((m,1)))
- for i in range(m): #枚举每一个数据,并对其进行预测
- yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)
- return yHat
然后测试一下三者对测试数据的预测效果,这里用相关系数R2来衡量,R2的值越接近于1.0,预测的效果越好。
首先是回归树:
然后是模型树:
最后是线性回归:
从上面可以看得出:模型树 > 回归树 > 线性回归 。
所以多做了点功夫,效果就较之好一点,是说得过去的。
《机器学习实战》学习笔记第九章 —— 决策树之CART算法的更多相关文章
- 【机器学习实战学习笔记(2-2)】决策树python3.6实现及简单应用
文章目录 1.ID3及C4.5算法基础 1.1 计算香农熵 1.2 按照给定特征划分数据集 1.3 选择最优特征 1.4 多数表决实现 2.基于ID3.C4.5生成算法创建决策树 3.使用决策树进行分 ...
- 【机器学习实战学习笔记(1-1)】k-近邻算法原理及python实现
笔者本人是个初入机器学习的小白,主要是想把学习过程中的大概知识和自己的一些经验写下来跟大家分享,也可以加强自己的记忆,有不足的地方还望小伙伴们批评指正,点赞评论走起来~ 文章目录 1.k-近邻算法概述 ...
- 【机器学习实战】第3章 决策树(Decision Tree)
第3章 决策树 <script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/ ...
- o'Reill的SVG精髓(第二版)学习笔记——第九章
第九章:文本 9.1 字符:在XML文档中,字符是指带有一个数字值的一个或多个字节,数字只与Unicode标准对应. 符号:符号(glyph)是指字符的视觉呈现.每个字符都可以用很多不同的符号来呈现. ...
- 《Python基础教程(第二版)》学习笔记 -> 第九章 魔法方法、属性和迭代器
准备工作 >>> class NewStyle(object): more_code_here >>> class OldStyle: more_code_here ...
- 学习笔记 第九章 使用CSS美化表格
第9章 使用CSS美化表格 学习重点 正确使用表格标签: 设置表格和单元格属性: 设计表格的CSS样式. 9.1 表格的基本结构 表格由行.列.单元格3部分组成,单元格时行与列交叉的部分. 在HTM ...
- 【机器学习实战学习笔记(1-2)】k-近邻算法应用实例python代码
文章目录 1.改进约会网站匹配效果 1.1 准备数据:从文本文件中解析数据 1.2 分析数据:使用Matplotlib创建散点图 1.3 准备数据:归一化特征 1.4 测试算法:作为完整程序验证分类器 ...
- 《DOM Scripting》学习笔记-——第九章 CSS-DOM
本章内容: 一.style属性 二.如何检索样式信息 三.如何改变样式 属性: 包含位置信息:parentNode , nextSibling , previousSibling , childNod ...
- Head First Servlets & JSP 学习笔记 第九章 —— 使用JSTL
JSTL1.1 不是JSP2.0规范的一部分!你能访问Servlet和JSP API 不意味着你能访问JSTL! 使用JSTL之前,需要将两个文件("jstl.jar" 和 &qu ...
随机推荐
- WebScarab安装
1.下载webscarab 下载地址:http://sourceforge.net/projects/owasp/files/WebScarab/20070504-1631/ 2.安装webscara ...
- [ Laravel 5.6 文档 ]laravel数据库操作分页(自定义分页实现和自定义分页样式)
简介 在其他框架中,分页可能是件非常痛苦的事,Laravel 让这件事变得简单.易于上手.Laravel 的分页器与查询构建器和 Eloquent ORM 集成在一起,并开箱提供方便的.易于使用的.基 ...
- 简述JS中 appy 和 call 的详细用法
Apply 和 Call 两个老生常言的方法,使用过程的一些细节还是有很大的异同,具体使用情况可以参照下面例子详细回顾一下. 区别和详解:js中call()和apply()的用法 1.关于call() ...
- secureCrt 开启Linux上的oracle服务
IP : 192.168.0.21 user: root pwd: 123456 --------------------------------------------------- ...
- Java线程—-Runnable和Callable的区别和联系
Java 提供了三种创建线程的方法 1.继承Thread接口 public class Thread2Thread { public static void main(String[] args) { ...
- .Net中多线程类的使用和总结
lock, Monitor, Thread, Join, BackGroundWorker. 消费者和生产者.Async 委托Invoke TypeHandle中BlockIndex. http: ...
- HDU 3397 Sequence operation(区间合并 + 区间更新)
题目链接:pid=3397">http://acm.hdu.edu.cn/showproblem.php?pid=3397 题意:给定n个数,由0,1构成.共同拥有5种操作. 每一个操 ...
- 函数柯里化常见应用---add(1,2) add(1)(2) add(1)(2)(3) add(1,2,3)(4)
这是一道经典的题目,先上代码: 解法1: function add () { var args = Array.prototype.slice.call(arguments); var fn = fu ...
- mxnet编译问题手记
MXNet在64位Win7下的编译安装:https://www.cnblogs.com/noahzn/p/5506086.html http://blog.csdn.net/Jarvis_wxy/ar ...
- PHP fsockopen模拟POST/GET方法
原文链接:http://www.nowamagic.net/academy/detail/12220214 fsockopen 除了前面小节的模拟生成 HTTP 连接之外,还能实现很多功能,比如模拟 ...