目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。

这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。

关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。

bagging和boosting的区别

bagging:是指在原始数据上通过放回抽样,抽出和原始数据大小相等的新数据集(这个性质说明新数据集存在重复的值,而原始数据部分数据值不会出现在新数据集中),并重复该过程选择N个新数据集,这样通过N个分类器对这个N个数据集进行分类,最后选择分类器投票结果中最多类别作为最后的分类结果。
boosting:相比bagging,boosting像是一种串行,bagging是一种并行的,bagging可以对于N个数据集通过N个分类器同时进行分类,并且每个分类器的权重是一样的,但是boosting则相反,boosting是利用一个数据集依次由每个分类器进行分类,而确定每个分类器的权重是加大正确率高的分类器的权重,减少正确率低的分类器的权重。同时为了提高准确率,每次会降低被正确分类的样本的权重,提高没有正确分类的样本的权重。这样做其实比较符合人的决策过程,就是要多训练自己容易做错的题型,并且要多听取正确性高的老师的意见。
 
那么AdaBoost的主要的两个过程就是提高错误分类的样本权重和提高正确率高的分类器的权重。
算法的步骤:
输入:训练集T,弱学习分类器(这里是一个节点的决策树)
输出:最终的分类器G
1 先初始化样本权重值,D1={W11...W1n}W1i=1/n
2 根据样本权重D1以及决策树求分类误差率,并求的最小的误差率em,以及该决策树
  em=
3 计算该分类器的权重
  可以看出,误差率越小的,其权重越大
4 更新各个样本的权重,Dm+1,(用公式编辑器好麻烦。。。 )
  
其中Zm是规范化银子:
  
5 构建基本分类器
  F(X)=
6 计算该分类器下的误差率,如果小于某个阈值就停止,否则从第二步开始迭代
 
