起源:决策树切分数据集

决策树每次决策时,按照一定规则切分数据集,并将切分后的小数据集递归处理。这样的处理方式给了线性回归处理非线性数据一个启发。

能不能先将类似特征的数据切成一小部分,再将这一小部分放大处理,使用线性的方法增加准确率呢?

Part I:  树的枝与叶

枝:二叉 or 多叉?

在AdaBoost的单决策树中,对于连续型数据构建决策树,我们采取步进阈值切分2段的方法。还有一种简化处理,即选择子数据集中的当前维度所有不同的值作为阈值切分。

而在CART里,大于阈值归为左孩子,小于阈值的归为右孩子。若是离散型数据,则根据离散数据种类建立对应的多叉树即可。

叶:何时不再切分?

ID3决策树中,停止切分的条件有两个:

①DFS链路中全部切分方式被扫过一次,很明显,对于离散型特征,每次按照某一维度异同切分,再扫相同维度毫无意义。

对于连续型特征,则共有维度*(当前维度不同值数量,即阈值数量)种切分方式,同种方式也毫无意义。

②当前子数据集分类全部一致,已经是很完美的切分了,再切也没意思。

CART中,由于搜索深度只有1,重复选取也不会卡死。所以直接遵循②。

追加③条件:手动限制切分子集数量下限tolN,误差变化下限tolS(目标函数收敛)。一旦达到这两个下限,就立刻停止。

枝:一个好枝?

ID3算法给出了一个评价离散Label的好枝的标准:分类的混乱度(香农熵)降低。

对于连续数据,好枝的参考标准则是类似最小二乘法的目标函数,即误差越小越好。

由于计算误差需要先进行线性回归,相当于树套回归,虽然效果很好,但是无疑带来计算压力。

在这点上, CART利用均值和方差的性质给出了一个简化的误差计算:即假设一团数据的回归结果是这团数据的均值,那么目标函数即可当成总方差。

使用均值替代回归结果的树称为回归树,使用实际回归结果的树成为模型树。

叶:数量越多越好?

叶结点数量越多,越容易过拟合。数量越少,则容易欠拟合。

而tolN和tolS在选择最好的切分方式时,控制着叶结点的数量,这两个值越小,叶子越多,且对tolS的值很敏感。

树的递归构建:

①对当前数据集做最好的切分。

②若不能切分,则将该结点设为叶结点。

否则,由于切分的性质,所以切出的两个子集必定不为空。对大于阈值的子集进行左孩子递归构建,小于阈值的子集进行右孩子递归构建。

Part II :  树的剪枝

叶结点数量决定着拟合情况。人工调整不是一件好事。

所以出现一种先强行过拟合(tolN=0,tolS=1)生成CART树,然后利用新的样本数据进行剪枝的方法,称为后剪枝。

后剪枝有两种方法:

①后剪枝会将大量的枝从树顶直接转化成叶子,相当于废掉原树中很多数据,所以需要引入新的数据。

而把一个大枝转为叶子的方法,则是利用均值的性质。新叶子的回归值=原枝上所有叶的均值。

②除了废枝为叶,还有利用均值的计算性质、借助新数据归并两叶。当然归并是有条件的。

新数据递归切分之后,必然会分到叶子上。如果恰好一个枝上是两片叶子,那么分别计算ErrNoMerga、ErrMerga的值,观察是否变小来决定是否归并。

$ErrNoMerga=\sum_{i=1}^{LSet}(Set[i].y-L.leaf)^{2}+\sum_{i=1}^{RSet}(Set[i].y-R.leaf)^{2}$

$NewLeaf=mergaMean=avg(L.leaf+R.leaf)$

$ErrMerga=\sum_{i=1}^{Set}{(Set[i].y-mergaMean)^{2}}$

Part III:   回归与模型树

对于每条测试数据,从树顶按照树中保存的切分规则左右递归直到叶结点,返回叶结点的值作为回归值。

实际测试结果下,效果并不好。所以应当每一个叶结点:保留数据、以及线性回归方程(w、b),从而建立起模型树。

线性模型树方法将取代回归树中的均值误差理论,主要修改地方在选择分支、后剪枝上。

