目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。

这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。

关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。

bagging和boosting的区别

bagging:是指在原始数据上通过放回抽样,抽出和原始数据大小相等的新数据集(这个性质说明新数据集存在重复的值,而原始数据部分数据值不会出现在新数据集中),并重复该过程选择N个新数据集,这样通过N个分类器对这个N个数据集进行分类,最后选择分类器投票结果中最多类别作为最后的分类结果。
boosting:相比bagging,boosting像是一种串行,bagging是一种并行的,bagging可以对于N个数据集通过N个分类器同时进行分类,并且每个分类器的权重是一样的,但是boosting则相反,boosting是利用一个数据集依次由每个分类器进行分类,而确定每个分类器的权重是加大正确率高的分类器的权重,减少正确率低的分类器的权重。同时为了提高准确率,每次会降低被正确分类的样本的权重,提高没有正确分类的样本的权重。这样做其实比较符合人的决策过程,就是要多训练自己容易做错的题型,并且要多听取正确性高的老师的意见。
 
那么AdaBoost的主要的两个过程就是提高错误分类的样本权重和提高正确率高的分类器的权重。
算法的步骤:
输入:训练集T,弱学习分类器(这里是一个节点的决策树)
输出:最终的分类器G
1 先初始化样本权重值,D1={W11...W1n}W1i=1/n
2 根据样本权重D1以及决策树求分类误差率,并求的最小的误差率em,以及该决策树
  em=
3 计算该分类器的权重
  可以看出,误差率越小的,其权重越大
4 更新各个样本的权重,Dm+1,(用公式编辑器好麻烦。。。 )
  
其中Zm是规范化银子:
  
5 构建基本分类器
  F(X)=
6 计算该分类器下的误差率,如果小于某个阈值就停止,否则从第二步开始迭代
 
