DATA类

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Scanner;

/**
 * Loads a comma-separated text file into memory as a list of rows. Each row is the list of the
 * line's comma-separated fields; the whole line is trimmed before splitting, individual fields
 * are kept as-is (they may carry leading spaces).
 */
public class Data {
    /** Default file location (an "adult" CSV, judging by the file name). */
    private static final String DEFAULT_DATA_PATH = "D://javajavajava//dbdt//src//script//data//adult.data.csv";

    private ArrayList<ArrayList<String>> trainData = new ArrayList<ArrayList<String>>();

    /** Returns the parsed rows. This is the live internal list, not a copy. */
    public ArrayList<ArrayList<String>> getTrainData() {
        return this.trainData;
    }

    /** Loads rows from the default data path. */
    public Data() {
        this(DEFAULT_DATA_PATH);
    }

    /**
     * Loads rows from {@code dataPath}.
     *
     * <p>Blank (or whitespace-only) lines are skipped. If the file cannot be opened, the stack
     * trace is printed and the instance holds no rows (best-effort, as before).
     *
     * @param dataPath path of the comma-separated file to read
     */
    public Data(String dataPath) {
        // try-with-resources closes the Scanner (and the underlying file) on every path.
        try (Scanner in = new Scanner(new File(dataPath), "UTF-8")) {
            while (in.hasNextLine()) {
                String line = in.nextLine().trim();
                if (line.isEmpty()) {
                    continue; // a blank line is not a sample
                }
                ArrayList<String> row = new ArrayList<>();
                for (String field : line.split(",")) {
                    row.add(field);
                }
                this.trainData.add(row);
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        Data d = new Data();
    }
}

  TREE类

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.spi.TimeZoneNameProvider;

/**
 * Binary regression tree used as the weak learner of the boosting driver.
 *
 * <p>Internal nodes split on one attribute. For a numeric attribute a sample goes left when its
 * value is {@code <=} the split value; for a categorical attribute it goes left when its value
 * equals the split value. Attribute values are compared after {@code trim()}. Leaves store the
 * multiclass Newton step (K-1)/K * sum(r) / sum(|r|(1-|r|)) over the targets r of their samples.
 */
public class Tree {
    /** Shared source of randomness for sampling numeric split candidates. */
    private static final Random RANDOM = new Random();

    // Children are assigned by constructTree. They must not be initialised with new Tree():
    // that recursed without end and made every construction fail with StackOverflowError.
    private Tree leftTree;
    private Tree rightTree;
    /** Split criterion of the chosen split; -1 for leaves. */
    private double loss = -1;
    /** Index of the attribute (column) this node splits on. */
    private int attributeSplit = 0;
    /** Trimmed split value compared against a sample's trimmed attribute value. */
    private String attributeSplitType = "";
    boolean isLeaf;
    double leafValue;
    /** Sample indices that reached this leaf (empty for internal nodes). */
    private ArrayList<Integer> leafNodeSet = new ArrayList<>();

    /** Best split found while growing one node. */
    private static final class Split {
        int attribute;
        String value;
        double loss;
        ArrayList<Integer> left;
        ArrayList<Integer> right;
    }

    /**
     * Returns the distinct values of column {@code idx} over all rows of {@code trainData}.
     * Values are trimmed so they compare equal to the trimmed values used when splitting.
     */
    public ArrayList<String> getAttributeSet(ArrayList<ArrayList<String>> trainData, int idx) {
        HashSet<String> distinct = new HashSet<>();
        for (ArrayList<String> row : trainData) {
            distinct.add(row.get(idx).trim());
        }
        return new ArrayList<>(distinct);
    }

    /**
     * Returns true when the integer in {@code str1} is less than or equal to the one in {@code str2}.
     *
     * @throws NumberFormatException if either trimmed string is not an integer
     */
    public boolean myCmpLess(String str1, String str2) {
        return Integer.parseInt(str1.trim()) <= Integer.parseInt(str2.trim());
    }

    /**
     * Split criterion for one side of a split: the square root of the sum of squared deviations
     * from the mean. Returns 0 for an empty list.
     */
    public double computeLoss(ArrayList<Double> values) {
        if (values.isEmpty()) {
            return 0;
        }
        double sum = 0;
        for (double v : values) {
            sum += v;
        }
        double mean = sum / values.size();
        double squaredDeviations = 0;
        for (double v : values) {
            squaredDeviations += Math.pow(v - mean, 2);
        }
        return Math.sqrt(squaredDeviations);
    }

    /**
     * Leaf value for the samples in {@code subIdx}: (K-1)/K * sum(r) / sum(|r|(1-|r|)), where r
     * is {@code target.get(id)}. The formula is Friedman's multiclass step.
     *
     * @param K      number of classes
     * @param subIdx sample ids in the leaf
     * @param target per-sample targets, indexed by sample id
     * @return the leaf value, or 0 when the denominator is 0 (e.g. an empty leaf)
     */
    public double getPredictValue(int K, ArrayList<Integer> subIdx, ArrayList<Double> target) {
        double numerator = 0;
        double denominator = 0;
        for (int id : subIdx) {
            double r = target.get(id);
            double absR = Math.abs(r);
            numerator += r;
            denominator += absR * (1 - absR);
        }
        if (denominator == 0) {
            return 0;
        }
        // (K - 1.0) forces floating-point division; the integer (K-1)/K was always 0.
        return (K - 1.0) / K * numerator / denominator;
    }

    /** Returns the value stored in {@code root} (meaningful only when {@code root} is a leaf). */
    public double getPredictValue(Tree root) {
        return root.leafValue;
    }

    /**
     * Routes {@code instance} from {@code root} down to a leaf and returns that leaf's value.
     *
     * @param isDigit per-attribute flag: true for numeric attributes
     */
    public double getPredictValue(Tree root, ArrayList<String> instance, Boolean isDigit[]) {
        Tree node = root;
        while (!node.isLeaf) {
            boolean left = goesLeft(instance.get(node.attributeSplit), node.attributeSplitType,
                    isDigit[node.attributeSplit]);
            node = left ? node.leftTree : node.rightTree;
        }
        return node.leafValue;
    }

    /**
     * Recursively grows a tree over the samples {@code subIdx}.
     *
     * <p>A node becomes a leaf when {@code depth >= maxDepth[0]}, when it holds fewer than two
     * samples, or when no split leaves both sides non-empty. Each leaf appends its sample ids to
     * {@code leafNodes} and its value to {@code leafValues}. The two lists are index-aligned and
     * filled left subtree first.
     *
     * @param splitPoints maximum number of randomly sampled split values per numeric attribute
     *                    (non-positive means all distinct values)
     * @param target      per-sample targets, indexed by sample id
     */
    public Tree constructTree(ArrayList<ArrayList<Integer>> leafNodes, ArrayList<Double> leafValues, int K,
            int splitPoints, Boolean isDigit[], ArrayList<Integer> subIdx, ArrayList<ArrayList<String>> trainData,
            ArrayList<Double> target, int maxDepth[], int depth) {
        if (depth < maxDepth[0] && subIdx.size() >= 2) {
            Split best = findBestSplit(splitPoints, isDigit, subIdx, trainData, target);
            if (best != null) {
                Tree node = new Tree();
                node.isLeaf = false;
                node.attributeSplit = best.attribute;
                node.attributeSplitType = best.value;
                node.loss = best.loss;
                node.leftTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, best.left,
                        trainData, target, maxDepth, depth + 1);
                node.rightTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, best.right,
                        trainData, target, maxDepth, depth + 1);
                return node;
            }
        }
        Tree leaf = new Tree();
        leaf.isLeaf = true;
        leaf.leafValue = getPredictValue(K, subIdx, target);
        leaf.leafNodeSet.addAll(subIdx);
        leafNodes.add(new ArrayList<>(subIdx));
        leafValues.add(leaf.leafValue);
        return leaf;
    }

