AdaBoost的java实现
目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。
这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。
关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList; class Stump{
public int dim;
public double thresh;
public String condition;
public double error;
public ArrayList<Integer> labelList;
double factor; public String toString(){
return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList;
}
} class Utils{
//加载数据集
public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{
ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
ArrayList<Double> data=new ArrayList<Double>();
String[] s=line.split(" "); for(int i=0;i<s.length-1;i++){
data.add(Double.parseDouble(s[i]));
}
dataSet.add(data);
}
return dataSet;
} //加载类别
public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{
ArrayList<Integer> labelSet=new ArrayList<Integer>(); FileInputStream fis=new FileInputStream(filename);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line=""; while((line=br.readLine())!=null){
String[] s=line.split(" ");
labelSet.add(Integer.parseInt(s[s.length-1]));
}
return labelSet;
}
//测试用的
public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){
for(ArrayList<Double> data:dataSet){
System.out.println(data);
}
}
//获取最大值,用于求步长
public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){
double max=-9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)>max){
max=data.get(index);
}
}
return max;
}
//获取最小值,用于求步长
public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){
double min=9999.0;
for(ArrayList<Double> data:dataSet){
if(data.get(index)<min){
min=data.get(index);
}
}
return min;
} //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别
public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){
ArrayList<Integer> labelList=new ArrayList<Integer>();
if(condition.compareTo("lt")==0){
for(ArrayList<Double> data:dataSet){
if(data.get(feature)<=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}else{
for(ArrayList<Double> data:dataSet){
if(data.get(feature)>=thresh){
labelList.add(1);
}else{
labelList.add(-1);
}
}
}
return labelList;
}
//求预测类别与真实类别的加权误差
public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){
double error=0; int n=real.size(); for(int i=0;i<fake.size();i++){
if(fake.get(i)!=real.get(i)){
error+=weights.get(i); }
} return error;
}
//构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。
public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){
int featureNum=dataSet.get(0).size(); int rowNum=dataSet.size();
Stump stump=new Stump();
double minError=999.0;
System.out.println("第"+n+"次迭代");
for(int i=0;i<featureNum;i++){
double min=getMin(dataSet,i);
double max=getMax(dataSet,i);
double step=(max-min)/(rowNum);
for(double j=min-step;j<=max+step;j=j+step){
String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类
for(String condition:conditions){
ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition); double error=Utils.getError(labelList,labelSet,weights);
if(error<minError){
minError=error;
stump.dim=i;
stump.thresh=j;
stump.condition=condition;
stump.error=minError;
stump.labelList=labelList;
stump.factor=0.5*(Math.log((1-error)/error));
} }
} } return stump;
} public static ArrayList<Double> getInitWeights(int n){
double weight=1.0/n;
ArrayList<Double> weights=new ArrayList<Double>();
for(int i=0;i<n;i++){
weights.add(weight);
}
return weights;
}
//更新样本权值
public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){
double Z=0;
ArrayList<Double> newWeights=new ArrayList<Double>();
int row=labelList.size();
double e=Math.E;
double factor=stump.factor;
for(int i=0;i<row;i++){
Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i));
} for(int i=0;i<row;i++){
double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z;
newWeights.add(weight);
}
return newWeights;
}
//对加权误差累加
public static ArrayList<Double> InitAccWeightError(int n){
ArrayList<Double> accError=new ArrayList<Double>();
for(int i=0;i<n;i++){
accError.add(0.0);
}
return accError;
} public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){
ArrayList<Integer> t=stump.labelList;
double factor=stump.factor;
ArrayList<Double> newAccError=new ArrayList<Double>();
for(int i=0;i<t.size();i++){
double a=accerror.get(i)+factor*t.get(i);
newAccError.add(a);
}
return newAccError;
} public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){
ArrayList<Integer> a=new ArrayList<Integer>();
int wrong=0;
for(int i=0;i<accError.size();i++){
if(accError.get(i)>0){
if(labelList.get(i)==-1){
wrong++;
}
}else if(labelList.get(i)==1){
wrong++;
}
}
double error=wrong*1.0/accError.size();
return error;
} public static void showStumpList(ArrayList<Stump> G){
for(Stump s:G){
System.out.println(s);
System.out.println(" ");
}
}
} public class Adaboost { /**
* @param args
* @throws IOException
*/ public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){
int row=labelList.size();
ArrayList<Double> weights=Utils.getInitWeights(row);
ArrayList<Stump> G=new ArrayList<Stump>();
ArrayList<Double> accError=Utils.InitAccWeightError(row);
int n=1;
while(true){
Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树
G.add(stump);
weights=Utils.updateWeights(stump,labelList,weights);//更新权值
accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了
double error=Utils.calErrorRate(accError,labelList);
if(error<0.001){
break;
}
n++;
}
return G;
} public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";
ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file);
ArrayList<Integer> labelSet=Utils.loadLabelSet(file);
ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet);
Utils.showStumpList(G);
System.out.println("finished");
} }
这里的数据采用的是统计学习方法中的数据
0 1
1 1
2 1
3 -1
4 -1
5 -1
6 1
7 1
8 1
9 -1
这里是单个特征的,也可以是多维数据,例如
1.0 2.1 1
2.0 1.1 1
1.3 1.0 -1
1.0 1.0 -1
2.0 1.0 1
AdaBoost的java实现的更多相关文章
- Spark案例分析
一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...
- 机器学习之——集成算法,随机森林,Bootsing,Adaboost,Staking,GBDT,XGboost
集成学习 集成算法 随机森林(前身是bagging或者随机抽样)(并行算法) 提升算法(Boosting算法) GBDT(迭代决策树) (串行算法) Adaboost (串行算法) Stacking ...
- smile——Java机器学习引擎
资源 https://haifengl.github.io/ https://github.com/haifengl/smile 介绍 Smile(统计机器智能和学习引擎)是一个基于Java和Scal ...
- 故障重现(内存篇2),JAVA内存不足导致频繁回收和swap引起的性能问题
背景起因: 记起以前的另一次也是关于内存的调优分享下 有个系统平时运行非常稳定运行(没经历过大并发考验),然而在一次活动后,人数并发一上来后,系统开始卡. 我按经验开始调优,在每个关键步骤的加入如 ...
- Elasticsearch之java的基本操作一
摘要 接触ElasticSearch已经有一段了.在这期间,遇到很多问题,但在最后自己的不断探索下解决了这些问题.看到网上或多或少的都有一些介绍ElasticSearch相关知识的文档,但个人觉得 ...
- 论:开发者信仰之“天下IT是一家“(Java .NET篇)
比尔盖茨公认的IT界领军人物,打造了辉煌一时的PC时代. 2008年,史蒂夫鲍尔默接替了盖茨的工作,成为微软公司的总裁. 2013年他与微软做了最后的道别. 2013年以后,我才真正看到了微软的变化. ...
- 故障重现, JAVA进程内存不够时突然挂掉模拟
背景,服务器上的一个JAVA服务进程突然挂掉,查看产生了崩溃日志,如下: # Set larger code cache with -XX:ReservedCodeCacheSize= # This ...
- 死磕内存篇 --- JAVA进程和linux内存间的大小关系
运行个JAVA 用sleep去hold住 package org.hjb.test; public class TestOnly { public static void main(String[] ...
- 【小程序分享篇 一 】开发了个JAVA小程序, 用于清除内存卡或者U盘里的垃圾文件非常有用
有一种场景, 手机内存卡空间被用光了,但又不知道哪个文件占用了太大,一个个文件夹去找又太麻烦,所以我开发了个小程序把手机所有文件(包括路径下所有层次子文件夹下的文件)进行一个排序,这样你就可以找出哪个 ...
随机推荐
- C#中partial关键字
1. 什么是局部类型? C# 2.0 引入了局部类型的概念.局部类型允许我们将一个类.结构或接口分成几个部分,分别实现在几个不同的.cs文件中. 局部类型适用于以下情况: (1) 类型特别大,不宜放在 ...
- C语言头文件
最近在工作当中遇到了一点小问题,关于C语言头文件的应用问题,主要还是关于全局变量的定义和声明问题.学习C语言已经有好几年了,工作使用也近半年了,但是对于这部分的东西的确还没有深入的思考过.概念上还是比 ...
- 利用Fiddler抓取手机APP数据包
Fiddler是一个调试代理,下载地址http://www.telerik.com/download/fiddler 下载安装运行后,查出运行机器的IP,手机连接同一网域内的WIFI,手机WIFI连接 ...
- thunk技术
Thunk : 将一段机器码对应的字节保存在一个连续内存结构里, 然后将其指针强制转换成函数. 即用作函数来执行,通常用来将对象的成员函数作为回调函数. #include "stdafx.h ...
- UIImageView 的contentMode属性 浅析
UIImageView 的contentMode这个属性是用来设置图片的显示方式,如居中.居右,是否缩放等,有以下几个常量可供设定:UIViewContentModeScaleToFillUIView ...
- iOS 开发之照片框架详解
转载自:http://kayosite.com/ios-development-and-detail-of-photo-framework.html 一. 概要 在 iOS 设备中,照片和视频是相当重 ...
- hdu_5724_Chess(组合博弈)
题目链接:hdu_5724_Chess 题意: 给你一个n行20列的棋盘,棋盘里面有些棋子,每个棋子每次只能往右走一步,如果右边有棋子,可以跳过去,前提是最右边有格子,如果当前选手走到没有棋子可以走了 ...
- 安卓开发之探秘蓝牙隐藏API(转)
源:http://www.cnblogs.com/xiaochao1234/p/3793172.html 上次讲解Android的蓝牙基本用法,这次讲得深入些,探讨下蓝牙方面的隐藏API.用过Andr ...
- Segment,Path,Ring和Polyline对象
Segment几何对象 Segment对象是一个有起点和终点的“线“,也就是说Segement只有两个点,至于两点之间的线是直的,还是曲的,需要其余的参数定义.所以Segment是由起点,终点和参 ...
- oracle中的赋权
1 怎么给用户赋权限 grant create view to scott; (create view 是权限的名称) 2 怎么给用户撤销权限 revoke create view from scot ...