CART(Classification And Regression Tree),分类回归树,,决策树可以分为ID3算法,C4.5算法,和CART算法。ID3算法,C4.5算法可以生成二叉树或者多叉树,CART只支持二叉树,既可支持分类树,又可以作为回归树。

分类树: 基于数据判断某物或者某人的某种属性(个人理解)可以处理离散数据,就是有限的数据,输出样本的类别

回归树: 给定了数据,预测具体事物的某个值;可以对连续型的数据进行预测,也就是数据在某个区间内都有取值的可能,它输出的是一个数值

CART 分类树的工作流程

CART和C4.5算法类似,知识属性选择的指标采用的是基尼系数,基尼系数本身反应了样本的不确定度,当基尼系数越小的时候,说明样本之间的差异性小,不确定度低。分类的过程是一个不确定度降低的过程,即纯度提升的过程,所以构造分类树的时候会基于基尼系数最小的属性作为划分。

了解基尼系数:

假设t为节点,那么该节点的GINI系数的计算公式为:

p(Ck|t) 表示t属性类别Ck的概率,节点t的基尼系数为1减去各个分类Ck概率平方和  

例如集合1: 6个人去游泳,  那么p(Ck|t)=1,因此  GINI(t) = 1-1 =0

集合2      :  3个人去游泳,3个人不去,那么p(C1k|t) = 0.5 ,p(C2k|t) = 0.5

得出,集合1样本基尼系数最小,样本最稳定,2的样本不稳定性大

该公式表示节点D的基尼系数等于子节点D1,D2的归一化基尼系数之和

使用CART算法创建分类树

iris是sklearn 自带IRIS(鸢尾花)数据集sklearn中的来对特征处理功能进行说明包含4个特征(Sepal.Length(花萼长度)、Sepal.Width(花萼宽度)、Petal.Length(花瓣长度)、Petal.Width(花瓣宽度)),特征值都为正浮点数,单位为厘米

目标值为鸢尾花的分类(Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),Iris Virginica(维吉尼亚鸢尾))

  1. 1 # encoding=utf-8
  2. from sklearn.model_selection import train_test_split
  3. from sklearn.metrics import accuracy_score
  4. from sklearn.tree import DecisionTreeClassifier
  5. from sklearn.datasets import load_iris
  6. # 准备数据集
  7. iris=load_iris()
  8. # 获取特征集和分类标识
  9. features = iris.data
  10. labels = iris.target
  11. # 随机抽取 33% 的数据作为测试集,其余为训练集 使用sklearn.model_selection train_test_split 训练
  12. train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.33, random_state=0)
  13. # 创建 CART 分类树
  14. clf = DecisionTreeClassifier(criterion='gini')
  15. # 拟合构造 CART 分类树
  16. clf = clf.fit(train_features, train_labels)
  17. # 用 CART 分类树做预测 得到预测结果
  18. test_predict = clf.predict(test_features)
  19. # 预测结果与测试集结果作比对
  20. score = accuracy_score(test_labels, test_predict)
  21. print("CART 分类树准确率 %.4lf" % score)
  1. CART 分类树准确率 0.9600

train_test_split 可以把数据集抽取一部分作为测试集,就可以德奥训练集和测试集

14 初始化一棵cart树,16 训练集的特征值和分类表示作为参数进行拟合得到cart分类树

cart回归树的工作流程

cart回归树划分数据集的过程和分类树的过程是一样的,回归树得到的预测结果是连续值,评判不纯度的指标不同,分类树采用的是基尼系数,回归树需要根据样本的离散程度来评价 不纯度

样本离散程度计算方式,每个样本值到均值的差值,可以去差值的绝对值,或者方差

         方差为每个样本值减去样本均值的平方和除以样本可数

最小绝对偏差(LAD) 最小二乘偏差

如何使用CART回归树做预测

这里使用sklearn字典的博士度房价数据集,该数据集给出了影响房价的一些指标,比如犯罪了房产税等,最后给出了房价

  1. # encoding=utf-8
  2. from sklearn.metrics import mean_squared_error
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.datasets import load_boston
  5. from sklearn.metrics import r2_score,mean_absolute_error,mean_squared_error
  6. from sklearn.tree import DecisionTreeRegressor
  7. # 准备数据集
  8. boston=load_boston()
  9. # 探索数据
  10. print(boston.feature_names)
  11. # 获取特征集和房价
  12. features = boston.data
  13. prices = boston.target
  14. # 随机抽取 33% 的数据作为测试集,其余为训练集
  15. train_features, test_features, train_price, test_price = train_test_split(features, prices, test_size=0.33)
  16. # 创建 CART 回归树
  17. dtr=DecisionTreeRegressor()
  18. # 拟合构造 CART 回归树
  19. dtr.fit(train_features, train_price)
  20. # 预测测试集中的房价
  21. predict_price = dtr.predict(test_features)
  22. # 测试集的结果评价
  23. print('回归树二乘偏差均值:', mean_squared_error(test_price, predict_price))
  24. print('回归树绝对值偏差均值:', mean_absolute_error(test_price, predict_price))
  1. ['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO'
  2. 'B' 'LSTAT']
  3. 回归树二乘偏差均值: 32.065568862275455
  4. 回归树绝对值偏差均值 3.2892215568862277

cart决策树的剪枝

cart决策树剪枝采用的是CCP方法,一种后剪枝的方法,cost-complexity prune 中文:代价复杂度,这种剪枝用到一个指标 叫做  节点的表面误差率增益值,以此作为剪枝前后误差的定义

 Tt 代表以t为根节点的子树,C(Tt)表示节点t的子树没被裁剪时子树Tt的误差,C(t)表示节点t的子树被剪枝后节点t的误差,|Tt|代子树Tt的叶子树,剪枝后,T的叶子树减一

