决策树ID3算法的java实现
决策树的分类过程和人的决策过程比较相似,就是先挑“权重”最大的那个考虑,然后再往下细分。比如你去看医生,症状是流鼻涕,咳嗽等,那么医生就会根据你的流鼻涕这个权重最大的症状先认为你是感冒,接着再根据你咳嗽等症状细分你是否为病毒性感冒等等。决策树的过程其实也是基于极大似然估计。那么我们用一个什么标准来衡量某个特征是权重最大的呢,这里有信息增益和基尼系数两个。ID3算法采用的是信息增益这个量。
根据《统计学习方法》中的描述,G(D,A)表示数据集D在特征A的划分下的信息增益。具体公式:
G(D,A)=H(D)-H(D|A)。其中H(D)表示数据集D的熵,熵可以用来描述其混乱度,计算公式为
H(D)=可见对于数据集D而言,|Dk|表示类别为k的数量,其类别越多,越混乱。
而H(D|A)表示数据集D在A的划分下的的不确定性。他们的差也即是信息增益,表示由于特征A使得数据集D的分类的不确定减少的差,所以这个值越大说明A的分类对D越有效,也就是权重越大。
H(D|A)=|Dik|表示在特征A中value为i的划分下数据集类别为k的数量。
有了这两个公式,接下来就可以写代码了。这里为了清晰的表示这个结果,采用了xml来输出。由于刚开始学java所以只能即学即用(java和python简直不能比,python写ID3一百行代码妥妥的搞定,java用了将近300行。。。)
算法步骤:
输入:数据集D,特征集A(这里也可以输入一个阀值,如果信息增益小于该阀值就直接作为叶节点,这样可以避免过拟合)
输出:xml文件
1 如果D中的类别是同一类,则作为叶节点,标记为该类Ck
2 如果特征集A中没有特征了,那么作为叶节点,并且用数据集D中类别最多的类作为类标记
3 对D的各个特征求最大信息增益,选择信息增益最大的特征Ag
4 对特征Ag中各个值ai继续对数据集进行分割为Di
5 以Di为数据集,A-Ag为特征集为输入进行1-4步骤
具体代码:
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set; import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.XMLWriter; class Utils{
//用于从文件中获取数据集
public static ArrayList<ArrayList<String>> loadDataSet(String file) throws IOException{
ArrayList<ArrayList<String>> dataSet=new ArrayList<ArrayList<String>>();
FileInputStream fis=new FileInputStream(file);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr);
String line="";
line=br.readLine();
while((line=br.readLine())!=null){
String[] words=line.split(",");
ArrayList<String> data=new ArrayList<String>();
for(int i=0;i<words.length;i++){
data.add(words[i]);
}
dataSet.add(data);
}
br.close();
isr.close();
fis.close();
return dataSet;
}
//用于从文件中获取特征
public static ArrayList<String> loadFeature(String file) throws IOException{
FileInputStream fis=new FileInputStream(file);
InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
BufferedReader br=new BufferedReader(isr); String[] line=br.readLine().split(",");
ArrayList<String> feature=new ArrayList<String>();
for(int i=0;i<line.length-1;i++){
feature.add(line[i]);
}
br.close();
isr.close();
fis.close();
return feature;
}
//用于获得数据集中的类别列表
public static ArrayList<String> getClassList(ArrayList<ArrayList<String>> dataSet){
ArrayList<String> classList=new ArrayList<String>();
int length=dataSet.get(0).size();
for(ArrayList<String> data:dataSet){
String label=data.get(length-1);
classList.add(label);
}
return classList;
}
//返回数据集中的特征数
public static int featureNum(ArrayList<ArrayList<String>> dataList){
int len=dataList.get(0).size()-1;
return len;
} // public static void writeToXML(String fileName) throws IOException{
// Document document = DocumentHelper.createDocument();
// Element root = document.addElement("DecisionTree");
// Element outlook=root.addElement("outlook");
// outlook.addAttribute("value","sunny");
// Element humidity1=outlook.addElement("humidity");
// humidity1.addAttribute("value","high");
// humidity1.addText("no");
// Element humidity2=outlook.addElement("humidity");
// humidity2.addAttribute("value","normal");
// humidity2.addText("yes");
//
// XMLWriter writer=new XMLWriter(new FileWriter(fileName));
// writer.write(document);
// writer.close();
// }
//用于获得数据集中第index列的map映射,方便后续的遍历value和计算熵
public static Map<String,Integer> getSubMap(ArrayList<ArrayList<String>> dataSet,int index){
int total=dataSet.size();
Map<String,Integer> subMap=new HashMap();
for(ArrayList<String> data:dataSet){
String lable=data.get(index);
if(subMap.get(lable)==null){
subMap.put(lable,1);
}else{
subMap.put(lable,subMap.get(lable)+1);
}
}
return subMap;
}
//打印map,用于debug的时候
public static void showMap(Map<String,Integer> map){
for(Map.Entry<String,Integer> entry:map.entrySet()){
System.out.println(entry.getKey()+":"+entry.getValue());
}
}
//求熵
public static double getEntropy(ArrayList<ArrayList<String>> dataSet,int index){
int total=dataSet.size();
Map<String,Integer> subMap=getSubMap(dataSet,index);
double entropy=0;
for(Map.Entry<String,Integer> entry:subMap.entrySet()){
double temp=entry.getValue()*1.0/total;
entropy+=temp*(Math.log(temp)/Math.log(2));
}
return -entropy;
}
//求信息增益最大的分割点
public static String bestFeatureSplit(ArrayList<ArrayList<String>> dataSet,ArrayList<String> featureList){
int length=dataSet.get(0).size();
double totalEntropy=getEntropy(dataSet,length-1); int featureNum=dataSet.get(0).size()-1;
int index=-1;
double maxInfoGain=-1;
for(int i=0;i<featureNum;i++){
double entropy=getEntropy(dataSet,i);
Map<String,Integer> map=getSubMap(dataSet,i);//获得该特征下的map
ArrayList<String> lableList=new ArrayList<String>();
double entropySum=0; for(Map.Entry<String,Integer> entry:map.entrySet()){//这里的Di就是map中的特征的value值
Map<String,Integer> subMap=new HashMap(); for(ArrayList<String> data:dataSet){
if(data.get(i).compareTo(entry.getKey())==0){
if(subMap.get(data.get(length-1))==null){ subMap.put(data.get(length-1),1);
}else{
subMap.put(data.get(length-1),subMap.get(data.get(length-1))+1);
}
}
}
double x=0;
for(Map.Entry<String,Integer> subEntry:subMap.entrySet()){
double temp=subEntry.getValue()*1.0/entry.getValue();
x+=temp*(Math.log(temp)/Math.log(2));
} entropySum+=-x*(entry.getValue())/dataSet.size();
}
entropySum=totalEntropy-entropySum;
if(entropySum>maxInfoGain){
index=i;
maxInfoGain=entropySum;
}
}
return featureList.get(index);
}
//分割数据集,index为特征的下标
public static ArrayList<ArrayList<String>> splitDataSet(ArrayList<ArrayList<String>> dataSet,int index,String value){
ArrayList<ArrayList<String>> subDataSet=new ArrayList<ArrayList<String>>();
for(ArrayList<String> data:dataSet){
if(data.get(index).compareTo(value)==0){
ArrayList<String> temp=new ArrayList<String>();
for(int i=0;i<data.size();i++){
if(i!=index){
temp.add(data.get(i));
}
}
subDataSet.add(temp);
}
}
return subDataSet;
}
//list-》map
public static Map<String,Integer> arrayToMap(ArrayList<String> list){
Map<String,Integer> map=new HashMap();
for(String word:list){
if(map.get(word)==null){
map.put(word,1);
}else{
map.put(word,map.get(word)+1);
}
}
return map;
}
//求label中某个数量最多的类别
public static String major(ArrayList<String> labelList){
Map<String,Integer> map=arrayToMap(labelList);
int max=0;
String label="";
for(Map.Entry<String,Integer> entry:map.entrySet()){
if(entry.getValue()>max){
label=entry.getKey();
}
}
return label;
} public static Set<String> getValueFromDataSet(ArrayList<ArrayList<String>> dataSet,int index){
ArrayList<String> values=new ArrayList<String>();
for(ArrayList<String> data:dataSet){
try{
values.add(data.get(index));
}catch(Exception e){ System.out.println("index is "+index);
}
}
Set<String> set=new HashSet();
for(String value:values){
set.add(value);
}
return set;
} public static ArrayList<String> copyArrayList(ArrayList<String> src){
ArrayList<String> dest=new ArrayList<String>();
for(String s:src){
dest.add(s);
}
return dest;
} public static void showArrayList(ArrayList<ArrayList<String>> dataSet){
for(ArrayList<String> data:dataSet){
System.out.println(data);
}
} } public class DecisionTree { public static int createTree(ArrayList<ArrayList<String>> dataSet,ArrayList<String> featureList,Element e){
ArrayList<String> labelList=Utils.getClassList(dataSet);//获取数据集中label的列表
if(Utils.arrayToMap(labelList).size()==1){//表示label中只有一种类别,所以此时不需要再分类了
e.addText(labelList.get(0));
return 1;
}
if(dataSet.get(0).size()==1){//表示此时已经没有特征了,所以也不需要再继续了,此时以label中最多的类别为该节点的类别
e.addText(Utils.major(labelList));
return 1;
} ArrayList<String> subFeatureList=Utils.copyArrayList(featureList); String feature=Utils.bestFeatureSplit(dataSet,featureList);
subFeatureList.remove(feature);
int index=featureList.indexOf(feature); Set<String> valueSet=Utils.getValueFromDataSet(dataSet,index);
// Element next=e.addElement(feature);//原来的代码位置
for(String value:valueSet){
Element next=e.addElement(feature);//后来放到这里之后,xml的输出就正确了,原因在于每递归一次就需要创建一个element,所以应该在for内创建。
next.addAttribute("value",value);
ArrayList<ArrayList<String>> subDataSet=Utils.splitDataSet(dataSet,index,value);
createTree(subDataSet,subFeatureList,next);
}
return 1;
} public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
String file="C:/Users/Administrator/Desktop/upload/DT.txt";
String xml="C:/Users/Administrator/Desktop/upload/DT1.xml";
ArrayList<ArrayList<String>> dataSet=Utils.loadDataSet(file);
ArrayList<String> featureList=Utils.loadFeature(file);
Document document = DocumentHelper.createDocument();
Element root = document.addElement("DecisionTree");
createTree(dataSet,featureList,root);
XMLWriter writer=new XMLWriter(new FileWriter(xml));
writer.write(document);
writer.close();
System.out.println("finished");
} }
这次除了算法上的理解更加深刻了外,在java上也学到了些关于xml解析,读写等方法。
另外对递归的使用也更加形象些,对于递归一个容易错的点就是函数上的参数,一定要认真对待,要清楚该参数该在什么时候初始化,什么时候被用到。我一开始在第269行上就出现错误了,一开始没有考虑清楚这个next该在什么时候分配,后来发现每次创建节点的时候我们在xml就要创建一个相应的节点用来描述他,所以应该是在for循环里面创建,如果在for外面创建就表示,该特征下的所有值都只有一个element。
当然对于set,map的遍历啥的也更加清晰了。
决策树ID3算法的java实现的更多相关文章
- 决策树ID3算法的java实现(基本试用所有的ID3)
已知:流感训练数据集,预定义两个类别: 求:用ID3算法建立流感的属性描述决策树 流感训练数据集 No. 头痛 肌肉痛 体温 患流感 1 是(1) 是(1) 正常(0) 否(0) 2 是(1) 是(1 ...
- 决策树ID3算法的java实现(基本适用所有的ID3)
已知:流感训练数据集,预定义两个类别: 求:用ID3算法建立流感的属性描述决策树 流感训练数据集 No. 头痛 肌肉痛 体温 患流感 1 是(1) 是(1) 正常(0) 否(0) 2 是(1) 是(1 ...
- 数据挖掘之决策树ID3算法(C#实现)
决策树是一种非常经典的分类器,它的作用原理有点类似于我们玩的猜谜游戏.比如猜一个动物: 问:这个动物是陆生动物吗? 答:是的. 问:这个动物有鳃吗? 答:没有. 这样的两个问题顺序就有些颠倒,因为一般 ...
- 决策树ID3算法[分类算法]
ID3分类算法的编码实现 <?php /* *决策树ID3算法(分类算法的实现) */ /* *求信息增益Grain(S1,S2) */ //-------------------------- ...
- 决策树---ID3算法(介绍及Python实现)
决策树---ID3算法 决策树: 以天气数据库的训练数据为例. Outlook Temperature Humidity Windy PlayGolf? sunny 85 85 FALSE no ...
- 02-21 决策树ID3算法
目录 决策树ID3算法 一.决策树ID3算法学习目标 二.决策树引入 三.决策树ID3算法详解 3.1 if-else和决策树 3.2 信息增益 四.决策树ID3算法流程 4.1 输入 4.2 输出 ...
- 机器学习之决策树(ID3)算法与Python实现
机器学习之决策树(ID3)算法与Python实现 机器学习中,决策树是一个预测模型:他代表的是对象属性与对象值之间的一种映射关系.树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每 ...
- ID3算法(Java实现)
数据存储文件:buycomputer.properties #数据个数 datanum=14 #属性及属性值 nodeAndAttribute=年龄:青/中/老,收入:高/中/低,学生:是/否,信誉: ...
- 决策树 -- ID3算法小结
ID3算法(Iterative Dichotomiser 3 迭代二叉树3代),是一个由Ross Quinlan发明的用于决策树的算法:简单理论是越是小型的决策树越优于大的决策树. 算法归 ...
随机推荐
- 【stack】模拟网页浏览 poj 1028
#include<stdio.h> #include<string.h> int main() { ][]; ]; int i,depth; strcpy(s[]," ...
- java解析XML获取城市代码
运行前先导入dom4j架包,由于我们公司用的代理服务器所以下面我设置了代理ip,不需要的可直接忽略 package com.chengshidaima.tools; import java.io.Bu ...
- Java的引用c++的引用和C指针的区别
Java的引用本质上就是C中的指针,而c++的引用则完全不同:有一个类 class Point { int x; int y;} 同样的一个Point p; 在Java中p表示一个引用,它等同于C语言 ...
- easyui formatter 返回easyui组件
<table id="dg2" title="标题" style="width:400px;float: left;"> < ...
- HTML5编程之旅系列一:HTML5 Geolocation 初探
让我们假设这样一个场景,有一个web应用程序,它可以向用户提供附近不远处某商场的打折优惠信息.使用HTML5 Geolocation API(地理定位 API),可以请求用户共享他们的位置信息. HT ...
- android性能优化优秀文章
郭霖最近整理的文章: 合理管理内存 分析编码过程中如何避免过多内存占用,以及如何实现高性能的内存使用. 尽可能使用intentService; 当界面不可见时释放内存; 当内存紧张时释放内存; 避免b ...
- Json解析要点
解析Json时,昨天遇到了新的问题,之前都是解析的数组,不是数组的用类来做. 这是Json串; {"status":"00001","ver" ...
- Linux RAR 安装和使用
描述:Linux默认自带ZIP压缩,最大支持4GB压缩,RAR的压缩比大于4GB. 流程:下载 >安装 > 使用 ------------------------------------- ...
- AJAX开发技术--AJAX简介
Asynchronous JavaScript and XML,异步JavaScript和XML 主要目的用于页面的局部刷新.不用全部刷新,提高性能. 在AJAX中主要是通过XMLHttpReque ...
- 转 如何高效使用和管理Bitmap--图片缓存管理模块的设计与实现
上周为360全景项目引入了图片缓存模块.因为是在Android4.0平台以上运作,出于惯性,都会在设计之前查阅相关资料,尽量避免拿一些以前2.3平台积累的经验来进行类比处理.开发文档中有一个 Bitm ...