    /** Finds the candidate split with the smallest combined loss, or null if none is non-degenerate. */
    private Split findBestSplit(int splitPoints, Boolean isDigit[], ArrayList<Integer> subIdx,
            ArrayList<ArrayList<String>> trainData, ArrayList<Double> target) {
        int dim = Math.min(trainData.get(subIdx.get(0)).size(), isDigit.length);
        Split best = null;
        for (int attr = 0; attr < dim; attr++) {
            boolean numeric = isDigit[attr];
            ArrayList<String> candidates = getAttributeSet(trainData, attr);
            if (numeric) {
                candidates = sampleCandidates(candidates, splitPoints);
            }
            for (String candidate : candidates) {
                // Fresh partitions per candidate; reusing them accumulated samples across candidates.
                ArrayList<Integer> left = new ArrayList<>();
                ArrayList<Integer> right = new ArrayList<>();
                ArrayList<Double> leftTarget = new ArrayList<>();
                ArrayList<Double> rightTarget = new ArrayList<>();
                for (int id : subIdx) {
                    if (goesLeft(trainData.get(id).get(attr), candidate, numeric)) {
                        left.add(id);
                        leftTarget.add(target.get(id));
                    } else {
                        right.add(id);
                        rightTarget.add(target.get(id));
                    }
                }
                if (left.isEmpty() || right.isEmpty()) {
                    continue; // a split that separates nothing is not a split
                }
                double candidateLoss = computeLoss(leftTarget) + computeLoss(rightTarget);
                if (best == null || candidateLoss < best.loss) {
                    best = new Split();
                    best.attribute = attr;
                    best.value = candidate;
                    best.loss = candidateLoss;
                    best.left = left;
                    best.right = right;
                }
            }
        }
        return best;
    }