所以节点的表面误差率增益值 等于 节点t的子树被剪枝后的误差变化除以 减掉的叶子数量

因此希望剪枝前后误差最小,所以我们要寻找就是最小α值对应的节点,把它减掉。生成第一个子树,重复上面过程继续剪枝,知直到最后为根节点,即为最后一个子树

得到剪枝后的子树集合后,我们需要采用验证集对所有子树的误差计算一遍,可以计算每个子树的基尼指数或平房误差,去最小的那棵树

python数据分析算法(决策树2)CART算法的更多相关文章

  1. 决策树2 -- CART算法

    声明: 1,本篇为个人对<2012.李航.统计学习方法.pdf>的学习总结.不得用作商用,欢迎转载,但请注明出处(即:本帖地址). 2,因为本人在学习初始时有非常多数学知识都已忘记.所以为 ...

  2. 决策树之CART算法

    顾名思义,CART算法(classification and regression tree)分类和回归算法,是一种应用广泛的决策树学习方法,既然是一种决策树学习方法,必然也满足决策树的几大步骤,即: ...

  3. 《机器学习实战》学习笔记第九章 —— 决策树之CART算法

    相关博文: <机器学习实战>学习笔记第三章 —— 决策树 主要内容: 一.CART算法简介 二.分类树 三.回归树 四.构建回归树 五.回归树的剪枝 六.模型树 七.树回归与标准回归的比较 ...

  4. 简单易学的机器学习算法——决策树之ID3算法

    一.决策树分类算法概述     决策树算法是从数据的属性(或者特征)出发,以属性作为基础,划分不同的类.例如对于如下数据集 (数据集) 其中,第一列和第二列为属性(特征),最后一列为类别标签,1表示是 ...

  5. 02-23 决策树CART算法

    目录 决策树CART算法 一.决策树CART算法学习目标 二.决策树CART算法详解 2.1 基尼指数和熵 2.2 CART算法对连续值特征的处理 2.3 CART算法对离散值特征的处理 2.4 CA ...

  6. 机器学习——十大数据挖掘之一的决策树CART算法

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第23篇文章,我们今天分享的内容是十大数据挖掘算法之一的CART算法. CART算法全称是Classification ...

  7. 决策树-预测隐形眼镜类型 (ID3算法,C4.5算法,CART算法,GINI指数,剪枝,随机森林)

    1. 1.问题的引入 2.一个实例 3.基本概念 4.ID3 5.C4.5 6.CART 7.随机森林 2. 我们应该设计什么的算法,使得计算机对贷款申请人员的申请信息自动进行分类,以决定能否贷款? ...

  8. 决策树模型 ID3/C4.5/CART算法比较

    决策树模型在监督学习中非常常见,可用于分类(二分类.多分类)和回归.虽然将多棵弱决策树的Bagging.Random Forest.Boosting等tree ensembel 模型更为常见,但是“完 ...

  9. 机器学习技法-决策树和CART分类回归树构建算法

    课程地址:https://class.coursera.org/ntumltwo-002/lecture 重要!重要!重要~ 一.决策树(Decision Tree).口袋(Bagging),自适应增 ...

随机推荐

  1. spring集成Junit做单元测试及常见异常解决办法

    spring-test依赖包 <!--Spring-test --> <!-- https://mvnrepository.com/artifact/org.springframew ...

  2. 用WKWebView 截取整个Html页面

    以前使用UIWebview时,想截取整个页面,可以调整内部scrollView的frame,之后调用 scrollView的layer的 render 方法,很方便. 但是在WKWebView上,行不 ...

  3. LVS原理详解(3种工作方式8种调度算法)--老男孩

    一.LVS原理详解(4种工作方式8种调度算法) 集群简介 集群就是一组独立的计算机,协同工作,对外提供服务.对客户端来说像是一台服务器提供服务. LVS在企业架构中的位置: 以上的架构只是众多企业里面 ...

  4. bzoj 2780

    后缀自动机的应用 首先我们观察到:如果一个询问串的答案不为0,那么这个串一定是至少一个模式串的子串 如果只有一个模式串,那么这个问题可以简单地用什么东西解决掉(比如普通后缀自动机) 而这里有很多模式串 ...

  5. SSL通信-忽略证书认证错误

    .NET的SSL通信过程中,使用的证书可能存在各种问题,某种情况下可以忽略证书的错误继续访问.可以用下面的方式跳过服务器证书验证,完成正常通信. 1.设置回调属性ServicePointManager ...

  6. Postman 安装及使用入门教程(我主要使用接口测试)

    1.Postman 安装及使用入门教程(我主要使用接口测试)Postman的English官网:https://www.getpostman.com/chrome插件整理的Postman中文使用教程( ...

  7. RabbitMQ 学习日记

    RabbitMQ三种Exchange模式(fanout,direct,topic)的性能比较 http://www.rabbitmq.com/tutorials/tutorial-one-dotnet ...

  8. Datatables插件1.10.15版本服务器处理模式ajax获取分页数据实例解析

    一.问题描述 前端需要使用表格来展示数据,找了一些插件,最后确定使用dataTables组件来做. 后端的分页接口已经写好了,不能修改.接口需要传入页码(pageNumber)和页面显示数据条数(pa ...

  9. 【BZOJ4589】Hard Nim(FWT)

    题解: 由博弈论可以知道题目等价于求这$n$个数$\^$为0 快速幂$+fwt$ 这样是$nlog^2$的 并不能过 而且得注意$m$的数组$\^$一下会生成$2m$ #include <bit ...

  10. [数据结构] 快速排序C语言程序

    //由大到小//快速排序(待排序数组,左侧起点,右侧起点) void quickSort(int *array, int l, int r) { if ( l >= r) return; int ...