大白话5分钟带你走进人工智能-第二十六节决策树系列之Cart回归树及其参数(5)
第二十六节决策树系列之Cart回归树及其参数(5)
上一节我们讲了不同的决策树对应的计算纯度的计算方法,其实都是针对分类来说,本节的话我们讲解回归树的部分。
目录
1-Cart回归树的概念
对于回归树来说,之前咱们讲的三个决策树(ID3,C4.5和Cart树)里只有CART树具有回归上的意义,其实它无非就是把分裂条件给变了变,把叶子节点的表达给变了变。剩下的全部过程都是和分类树没有区别的。它的分裂条件变成什么了呢?分裂条件仍然是通过遍历维度搜索。当你搜索完了,尝试分裂,你要评估这次分裂是好还是不好的时候?不能再使用Gini系数和信息熵了。因为每一个样本跟每一个样本之间的结果都不一样。你想你原来是怎么算信息熵和Gini系数的?先看看我这叶子连接节点有几类数据,把它们分别统计一下,算出一个数。而回归问题,它的y lable有一样的吗?应该说没有一样的。这种情况下肯定不能用刚才那个Gini系数和熵来做了。那用什么呢,用mse来统计。
举例比如下图:
根节点里有100个数据我尝试分裂。分裂出两支来, 一分支是有60个数据。另一支有40个数据。此时怎么评估这次的分裂效果呢?先计算这60条数据的y的均值。然后用这60条数据的每一个真实的y减去y的均值加平方求和除以60。就得出了这个叶子节点里边的平均mse,能够理解吧?那么右边一样,先是计算出40条y平均。用每一条y减去这个y的平均加平方求和,最后乘以各自的权重,还是要乘以一个1/60和1/40的。那么你多次尝试分裂是不是就得到 或者你去想它会把y比较相近的一些节点分到同一个节点里边去,对不对? 所以这就是回归树的计算流程。评估每次分裂效果的指标我们叫它mse,它实际上是方差。就是一个集合里边的每一个数减去均值平方通通加起来再除以数目本身,假如有十个数,求这10个数的方差,首先要求出它的均值μ,用每一个数减去μ的差的平方,再相加,除一个1/10。这就是这个集合的方差。方差是一个统计学的指标,它描述的是什么?是这一组数据的离散程度。你方差越大代表这个数据里边天差地别,对吗?天南海北。方差越小,代表这一组数据非常紧密。彼此之间都差不了多少。那我们既然要做回归问题,我最终希望落到这个叶子节点里边的lable越近越好还是越远越好?那肯定是越近越好对吧。我分着分着,越分越近,越分越近,最后得到的叶子结点都是最近的那些落到同一个叶子节点。那未来预测的时候怎么办?它落到某一个叶子节点了。这个叶子节点是不是不知道应该给它输出多少值啊?它会输出多少呢?平均值。能够理解吗?也就是说这个东西回归分析做出来之后,它是锯齿状的。能够理解吗?锯齿状的一个回归分析。例如下图:
就是因为x都落在同一个叶子节点里边输出一个均值。而不象参数型模型了,按理来说,你只要变一点儿,那么y的结果多多少少都会变一点儿。而这个你的x只要变了一点儿就会影响到你最终落到那一个叶子节点。这样你给的输出是不是就都是一样的了。所以对于回归树来说还是那四个问题。
一、它分几支,我们刚才看了这个数分两支,对不对?
二、它怎么判断分裂条件。从原来的Gini系数变成了方差。或者说变成了mse
三、它什么时候停止?还是那些预剪枝的过程。我们后面会讲。
四、叶子节点怎么表达。从原来的投票算概率变成了算平均值。就是这么简单。
1-代码详解
我们来看下决策树的应用代码:
import pandas as pd import numpy as np from sklearn.datasets import load_iris from sklearn.ensemble import RandomForestClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.tree import export_graphviz from sklearn.tree import DecisionTreeRegressor from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score import matplotlib.pyplot as plt import matplotlib as mpl #读取iris数据集 iris = load_iris() #iris['data'] ['target'] # 读取数据集 data = pd.DataFrame(iris.data) data.columns = iris.feature_names print(data.columns ) data['Species'] = load_iris().target print(data) print(data.shape) # #取数据帧的前四列(所有行)也就是X x = data.iloc[:, :4] # 花萼长度和宽度 # x = data.iloc[:, :4] # 花萼长度和宽度 #取数据帧的最后一列(所有行)也就y y = data.iloc[:, -1] # print(type(x),1) # y = pd.Categorical(data[4]).codes # print(x) # print(y) # #训练集和测试集的划分 x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.75, random_state=42) tree_clf = DecisionTreeClassifier(max_depth=6, criterion='entropy') tree_clf.fit(x_train, y_train) y_test_hat = tree_clf.predict(x_test) print("acc score:", accuracy_score(y_test, y_test_hat)) # # export_graphviz( # tree_clf, # out_file="./iris_tree.dot", # feature_names=iris.feature_names[:2], # class_names=iris.target_names, # rounded=True, # filled=True # ) # # # ./dot -Tpng ~/PycharmProjects/mlstudy/bjsxt/iris_tree.dot -o ~/PycharmProjects/mlstudy/bjsxt/iris_tree.png print(tree_clf.predict_proba([[5, 1.5]])) print(tree_clf.predict([[5, 1.5]])) # RandomForestClassifier #生成一个数组 depth = np.arange(1, 15)#不同的深度对决策树的影响 err_list = [] for d in depth: clf = DecisionTreeClassifier(criterion='entropy', max_depth=d)#预剪枝 clf.fit(x_train, y_train) y_test_hat = clf.predict(x_test) result = (y_test_hat == y_test) # 生成一个长度为验证集数量的数组,每一个元素是yhat和y是否相等的结果, print(list(result)) if d == 1: print(result) #生成错误率 err = 1 - np.mean(result) print(100 * err) err_list.append(err) print(d, ' 错误率:%.2f%%' % (100 * err)) plt.figure(facecolor='w') plt.plot(depth, err_list, 'ro-', lw=2) plt.xlabel('决策树深度', fontsize=15) plt.ylabel('错误率', fontsize=15) plt.title('决策树深度和过拟合', fontsize=18) plt.grid(True) # plt.show() from sklearn import tree X = [[0, 0], [2, 2]] y = [0.5, 2.5] clf = tree.DecisionTreeRegressor()#回归树 clf = clf.fit(X, y) clf.predict([[1, 1]]) # tree_reg = DecisionTreeRegressor(max_depth=2) # # tree_reg.fit(X, y)
解释下上面代码:
1、from sklearn.datasets import load_iris iris = load_iris(),iris['data'] ['target']。这个iris里边就包含了iris(data)和(target),这里边有两种调用它的方式。一种你可以写iris.Data,一种还有这种字典的方式,索引data,实际上sklearn把我们这两种风格的ATI都保留下来了。
2、我们在这儿引入了一个工具叫pandas,我们之前简单的讲了讲numpy就是一个简单的玩数组的东西,而pandas就是对numpy简单的进行了一个加强。原来的numpy是一个数组,pandas给每一列数组起了一个名字。比如说data是数组。你想调用其中一个元素,用numoy来去你就必须写data[0,0],而pandas分别给行和列起了索引号。可以使用名称来更灵活的调用它,这是其一,也是最根本的区别。其次,pandas里边集成了很多方面的数据操作的东西。这两个就是一个简单的tool就是两个简单的工具。你学Excel有多难学它就有多难能明所以它不是很复杂的东西。pandas里边有一个对象叫dataframe实际上是叫数据帧,数据帧就是一个带名称的二维数组。二维数组只有索引号。而dataframe加了一个名称。data = pd.DataFrame(iris.data),我们把这个iris里边儿的data拿出来,它是一个numpy数组。二维数组。data扔到dataframe中返回的一个什么东西呢?返回一个panda里边叫df的对象。那df对象有两个属性。一个叫columns,是指它这个列的名称。一个叫index是指行的名称。
3、然后我们通过train_test_split 这个工具来划分出验证集和测试集。然后我们新建一个对象叫做DecisionTreeClassifier,然后我们可以看到它实际上有两个类是决策树的。DecisionTreeClassifier,DecisionTreeRegressor。分别是什么意思呢?不用我说大家是不是已经明白了?一个是用来做分类的,一个是用来做回归的。
4、我们看下DecisionTreeClassifier的超参数有哪些呢?
criterion,是拿什么东西来评价的标准。可以取值gini,Gini越高它越不纯。也可以取值entropy评估的是信息增益。
splitter,取值Best是找到最好的那个分裂。取值random是找到最好的随机分裂,也就是说它随机多少次之后,把随机出来过的最好的结果给你。相当于一个加速运算的东西。相当于找到了一个随机出来的最优解,有点像随机梯度的意思。
max-depth,树的最大深度。这个可以说是我们最常用的预剪枝的操作手段。我们很多时候不去设置那些细枝末节的规则。仅仅设一下树的最大深度,就是你分裂多少层就不要再继续分裂了。我管你分的好分不好,你都不要再分裂了。
min_samples_split ,除了根结点和叶子结点其他中间的那些节点分裂一个所需要的最小的样本量默认是2。意思是这些节点要分裂所需的最小样本数是2。
min_samples_leaf,叶节点最小样本数。
min_weight_fraction_leaf,就是你这个叶子结点占总的比例有多少能成为叶子结点,这个比较有意思。
max_features,就是说在你寻找最佳切割点的时候要不要考虑所有的维度咱们本来是不是遍历所有的维度?现在改成随机取几个维度遍历。不取全了,能明白我的意思吗?因为它要分裂很多层,虽然第一层没有考虑到这个维度。第二层的时候有可能就考虑到了,如果是default=none那么如果是none什么意思啊?全部的维度都要进来去考虑如果你是int那么是什么意思呢?你传一个整形进来。就是每次在寻找切割的时候就随机的找到。你比如说乘以六。那么它就随机找的六个维度去考虑,寻找最佳切割点。那这样就肯定会变的不准了,但会变得更快了。然后如果你传一个浮点数过来那么实际上是百分比。你比如说传一个0.6。就是你每一次分裂的时候,就随机挑选出60%的维度出来,来寻找最佳切割点,如果是auto,是开个根号。比如说你有100个维度。我就给你整十个纬度。能够理解吗?sqrt也是开根号。Log2是取个log2然后再取,默认的通常是就选none,有多少我就考虑多少。
max_leaf_nodes ,最多的叶子节点也是,如果叶子节点够多了,你就不用再分裂了。
min_impurity_decrease ,我们的目的是gini系数必须得变小,我才让你分裂。你原来的分裂越大,当传入一个float进来,必须要gini系数必须要缩小多少才能进行此次分裂。
class-weigh,我们可以看到所有的函数,分类器也好,回归器也好,都会有这个参数。class-weight代表什么呢?代表每类样本,你到底有多么看重它?它的目的是将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)。该参数在处理非平衡的训练数据(某些类的训练样本数很少)时,可以使得损失函数对样本数不足的数据更加关注。
5、我们创造一个这个DecisionTreeClassifier,然后输出acc score,然后输出我们再验证集上的准确度,达到了97%,能看到吗?比之前咱们的逻辑回归,训练及效果要好不少。
6、for d in depth,以及下面的代码大致意思是,我把树的深度从一到15遍历了一遍。然后分别画出这15棵树到底错误率是多少。训练15个模型,我们可以看,随着树的深度增加,在4的时候验证集错误率最低,但是后来随着深度的增加,反倒又上升了。这就有一点点过拟合的意思,但在咱们这个数据集里很难形成过拟合。因为总共才150条数据。但是这个东西你可以看到,不是说树越深在验证集上效果就表现得越好。
下一节里面我们会讲解决策树的另一个问题即什么时候停止的问题。
大白话5分钟带你走进人工智能-第二十六节决策树系列之Cart回归树及其参数(5)的更多相关文章
- 大白话5分钟带你走进人工智能-第二十九节集成学习之随机森林随机方式 ,out of bag data及代码(2)
大白话5分钟带你走进人工智能-第二十九节集成学习之随机森林随机方式 ,out of bag data及代码(2) 上一节中我们讲解了随机森林的基本概念,本节的话我们讲解随机森 ...
- 大白话5分钟带你走进人工智能-第二十节逻辑回归和Softmax多分类问题(5)
大白话5分钟带你走进人工智能-第二十节逻辑回归和Softmax多分类问题(5) 上一节中,我们讲 ...
- 大白话5分钟带你走进人工智能-第十四节过拟合解决手段L1和L2正则
第十四节过拟合解决手段L1和L2正则 第十三节中, ...
- 大白话5分钟带你走进人工智能-第十五节L1和L2正则几何解释和Ridge,Lasso,Elastic Net回归
第十五节L1和L2正则几何解释和Ridge,Lasso,Elastic Net回归 上一节中我们讲解了L1和L2正则的概念,知道了L1和L2都会使不重要的维度权重下降得多,重要的维度权重下降得少,引入 ...
- 大白话5分钟带你走进人工智能-第32节集成学习之最通俗理解XGBoost原理和过程
目录 1.回顾: 1.1 有监督学习中的相关概念 1.2 回归树概念 1.3 树的优点 2.怎么训练模型: 2.1 案例引入 2.2 XGBoost目标函数求解 3.XGBoost中正则项的显式表达 ...
- 大白话5分钟带你走进人工智能-第30节集成学习之Boosting方式和Adaboost
目录 1.前述: 2.Bosting方式介绍: 3.Adaboost例子: 4.adaboost整体流程: 5.待解决问题: 6.解决第一个问题:如何获得不同的g(x): 6.1 我们看下权重与函数的 ...
- 大白话5分钟带你走进人工智能-第31节集成学习之最通俗理解GBDT原理和过程
目录 1.前述 2.向量空间的梯度下降: 3.函数空间的梯度下降: 4.梯度下降的流程: 5.在向量空间的梯度下降和在函数空间的梯度下降有什么区别呢? 6.我们看下GBDT的流程图解: 7.我们看一个 ...
- 大白话5分钟带你走进人工智能-第三节最大似然推导mse损失函数(深度解析最小二乘来源)(1)
第三节最大似然推导mse损失函数(深度解析最小二乘来源) 在第二节中,我们介绍了高斯分布的 ...
- 大白话5分钟带你走进人工智能-第35节神经网络之sklearn中的MLP实战(3)
本节的话我们开始讲解sklearn里面的实战: 先看下代码: from sklearn.neural_network import MLPClassifier X = [[0, 0], [1, 1]] ...
随机推荐
- mooc课程mit 6.00.1x--problem set2解决方法
PAYING THE MINIMUM 计算每月信用卡最低还款额及剩余应还款额 balance = 4842 #还款额 annualInterestRate = 0.2 #年利息比率 monthlyPa ...
- java中的clone方法
Java中对象的创建 clone顾名思义就是复制, 在Java语言中, clone方法被对象调用,所以会复制对象.所谓的复制对象,首先要分配一个和源对象同样大小的空间,在这个空间中创建一个新的对象.那 ...
- 扫盲--.net 程序集
前言:用了几天的时间把高级编程里面程序集一章看完了,原来自己只知道写代码,右键添加引用,从来也不知道操作的实质是什么,微软总是这个套路,鼠标点点就能把任务完成,这对新手友好但是对要通透了解程序执行和内 ...
- linux 防火墙配置与REJECT导致没有生效问题
1.进入到/etc/sysconfig 如图 2.使用vi命令对iptables进行编辑."vi iptables",然后显示如图 # Firewall configuration ...
- Linux 设备和模块的分类
概念:在Linux系统中,所有设备都被映射成 [设备文件] 来处理,设备文件,应用程序可以像操作普通文件一样对硬件设备进行操作. 一.设备类型 整理自:(相当不错,建议有时间看下原文) <第一章 ...
- Java多线程系列 基础篇06 synchronized(同步锁)
转载 http://www.cnblogs.com/paddix/ 作者:liuxiaopeng http://www.infoq.com/cn/articles/java-se-16-synchro ...
- 冷门PHP函数汇总
概述 整理一些日常生活中基本用不到的PHP函数,也可以说在框架内基本都内置了,无需我们去自行使用的函数.量不多.后续在日常开发中如遇到更多的冷门,会更新本文章 sys_getloadavg 获取系统的 ...
- POJ 2348 Euclid Game (模拟题)
Euclid's Game Time Limit: 1000MS Memory Limit: 65536K Total Submissions: 7942 Accepted: 3227 Des ...
- 无法定位程序输入点glPopAttrib于动态连结库OPENGL.dll上
已经下载glut.lib glut32.lib glut.h glut.dll glut32.dll并放到了相应的文件夹中运行程序时还提示说缺少opengl.dll,我又下载了一个opengl.dll ...
- matlab之sortrows()函数
sortrows()函数的格式: sortrows(A,column) A是一个矩阵,如果没有第二个参数column,则默认按照第一列升序排列,如果遇到重复数字,则按照第二列升序排列,依次类推... ...