    /** Decides whether a sample whose attribute value is {@code value} goes to the left child. */
    private boolean goesLeft(String value, String splitValue, boolean numeric) {
        String trimmed = value.trim();
        return numeric ? myCmpLess(trimmed, splitValue) : trimmed.equals(splitValue);
    }

    /**
     * Picks up to {@code count} distinct values uniformly without replacement (partial
     * Fisher-Yates shuffle). A non-positive count, or one covering every value, returns all values.
     */
    private static ArrayList<String> sampleCandidates(ArrayList<String> values, int count) {
        if (count <= 0 || count >= values.size()) {
            return values;
        }
        ArrayList<String> pool = new ArrayList<>(values);
        for (int i = 0; i < count; i++) {
            int j = i + RANDOM.nextInt(pool.size() - i);
            String tmp = pool.get(i);
            pool.set(i, pool.get(j));
            pool.set(j, tmp);
        }
        return new ArrayList<>(pool.subList(0, count));
    }

    public static void main(String[] args) {
        Tree aTree = new Tree();
    }
}

  

GBDT类

import java.rmi.server.SkeletonNotFoundException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;

/**
 * Two-class gradient-boosted trees using softmax scores (Friedman's multiclass scheme with K = 2).
 *
 * <p>Each boosting round draws a random subsample and computes softmax pseudo-residuals for it.
 * It then fits one {@link Tree} per class to those residuals and adds the learning-rate-scaled
 * leaf values to the per-sample class scores {@code F}. Class 0 corresponds to label -1
 * ("<=50K"), class 1 to label 1.
 */
public class GBDT {
    /** Raw rows as returned by Data: feature columns followed by the label column. */
    private ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();
    /** Feature fields of the first raw row. */
    private ArrayList<String> labelSets = new ArrayList<>();
    /** Per-sample class scores: n rows of K values. */
    private ArrayList<ArrayList<Double>> F = new ArrayList<ArrayList<Double>>();
    /** Per-sample pseudo-residuals, same shape as F but separate storage. */
    private ArrayList<ArrayList<Double>> residual = new ArrayList<ArrayList<Double>>();
    /** Cleaned feature rows (label column removed), aligned with labelTrainData. */
    private ArrayList<ArrayList<String>> trainData = new ArrayList<ArrayList<String>>();
    /** -1 when the label is "<=50K", 1 otherwise. */
    private ArrayList<Integer> labelTrainData = new ArrayList<Integer>();
    private int K;
    private Boolean isDigit[];
    private int dim;
    private int n;
    private double learningRate;
    /** All fitted trees: one list of K trees per boosting round. */
    private ArrayList<ArrayList<Tree>> trees = new ArrayList<ArrayList<Tree>>();
    /** Number of boosting rounds. (Previously swallowed by a trailing line comment.) */
    private int max_iter;
    private double sampleRate;
    private int maxDepth;
    private int splitPoints;
    /** Source of randomness for subsampling. */
    private final Random random = new Random();