$Err =\sum_{i=1}^{m} (data[i].y-Regression(y))^{2}$

这样,叶结点就变成了一个线性回归器,返回线性回归结果即可。

Part IV 代码

#include "cstdio"
#include "iostream"
#include "fstream"
#include "math.h"
#include "sstream"
#include "string"
#include "vector"
#include "set"
using namespace std;
#define Dim dataSet[0].feature.size()
#define TREE pair<vector<Data>,vector<Data> >
#define NULL 0
struct Data
{
vector<double> feature;
double y;
Data(vector<double> feature,double y):feature(feature),y(y) {}
};
struct RegTree
{
int dim;double value;
RegTree *Left,*Right;
RegTree():Left(NULL),Right(NULL) {}
RegTree(int dim,double value):Left(NULL),Right(NULL),dim(dim),value(value) {}
};
vector<Data> dataSet,addSet,testSet;
pair<int,double> ops(,);
void read()
{
ifstream fin("data1.txt"),fin2("data2.txt"),fin3("data3.txt");
string line;double tmp,y;
while(getline(fin,line))
{
stringstream sin(line);
vector<double> feature;
while(sin>>tmp) feature.push_back(tmp);
y=feature.back();feature.pop_back();
dataSet.push_back(Data(feature,y));
}
while(getline(fin2,line))
{
stringstream sin(line);
vector<double> feature;
while(sin>>tmp) feature.push_back(tmp);
y=feature.back();feature.pop_back();
addSet.push_back(Data(feature,y));
}
while(getline(fin3,line))
{
stringstream sin(line);
vector<double> feature;
while(sin>>tmp) feature.push_back(tmp);
y=feature.back();feature.pop_back();
testSet.push_back(Data(feature,y));
}
}
pair<vector<Data>,vector<Data> > splitDataSet(vector<Data> dataSet,int dim,double value)
{
vector<Data> Left,Right;
for(int i=;i<dataSet.size();i++)
{
if(dataSet[i].feature[dim]>value) Left.push_back(dataSet[i]);
else Right.push_back(dataSet[i]);
}
return make_pair(Left,Right);
}
double regLeaf(vector<Data> dataSet)
{
double ret=0.0;
//printf("Leaf:\n");
for(int i=;i<dataSet.size();i++)
{
ret+=dataSet[i].y;
/*
for(int j=0;j<dataSet[i].feature.size();j++) printf("%.2lf ",dataSet[i].feature[j]);
printf("%lf\n",dataSet[i].y);*/
}
//printf("\n");
return ret/dataSet.size();
}
double calcErr(vector<Data> dataSet)
{
double avg=0.0,ret=0.0;
for(int i=;i<dataSet.size();i++) avg+=dataSet[i].y;
avg/=dataSet.size();
for(int i=;i<dataSet.size();i++) ret+=(dataSet[i].y-avg)*(dataSet[i].y-avg);
return ret;
}
pair<int,double> chooseBestSplit(vector<Data> dataSet)
{
//tolN、tolS(较敏感)过小都会导致Leaf过多,过大则会导致Leaf过少
int tolN=ops.first;double tolS=ops.second,S,newS,bestS=1e10,bestValue,bestDim;
set<double> y;
for(int i=;i<dataSet.size();i++) y.insert(dataSet[i].y);
if(y.size()==) return make_pair(-,regLeaf(dataSet));
S=calcErr(dataSet);
for(int i=;i<Dim;i++)
{
set<double> splitValue;
for(int j=;j<dataSet.size();j++) splitValue.insert(dataSet[j].feature[i]);
for(set<double>::iterator j=splitValue.begin();j!=splitValue.end();j++)
{
TREE tree=splitDataSet(dataSet,i,*j);
if(tree.first.size()<tolN||tree.second.size()<tolN) continue;
newS=calcErr(tree.first)+calcErr(tree.second);
if(newS<bestS) {bestDim=i;bestValue=*j;bestS=newS;}
}
}
if(S-bestS<tolS) return make_pair(-,regLeaf(dataSet));
TREE tree=splitDataSet(dataSet,bestDim,bestValue);
if(tree.first.size()<tolN||tree.second.size()<tolN) return make_pair(-,regLeaf(dataSet));
return make_pair(bestDim,bestValue);
}
RegTree *buildTree(vector<Data> dataSet)
{
pair<int,double> info=chooseBestSplit(dataSet);
if(info.first==-)
{
RegTree *node=new RegTree(info.first,info.second);
return node;
}
RegTree *node=new RegTree(info.first,info.second);
TREE tree=splitDataSet(dataSet,info.first,info.second);
//printf("Node: dim:%d %.2lf\n",info.first,info.second);
node->Left=buildTree(tree.first);
node->Right=buildTree(tree.second);
return node;
}
double getMean(RegTree *root)
{
double ret=0.0;
if(root->Left->dim!=-) ret+=getMean(root->Left);
else ret+=root->Left->value;
if(root->Right->dim!=-) ret+=getMean(root->Right);
else ret+=root->Right->value;
return ret/=;
}
RegTree *prune(RegTree *&root,vector<Data> dataSet)
{
if(dataSet.size()==) return new RegTree(-,getMean(root));
double errNoMerga=0.0,errMerga=0.0;
if(root->Left->dim!=-||root->Right->dim!=-)
{
TREE tree=splitDataSet(dataSet,root->dim,root->value);
if(root->Left->dim!=-) root->Left=prune(root->Left,tree.first);
if(root->Right->dim!=-) root->Right=prune(root->Right,tree.second);
}
if(root->Left->dim==-&&root->Right->dim==-)
{
TREE tree=splitDataSet(dataSet,root->dim,root->value);
for(int i=;i<tree.first.size();i++) errNoMerga+=(tree.first[i].y-root->Left->value)*(tree.first[i].y-root->Left->value);
for(int i=;i<tree.second.size();i++) errNoMerga+=(tree.second[i].y-root->Right->value)*(tree.second[i].y-root->Right->value);
double mergaMean=(root->Left->value+root->Right->value)/;
for(int i=;i<dataSet.size();i++) errMerga+=(dataSet[i].y-mergaMean)*(dataSet[i].y-mergaMean);
if(errMerga<errNoMerga) {/*cout<<"Merga"<<endl;*/return new RegTree(-,mergaMean);}
else return root;
}
return root;
}
int ccnt=;
void displayTree(RegTree *root)
{
if(root->Left->dim!=-) displayTree(root->Left);
else {printf("Leaf:%.2lf\n",root->Left->value);ccnt++;}
if(root->Right->dim!=-) displayTree(root->Right);
else {printf("Leaf:%.2lf\n",root->Right->value);ccnt++;}
}
double forcast(RegTree *root,Data data)
{
if(root->dim==-) return root->value; //in case the super root is a leaf
if(data.feature[root->dim]>root->value)
{
if(root->Left->dim!=-) return forcast(root->Left,data);
else return root->Left->value;
}
else
{
if(root->Right->dim!=-) return forcast(root->Right,data);
else return root->Right->value;
}
}
void forcastAll(RegTree *root,vector<Data> dataSet)
{
for(int i=;i<dataSet.size();i++)
{
double y=forcast(root,dataSet[i]);
printf("origin:%.2lf forcast:%.2lf\n",dataSet[i].y,y);
}
}
int main()
{
read();
RegTree *root=buildTree(dataSet);
root=prune(root,addSet);
forcastAll(root,testSet);
}

