作为机器学习的小白和matlab的小白自己参照 python的 《机器学习实战》 写了一下分类回归树,这里记录一下。

关于决策树的基础概念就不过多介绍了,至于是分类还是回归。。我说不清楚。。我用的数据集是这个http://archive.ics.uci.edu/ml/datasets/Abalone 就是通过一些属性来预测鲍鱼有多少头,下面看一下

Length / continuous / mm / Longest shell measurement 
Diameter / continuous / mm / perpendicular to length 
Height / continuous / mm / with meat in shell 
Whole weight / continuous / grams / whole abalone 
Shucked weight / continuous / grams / weight of meat 
Viscera weight / continuous / grams / gut weight (after bleeding) 
Shell weight / continuous / grams / after being dried 
Rings / integer / -- / +1.5 gives the age in years

这些属性除了最后的Rings是整数,可以看做是离散的,其他都是浮点数,是连续的。所以还是用cart中二分的思想,就是小于等于分一边,大于分一边。但是没有用gini指数,因为熵还是好一点。

代码在github:https://github.com/jokermask/matlab_cart

参照《机器学习实战》代码有5个部分:getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集),chooseBestFeatureToSplit(寻找最佳分割点和阈值),createTree(建树),predict(预测)。

我按流程梳理一下,首先函数脚本来将数据集划分成,训练集和测试集,然后用训练集建树,用测试集测试,(更改后变成bootstrap sampleing)

  1. dataset = importdata('abalone.data.txt') ;
  2. origin_data = dataset.data ;
  3. labels = {'Length';'Diam';'Height'; 'Whole';'Shucked';'Viscera';'Shell';'Rings'} ;
  4. test_runtimes = ;
  5. ae = ;
  6. rr = ;
  7. for i=:test_runtimes
  8. data = sampleWithReplace(origin_data) ;%bootstrap sampling
  9. len = floor(length(data)/*) ;
  10. train_data = data(:len,:) ;
  11. test_data = data(len:end,:) ;
  12. test_y_truth = test_data(:,end) ;
  13. % tree = createTree(train_data,labels,) ;
  14. % predict_y = predict(tree,test_data,labels) ;
  15. % com_matrix = [predict_y,test_y_truth] ;
  16. % count = sum(predict_y==test_y_truth) ;
  17. % disp(com_matrix) ;
  18. % disp(mae) ;
  19. % disp(rr) ;
  20.  
  21. %plot single runtime
  22. % x = ::size(test_y_truth,) ;
  23. % plot(x,predict_y,'-b',x,test_y_truth,'-r') ;
  24.  
  25. ae = ae+sum(abs(predict_y-test_y_truth))/size(test_y_truth,) ;
  26. rr = rr+count/size(test_y_truth,) ;
  27.  
  28. %trian with office tools fitctree
  29.  
  30. std_tree = fitctree(train_data(:,:),train_data(:,end)) ;
  31. % view(std_tree) ;
  32. std_y = predict(std_tree,test_data(:,:)) ;
  33. % disp([std_y,y]) ;
  34. ae = ae+sum(abs(std_y-test_y_truth))/size(test_y_truth,) ;
  35. rr = rr+sum(std_y==test_y_truth)/size(test_y_truth,) ;
  36. end
  37. mae = mae / test_runtimes ;
  38. mrr = rr / test_runtimes ;
  39. disp('mae') ;
  40. disp(mae) ;
  41. disp('mrr') ;
  42. disp(mrr) ;

createTree函数:由于matlab没有指针,所以只能写成嵌套结构,就像tree{tree{tree}}这样。我们是递归实现的,但怎么样才会停止建树?条件是当前节点所有标签的类别一样,比如rings都为10,那说明这一个子集已经纯了,或者是这颗树的高度已经超出了我们设的阈值,就停止,第二种情况很可能当前节点下的数据集不纯,我们就找一个出现频率最高的类别代表该节点

  1. function [ tree ] = createTree( dataset,labels,heightcount )
  2. len = size(dataset,) ;
  3. templabel = dataset(,end) ;
  4. tree = templabel ;
  5. max_depth = ;%最大树高
  6. flag = ; %判断是否数据集中所有标签都一致了(纯的),是则返回
  7. for i=:len
  8. if templabel~=dataset(i,end) ;
  9. flag = ;
  10. end
  11. end
  12. if flag==
  13. return ;
  14. end
  15. if heightcount>max_depth
  16. labelVec = dataset(:,end) ;
  17. disp(labelVec) ;
  18. element = :max(labelVec) ;
  19. counts = histc(labelVec,element) ;
  20. [~,max_idx] = max(counts) ;
  21. tree = element(max_idx) ;
  22. return ;
  23. end
  24. [bestFeat,bestT] = chooseBestFeatureToSplit(dataset) ;
  25. bestFeatLabel = labels{bestFeat} ;
  26. tree = struct ;%struct储存树结构
  27. tree.bestFeatLabel = bestFeatLabel ;
  28. tree.bestT = bestT ;
  29. tree.greaterthan = createTree(splitDataset(dataset,bestFeat,bestT,),labels,heightcount+) ;%大于阈值部分的子树
  30. tree.lessthan = createTree(splitDataset(dataset,bestFeat,bestT,),labels,heightcount+) ;%小于阈值部分的子树
  31. end

chooseBestFeatureToSplit函数:在createTree时,每次递归都要找那个当前最佳的特征和阈值,也就是调用chooseBestFeatureToSplit函数,所以两层循环,第一层遍历每个属性,第二层本应该遍历每个属性下的值,但是那样计算量太大了,所以我就将值排序之后分成10端取中位数遍历,在里面找阈值,如果当前节点的数据子集已经不足10个里,那就把所有属性都遍历一哈

  1. function [ bestFeat,bestT ] = chooseBestFeatureToSplit( dataset )
  2. [~,numFeats] = size(dataset) ;
  3. numFeats = numFeats- ;%除去标签那一列
  4. baseEnt = getEnt(dataset) ;
  5. baseInfoGain = ;
  6. bestFeat = - ;
  7. for i=:numFeats
  8. featVec = dataset(:,i) ;
  9. %由于值是连续的,所以对于特征向量组排序分成n段取中位数
  10. sortedFeatVec = sort(featVec,'ascend') ;
  11. lengthofT = floor(sqrt(length(sortedFeatVec))) ; %取向量长度开根号来确定阈值的个数
  12. if lengthofT<
  13. lengthofT = length(sortedFeatVec) ;
  14. selectedFeat = sortedFeatVec ;
  15. else
  16. step = floor(length(sortedFeatVec)/lengthofT) ;
  17. selectedFeat = zeros(lengthofT,) ;
  18. for j=:lengthofT
  19. head = (j-)*step+ ;
  20. tail = j*step ;
  21. subSortedFeatVec = sortedFeatVec(head:tail) ;
  22. selectedFeat(j) = median(subSortedFeatVec) ;
  23. end
  24. end
  25. for k=:lengthofT
  26. newEnt = ;
  27. for l=:
  28. subDataset = splitDataset(dataset,i,selectedFeat(k),l) ;
  29. prob = size(subDataset,)/size(dataset,) ;
  30. newEnt = newEnt + prob*getEnt(subDataset) ;
  31. end
  32. infoGain = baseEnt - newEnt ;
  33. % disp('infoGain') ;
  34. % disp(infoGain) ;
  35. if(infoGain>baseInfoGain)
  36. baseInfoGain = infoGain ;
  37. bestFeat= i ;
  38. bestT = selectedFeat(k) ;
  39. end
  40. end
  41. end
  42. end

计算信息增益(infoGain)的时候需要用到getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集)函数

splitDataset:

  1. function [ retDataset ] = splitDataset(dataset,axis,value,arg )
  2. %axis 代表键值的位置 value表示阈值 返回划分后的dataset,arg表示取大于的部分()还是小于等于的部分
  3. if arg==
  4. retDataset = dataset(dataset(:,axis)>value,:) ;
  5. else
  6. retDataset = dataset(dataset(:,axis)<=value,:) ;
  7. end
  8. end

getEnt:

  1. function [ ent ] = getEnt( data )
  2. %index present the label
  3. [datalen,~] = size(data) ;
  4. maxLabel = max(data(:,end)) ;
  5. labelCountsMap = zeros(maxLabel,) ;%rings are all numbers
  6. for i=:datalen
  7. label = data(i,end) ;
  8. if labelCountsMap(label)~=
  9. labelCountsMap(label) = labelCountsMap(label) + ;
  10. else
  11. labelCountsMap(label) = ;
  12. end
  13. end
  14. ent = ;
  15. % disp('labelMap') ;
  16. % disp(labelCountsMap) ;
  17. for i=:maxLabel
  18. if labelCountsMap(i)~=
  19. prob = labelCountsMap(i)/datalen ;
  20. ent = ent - prob*log2(prob) ;
  21. end
  22. end
  23. end

最后预测函数:

  1. function [ classVec ] = predict( tree , dataset , labels)
  2. %tree应由createTree函数生成
  3. len = size(dataset,) ;
  4. classVec = zeros(len,) ;
  5. for i=:len
  6. dataVec = dataset(i,:end-) ;
  7. tempnode = tree ;
  8. while(isstruct(tempnode))
  9. [~,tempFeatIdx] = ismember(tempnode.bestFeatLabel,labels) ;
  10. if(dataVec(tempFeatIdx)>tempnode.bestT)
  11. tempnode = tempnode.greaterthan ;
  12. else
  13. tempnode = tempnode.lessthan ;
  14. end
  15. end
  16. classVec(i) = tempnode ;
  17. end
  18. end

更新了一下代码,加入了boostrap采样,就是有放回的采样,我是这样采用的,有多少个样本就进行多少次有放回采样,然后这个过程进行50次求均值。用了之后,官方的库正确率道理44%,而我的还在30%。。差距一下突显,还需继续学习。。

补充一下那个sampleWithReplace函数

  1. function [ sample_data ] = sampleWithReplace( dataset )
  2. len = size(dataset,) ;
  3. randidx = randsample(len,len,true) ;
  4. sample_data = dataset(randidx,:) ;
  5. end

matlab实现cart(回归分类树)的更多相关文章

  1. 连续值的CART(分类回归树)原理和实现

    上一篇我们学习和实现了CART(分类回归树),不过主要是针对离散值的分类实现,下面我们来看下连续值的cart分类树如何实现 思考连续值和离散值的不同之处: 二分子树的时候不同:离散值需要求出最优的两个 ...

  2. 用cart(分类回归树)作为弱分类器实现adaboost

    在之前的决策树到集成学习里我们说了决策树和集成学习的基本概念(用了adaboost昨晚集成学习的例子),其后我们分别学习了决策树分类原理和adaboost原理和实现, 上两篇我们学习了cart(决策分 ...

  3. CART回归树

    决策树算法原理(ID3,C4.5) 决策树算法原理(CART分类树) 决策树的剪枝 CART回归树模型表达式: 其中,数据空间被划分为R1~Rm单元,每个单元有一个固定的输出值Cm.这样可以计算模型输 ...

  4. 决策树算法原理(CART分类树)

    决策树算法原理(ID3,C4.5) CART回归树 决策树的剪枝 在决策树算法原理(ID3,C4.5)中,提到C4.5的不足,比如模型是用较为复杂的熵来度量,使用了相对较为复杂的多叉树,只能处理分类不 ...

  5. 大白话5分钟带你走进人工智能-第二十六节决策树系列之Cart回归树及其参数(5)

                                                    第二十六节决策树系列之Cart回归树及其参数(5) 上一节我们讲了不同的决策树对应的计算纯度的计算方法, ...

  6. 机器学习实战---决策树CART回归树实现

    机器学习实战---决策树CART简介及分类树实现 一:对比分类树 CART回归树和CART分类树的建立算法大部分是类似的,所以这里我们只讨论CART回归树和CART分类树的建立算法不同的地方.首先,我 ...

  7. 机器学习实战---决策树CART简介及分类树实现

    https://blog.csdn.net/weixin_43383558/article/details/84303339?utm_medium=distribute.pc_relevant_t0. ...

  8. sklearn 学习之分类树

    概要 基于 sklearn 包自带的 iris 数据集,了解一下分类树的各种参数设置以及代表的意义.   iris 数据集介绍 iris 数据集包含 150 个样本,对应数据集的每行数据,每行数据包含 ...

  9. sklearn CART决策树分类

    sklearn CART决策树分类 决策树是一种常用的机器学习方法,可以用于分类和回归.同时,决策树的训练结果非常容易理解,而且对于数据预处理的要求也不是很高. 理论部分 比较经典的决策树是ID3.C ...

随机推荐

  1. Qt使用QNetworkAccessManager实现Http操作

    版权声明:若无来源注明,Techie亮博客文章均为原创. 转载请以链接形式标明本文标题和地址: 本文标题:Qt使用QNetworkAccessManager实现Http操作     本文地址:http ...

  2. Destoon 模板存放规则 及 语法参考

    模板存放规则及语法参考 一.模板存放及调用规则 模板存放于系统 template 目录,template 目录下的一个目录例如 template/default/ 即为一套模板 模板文件以 .htm ...

  3. linux 内核态调试函数BUG_ON()[转]

    一些内核调用可以用来方便标记bug,提供断言并输出信息.最常用的两个是BUG()和BUG_ON(). 当被调用的时候,它们会引发oops,导致栈的回溯和错误信息的打印.为什么这些声明会导致 oops跟 ...

  4. mysql导出/导入表结构以及表数据

    导出: 命令行下具体用法如下:  mysqldump -u用戶名 -p密码 -d 数据库名 表名 脚本名; 1.导出数据库为dbname的表结构(其中用戶名为root,密码为dbpasswd,生成的脚 ...

  5. 第197天:js---caller、callee、constructor和prototype用法

    一.caller---返回函数调用者 //返回函数调用者 //caller的应用场景 主要用于察看函数本身被哪个函数调用 function fn() { //判断某函数是否被调用 if (fn.cal ...

  6. Vue.js 判断对象属性是否存,不存在添加

    Vue.set是可以对对象添加属性的,这里item对象添加一个checked属性 //if(typeof item.checked=='undefined'){if(!this.item.checke ...

  7. 洛谷P2740 [USACO4.2]草地排水Drainage Ditches

    题目背景 在农夫约翰的农场上,每逢下雨,贝茜最喜欢的三叶草地就积聚了一潭水.这意味着草地被水淹没了,并且小草要继续生长还要花相当长一段时间.因此,农夫约翰修建了一套排水系统来使贝茜的草地免除被大水淹没 ...

  8. 【比赛】HNOI2018 毒瘤

    虚树+dp 直接看zlttttt的强大题解 zlttttt的题解看这里 #include<bits/stdc++.h> #define ui unsigned int #define ll ...

  9. POJ.2251 Dungeon Master (三维BFS)

    POJ.2251 Dungeon Master (三维BFS) 题意分析 你被困在一个3D地牢中且继续寻找最短路径逃生.地牢由立方体单位构成,立方体中不定会充满岩石.向上下前后左右移动一个单位需要一分 ...

  10. MapReduce(三) 典型场景(一)

    一.mapreduce多job串联 1.需求 一个稍复杂点的处理逻辑往往需要多个 mapreduce 程序串联处理,多 job 的串联可以借助 mapreduce 框架的 JobControl 实现 ...