    /**
     * Recomputes the pseudo-residuals y_k - p_k for the samples in {@code subId}. Here p is the
     * softmax of the sample's scores in F, and y is the one-hot encoding of its class.
     */
    public void computeResidual(ArrayList<Integer> subId) {
        for (int idx : subId) {
            ArrayList<Double> scores = this.F.get(idx);
            int trueClass = this.labelTrainData.get(idx) == -1 ? 0 : 1;
            // Subtracting the max score keeps Math.exp from overflowing.
            double maxScore = Double.NEGATIVE_INFINITY;
            for (double s : scores) {
                maxScore = Math.max(maxScore, s);
            }
            double norm = 0;
            for (double s : scores) {
                norm += Math.exp(s - maxScore);
            }
            ArrayList<Double> res = this.residual.get(idx);
            for (int k = 0; k < scores.size(); k++) {
                double p = Math.exp(scores.get(k) - maxScore) / norm;
                double y = (k == trueClass) ? 1 : 0;
                res.set(k, y - p);
            }
        }
    }

    /**
     * Returns {@code num} distinct random integers in [0, maxNum). {@code num} is clamped to
     * [0, maxNum]; without the clamp an oversized request looped forever.
     */
    public ArrayList<Integer> myrandom(int maxNum, int num) {
        int wanted = Math.max(0, Math.min(num, maxNum));
        Set<Integer> picked = new HashSet<>();
        while (picked.size() < wanted) {
            picked.add(random.nextInt(maxNum));
        }
        return new ArrayList<>(picked);
    }

    /** Creates a model with default hyper-parameters and loads the training data. */
    public GBDT() {
        this.max_iter = 50;
        this.sampleRate = 0.8;
        this.K = 2; // two-class problem
        this.maxDepth = 6;
        this.splitPoints = 3;
        this.learningRate = 0.01;
        getData();
    }

    /** Runs max_iter boosting rounds, fitting K trees per round. */
    public void train() {
        int numSubset = (int) (this.n * this.sampleRate);
        int[] maxdepths = {this.maxDepth};
        for (int i = 0; i < max_iter; i++) {
            ArrayList<Integer> subSet = myrandom(this.n, numSubset);
            computeResidual(subSet);
            ArrayList<Tree> roundTrees = new ArrayList<>();
            for (int j = 0; j < this.K; j++) {
                // Tree indexes target by sample id, so it needs one entry per sample (n),
                // not one per subsample position.
                ArrayList<Double> target = new ArrayList<>(this.n);
                for (int id = 0; id < this.n; id++) {
                    target.add(residual.get(id).get(j));
                }
                ArrayList<ArrayList<Integer>> leafNodes = new ArrayList<ArrayList<Integer>>();
                ArrayList<Double> leafValues = new ArrayList<>();
                Tree iterTree = new Tree().constructTree(leafNodes, leafValues, K, splitPoints, isDigit, subSet,
                        trainData, target, maxdepths, 0);
                roundTrees.add(iterTree);
                updateFvalue(isDigit, subSet, leafNodes, leafValues, j, iterTree);
            }
            trees.add(roundTrees);
        }
    }

    /**
     * Adds learningRate * (leaf value) to score {@code label} of every sample. Training samples
     * use the leaf they were assigned to; the remaining samples are routed through {@code root}.
     */
    public void updateFvalue(Boolean isDigit[], ArrayList<Integer> subIdx, ArrayList<ArrayList<Integer>> leafNodes,
            ArrayList<Double> leafValues, int label, Tree root) {
        boolean[] usedForTraining = new boolean[this.n];
        for (int id : subIdx) {
            usedForTraining[id] = true;
        }
        // leafNodes and leafValues are index-aligned: leaf i holds leafNodes.get(i).
        for (int leaf = 0; leaf < leafNodes.size(); leaf++) {
            double step = this.learningRate * leafValues.get(leaf);
            for (int id : leafNodes.get(leaf)) {
                addToScore(id, label, step);
            }
        }
        // Samples not used to grow this tree.
        for (int id = 0; id < this.n; id++) {
            if (!usedForTraining[id]) {
                double leafV = root.getPredictValue(root, this.trainData.get(id), isDigit);
                addToScore(id, label, this.learningRate * leafV);
            }
        }
    }