终于不用打公式了。。。。
附上代码:
 import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList; class Stump{
public int dim;
public double thresh;
public String condition;
public double error;
public ArrayList<Integer> labelList;
double factor; public String toString(){
return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList;
}
} class Utils{
//加载数据集
public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{
ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
ArrayList<Double> data=new ArrayList<Double>();
String[] s=line.split(" "); for(int i=0;i<s.length-1;i++){
data.add(Double.parseDouble(s[i]));
}
dataSet.add(data);
}
return dataSet;
} //加载类别
public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{
ArrayList<Integer> labelSet=new ArrayList<Integer>(); FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
String[] s=line.split(" ");
labelSet.add(Integer.parseInt(s[s.length-1]));
}
return labelSet;
}
//测试用的
public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){
for(ArrayList<Double> data:dataSet){
System.out.println(data);
}
}
//获取最大值,用于求步长
public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){
double max=-9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)>max){
max=data.get(index);
}
}
return max;
}
//获取最小值,用于求步长
public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){
double min=9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)<min){
min=data.get(index);
}
}
return min;
} //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别
public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){
ArrayList<Integer> labelList=new ArrayList<Integer>();
if(condition.compareTo("lt")==0){
for(ArrayList<Double> data:dataSet){
if(data.get(feature)<=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}else{
for(ArrayList<Double> data:dataSet){
if(data.get(feature)>=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}
return labelList;
}
//求预测类别与真实类别的加权误差
public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){
double error=0; int n=real.size(); for(int i=0;i<fake.size();i++){
if(fake.get(i)!=real.get(i)){
error+=weights.get(i); }
} return error;
}
//构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。
public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){
int featureNum=dataSet.get(0).size(); int rowNum=dataSet.size();
Stump stump=new Stump();
double minError=999.0;
System.out.println("第"+n+"次迭代");
for(int i=0;i<featureNum;i++){
double min=getMin(dataSet,i);
double max=getMax(dataSet,i);
double step=(max-min)/(rowNum);
for(double j=min-step;j<=max+step;j=j+step){
String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类
for(String condition:conditions){
ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition); double error=Utils.getError(labelList,labelSet,weights);
if(error<minError){
minError=error;
stump.dim=i;
stump.thresh=j;
stump.condition=condition;
stump.error=minError;
stump.labelList=labelList;
stump.factor=0.5*(Math.log((1-error)/error));
} }
} } return stump;
} public static ArrayList<Double> getInitWeights(int n){
double weight=1.0/n;
ArrayList<Double> weights=new ArrayList<Double>();
for(int i=0;i<n;i++){
weights.add(weight);
}
return weights;
}
//更新样本权值
public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){
double Z=0;
ArrayList<Double> newWeights=new ArrayList<Double>();
int row=labelList.size();
double e=Math.E;
double factor=stump.factor;
for(int i=0;i<row;i++){
Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i));
} for(int i=0;i<row;i++){
double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z;
newWeights.add(weight);
}
return newWeights;
}
//对加权误差累加
public static ArrayList<Double> InitAccWeightError(int n){
ArrayList<Double> accError=new ArrayList<Double>();
for(int i=0;i<n;i++){
accError.add(0.0);
}
return accError;
} public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){
ArrayList<Integer> t=stump.labelList;
double factor=stump.factor;
ArrayList<Double> newAccError=new ArrayList<Double>();
for(int i=0;i<t.size();i++){
double a=accerror.get(i)+factor*t.get(i);
newAccError.add(a);
}
return newAccError;
} public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){
ArrayList<Integer> a=new ArrayList<Integer>();
int wrong=0;
for(int i=0;i<accError.size();i++){
if(accError.get(i)>0){
if(labelList.get(i)==-1){
wrong++;
}
}else if(labelList.get(i)==1){
wrong++;
}
}
double error=wrong*1.0/accError.size();
return error;
} public static void showStumpList(ArrayList<Stump> G){
for(Stump s:G){
System.out.println(s);
System.out.println(" ");
}
}
} public class Adaboost { /**
* @param args
* @throws IOException
*/ public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){
int row=labelList.size();
ArrayList<Double> weights=Utils.getInitWeights(row);
ArrayList<Stump> G=new ArrayList<Stump>();
ArrayList<Double> accError=Utils.InitAccWeightError(row);
int n=1;
while(true){
Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树
G.add(stump);
weights=Utils.updateWeights(stump,labelList,weights);//更新权值
accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了
double error=Utils.calErrorRate(accError,labelList);
if(error<0.001){
break;
}
n++;
}
return G;
} public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";
ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file);
ArrayList<Integer> labelSet=Utils.loadLabelSet(file);
ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet);
Utils.showStumpList(G);
System.out.println("finished");
} }

这里的数据采用的是统计学习方法中的数据

0 1
1 1
2 1
3 -1
4 -1
5 -1
6 1
7 1
8 1
9 -1

这里是单个特征的,也可以是多维数据,例如

1.0 2.1 1
2.0 1.1 1
1.3 1.0 -1
1.0 1.0 -1
2.0 1.0 1