终于不用打公式了。。。。
附上代码:
 import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList; class Stump{
public int dim;
public double thresh;
public String condition;
public double error;
public ArrayList<Integer> labelList;
double factor; public String toString(){
return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList;
}
} class Utils{
//加载数据集
public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{
ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
ArrayList<Double> data=new ArrayList<Double>();
String[] s=line.split(" "); for(int i=0;i<s.length-1;i++){
data.add(Double.parseDouble(s[i]));
}
dataSet.add(data);
}
return dataSet;
} //加载类别
public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{
ArrayList<Integer> labelSet=new ArrayList<Integer>(); FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
String[] s=line.split(" ");
labelSet.add(Integer.parseInt(s[s.length-1]));
}
return labelSet;
}
//测试用的
public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){
for(ArrayList<Double> data:dataSet){
System.out.println(data);
}
}
//获取最大值,用于求步长
public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){
double max=-9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)>max){
max=data.get(index);
}
}
return max;
}
//获取最小值,用于求步长
public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){
double min=9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)<min){
min=data.get(index);
}
}
return min;
} //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别
public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){
ArrayList<Integer> labelList=new ArrayList<Integer>();
if(condition.compareTo("lt")==0){
for(ArrayList<Double> data:dataSet){
if(data.get(feature)<=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}else{
for(ArrayList<Double> data:dataSet){
if(data.get(feature)>=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}
return labelList;
}
//求预测类别与真实类别的加权误差
public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){
double error=0; int n=real.size(); for(int i=0;i<fake.size();i++){
if(fake.get(i)!=real.get(i)){
error+=weights.get(i); }
} return error;
}
//构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。
public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){
int featureNum=dataSet.get(0).size(); int rowNum=dataSet.size();
Stump stump=new Stump();
double minError=999.0;
System.out.println("第"+n+"次迭代");
for(int i=0;i<featureNum;i++){
double min=getMin(dataSet,i);
double max=getMax(dataSet,i);
double step=(max-min)/(rowNum);
for(double j=min-step;j<=max+step;j=j+step){
String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类
for(String condition:conditions){
ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition); double error=Utils.getError(labelList,labelSet,weights);
if(error<minError){
minError=error;
stump.dim=i;
stump.thresh=j;
stump.condition=condition;
stump.error=minError;
stump.labelList=labelList;
stump.factor=0.5*(Math.log((1-error)/error));
} }
} } return stump;
} public static ArrayList<Double> getInitWeights(int n){
double weight=1.0/n;
ArrayList<Double> weights=new ArrayList<Double>();
for(int i=0;i<n;i++){
weights.add(weight);
}
return weights;
}
//更新样本权值
public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){
double Z=0;
ArrayList<Double> newWeights=new ArrayList<Double>();
int row=labelList.size();
double e=Math.E;
double factor=stump.factor;
for(int i=0;i<row;i++){
Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i));
} for(int i=0;i<row;i++){
double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z;
newWeights.add(weight);
}
return newWeights;
}
//对加权误差累加
public static ArrayList<Double> InitAccWeightError(int n){
ArrayList<Double> accError=new ArrayList<Double>();
for(int i=0;i<n;i++){
accError.add(0.0);
}
return accError;
} public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){
ArrayList<Integer> t=stump.labelList;
double factor=stump.factor;
ArrayList<Double> newAccError=new ArrayList<Double>();
for(int i=0;i<t.size();i++){
double a=accerror.get(i)+factor*t.get(i);
newAccError.add(a);
}
return newAccError;
} public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){
ArrayList<Integer> a=new ArrayList<Integer>();
int wrong=0;
for(int i=0;i<accError.size();i++){
if(accError.get(i)>0){
if(labelList.get(i)==-1){
wrong++;
}
}else if(labelList.get(i)==1){
wrong++;
}
}
double error=wrong*1.0/accError.size();
return error;
} public static void showStumpList(ArrayList<Stump> G){
for(Stump s:G){
System.out.println(s);
System.out.println(" ");
}
}
} public class Adaboost { /**
* @param args
* @throws IOException
*/ public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){
int row=labelList.size();
ArrayList<Double> weights=Utils.getInitWeights(row);
ArrayList<Stump> G=new ArrayList<Stump>();
ArrayList<Double> accError=Utils.InitAccWeightError(row);
int n=1;
while(true){
Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树
G.add(stump);
weights=Utils.updateWeights(stump,labelList,weights);//更新权值
accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了
double error=Utils.calErrorRate(accError,labelList);
if(error<0.001){
break;
}
n++;
}
return G;
} public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";
ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file);
ArrayList<Integer> labelSet=Utils.loadLabelSet(file);
ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet);
Utils.showStumpList(G);
System.out.println("finished");
} }

这里的数据采用的是统计学习方法中的数据

0 1
1 1
2 1
3 -1
4 -1
5 -1
6 1
7 1
8 1
9 -1

这里是单个特征的,也可以是多维数据,例如

1.0 2.1 1
2.0 1.1 1
1.3 1.0 -1
1.0 1.0 -1
2.0 1.0 1