    /** Adds {@code delta} to score {@code label} of sample {@code id}. */
    private void addToScore(int id, int label, double delta) {
        ArrayList<Double> scores = this.F.get(id);
        scores.set(label, scores.get(label) + delta);
    }

    /** Returns true when {@code str}, ignoring surrounding whitespace, is a non-empty run of ASCII digits. */
    public boolean checkDigit(String str) {
        String s = str.trim();
        if (s.isEmpty()) {
            return false;
        }
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if (c < '0' || c > '9') {
                return false;
            }
        }
        return true;
    }

    /**
     * Loads the rows via Data, cleans them, and (re)initialises trainData, the labels, F and the
     * residuals. Calling it again rebuilds the state instead of duplicating samples.
     *
     * <p>The first row decides which attributes are numeric and is not used for training.
     * NOTE(review): this assumes the first row is a well-formed sample; confirm against the file.
     *
     * @throws IllegalStateException if Data returned no rows
     */
    public void getData() {
        Data d = new Data();
        this.datas = d.getTrainData();
        if (this.datas.isEmpty()) {
            throw new IllegalStateException("no rows were loaded; check the data file used by Data");
        }
        this.labelSets.clear();
        this.trainData.clear();
        this.labelTrainData.clear();
        this.F.clear();
        this.residual.clear();

        ArrayList<String> first = this.datas.get(0);
        this.dim = first.size() - 1;
        this.isDigit = new Boolean[this.dim];
        for (int i = 0; i < this.dim; i++) {
            labelSets.add(first.get(i));
            this.isDigit[i] = checkDigit(first.get(i));
        }
        // Drop rows with missing values ("?"), too few columns, or non-numeric numeric fields.
        for (int i = 1; i < this.datas.size(); i++) {
            ArrayList<String> row = this.datas.get(i);
            if (!isUsableRow(row)) {
                continue;
            }
            trainData.add(new ArrayList<>(row.subList(0, this.dim)));
            labelTrainData.add(row.get(this.dim).trim().equals("<=50K") ? -1 : 1);
        }
        this.n = this.labelTrainData.size();
        // F and residual start at 0 with K columns; each row gets its own lists so that
        // writing residuals never overwrites scores.
        for (int i = 0; i < this.n; i++) {
            ArrayList<Double> scores = new ArrayList<Double>();
            ArrayList<Double> res = new ArrayList<Double>();
            for (int k = 0; k < this.K; k++) {
                scores.add(0.0);
                res.add(0.0);
            }
            this.F.add(scores);
            this.residual.add(res);
        }
    }

    /** Returns true when {@code row} has a label column, no "?" field, and parseable numeric fields. */
    private boolean isUsableRow(ArrayList<String> row) {
        if (row.size() <= this.dim) {
            return false;
        }
        for (int j = 0; j <= this.dim; j++) {
            String value = row.get(j).trim();
            if (value.equals("?")) {
                return false;
            }
            if (j < this.dim && this.isDigit[j] && !checkDigit(value)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        GBDT dGbdt = new GBDT(); // the constructor already loads the data
        System.err.println(dGbdt.n);
    }
}

  

java实现gbdt的更多相关文章

  1. Spark案例分析

    一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...

  2. 梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)

    梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python) http://blog.csdn.net/liulingyuan6/article/details ...

  3. 决策树和基于决策树的集成方法(DT,RF,GBDT,XGBT)复习总结

    摘要: 1.算法概述 2.算法推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 内容: 1.算法概述 1.1 决策树(DT)是一种基本的分类和回归方法.在分类问题中它可以认为是if-the ...

  4. GBDT算法原理深入解析

    GBDT算法原理深入解析 标签: 机器学习 集成学习 GBM GBDT XGBoost 梯度提升(Gradient boosting)是一种用于回归.分类和排序任务的机器学习技术,属于Boosting ...

  5. 决策树和基于决策树的集成方法(DT,RF,GBDT,XGB)复习总结

    摘要: 1.算法概述 2.算法推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 内容: 1.算法概述 1.1 决策树(DT)是一种基本的分类和回归方法.在分类问题中它可以认为是if-the ...

  6. Spark2.0机器学习系列之6:GBDT(梯度提升决策树)、GBDT与随机森林差异、参数调试及Scikit代码分析

    概念梳理 GBDT的别称 GBDT(Gradient Boosting Decision Tree),梯度提升决策树.     GBDT这个算法还有一些其他的名字,比如说MART(Multiple Addi ...

  7. 机器学习之——集成算法,随机森林,Boosting,Adaboost,Stacking,GBDT,XGBoost

    集成学习 集成算法 随机森林(前身是bagging或者随机抽样)(并行算法) 提升算法(Boosting算法) GBDT(迭代决策树) (串行算法) Adaboost (串行算法) Stacking ...

  8. 故障重现(内存篇2),JAVA内存不足导致频繁回收和swap引起的性能问题

    背景起因: 记起以前的另一次也是关于内存的调优分享下   有个系统平时运行非常稳定运行(没经历过大并发考验),然而在一次活动后,人数并发一上来后,系统开始卡. 我按经验开始调优,在每个关键步骤的加入如 ...

  9. Elasticsearch之java的基本操作一

    摘要   接触ElasticSearch已经有一段了.在这期间,遇到很多问题,但在最后自己的不断探索下解决了这些问题.看到网上或多或少的都有一些介绍ElasticSearch相关知识的文档,但个人觉得 ...

随机推荐

  1. POJ 1365 Prime Land(整数拆分)

    题意:感觉题意不太好懂,题目并不难,就是给一些p和e,p是素数,e是指数,然后把这个数求出来,设为x,然后让我们逆过程输出x-1的素数拆分形式,形式与输入保持一致. 思路:素数打表以后正常拆分即可. ...

  2. UIImage 和 UIImageView区别

    // // ViewController.m // 06-UIImage 和 UIImageView // // Created by Stephen on 16/4/18. // Copyright ...

  3. eclipse 配置scala问题-More than one scala library found in the build path

    配置eclipse出错解决 山重水复疑无路,柳暗花明又一村!经过大量的验证...终于make it. 参考博客:http://blog.csdn.net/wankunde/article/detail ...

  4. Win7下配置Django+Apache+mod_wsgi+Sqlite

    搭建环境: win7 64位 Django 1.8.5 Apache2.4.17 mod_wsgi_ap24py27.so Python2.7.9 1 安装Apache 下载Apache Haus版, ...

  5. 激活OFFICE2010时,提示choice.exe不是有效的win32程序

    我在安装office2010破解版时,提示choice.exe不是有效的win32应用程序 删除choice.exe再激活,按提示找到目录删掉这个文件,需要设置显示隐藏文件夹

  6. 随机法解决TSP问题

    TSP问题一直是个头疼的问题,但是解决的方法数不胜数,很多的算法也都能解决.百度资料一大堆,但是我找到了代码比较简练的一种.随机法.下面只是个人的看法而已,如果有任何问题虚心接受. 顾名思义,随机法就 ...

  7. hdu_5589_Tree(莫队+字典树)

    题目连接:hdu_5589_Tree 题意:给你一棵树和一些边值,n个点n-1条边,一个m,q个询问,每个询问让你输出在[l,r]区间内任意两点树上的路径的边权异或的和大于m的点对数. 题解:这题很巧 ...

  8. HTML center tag

    <center>This text will be center-aligned.</center> 或者可以把一个div给center了,例如将一个html表格给center ...

  9. VC分发包版本问题

    来源:http://www.cnblogs.com/mixiyou/archive/2010/02/09/1663620.html 之前曾经写过一篇个人经历,是关于VC2005分发包版本不一致而引起应 ...

  10. C++ Builder string相互转换(转)

    源:http://www.cnblogs.com/zhcncn/archive/2013/05/20/3089084.html 1. char*->string (1)直接转换 const ch ...