AdaBoost的java实现的更多相关文章

  1. Spark案例分析

    一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...

  2. 机器学习之——集成算法,随机森林,Bootsing,Adaboost,Staking,GBDT,XGboost

    集成学习 集成算法 随机森林(前身是bagging或者随机抽样)(并行算法) 提升算法(Boosting算法) GBDT(迭代决策树) (串行算法) Adaboost (串行算法) Stacking ...

  3. smile——Java机器学习引擎

    资源 https://haifengl.github.io/ https://github.com/haifengl/smile 介绍 Smile(统计机器智能和学习引擎)是一个基于Java和Scal ...

  4. 故障重现(内存篇2),JAVA内存不足导致频繁回收和swap引起的性能问题

    背景起因: 记起以前的另一次也是关于内存的调优分享下   有个系统平时运行非常稳定运行(没经历过大并发考验),然而在一次活动后,人数并发一上来后,系统开始卡. 我按经验开始调优,在每个关键步骤的加入如 ...

  5. Elasticsearch之java的基本操作一

    摘要   接触ElasticSearch已经有一段了.在这期间,遇到很多问题,但在最后自己的不断探索下解决了这些问题.看到网上或多或少的都有一些介绍ElasticSearch相关知识的文档,但个人觉得 ...

  6. 论:开发者信仰之“天下IT是一家“(Java .NET篇)

    比尔盖茨公认的IT界领军人物,打造了辉煌一时的PC时代. 2008年,史蒂夫鲍尔默接替了盖茨的工作,成为微软公司的总裁. 2013年他与微软做了最后的道别. 2013年以后,我才真正看到了微软的变化. ...

  7. 故障重现, JAVA进程内存不够时突然挂掉模拟

    背景,服务器上的一个JAVA服务进程突然挂掉,查看产生了崩溃日志,如下: # Set larger code cache with -XX:ReservedCodeCacheSize= # This ...

  8. 死磕内存篇 --- JAVA进程和linux内存间的大小关系

    运行个JAVA 用sleep去hold住 package org.hjb.test; public class TestOnly { public static void main(String[] ...

  9. 【小程序分享篇 一 】开发了个JAVA小程序, 用于清除内存卡或者U盘里的垃圾文件非常有用

    有一种场景, 手机内存卡空间被用光了,但又不知道哪个文件占用了太大,一个个文件夹去找又太麻烦,所以我开发了个小程序把手机所有文件(包括路径下所有层次子文件夹下的文件)进行一个排序,这样你就可以找出哪个 ...

随机推荐

  1. 标签—box-shadow

    box-shadow:2px 3px 4px #CCC; 一个带外阴影的元素,阴影位置x轴偏移2px,y轴偏移3px,模糊范围4px,阴影颜色#CCC box-shadow:inset 0 -4px  ...

  2. C语言头文件

    最近在工作当中遇到了一点小问题,关于C语言头文件的应用问题,主要还是关于全局变量的定义和声明问题.学习C语言已经有好几年了,工作使用也近半年了,但是对于这部分的东西的确还没有深入的思考过.概念上还是比 ...

  3. T-shirts Distribution

    T-shirts Distribution time limit per test 1 second memory limit per test 256 megabytes input standar ...

  4. JPA 系列教程5-双向一对多

    双向一对多的ddl语句 同单向多对一,单向一对多表的ddl语句一致 Product package com.jege.jpa.one2many; import javax.persistence.En ...

  5. Provisioning profile 浅析

    转载自:    http://blog.csdn.net/chenyufeng1991/article/details/48976245 一般在我们代码编写中不会用到Provisioning prof ...

  6. C# Socket的TCP通讯 异步 (2015-11-07 10:07:19)转载▼

    异步 相对于同步,异步中的连接,接收和发送数据的方法都不一样,都有一个回调函数,就是即使不能连接或者接收不到数据,程序还是会一直执行下去,如果连接上了或者接到数据,程序会回到这个回调函数的地方重新往下 ...

  7. static加载问题

    原文地址:http://blog.csdn.net/lubiaopan/article/details/4802430     感谢原作者! static{}(即static块),会在类被加载的时候执 ...

  8. angular.js_$scope

    Scope(作用域) 是应用在 HTML (视图) 和 JavaScript (控制器)之间的纽带. Scope 是一个对象,有可用的方法和属性. Scope 可应用在视图和控制器上. Angular ...

  9. springmvc json数据

    的 @RequestMapping("/getAllEdu") @ResponseBody public void getAllEdu(HttpServletRequest req ...

  10. 黄聪:基于Asp.net的CMS系统We7架设实验(环境WIN7,SQL2005,.NET3.5)(初学者参考贴)

    http://www.cnblogs.com/huangcong/archive/2010/03/30/1700348.html