AdaBoost的java实现的更多相关文章

  1. Spark案例分析

    一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...

  2. 机器学习之——集成算法,随机森林,Bootsing,Adaboost,Staking,GBDT,XGboost

    集成学习 集成算法 随机森林(前身是bagging或者随机抽样)(并行算法) 提升算法(Boosting算法) GBDT(迭代决策树) (串行算法) Adaboost (串行算法) Stacking ...

  3. smile——Java机器学习引擎

    资源 https://haifengl.github.io/ https://github.com/haifengl/smile 介绍 Smile(统计机器智能和学习引擎)是一个基于Java和Scal ...

  4. 故障重现(内存篇2),JAVA内存不足导致频繁回收和swap引起的性能问题

    背景起因: 记起以前的另一次也是关于内存的调优分享下   有个系统平时运行非常稳定运行(没经历过大并发考验),然而在一次活动后,人数并发一上来后,系统开始卡. 我按经验开始调优,在每个关键步骤的加入如 ...

  5. Elasticsearch之java的基本操作一

    摘要   接触ElasticSearch已经有一段了.在这期间,遇到很多问题,但在最后自己的不断探索下解决了这些问题.看到网上或多或少的都有一些介绍ElasticSearch相关知识的文档,但个人觉得 ...

  6. 论:开发者信仰之“天下IT是一家“(Java .NET篇)

    比尔盖茨公认的IT界领军人物,打造了辉煌一时的PC时代. 2008年,史蒂夫鲍尔默接替了盖茨的工作,成为微软公司的总裁. 2013年他与微软做了最后的道别. 2013年以后,我才真正看到了微软的变化. ...

  7. 故障重现, JAVA进程内存不够时突然挂掉模拟

    背景,服务器上的一个JAVA服务进程突然挂掉,查看产生了崩溃日志,如下: # Set larger code cache with -XX:ReservedCodeCacheSize= # This ...

  8. 死磕内存篇 --- JAVA进程和linux内存间的大小关系

    运行个JAVA 用sleep去hold住 package org.hjb.test; public class TestOnly { public static void main(String[] ...

  9. 【小程序分享篇 一 】开发了个JAVA小程序, 用于清除内存卡或者U盘里的垃圾文件非常有用

    有一种场景, 手机内存卡空间被用光了,但又不知道哪个文件占用了太大,一个个文件夹去找又太麻烦,所以我开发了个小程序把手机所有文件(包括路径下所有层次子文件夹下的文件)进行一个排序,这样你就可以找出哪个 ...

随机推荐

  1. css(非表格变成表格用)

    父元素:display:table: 子元素:display:table-cell:vertical-align:middle:

  2. position的absolute与fixed共同点与不同点

    这个问题面试被问到过~~ A:共同点: 1.改变行内元素的呈现方式,display被置为block: 2.让元素脱离普通流,不占据空间: 3.默认会覆盖到非定位元素上   B不同点: absolute ...

  3. 过滤器HttpModule

    1.建一个类库文件  FirsModule,实现IHttpModule接口,实现其中的两个方法,写一函数实现自己的代码逻辑,在Init方法中调用即可. // <summary> /// 第 ...

  4. 初识Jmeter(一)

    倒霉熊的推荐: 文本学习网址:http://m.open-open.com/m/doc/category/105 视频学习网址: 软件学习网:http://www.ask3.cn/index.html ...

  5. JS-运动基础(一)续

    2.淡入淡出的图片 用变量存储透明度 <title>无标题文档</title> <style> #div1{width:293px; height:220px; b ...

  6. 优化之sitemap+RSS

    RSS也叫聚合, RSS是在线共享内容的一种简易方式,也叫聚合内容,Really Simple Syndication. 通常在时效性比较强的网站或网络平台上应用RSS订阅功能可以更快速获取信息,网站 ...

  7. Itext简绍及操作PDF文件

    iText简介 iText是著名的开放源码的站点sourceforge一个项目,是用于生成PDF文档的一个java类库.通过iText不仅可以生成PDF或rtf的文档,而且可以将XML.Html文件转 ...

  8. typedef void far *LPVOID 的具体定义

    首先这里的far,在32位系统已经废除不用了.它是C/C++语言在16位系统中用以标明指针是个远指针的修饰符. 远指针是说指针所指向的地址已经超出了64K(2的十六次方),所以需要使用DS加偏移量的方 ...

  9. mysql date range

    http://stackoverflow.com/questions/9935690/mysql-datetime-range-query-issue " ";

  10. ASP.NET Cache 类

    在查找资料的过程中.原来园子里面已经有过分析了.nopCommerce架构分析系列(二)数据Cache. 接下来是一些学习补充. 1.Nop中没有System.Web.Caching.Cache的实现 ...