CART
一、为什么有CART回归树
以前学过全局回归,顾名思义,就是指全部数据符合某种曲线。比如线性回归,多项式拟合(泰勒)等等。可是这些数学规律多强,硬硬地将全部数据逼近一些特殊的曲线。生活中的数据可是千变万化。那么,局部回归是一种合理地选择。在斯坦福大学NG的公开课中,他也提到局部回归的好处。其中,CART回归树就是局部回归的一种。
二、CART回归树的算法流程
注意到,(1)中两步优化,即选择最优切分变量和切分点。(i)如果给定x的切分点。那么可以马上求得中括号内的最优。(ii)对于切分点怎么确定,这里是用遍历的方法。
三、CART分类树
实际上,CART分类树的生成树和ID3方法类似,只是这里用基尼指数代替了信息增益,定义
四、CART剪枝算法流程
例子参考:http://www.cnblogs.com/zhangchaoyang/articles/2709922.html
比如:
当分类回归树划分得太细时,会对噪声数据产生过拟合作用。因此我们要通过剪枝来解决。剪枝又分为前剪枝和后剪枝:前剪枝是指在构造树的过程中就知道哪些节点可以剪掉,于是干脆不对这些节点进行分裂,在N皇后问题和背包问题中用的都是前剪枝,上面的χ2方法也可以认为是一种前剪枝;后剪枝是指构造出完整的决策树之后再来考查哪些子树可以剪掉。
在分类回归树中可以使用的后剪枝方法有多种,比如:代价复杂性剪枝、最小误差剪枝、悲观误差剪枝等等。这里我们只介绍代价复杂性剪枝法。
对于分类回归树中的每一个非叶子节点计算它的表面误差率增益值α。
是子树中包含的叶子节点个数;
是节点t的误差代价,如果该节点被剪枝;
r(t)是节点t的误差率;
p(t)是节点t上的数据占所有数据的比例。
是子树Tt的误差代价,如果该节点不被剪枝。它等于子树Tt上所有叶子节点的误差代价之和。
比如有个非叶子节点t4如图所示:
已知所有的数据总共有60条,则节点t4的节点误差代价为:
子树误差代价为:
以t4为根节点的子树上叶子节点有3个,最终:
找到α值最小的非叶子节点,令其左右孩子为NULL。当多个非叶子节点的α值同时达到最小时,取最大的进行剪枝。
#include<iostream>
#include<fstream>
#include<sstream>
#include<string>
#include<map>
#include<list>
#include<set>
#include<queue>
#include<utility>
#include<vector>
#include<cmath> using namespace std; //置信水平取0.95时的卡方表
const double CHI[]={0.004,0.103,0.352,0.711,1.145,1.635,2.167,2.733,3.325,3.94,4.575,5.226,5.892,6.571,7.261,7.962};
/*根据多维数组计算卡方值*/
template<typename Comparable>
double cal_chi(Comparable **arr,int row,int col){
vector<Comparable> rowsum(row);
vector<Comparable> colsum(col);
Comparable totalsum=static_cast<Comparable>();
//cout<<"observation"<<endl;
for(int i=;i<row;++i){
for(int j=;j<col;++j){
//cout<<arr[i][j]<<"\t";
totalsum+=arr[i][j];
rowsum[i]+=arr[i][j];
colsum[j]+=arr[i][j];
}
//cout<<endl;
}
double rect=0.0;
//cout<<"exception"<<endl;
for(int i=;i<row;++i){
for(int j=;j<col;++j){
double excep=1.0*rowsum[i]*colsum[j]/totalsum;
//cout<<excep<<"\t";
if(excep!=)
rect+=pow(arr[i][j]-excep,2.0)/excep;
}
//cout<<endl;
}
return rect;
} class MyTriple{
public:
double first;
int second;
int third;
MyTriple(){
first=0.0;
second=;
third=;
}
MyTriple(double f,int s,int t):first(f),second(s),third(t){}
bool operator< (const MyTriple &obj) const{
int cmp=this->first-obj.first;
if(cmp>)
return false;
else if(cmp<)
return true;
else{
cmp=obj.second-this->second;
if(cmp<)
return true;
else
return false;
}
}
}; typedef map<string,int> MAP_REST_COUNT;
typedef map<string,MAP_REST_COUNT> MAP_ATTR_REST;
typedef vector<MAP_ATTR_REST> VEC_STATI; const int ATTR_NUM=; //自变量的维度
vector<string> X(ATTR_NUM);
int rest_number; //因变量的种类数,即类别数
vector<pair<string,int> > classes; //把类别、对应的记录数存放在一个数组中
int total_record_number; //总的记录数
vector<vector<string> > inputData; //原始输入数据 class node{
public:
node* parent; //父节点
node* leftchild; //左孩子节点
node* rightchild; //右孩子节点
string cond; //分枝条件
string decision; //在该节点上作出的类别判定
double precision; //判定的正确率
int record_number; //该节点上涵盖的记录个数
int size; //子树包含的叶子节点的数目
int index; //层次遍历树,给节点标上序号
double alpha; //表面误差率的增加量
node(){
parent=NULL;
leftchild=NULL;
rightchild=NULL;
precision=0.0;
record_number=;
size=;
index=;
alpha=1.0;
}
node(node* p){
parent=p;
leftchild=NULL;
rightchild=NULL;
precision=0.0;
record_number=;
size=;
index=;
alpha=1.0;
}
node(node* p,string c,string d):cond(c),decision(d){
parent=p;
leftchild=NULL;
rightchild=NULL;
precision=0.0;
record_number=;
size=;
index=;
alpha=1.0;
}
void printInfo(){
cout<<"index:"<<index<<"\tdecisoin:"<<decision<<"\tprecision:"<<precision<<"\tcondition:"<<cond<<"\tsize:"<<size;
if(parent!=NULL)
cout<<"\tparent index:"<<parent->index;
if(leftchild!=NULL)
cout<<"\tleftchild:"<<leftchild->index<<"\trightchild:"<<rightchild->index;
cout<<endl;
}
void printTree(){
printInfo();
if(leftchild!=NULL)
leftchild->printTree();
if(rightchild!=NULL)
rightchild->printTree();
}
}; int readInput(string filename){
ifstream ifs(filename.c_str());
if(!ifs){
cerr<<"open inputfile failed!"<<endl;
return -;
}
map<string,int> catg;
string line;
getline(ifs,line);
string item;
istringstream strstm(line);
strstm>>item;
for(int i=;i<X.size();++i){
strstm>>item;
X[i]=item;
}
while(getline(ifs,line)){
vector<string> conts(ATTR_NUM+);
istringstream strstm(line);
//strstm.str(line);
for(int i=;i<conts.size();++i){
strstm>>item;
conts[i]=item;
if(i==conts.size()-)
catg[item]++;
}
inputData.push_back(conts);
}
total_record_number=inputData.size();
ifs.close();
map<string,int>::const_iterator itr=catg.begin();
while(itr!=catg.end()){
classes.push_back(make_pair(itr->first,itr->second));
itr++;
}
rest_number=classes.size();
return ;
} /*根据inputData作出一个统计stati*/
void statistic(vector<vector<string> > &inputData,VEC_STATI &stati){
for(int i=;i<ATTR_NUM+;++i){
MAP_ATTR_REST attr_rest;
for(int j=;j<inputData.size();++j){
string attr_value=inputData[j][i];
string rest=inputData[j][ATTR_NUM+];
MAP_ATTR_REST::iterator itr=attr_rest.find(attr_value);
if(itr==attr_rest.end()){
MAP_REST_COUNT rest_count;
rest_count[rest]=;
attr_rest[attr_value]=rest_count;
}
else{
MAP_REST_COUNT::iterator iter=(itr->second).find(rest);
if(iter==(itr->second).end()){
(itr->second).insert(make_pair(rest,));
}
else{
iter->second+=;
}
}
}
stati.push_back(attr_rest);
}
} /*依据某条件作出分枝时,inputData被分成两部分*/
void splitInput(vector<vector<string> > &inputData,int fitIndex,string cond,vector<vector<string> > &LinputData,vector<vector<string> > &RinputData){
for(int i=;i<inputData.size();++i){
if(inputData[i][fitIndex+]==cond)
LinputData.push_back(inputData[i]);
else
RinputData.push_back(inputData[i]);
}
} void printStati(VEC_STATI &stati){
for(int i=;i<stati.size();i++){
MAP_ATTR_REST::const_iterator itr=stati[i].begin();
while(itr!=stati[i].end()){
cout<<itr->first;
MAP_REST_COUNT::const_iterator iter=(itr->second).begin();
while(iter!=(itr->second).end()){
cout<<"\t"<<iter->first<<"\t"<<iter->second;
iter++;
}
itr++;
cout<<endl;
}
cout<<endl;
}
} void split(node *root,vector<vector<string> > &inputData,vector<pair<string,int> > classes){
//root->printInfo();
root->record_number=inputData.size();
VEC_STATI stati;
statistic(inputData,stati);
//printStati(stati);
//for(int i=0;i<rest_number;i++)
// cout<<classes[i].first<<"\t"<<classes[i].second<<"\t";
//cout<<endl;
/*找到最大化GINI指标的划分*/
double minGain=1.0; //最小的GINI增益
int fitIndex=-;
string fitCond;
vector<pair<string,int> > fitleftclasses;
vector<pair<string,int> > fitrightclasses;
int fitleftnumber;
int fitrightnumber;
for(int i=;i<stati.size();++i){ //扫描每一个自变量
MAP_ATTR_REST::const_iterator itr=stati[i].begin();
while(itr!=stati[i].end()){ //扫描自变量上的每一个取值
string condition=itr->first; //判定的条件,即到达左孩子的条件
//cout<<"cond 为"<<X[i]+condition<<"时:";
vector<pair<string,int> > leftclasses(classes); //左孩子节点上类别、及对应的数目
vector<pair<string,int> > rightclasses(classes); //右孩子节点上类别、及对应的数目
int leftnumber=; //左孩子节点上包含的类别数目
int rightnumber=; //右孩子节点上包含的类别数目
for(int j=;j<leftclasses.size();++j){ //更新类别对应的数目
string rest=leftclasses[j].first;
MAP_REST_COUNT::const_iterator iter2;
iter2=(itr->second).find(rest);
if(iter2==(itr->second).end()){ //没找到
leftclasses[j].second=;
rightnumber+=rightclasses[j].second;
}
else{ //找到
leftclasses[j].second=iter2->second;
leftnumber+=leftclasses[j].second;
rightclasses[j].second-=(iter2->second);
rightnumber+=rightclasses[j].second;
}
}
/**if(leftnumber==0 || rightnumber==0){
cout<<"左右有一边为空"<<endl; for(int k=0;k<rest_number;k++)
cout<<leftclasses[k].first<<"\t"<<leftclasses[k].second<<"\t";
cout<<endl;
for(int k=0;k<rest_number;k++)
cout<<rightclasses[k].first<<"\t"<<rightclasses[k].second<<"\t";
cout<<endl;
}**/
double gain1=1.0; //计算GINI增益
double gain2=1.0;
if(leftnumber==)
gain1=0.0;
else
for(int j=;j<leftclasses.size();++j)
gain1-=pow(1.0*leftclasses[j].second/leftnumber,2.0);
if(rightnumber==)
gain2=0.0;
else
for(int j=;j<rightclasses.size();++j)
gain2-=pow(1.0*rightclasses[j].second/rightnumber,2.0);
double gain=1.0*leftnumber/(leftnumber+rightnumber)*gain1+1.0*rightnumber/(leftnumber+rightnumber)*gain2;
//cout<<"GINI增益:"<<gain<<endl;
if(gain<minGain){
//cout<<"GINI增益:"<<gain<<"\t"<<i<<"\t"<<condition<<endl;
fitIndex=i;
fitCond=condition;
fitleftclasses=leftclasses;
fitrightclasses=rightclasses;
fitleftnumber=leftnumber;
fitrightnumber=rightnumber;
minGain=gain;
}
itr++;
}
} /*计算卡方值,看有没有必要进行分裂*/
//cout<<"按"<<X[fitIndex]+fitCond<<"划分,计算卡方"<<endl;
int **arr=new int*[];
for(int i=;i<;i++)
arr[i]=new int[rest_number];
for(int i=;i<rest_number;i++){
arr[][i]=fitleftclasses[i].second;
arr[][i]=fitrightclasses[i].second;
}
double chi=cal_chi(arr,,rest_number);
//cout<<"chi="<<chi<<" CHI="<<CHI[rest_number-2]<<endl;
if(chi<CHI[rest_number-]){ //独立,没必要再分裂了
delete []arr[]; delete []arr[]; delete []arr;
return; //不需要分裂函数就返回
}
delete []arr[]; delete []arr[]; delete []arr; /*分裂*/
root->cond=X[fitIndex]+"="+fitCond; //root的分枝条件
//cout<<"分类条件:"<<root->cond<<endl;
node *travel=root; //root及其祖先节点的size都要加1
while(travel!=NULL){
(travel->size)++;
travel=travel->parent;
} node *LChild=new node(root); //创建左右孩子
node *RChild=new node(root);
root->leftchild=LChild;
root->rightchild=RChild;
int maxLcount=;
int maxRcount=;
string Ldicision,Rdicision;
for(int i=;i<rest_number;++i){ //统计哪种类别出现的最多,从而作出类别判定
if(fitleftclasses[i].second>maxLcount){
maxLcount=fitleftclasses[i].second;
Ldicision=fitleftclasses[i].first;
}
if(fitrightclasses[i].second>maxRcount){
maxRcount=fitrightclasses[i].second;
Rdicision=fitrightclasses[i].first;
}
}
LChild->decision=Ldicision;
RChild->decision=Rdicision;
LChild->precision=1.0*maxLcount/fitleftnumber;
RChild->precision=1.0*maxRcount/fitrightnumber; /*递归对左右孩子进行分裂*/
vector<vector<string> > LinputData,RinputData;
splitInput(inputData,fitIndex,fitCond,LinputData,RinputData);
//cout<<"左边inputData行数:"<<LinputData.size()<<endl;
//cout<<"右边inputData行数:"<<RinputData.size()<<endl;
split(LChild,LinputData,fitleftclasses);
split(RChild,RinputData,fitrightclasses);
} /*计算子树的误差代价*/
double calR2(node *root){
if(root->leftchild==NULL)
return (-root->precision)*root->record_number/total_record_number;
else
return calR2(root->leftchild)+calR2(root->rightchild);
} /*层次遍历树,给节点标上序号。同时计算alpha*/
void index(node *root,priority_queue<MyTriple> &pq){
int i=;
queue<node*> que;
que.push(root);
while(!que.empty()){
node* n=que.front();
que.pop();
n->index=i++;
if(n->leftchild!=NULL){
que.push(n->leftchild);
que.push(n->rightchild);
//计算表面误差率的增量
double r1=(-n->precision)*n->record_number/total_record_number; //节点的误差代价
double r2=calR2(n);
n->alpha=(r1-r2)/(n->size-);
pq.push(MyTriple(n->alpha,n->size,n->index));
}
}
} /*剪枝*/
void prune(node *root,priority_queue<MyTriple> &pq){
MyTriple triple=pq.top();
int i=triple.third;
queue<node*> que;
que.push(root);
while(!que.empty()){
node* n=que.front();
que.pop();
if(n->index==i){
cout<<"将要剪掉"<<i<<"的左右子树"<<endl;
n->leftchild=NULL;
n->rightchild=NULL;
int s=n->size-;
node *trav=n;
while(trav!=NULL){
trav->size-=s;
trav=trav->parent;
}
break;
}
else if(n->leftchild!=NULL){
que.push(n->leftchild);
que.push(n->rightchild);
}
}
} void test(string filename,node *root){
ifstream ifs(filename.c_str());
if(!ifs){
cerr<<"open inputfile failed!"<<endl;
return;
}
string line;
getline(ifs,line);
string item;
istringstream strstm(line); //跳过第一行
map<string,string> independent; //自变量,即分类的依据
while(getline(ifs,line)){
istringstream strstm(line);
//strstm.str(line);
strstm>>item;
cout<<item<<"\t";
for(int i=;i<ATTR_NUM;++i){
strstm>>item;
independent[X[i]]=item;
}
node *trav=root;
while(trav!=NULL){
if(trav->leftchild==NULL){
cout<<(trav->decision)<<"\t置信度:"<<(trav->precision)<<endl;;
break;
}
string cond=trav->cond;
string::size_type pos=cond.find("=");
string pre=cond.substr(,pos);
string post=cond.substr(pos+);
if(independent[pre]==post)
trav=trav->leftchild;
else
trav=trav->rightchild;
}
}
ifs.close();
} int main(){
string inputFile="animal";
readInput(inputFile);
VEC_STATI stati; //最原始的统计
statistic(inputData,stati); // for(int i=0;i<classes.size();++i)
// cout<<classes[i].first<<"\t"<<classes[i].second<<"\t";
// cout<<endl;
node *root=new node();
split(root,inputData,classes); //分裂根节点
priority_queue<MyTriple> pq;
index(root,pq);
root->printTree();
cout<<"剪枝前使用该决策树最多进行"<<root->size-<<"次条件判断"<<endl;
/**
//检验一个是不是表面误差增量最小的被剪掉了
while(!pq.empty()){
MyTriple triple=pq.top();
pq.pop();
cout<<triple.first<<"\t"<<triple.second<<"\t"<<triple.third<<endl;
}
**/
test(inputFile,root); prune(root,pq);
cout<<"剪枝后使用该决策树最多进行"<<root->size-<<"次条件判断"<<endl;
test(inputFile,root);
return ;
}
参考文献:
http://blog.csdn.net/google19890102/article/details/32329823
CART的更多相关文章
- 【十大经典数据挖掘算法】CART
[十大经典数据挖掘算法]系列 C4.5 K-Means SVM Apriori EM PageRank AdaBoost kNN Naïve Bayes CART 1. 前言 分类与回归树(Class ...
- ID3、C4.5、CART、RandomForest的原理
决策树意义: 分类决策树模型是表示基于特征对实例进行分类的树形结构.决策树可以转换为一个if_then规则的集合,也可以看作是定义在特征空间划分上的类的条件概率分布. 它着眼于从一组无次序.无规则的样 ...
- C4.5,CART,randomforest的实践
#################################Weka-J48(C4.5)################################# ################### ...
- CART(分类回归树)
1.简单介绍 线性回归方法可以有效的拟合所有样本点(局部加权线性回归除外).当数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型的想法一个是困难一个是笨拙.此外,实际中很多问题为非线性的,例如常 ...
- 决策树-预测隐形眼镜类型 (ID3算法,C4.5算法,CART算法,GINI指数,剪枝,随机森林)
1. 1.问题的引入 2.一个实例 3.基本概念 4.ID3 5.C4.5 6.CART 7.随机森林 2. 我们应该设计什么的算法,使得计算机对贷款申请人员的申请信息自动进行分类,以决定能否贷款? ...
- CART:分类与回归树
起源:决策树切分数据集 决策树每次决策时,按照一定规则切分数据集,并将切分后的小数据集递归处理.这样的处理方式给了线性回归处理非线性数据一个启发. 能不能先将类似特征的数据切成一小部分,再将这一小部分 ...
- cart中回归树的原理和实现
前面说了那么多,一直围绕着分类问题讨论,下面我们开始学习回归树吧, cart生成有两个关键点 如何评价最优二分结果 什么时候停止和如何确定叶子节点的值 cart分类树采用gini系数来对二分结果进行评 ...
- 用cart(分类回归树)作为弱分类器实现adaboost
在之前的决策树到集成学习里我们说了决策树和集成学习的基本概念(用了adaboost昨晚集成学习的例子),其后我们分别学习了决策树分类原理和adaboost原理和实现, 上两篇我们学习了cart(决策分 ...
- 连续值的CART(分类回归树)原理和实现
上一篇我们学习和实现了CART(分类回归树),不过主要是针对离散值的分类实现,下面我们来看下连续值的cart分类树如何实现 思考连续值和离散值的不同之处: 二分子树的时候不同:离散值需要求出最优的两个 ...
- CART(分类回归树)原理和实现
前面我们了解了决策树和adaboost的决策树墩的原理和实现,在adaboost我们看到,用简单的决策树墩的效果也很不错,但是对于更多特征的样本来说,可能需要很多数量的决策树墩 或许我们可以考虑使用更 ...
随机推荐
- 10+优秀“分步引导”jQuery插件(转)
很 多时候一个网站或者一个Web应用出品,为了让你的用户知道你的站点(或应用)有些什么?如何操作?为了让你的用户有更好的体验.往往这个时候都会给你的 站点(应用)添加一个分步指引的效果.然而这样的效果 ...
- influxdb和boltDB简介——底层本质类似LMDB,MVCC+B+树
influxdb influxdb是最新的一个时间序列数据库,最新一两年才产生,但已经拥有极高的人气.influxdb 是用Go写的,0.9版本的influxdb对于之前会有很大的改变,后端存储有Le ...
- perl 正则匹配代码
36 chomp $line; 37 my @vec = split /\t/, $line; 38 my @vec2 = ($vec[1]=~/[a-z]+/g); 39 ...
- Eclipse 反编译器
Help-->Install New SoftWare 贴上反编译地址:http://opensource.cpupk.com/decompiler/update/ 选择add,一路向北,起飞.
- NodeJS无所不能:细数10个令人惊讶的NodeJS开源项目
在几年的时间里,NodeJS逐渐发展成一个成熟的开发平台,吸引了许多开发者.有许多大型高流量网站都采用NodeJS进行开发,像PayPal,此外,开发人员还可以使用它来开发一些快速移动Web框架. 除 ...
- 实现OAUTH协议 实现 QQ 第三方登录效果
1.OAuth的简述 OAuth(Open Authorization,开放授权)是为用户资源的授权定义了一个安全.开放及简单的标准,第三方无需知道用户的账号及密码,就可获取到用户的授权信息,并且这是 ...
- Mysqldump参数大全
Mysqldump参数大全(参数来源于mysql5.5.19源码) 参数 参数说明 --all-databases , -A 导出全部数据库. mysqldump -uroot -p --al ...
- js基础之弹性运动(四)
一.滑动菜单.图片 var iSpeed=0;var left=0;function startMove(obj,iTarg){ clearInterval(obj.timer);//记得先关定时器 ...
- HDU 3265 扫描线(矩形面积并变形)
Posters Time Limit: 5000/2000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others)Total Sub ...
- DotNetBar v12.9.0.0 Fully Cracked
更新信息: http://www.devcomponents.com/customeronly/releasenotes.asp?p=dnbwf&v=12.9.0.0 如果遇到破解问题可以与我 ...