回归树

CART:分类与回归树的更多相关文章

  1. CART分类与回归树与GBDT(Gradient Boost Decision Tree)

    一.CART分类与回归树 资料转载: http://dataunion.org/5771.html        Classification And Regression Tree(CART)是决策 ...

  2. CART分类与回归树 学习笔记

    CART:Classification and regression tree,分类与回归树.(是二叉树) CART是决策树的一种,主要由特征选择,树的生成和剪枝三部分组成.它主要用来处理分类和回归问 ...

  3. 【机器学习笔记之三】CART 分类与回归树

    本文结构: CART算法有两步 回归树的生成 分类树的生成 剪枝 CART - Classification and Regression Trees 分类与回归树,是二叉树,可以用于分类,也可以用于 ...

  4. 数据挖掘十大经典算法--CART: 分类与回归树

    一.决策树的类型  在数据挖掘中,决策树主要有两种类型: 分类树 的输出是样本的类标. 回归树 的输出是一个实数 (比如房子的价格,病人呆在医院的时间等). 术语分类和回归树 (CART) 包括了上述 ...

  5. CART 分类与回归树

    from www.jianshu.com/p/b90a9ce05b28 本文结构: CART算法有两步 回归树的生成 分类树的生成 剪枝 CART - Classification and Regre ...

  6. 回归树(Regression Tree)

    目录 回归树 理论解释 算法流程 ID3 和 C4.5 能不能用来回归? 回归树示例 References 说到决策树(Decision tree),我们很自然会想到用其做分类,每个叶子代表有限类别中 ...

  7. CART(分类回归树)

    1.简单介绍 线性回归方法可以有效的拟合所有样本点(局部加权线性回归除外).当数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型的想法一个是困难一个是笨拙.此外,实际中很多问题为非线性的,例如常 ...

  8. 连续值的CART(分类回归树)原理和实现

    上一篇我们学习和实现了CART(分类回归树),不过主要是针对离散值的分类实现,下面我们来看下连续值的cart分类树如何实现 思考连续值和离散值的不同之处: 二分子树的时候不同:离散值需要求出最优的两个 ...

  9. 机器学习技法-决策树和CART分类回归树构建算法

    课程地址:https://class.coursera.org/ntumltwo-002/lecture 重要!重要!重要~ 一.决策树(Decision Tree).口袋(Bagging),自适应增 ...

