java实现gbdt
DATA类
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Scanner; public class Data {
private ArrayList<ArrayList<String>> trainData=new ArrayList<ArrayList<String>>();
public ArrayList<ArrayList<String>> getTrainData() {
return this.trainData;
} public Data() {
String dataPath="D://javajavajava//dbdt//src//script//data//adult.data.csv";
Scanner in;
try {
in = new Scanner(new File(dataPath));
while (in.hasNext()) {
String line=in.nextLine();
String []strs=line.trim().split(",");
ArrayList<String> tmp=new ArrayList<>();
for(int i=0;i<strs.length;i++)
{
tmp.add(strs[i]);
}
this.trainData.add(tmp);
}
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} } public static void main(String[] args) {
// TODO Auto-generated method stub
Data d =new Data(); } }
TREE类
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.spi.TimeZoneNameProvider; public class Tree {
private Tree leftTree=new Tree();
private Tree rightTree=new Tree();
private double loss=-1;
private int attributeSplit=0;
private String attributeSplitType="";
boolean isLeaf;
double leafValue;
private ArrayList<Integer> leafNodeSet=new ArrayList<>(); public ArrayList<String> getAttributeSet(ArrayList<ArrayList<String>> trainData,int idx)
{
HashSet<String> mySet=new HashSet<>();
ArrayList<String> ans =new ArrayList<>();
for(int i=0;i<trainData.size();i++)
{
mySet.add(trainData.get(i).get(idx));
} Iterator<String> it=mySet.iterator(); while(it.hasNext())
{
ans.add(it.next());
} return ans;
}
public boolean myCmpLess(String str1,String str2)
{
if(Integer.parseInt(str1.trim())<=Integer.parseInt(str2.trim()))
return true;
else return false; }
public double computeLoss(ArrayList<Double> values)
{
double loss=0;
for(int i=0;i<values.size();i++)
{
loss+=values.get(i);
}
double mean=loss/values.size();
loss=0;
for(int i=0;i<values.size();i++)
{
loss+=Math.pow(values.get(i)-mean,2);
}
return Math.sqrt(loss);
}
public double getPredictValue(int K, ArrayList<Integer> subIdx,ArrayList<Double> target) {
double ans=0;
double sum=0,sum1=0;
for(int i=0;i<subIdx.size();i++)
{
sum+=target.get(subIdx.get(i));
}
for(int i=0;i<subIdx.size();i++)
{
sum1+=target.get(subIdx.get(i))*(1-target.get(subIdx.get(i)));
}
ans=(K-1)/K*sum/sum1;
return ans;
}
public double getPredictValue(Tree root)
{
return root.leafValue;
}
public double getPredictValue(Tree root,ArrayList<String> instance,Boolean isDigit[])
{ if(root.isLeaf)
return root.leafValue;
else if(isDigit[root.attributeSplit])
{
if(myCmpLess(instance.get(root.attributeSplit).trim(),root.attributeSplitType))
return getPredictValue(root.leftTree, instance, isDigit);
return getPredictValue(root.rightTree, instance, isDigit);
}
else
{
if(instance.get(root.attributeSplit).trim().equals(root.attributeSplitType))
return getPredictValue(root.leftTree, instance, isDigit);
return getPredictValue(root.rightTree, instance, isDigit);
} }
public Tree constructTree(ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int K,int splitPoints, Boolean isDigit[],ArrayList<Integer> subIdx,ArrayList<ArrayList<String>> trainData,ArrayList<Double> target,int maxDepth[],int depth)
{ int n=trainData.size();
int dim=trainData.get(0).size();
ArrayList<Integer> leftTreeIdx=new ArrayList<>();
ArrayList<Integer> rightTreeIdx=new ArrayList<>(); if(depth<maxDepth[0])
{
/*
* 从所有的attribute中选取最佳的attribute,并且attribute中最佳的分割点,对数据进行分割
* */
double loss=-1;
ArrayList<Integer> leftNodes=new ArrayList<>();
ArrayList<Integer> rightNodes=new ArrayList<>();
int attributeSplit=0;
String attributeSplitType=""; for(int i=0;i<dim;i++)//遍历所有的attribute
{
//得到该attribute下所有的distinct的值
ArrayList<String> myAttributeSet=new ArrayList<>();
ArrayList<String> subDigitAttribute=new ArrayList<>();
myAttributeSet=getAttributeSet(trainData, i);
if(isDigit[i])//如果是数字,就从数组中随机选取splitpoints个节点,代表这个属性可以在这splitpoints下进行分割
{
while(subDigitAttribute.size()<splitPoints)
{
Random r=new Random();
int tmp=r.nextInt(myAttributeSet.size());
subDigitAttribute.add(myAttributeSet.get(tmp));
myAttributeSet.clear();
myAttributeSet=subDigitAttribute;
}
}
for(int j=0;j<myAttributeSet.size();j++)
{
for(int k=0;k<subIdx.size();k++)
{
if((!isDigit[i]&&trainData.get(subIdx.get(k)).get(i).trim().equals(myAttributeSet.get(j)))||(isDigit[i]&&myCmpLess(trainData.get(subIdx.get(k)).get(i),myAttributeSet.get(j))))
{
leftTreeIdx.add(subIdx.get(k));
}
else
{
rightTreeIdx.add(subIdx.get(k));
}
}
ArrayList<Double> leftTarget=new ArrayList<>();
ArrayList<Double> rightTarget=new ArrayList<>();
for(int k=0;k<leftTreeIdx.size();k++)
leftTarget.add(target.get(leftTreeIdx.get(k)));
for(int k=0;k<rightTreeIdx.size();k++)
rightTarget.add(target.get(rightTreeIdx.get(k)));
double lossTmp=computeLoss(leftTarget)+computeLoss(rightTarget);
if(loss<0||loss<lossTmp)
{
leftNodes.clear();
rightNodes.clear();
for(int k=0;k<leftTreeIdx.size();k++)
leftNodes.add(leftTreeIdx.get(k));
for(int k=0;k<rightTreeIdx.size();k++)
rightNodes.add(rightTreeIdx.get(k));
attributeSplit=i;
attributeSplitType=myAttributeSet.get(j);
} } } Tree tmpTree=new Tree();
tmpTree.attributeSplit=attributeSplit;
tmpTree.attributeSplitType=attributeSplitType;
tmpTree.loss=loss;
tmpTree.isLeaf=false;
tmpTree.leftTree=constructTree(leafNodes,leafValues,K,splitPoints, isDigit, leftNodes, trainData, target, maxDepth, depth+1);
tmpTree.leftTree=constructTree(leafNodes,leafValues,K,splitPoints, isDigit, rightNodes, trainData, target, maxDepth, depth+1);
return tmpTree; }
else
{
Tree tmpTree=new Tree();
tmpTree.isLeaf=true;
tmpTree.leafValue=getPredictValue(K, subIdx, target);
for(int i=0;i<subIdx.size();i++)
tmpTree.leafNodeSet.add(subIdx.get(i));
leafNodes.add(subIdx);
leafValues.add(tmpTree.leafValue);
return tmpTree;
}
} public static void main(String[] args) {
// TODO Auto-generated method stub
Tree aTree=new Tree();
} }
GBDT类
import java.rmi.server.SkeletonNotFoundException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set; public class GBDT { private ArrayList<ArrayList<String>> datas=new ArrayList<ArrayList<String>>();
private ArrayList<String> labelSets=new ArrayList<>();
private ArrayList<ArrayList<Double>> F=new ArrayList<ArrayList<Double>>();
private ArrayList<ArrayList<Double>> residual=new ArrayList<ArrayList<Double>>();
private ArrayList<ArrayList<String>> trainData=new ArrayList<ArrayList<String>>();
private ArrayList<Integer> labelTrainData=new ArrayList<Integer>();
private int K;
private Boolean isDigit[];
private int dim;
private int n;
private double learningRate; private ArrayList<ArrayList<Tree>> trees=new ArrayList<ArrayList<Tree>>(); //存放所有的树 private int max_iter;
private double sampleRate;
private int maxDepth;
private int splitPoints; public void computeResidual(ArrayList<Integer> subId)
{
for(int i=0;i<subId.size();i++)
{
int idx=subId.get(i);
int y=0;
if(this.labelTrainData.get(idx)==-1) y=0;
else y=1;
double sum=Math.exp(this.F.get(idx).get(0))+Math.exp(this.F.get(idx).get(1));
double p1=Math.exp(this.F.get(idx).get(0))/sum,p2=Math.exp(this.F.get(idx).get(1))/sum;
this.residual.get(idx).set(0, y-p1);
this.residual.get(idx).set(1, y-p2);
}
}
public ArrayList<Integer> myrandom(int maxNum,int num)
{
ArrayList<Integer> ans=new ArrayList<>();
Set<Integer> mySet=new HashSet<>();
while(mySet.size()<num)
{
Random r=new Random();
int tmp=r.nextInt(maxNum);
mySet.add(tmp);
}
Iterator<Integer> it=mySet.iterator();
while(it.hasNext())
{
ans.add(it.next());
}
return ans;
} public GBDT()
{
this.max_iter=50;
this.sampleRate=0.8;
this.K=2;//2分类问题
this.maxDepth=6;
this.splitPoints=3;
this.learningRate=0.01;
getData();
} public void train()
{
for(int i=0;i<max_iter;i++)
{
ArrayList<Integer> subSet=new ArrayList<>();
int numSubset=(int)(this.n*this.sampleRate);
subSet=myrandom(this.n,numSubset);
computeResidual(subSet);
ArrayList<Double> target=new ArrayList<>();
ArrayList<Tree> tmpTree=new ArrayList<>();
int maxdepths[]={this.maxDepth};
for(int j=0;j<this.K;j++)
{
target.clear();
for(int k=0;k<subSet.size();k++)
{
target.add(residual.get(subSet.get(k)).get(j));
}
ArrayList<ArrayList<Integer>> leafNodes=new ArrayList<ArrayList<Integer>>();
ArrayList<Double> leafValues=new ArrayList<>();
Tree treeSub=new Tree();
Tree iterTree=treeSub.constructTree(leafNodes,leafValues,K,splitPoints, isDigit, subSet, trainData, target,maxdepths,0);
tmpTree.add(iterTree);
updateFvalue(isDigit, subSet,leafNodes,leafValues,j,iterTree);
} trees.add(tmpTree);
}
} public void updateFvalue(Boolean isDigit[], ArrayList<Integer> subIdx,ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int label,Tree root)
{
ArrayList<Integer> remainIdx=new ArrayList<>();
int arr[]=new int[this.n];
for(int i=0;i<this.n;i++)
arr[i]=i;
for(int i=0;i<subIdx.size();i++)
{
arr[subIdx.get(i)]=-1;
}
//求出不是用来训练树的余下集合
for(int i=0;i<this.n;i++)
{
if(arr[i]!=-1)
remainIdx.add(i);
}
for(int i=0;i<leafNodes.size();i++)
{
for(int j=0;j<leafNodes.get(i).size();j++)
{
this.F.get(leafNodes.get(i).get(j)).set(label, this.F.get(leafNodes.get(i).get(j)).get(label)+this.learningRate*root.getPredictValue(root));
}
}
for(int i=0;i<remainIdx.size();i++)
{
double leafV=root.getPredictValue(root,this.trainData.get(remainIdx.get(i)),isDigit);
this.F.get(remainIdx.get(i)).set(label, this.F.get(remainIdx.get(i)).get(label)+this.learningRate*leafV);
} } public boolean checkDigit(String str) {
for(int i=0;i<str.length();i++)
{
if(!(str.charAt(i)>='0'&&str.charAt(i)<='9'))
{
return false;
}
}
return true;
} public void getData() {
Data d =new Data();
this.datas=d.getTrainData();
this.dim=this.datas.get(0).size()-1;
this.isDigit=new Boolean[this.dim];
//遍历所有样本,去掉中间含有不是正常的数据
for(int i=0;i<this.datas.get(0).size()-1;i++)
labelSets.add(this.datas.get(0).get(i));
//保证数据的第一行是正确的,来判断,特征哪些纬度是数字,哪些纬度是字符串
for(int i=0;i<this.dim;i++)
{
if(checkDigit(this.datas.get(0).get(i)))
this.isDigit[i]=true;
else this.isDigit[i]=false;
}
//如果字符串==?说明是异常数据,这里做数据的清理
for(int i=1;i<this.datas.size();i++)
{
ArrayList<String> tmp=new ArrayList<>();
boolean flag=true;
for(int j=0;j<this.dim;j++)
{
if(datas.get(i).get(j).trim().equals("?"))
{
flag=false;
break;
}
}
if(!flag) continue;
if(datas.get(i).get(this.dim).trim().equals("?")) continue;
trainData.add(tmp);
if(datas.get(i).get(this.dim).trim().equals("<=50K"))
labelTrainData.add(-1);
else
labelTrainData.add(1); }
this.n=this.labelTrainData.size(); for(int i=0;i<this.datas.get(0).size()-1;i++)
labelSets.add(this.datas.get(0).get(i)); //初始化F矩阵为全0,F矩阵是n*2,是2分类问题,如果要多分类,改下这里就可以了
for(int i=0;i<this.n;i++)
{
ArrayList<Double> arrTmp=new ArrayList<Double>();
for(int j=0;j<2;j++)
{
arrTmp.add(0.0);
}
this.F.add(arrTmp);
this.residual.add(arrTmp);
} } public static void main(String[] args) {
GBDT dGbdt=new GBDT();
dGbdt.getData();
System.err.println(dGbdt.n); }
}
java实现gbdt的更多相关文章
- Spark案例分析
一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...
- 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)
梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python) http://blog.csdn.net/liulingyuan6/article/details ...
- 决策树和基于决策树的集成方法(DT,RF,GBDT,XGBT)复习总结
摘要: 1.算法概述 2.算法推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 内容: 1.算法概述 1.1 决策树(DT)是一种基本的分类和回归方法.在分类问题中它可以认为是if-the ...
- GBDT算法原理深入解析
GBDT算法原理深入解析 标签: 机器学习 集成学习 GBM GBDT XGBoost 梯度提升(Gradient boosting)是一种用于回归.分类和排序任务的机器学习技术,属于Boosting ...
- 决策树和基于决策树的集成方法(DT,RF,GBDT,XGB)复习总结
摘要: 1.算法概述 2.算法推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 内容: 1.算法概述 1.1 决策树(DT)是一种基本的分类和回归方法.在分类问题中它可以认为是if-the ...
- Spark2.0机器学习系列之6:GBDT(梯度提升决策树)、GBDT与随机森林差异、参数调试及Scikit代码分析
概念梳理 GBDT的别称 GBDT(Gradient Boost Decision Tree),梯度提升决策树. GBDT这个算法还有一些其他的名字,比如说MART(Multiple Addi ...
- 机器学习之——集成算法,随机森林,Bootsing,Adaboost,Staking,GBDT,XGboost
集成学习 集成算法 随机森林(前身是bagging或者随机抽样)(并行算法) 提升算法(Boosting算法) GBDT(迭代决策树) (串行算法) Adaboost (串行算法) Stacking ...
- 故障重现(内存篇2),JAVA内存不足导致频繁回收和swap引起的性能问题
背景起因: 记起以前的另一次也是关于内存的调优分享下 有个系统平时运行非常稳定运行(没经历过大并发考验),然而在一次活动后,人数并发一上来后,系统开始卡. 我按经验开始调优,在每个关键步骤的加入如 ...
- Elasticsearch之java的基本操作一
摘要 接触ElasticSearch已经有一段了.在这期间,遇到很多问题,但在最后自己的不断探索下解决了这些问题.看到网上或多或少的都有一些介绍ElasticSearch相关知识的文档,但个人觉得 ...
随机推荐
- 学习笔记——代理模式Proxy
代理模式,主要是逻辑和实现解耦.具体逻辑如何,由代理Proxy自己来设计,我们只需要把逻辑Subject交给代理即可. 主要应用场景,包括创建大开销对象时,使用代理来慢慢创建:远程代理,如网络不确定时 ...
- 多校 Cow Bowling
题目链接:http://acm.hust.edu.cn/vjudge/contest/124435#problem/I 密码:acm Sample Input Sample Output 分析: #i ...
- arrayList里的快速失败
快速失败是指某个线程在迭代集合类的时候,不允许其他线程修改该集合类的内容,这样迭代器迭代出来的结果就会不准确. 比如用iterator迭代collection的时候,iterator就是另外起的一个线 ...
- PAT1014
Suppose a bank has N windows open for service. 一个银行有N个服务的窗口 There is a yellow line in front of the w ...
- android开发中应该注意的问题
1. Activity可继承自BaseActivity,便于统一风格与处理公共事件,构建对话框统一构建器的建立,万一需要整体变动,一处修改到处有效. 2. 数据库表段字段常量和SQL逻辑分离,更清 ...
- Python3基础 list() 将一个元组转换成列表
镇场诗: 诚听如来语,顿舍世间名与利.愿做地藏徒,广演是经阎浮提. 愿尽吾所学,成就一良心博客.愿诸后来人,重现智慧清净体.-------------------------------------- ...
- Python3基础 函数 收集参数+普通参数 的示例
镇场诗: 诚听如来语,顿舍世间名与利.愿做地藏徒,广演是经阎浮提. 愿尽吾所学,成就一良心博客.愿诸后来人,重现智慧清净体.-------------------------------------- ...
- VS2015安装提示出现“安装包丢失或损坏”解决方法
原因:microsoft root certificate authority 2010.microsoft root certificate authority 2011证书未安装,导致文件校验未通 ...
- Struts2--带参数的结果集
带参数的结果集: 配置文件: <result type="redirect">/user_success.jsp?t=${type}</result> js ...
- android4.0 的图库Gallery2代码分析(一)
最近迫于生存压力,不得不给人兼职打工.故在博文中加了个求点击的链接.麻烦有时间的博友们帮我点击一下.没时间的不用勉强啊.不过请放心,我是做技术的,肯定链接没病毒,就是我打工的淘宝店铺.嘻嘻.http: ...