matlab实现cart(回归分类树)
作为机器学习的小白和matlab的小白自己参照 python的 《机器学习实战》 写了一下分类回归树,这里记录一下。
关于决策树的基础概念就不过多介绍了,至于是分类还是回归。。我说不清楚。。我用的数据集是这个http://archive.ics.uci.edu/ml/datasets/Abalone 就是通过一些属性来预测鲍鱼有多少头,下面看一下
Length / continuous / mm / Longest shell measurement
Diameter / continuous / mm / perpendicular to length
Height / continuous / mm / with meat in shell
Whole weight / continuous / grams / whole abalone
Shucked weight / continuous / grams / weight of meat
Viscera weight / continuous / grams / gut weight (after bleeding)
Shell weight / continuous / grams / after being dried
Rings / integer / -- / +1.5 gives the age in years
这些属性除了最后的Rings是整数,可以看做是离散的,其他都是浮点数,是连续的。所以还是用cart中二分的思想,就是小于等于分一边,大于分一边。但是没有用gini指数,因为熵还是好一点。
代码在github:https://github.com/jokermask/matlab_cart
参照《机器学习实战》代码有5个部分:getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集),chooseBestFeatureToSplit(寻找最佳分割点和阈值),createTree(建树),predict(预测)。
我按流程梳理一下,首先函数脚本来将数据集划分成,训练集和测试集,然后用训练集建树,用测试集测试,(更改后变成bootstrap sampleing)
dataset = importdata('abalone.data.txt') ;
origin_data = dataset.data ;
labels = {'Length';'Diam';'Height'; 'Whole';'Shucked';'Viscera';'Shell';'Rings'} ;
test_runtimes = ;
ae = ;
rr = ;
for i=:test_runtimes
data = sampleWithReplace(origin_data) ;%bootstrap sampling
len = floor(length(data)/*) ;
train_data = data(:len,:) ;
test_data = data(len:end,:) ;
test_y_truth = test_data(:,end) ;
% tree = createTree(train_data,labels,) ;
% predict_y = predict(tree,test_data,labels) ;
% com_matrix = [predict_y,test_y_truth] ;
% count = sum(predict_y==test_y_truth) ;
% disp(com_matrix) ;
% disp(mae) ;
% disp(rr) ;
%plot single runtime
% x = ::size(test_y_truth,) ;
% plot(x,predict_y,'-b',x,test_y_truth,'-r') ;
ae = ae+sum(abs(predict_y-test_y_truth))/size(test_y_truth,) ;
rr = rr+count/size(test_y_truth,) ;
%trian with office tools fitctree
std_tree = fitctree(train_data(:,:),train_data(:,end)) ;
% view(std_tree) ;
std_y = predict(std_tree,test_data(:,:)) ;
% disp([std_y,y]) ;
ae = ae+sum(abs(std_y-test_y_truth))/size(test_y_truth,) ;
rr = rr+sum(std_y==test_y_truth)/size(test_y_truth,) ;
end
mae = mae / test_runtimes ;
mrr = rr / test_runtimes ;
disp('mae') ;
disp(mae) ;
disp('mrr') ;
disp(mrr) ;
createTree函数:由于matlab没有指针,所以只能写成嵌套结构,就像tree{tree{tree}}这样。我们是递归实现的,但怎么样才会停止建树?条件是当前节点所有标签的类别一样,比如rings都为10,那说明这一个子集已经纯了,或者是这颗树的高度已经超出了我们设的阈值,就停止,第二种情况很可能当前节点下的数据集不纯,我们就找一个出现频率最高的类别代表该节点
function [ tree ] = createTree( dataset,labels,heightcount )
len = size(dataset,) ;
templabel = dataset(,end) ;
tree = templabel ;
max_depth = ;%最大树高
flag = ; %判断是否数据集中所有标签都一致了(纯的),是则返回
for i=:len
if templabel~=dataset(i,end) ;
flag = ;
end
end
if flag==
return ;
end
if heightcount>max_depth
labelVec = dataset(:,end) ;
disp(labelVec) ;
element = :max(labelVec) ;
counts = histc(labelVec,element) ;
[~,max_idx] = max(counts) ;
tree = element(max_idx) ;
return ;
end
[bestFeat,bestT] = chooseBestFeatureToSplit(dataset) ;
bestFeatLabel = labels{bestFeat} ;
tree = struct ;%struct储存树结构
tree.bestFeatLabel = bestFeatLabel ;
tree.bestT = bestT ;
tree.greaterthan = createTree(splitDataset(dataset,bestFeat,bestT,),labels,heightcount+) ;%大于阈值部分的子树
tree.lessthan = createTree(splitDataset(dataset,bestFeat,bestT,),labels,heightcount+) ;%小于阈值部分的子树
end
chooseBestFeatureToSplit函数:在createTree时,每次递归都要找那个当前最佳的特征和阈值,也就是调用chooseBestFeatureToSplit函数,所以两层循环,第一层遍历每个属性,第二层本应该遍历每个属性下的值,但是那样计算量太大了,所以我就将值排序之后分成10端取中位数遍历,在里面找阈值,如果当前节点的数据子集已经不足10个里,那就把所有属性都遍历一哈
function [ bestFeat,bestT ] = chooseBestFeatureToSplit( dataset )
[~,numFeats] = size(dataset) ;
numFeats = numFeats- ;%除去标签那一列
baseEnt = getEnt(dataset) ;
baseInfoGain = ;
bestFeat = - ;
for i=:numFeats
featVec = dataset(:,i) ;
%由于值是连续的,所以对于特征向量组排序分成n段取中位数
sortedFeatVec = sort(featVec,'ascend') ;
lengthofT = floor(sqrt(length(sortedFeatVec))) ; %取向量长度开根号来确定阈值的个数
if lengthofT<
lengthofT = length(sortedFeatVec) ;
selectedFeat = sortedFeatVec ;
else
step = floor(length(sortedFeatVec)/lengthofT) ;
selectedFeat = zeros(lengthofT,) ;
for j=:lengthofT
head = (j-)*step+ ;
tail = j*step ;
subSortedFeatVec = sortedFeatVec(head:tail) ;
selectedFeat(j) = median(subSortedFeatVec) ;
end
end
for k=:lengthofT
newEnt = ;
for l=:
subDataset = splitDataset(dataset,i,selectedFeat(k),l) ;
prob = size(subDataset,)/size(dataset,) ;
newEnt = newEnt + prob*getEnt(subDataset) ;
end
infoGain = baseEnt - newEnt ;
% disp('infoGain') ;
% disp(infoGain) ;
if(infoGain>baseInfoGain)
baseInfoGain = infoGain ;
bestFeat= i ;
bestT = selectedFeat(k) ;
end
end
end
end
计算信息增益(infoGain)的时候需要用到getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集)函数
splitDataset:
function [ retDataset ] = splitDataset(dataset,axis,value,arg )
%axis 代表键值的位置 value表示阈值 返回划分后的dataset,arg表示取大于的部分()还是小于等于的部分
if arg==
retDataset = dataset(dataset(:,axis)>value,:) ;
else
retDataset = dataset(dataset(:,axis)<=value,:) ;
end
end
getEnt:
function [ ent ] = getEnt( data )
%index present the label
[datalen,~] = size(data) ;
maxLabel = max(data(:,end)) ;
labelCountsMap = zeros(maxLabel,) ;%rings are all numbers
for i=:datalen
label = data(i,end) ;
if labelCountsMap(label)~=
labelCountsMap(label) = labelCountsMap(label) + ;
else
labelCountsMap(label) = ;
end
end
ent = ;
% disp('labelMap') ;
% disp(labelCountsMap) ;
for i=:maxLabel
if labelCountsMap(i)~=
prob = labelCountsMap(i)/datalen ;
ent = ent - prob*log2(prob) ;
end
end
end
最后预测函数:
function [ classVec ] = predict( tree , dataset , labels)
%tree应由createTree函数生成
len = size(dataset,) ;
classVec = zeros(len,) ;
for i=:len
dataVec = dataset(i,:end-) ;
tempnode = tree ;
while(isstruct(tempnode))
[~,tempFeatIdx] = ismember(tempnode.bestFeatLabel,labels) ;
if(dataVec(tempFeatIdx)>tempnode.bestT)
tempnode = tempnode.greaterthan ;
else
tempnode = tempnode.lessthan ;
end
end
classVec(i) = tempnode ;
end
end
更新了一下代码,加入了boostrap采样,就是有放回的采样,我是这样采用的,有多少个样本就进行多少次有放回采样,然后这个过程进行50次求均值。用了之后,官方的库正确率道理44%,而我的还在30%。。差距一下突显,还需继续学习。。
补充一下那个sampleWithReplace函数
function [ sample_data ] = sampleWithReplace( dataset )
len = size(dataset,) ;
randidx = randsample(len,len,true) ;
sample_data = dataset(randidx,:) ;
end
matlab实现cart(回归分类树)的更多相关文章
- 连续值的CART(分类回归树)原理和实现
上一篇我们学习和实现了CART(分类回归树),不过主要是针对离散值的分类实现,下面我们来看下连续值的cart分类树如何实现 思考连续值和离散值的不同之处: 二分子树的时候不同:离散值需要求出最优的两个 ...
- 用cart(分类回归树)作为弱分类器实现adaboost
在之前的决策树到集成学习里我们说了决策树和集成学习的基本概念(用了adaboost昨晚集成学习的例子),其后我们分别学习了决策树分类原理和adaboost原理和实现, 上两篇我们学习了cart(决策分 ...
- CART回归树
决策树算法原理(ID3,C4.5) 决策树算法原理(CART分类树) 决策树的剪枝 CART回归树模型表达式: 其中,数据空间被划分为R1~Rm单元,每个单元有一个固定的输出值Cm.这样可以计算模型输 ...
- 决策树算法原理(CART分类树)
决策树算法原理(ID3,C4.5) CART回归树 决策树的剪枝 在决策树算法原理(ID3,C4.5)中,提到C4.5的不足,比如模型是用较为复杂的熵来度量,使用了相对较为复杂的多叉树,只能处理分类不 ...
- 大白话5分钟带你走进人工智能-第二十六节决策树系列之Cart回归树及其参数(5)
第二十六节决策树系列之Cart回归树及其参数(5) 上一节我们讲了不同的决策树对应的计算纯度的计算方法, ...
- 机器学习实战---决策树CART回归树实现
机器学习实战---决策树CART简介及分类树实现 一:对比分类树 CART回归树和CART分类树的建立算法大部分是类似的,所以这里我们只讨论CART回归树和CART分类树的建立算法不同的地方.首先,我 ...
- 机器学习实战---决策树CART简介及分类树实现
https://blog.csdn.net/weixin_43383558/article/details/84303339?utm_medium=distribute.pc_relevant_t0. ...
- sklearn 学习之分类树
概要 基于 sklearn 包自带的 iris 数据集,了解一下分类树的各种参数设置以及代表的意义. iris 数据集介绍 iris 数据集包含 150 个样本,对应数据集的每行数据,每行数据包含 ...
- sklearn CART决策树分类
sklearn CART决策树分类 决策树是一种常用的机器学习方法,可以用于分类和回归.同时,决策树的训练结果非常容易理解,而且对于数据预处理的要求也不是很高. 理论部分 比较经典的决策树是ID3.C ...
随机推荐
- div跟随鼠标移动
1.目标是实现div跟随鼠标而移动,分三种情况进行实现 a)首先获取div,进行绑定鼠标移动事件,给div开启定位功能 第一种实现方式,假如body的大小跟页面大小一样,则可以用这个方法. 1)获取鼠 ...
- 【vue】vue组件的自定义事件
父组件: <template> <div> <my-child abcClick="sayHello"></my-child> &l ...
- ping不通的常见原因和解决办法
Ping是Windows.Unix和Linux系统下的一个命令.ping也属于一个通信协议,是TCP/IP协议的一部分.利用“ping”命令可以检查网络是否连通.如果ping不通则可以通过以下方式寻找 ...
- openstack的网络配置
首先在浏览器输入咱们的控制节点的ip地址登陆horizon,也就是dashboard控制页面 输入好用户名与密码,这时输入的用户名与密码会与我们的老大哥keystone进行认证.确认你输入的这个用户有 ...
- atcoder 2017Code festival C ——D题 Yet Another Palindrome Partitioning(思维+dp)
题目大意: 把一个字符串s分割成m个串,这m个串满足至多有一种字符出现次数为奇数次,其他均为偶数次,问m的最小值 题解: 首先我们想一下纯暴力怎么做 显然是可以n^2暴力的,然后dp[i]表示分割到i ...
- "strcmp()" Anyone? UVA - 11732(trie出现的次数)
给你n个单词,让他们两两比较,要求他们运用strcmp时,进行比较的次数. 边建树边统计 #include <iostream> #include <cstdio> #incl ...
- CF712E Memory and Casinos 期望概率
题意:\(n\)个赌场,每个赌场有\(p_{i}\)的胜率,如果赢了就走到下一个赌场,输了就退回上一个赌场,规定\(1\)号赌场的上一个是\(0\)号赌场,\(n\)号赌场的下一个是\(n + 1\) ...
- SpringBoot多数据源配置事务
除了消费降级,这将会是娱乐继续下沉的一年. 36氪从多个信源处获悉,资讯阅读应用趣头条已经完成了腾讯领投的Pre-IPO轮融资,交易金额预计达上亿美元,本轮融资估值在13-15亿美金之间:完成此轮融资 ...
- Mysql(一) 基本操作
一.介绍 1.数据库 数据库,通俗的讲,即为存储数据的“仓库”.不过,数据库不仅只是存储,还对所存储的数据做相应的管理,例如,访问权限,安全性,并发操作,数据的备份与恢复,日志等.实际上,我们所提及的 ...
- 【树链剖分换根】P3979 遥远的国度
Description zcwwzdjn在追杀十分sb的zhx,而zhx逃入了一个遥远的国度.当zcwwzdjn准备进入遥远的国度继续追杀时,守护神RapiD阻拦了zcwwzdjn的去路,他需要zcw ...