随机推荐

  1. 用java来删除数组中指定的元素

    public static void main(String[] args){        String[] a = new String[]{"1","5" ...

  2. 使用drozer连接时提示:Could not find java. Please ensure that it is installed and on your path

    在安装drozer后使用 drozer.bat console connect命令提示如下错误(实际上我已经安装了jdk并添加了path) 参考上面的链接已经它的提示解决方法如下: 建立名为 .dro ...

  3. android 5.1 WIFI图标上的感叹号及其解决办法

    转自:http://blog.csdn.net/w6980112/article/details/45843129 第一次调试android5.1的 WIFI更改小功能 Wifi 源码的相关路径目录  ...

  4. JAVA基础学习之final关键字、遍历集合、日期类对象的使用、Math类对象的使用、Runtime类对象的使用、时间对象Date(两个日期相减)(5)

    1.final关键字和.net中的const关键字一样,是常量的修饰符,但是final还可以修饰类.方法.写法规范:常量所有字母都大写,多个单词中间用 "_"连接. 2.遍历集合A ...

  5. 重温WCF之流与文件传输(七)

    WCF开启流模式,主要是设置一个叫TransferMode的属性,所以,你看看哪些Binding的派生类有这个属性就可以了. TransferMode其实是一个举枚,看看它的几个有效值: Buffer ...

  6. ASP.NET Web API 全局权限和全局异常处理

    在开发中,我使用json格式序列化,所以将默认的xml序列化移除 public static class WebApiConfig { public static void Register(Http ...

  7. 使用Asyncio的Coroutine来实现一个有限状态机

    如图: #!/usr/bin/env python # -*- coding: utf-8 -*- import asyncio import datetime import time from ra ...

  8. 无废话ExtJs 入门教程十[单选组:RadioGroup、复选组:CheckBoxGroup]

    无废话ExtJs 入门教程十[单选组:RadioGroup.复选组:CheckBoxGroup] extjs技术交流,欢迎加群(201926085) 继上一节内容,我们在表单里加了个一个单选组,一个复 ...

  9. Oracle 备份与恢复介绍

    一.Oracle备份方式分类:Oracle有两类备份方式:(1)物理备份:是将实际组成数据库的操作系统文件从一处拷贝到另一处的备份过程,通常是从磁盘到磁带.物理备份又分为冷备份.热备份:   (2)逻 ...

  10. C# 非UI线程对控件的控制

    第一步:定义委托 public delegate void wei(string ss); 第二步:控制UI的方法 public void get1(string ss) { richTextBox1 ...