ID3是以信息增益作为划分训练数据集的特征,即认为信息增益大的特征是对分类结果影响更大,但是信息增益的方法偏向于选择取值较多的特征,因此引入了C4.5决策树,也就是使用信息增益率(比)来作为划分数据集的特征,信息增益率定义如下:

就是在ID3中已经计算出特征A的信息增益之后再除一个熵HA(D),HA(D)的计算例子如下图所示:

对应的数据集是:

例子来自:http://baike.baidu.com/link?url=uVS7uFMB44R86TEdRzwwpNWsmzQtA3ds88X0CLYLN0C-8bmS-OAlOFnpD8PNv6pdD_SvWOIpV8UMKQRpVu4tHK

以下是代码实现:

  1. //import java.awt.color.ICC_ColorSpace;
  2. import java.io.*;
  3. import java.util.ArrayList;
  4. import java.util.Collections;
  5. import java.util.Comparator;
  6. import java.util.HashMap;
  7. import java.util.HashSet;
  8. import java.util.Iterator;
  9. //import java.util.Iterator;
  10. import java.util.List;
  11. //import java.util.Locale.Category;
  12. import java.util.Map;
  13. import java.util.Map.Entry;
  14. import java.util.Set;
  15. class decisionTree{
  16.  
  17. private static Map<String, Map<String, Integer>> featureValuesAndCounts=new HashMap<String, Map<String,Integer>>();
  18. private static ArrayList<String[]> dataSet=new ArrayList<String[]>();
  19. private static ArrayList<String> features=new ArrayList<String>();
  20. private static Set<String> category=new HashSet<String>();
  21. //public static DecisionNode root=new DecisionNode();
  22. //private static ArrayList<ArrayList<String>> featureValue=new ArrayList<ArrayList<String>>();
  23. public static void GetDataSet()
  24. {
  25. File file = new File("C:\\Users\\hfz\\workspace\\decisionTree\\src\\loan.txt");
  26. try{
  27. BufferedReader br = new BufferedReader(new FileReader(file));//
  28. String s = null;
  29. s=br.readLine();//读取第一行的内容,即是各特征的名称
  30. String[] tempFeatures=s.split(",");
  31. for (String string1 : tempFeatures) {
  32. features.add(string1);
  33. }
  34. s=br.readLine(); //开始读取特征值
  35. String[] tt=null;
  36. int flag=s.length();
  37. while(flag!=0){//英文文档读到结尾得到的值是null,而中文文档读到结尾得到的值却是""
  38. tt=s.split(",");
  39. dataSet.add(tt); //将特征值存入
  40. category.add(tt[tt.length-1]);//category为集合类型,用于存储类型值
  41.  
  42. s=br.readLine();
  43. if (s!=null) {
  44. flag = s.length();
  45. }
  46. else{
  47. flag=0;
  48. }
  49.  
  50. }
  51.  
  52. for (int j = 0; j < features.size(); j++) {//逻辑上模拟列优先的方式读取二维数组形式的数据集,就是首先读取一个特征名称,再遍历数据集
  53. Map<String, Integer> ttt=new HashMap<String, Integer>();//将某特征的各个特征值存入Map中,然后再度第二个特征,再遍历数据集。。。
  54. for (int i = 0; i < dataSet.size(); i++) {
  55. String currentFeatureValue=dataSet.get(i)[j];
  56. if(!(ttt.containsKey(currentFeatureValue)))
  57. ttt.put(currentFeatureValue, 1);
  58. else {
  59. ttt.replace(currentFeatureValue, ttt.get(currentFeatureValue)+1);
  60. }
  61.  
  62. }
  63. featureValuesAndCounts.put(features.get(j), ttt);//嵌套形式的Map,第一层的key是特征名称,value是一个新的Map
  64. // 新Map中key是特征的各个值,value是特征值在数据集中出现的次数。
  65.  
  66. }
  67.  
  68. br.close();
  69. }
  70.  
  71. catch(Exception e){
  72. e.printStackTrace();
  73. }
  74. }
  75. public static DecisionNode treeGrowth(ArrayList<String[]> dataset,String currentFeatureName,
  76. String currentFeatureValue,ArrayList<String> current_features,
  77. Map<String,Map<String,Integer>> current_featureValuesCounts){
  78. /*
  79. dataset:用于split方法,从dataset数据集中去除掉具有某个特征值的对应的若干实例,生成一个新的新的数据集
  80. currentFeatureName:当前的特征名称,用于叶子节点,赋值给叶子节点的featureName字段
  81. currentFeatureValue:当前特征名称对应的特征值,也用于叶子节点,赋值给featureValue字段
  82. current_features:当前数据集中包含的所有特征名称,用于findBestAttribute方法,找到信息增益最大的的属性
  83. current_featureValuesCounts:当前数据集中所有特征的各个特征值出现的次数,用于findBestAttribute方法,用于计算条件熵,进而计算信息增益。
  84. */
  85. ArrayList<String> classList=new ArrayList<String>();
  86. int flag=0;
  87. for (String[] string : dataset) {
  88. //测试数据集中类型值的数量,flag表示数据集中的类型数量
  89. if (classList.contains(string[string.length-1])) {
  90.  
  91. }
  92. else {
  93. classList.add(string[string.length-1]);
  94. flag++;//如果flag>1表示数据集
  95. }
  96.  
  97. }
  98. if(1==flag){//如果只有一个类结果,则返回此叶子节点
  99. DecisionNode d=new DecisionNode();
  100. d.init(currentFeatureName,classList.get(0),currentFeatureValue);
  101. return d;
  102. }
  103. if (dataset.get(0).length==1) {//如果数据集已经没有属性了只剩下类结果,则返回占比最大的类结果,也是叶子节点
  104. DecisionNode d=new DecisionNode();
  105. d.init(currentFeatureName,classify(classList),currentFeatureValue);
  106. return d;
  107. }
  108.  
  109. /*
  110. DecisionNode是一个自定义的递归型的数据类型,类中一个children字段是DecisionNode类型的数组,
  111. 正好用这种类型来存储递归算法产生的结果(决策树),也就是用这种结构来存储一棵树。
  112. */
  113. //程序运行到这里就说明此节点不是叶子节点
  114. DecisionNode root2=new DecisionNode();//那么root2就是一个决策属性节点(非叶子节点)了,非叶子节点就有孩子节点,下面就是计算它的孩子节点
  115.  
  116. int bestFeatureIndex=findBestAttribute(dataset,current_features,current_featureValuesCounts);
  117. String bestFeatureLabel=current_features.get(bestFeatureIndex);
  118. //root.testCondition=bestFeatureLabel;
  119. ArrayList<String> feature_values=new ArrayList<String>();
  120. for (Entry<String, Integer> featureEntry : current_featureValuesCounts.get(bestFeatureLabel).entrySet()) {
  121. feature_values.add(featureEntry.getKey());
  122.  
  123. }
  124. //给非叶子节点,也就是特征节点仅仅赋特征名称值
  125. root2.init(currentFeatureName,currentFeatureValue);//java中不能是使用像C++中默认参数的函数,只能通过重载来实现同样的目的。
  126. for (String values : feature_values) {
  127. //DecisionNode tempRoot=new DecisionNode();
  128.  
  129. ArrayList<String[]> subDataSet = splitDataSet(dataset, bestFeatureIndex, values);//生成子数据集,即去除了包含values的实例,
  130. // 接下来就是计算对此数据集利用决策树进行决策,又需要调用treeGrow方法
  131. //所以,接下来需要得到对应这个子数据集的特征名称以及每个特征值在数据集中出现的次数
  132. ArrayList<String> currentAttibutes=new ArrayList<>();
  133. Iterator item1=current_features.iterator();
  134. while(item1.hasNext()){
  135. currentAttibutes.add(item1.next().toString());//这个子数据集的特征名称
  136. }
  137.  
  138. Map<String,Map<String,Integer>> currentAttributeValuesCounts=new HashMap<String, Map<String, Integer>>();
  139. //ArrayList<String[]> subDataSet = splitDataSet(dataset, bestFeatureIndex, values);
  140. currentAttibutes.remove(bestFeatureLabel);
  141. for (int j = 0; j < currentAttibutes.size(); j++) {
  142. Map<String, Integer> ttt=new HashMap<String, Integer>();
  143. for (int i = 0; i <subDataSet.size(); i++) {
  144. String currentFeatureValueXX=subDataSet.get(i)[j];
  145. if(!(ttt.containsKey(currentFeatureValueXX)))
  146. ttt.put(currentFeatureValueXX, 1);
  147. else {
  148. ttt.replace(currentFeatureValueXX, ttt.get(currentFeatureValueXX)+1);
  149. }
  150.  
  151. }
  152. currentAttributeValuesCounts.put(currentAttibutes.get(j), ttt);//每个特征值在数据集中出现的次数
  153.  
  154. }
  155.  
  156. root2.add(treeGrowth(subDataSet, bestFeatureLabel, values, currentAttibutes, currentAttributeValuesCounts));
  157.  
  158. }
  159.  
  160. return root2;
  161.  
  162. }
  163.  
  164. public static void main(String[] agrs){
  165. decisionTree.GetDataSet();
  166. DecisionNode dd=decisionTree.treeGrowth(dataSet,"oo","xx",features,featureValuesAndCounts);
  167. System.out.print(dd);
  168.  
  169. }
  170.  
  171. public static double calEntropy(ArrayList<String[]> dataset){//熵表示随机变量X不确定性的度量,在决策树中计算的熵就是决策结果这个变量的熵。
  172. int sampleCounts=dataset.size();
  173. Map<String, Integer> categoryCounts=new HashMap<String, Integer>();
  174. for (String[] strings : dataset) {
  175.  
  176. if(categoryCounts.containsKey(strings[strings.length-1]))
  177. categoryCounts.replace(strings[strings.length-1], categoryCounts.get(strings[strings.length-1])+1);
  178. else {
  179. categoryCounts.put(strings[strings.length-1],1);
  180. }
  181.  
  182. }
  183. double shannonEnt=0.0;
  184. for (Integer value: categoryCounts.values()) {
  185. double probability=value.doubleValue()/sampleCounts;
  186. shannonEnt-=probability*(Math.log10(probability)/Math.log10(2));
  187.  
  188. }
  189. return shannonEnt;
  190. }
  191.  
  192. public static int findBestAttribute(ArrayList<String[]> dataset,ArrayList<String> currentFeatures,
  193. Map<String,Map<String,Integer>> currentFeatureValuesCounts){
  194. double baseEntroy=calEntropy(dataset);//计算基础熵,就是在不划分出某个特征的情况下。
  195. double bestInfoGain=0.0;
  196. int bestFeatureIndex=-1;
  197.  
  198. for (int i = 0; i <currentFeatures.size(); i++) {//遍历当前数据集的每个特征,计算每个特征的信息增益
  199. double conditionalEntroy=0.0;
  200. double entroy=0.0;
  201. Map<String,Integer> tempFeatureCounts=currentFeatureValuesCounts.get(currentFeatures.get(i));
  202. //Map类型有一个entrySet方法,此方法返回一个Map.Entry类型的集合,其中集合中的每个元素就是一个键值对,利用增强型的for循环可以遍历Map中
  203. //key(entry.getkey)和value(entry.getValue)
  204. for (Entry<String, Integer> entry : tempFeatureCounts.entrySet()) {
  205. //计算条件熵,就是根据某个具体特征值划分出新的数据集,计算新的数据集的基础熵,再乘以权值,累加得到某个特征的条件熵。
  206. conditionalEntroy+=(entry.getValue().doubleValue()/dataset.size())*calEntropy(splitDataSet(dataset, i, entry.getKey()));
  207. //根据信息增益进一步计算信息增益比
  208. double tempValue=entry.getValue().doubleValue()/dataset.size();
  209. entroy+=tempValue*(Math.log10(tempValue)/Math.log10(2));
  210.  
  211. }
  212.  
  213. if ((baseEntroy-conditionalEntroy)/(-entroy)>bestInfoGain) {
  214. bestInfoGain=(baseEntroy-conditionalEntroy)/(-entroy);
  215. bestFeatureIndex=i;
  216.  
  217. }
  218. }
  219. if (-1==bestFeatureIndex){
  220. System.out.print("cannot find best attribute!");
  221. return -1;
  222. }
  223. else {
  224. return bestFeatureIndex;//返回信息增益最大的特征的索引,在当前特征(currentFeatures)中的索引。
  225. }
  226. }
  227. public static String classify(ArrayList<String> dataset) {
  228.  
  229. Map<String, Integer> categoryCount = new HashMap<String, Integer>();
  230. for (String s1 : dataset) {
  231. if (categoryCount.containsKey(s1)) {
  232. categoryCount.replace(s1, categoryCount.get(s1) + 1);
  233. } else {
  234. categoryCount.put(s1, 1);
  235. }
  236. }
  237. int maxCounts=-1;
  238. String maxCountsCategory=null;
  239. for (Entry<String,Integer> entry:categoryCount.entrySet()){//利用Map.Entry得到Map中的Value最大的键值对。
  240. if (entry.getValue()>maxCounts){
  241. maxCounts=entry.getValue();
  242. maxCountsCategory=entry.getKey();
  243. }
  244. }
  245. return maxCountsCategory;
  246.  
  247. }
  248.  
  249. public static ArrayList<String[]> splitDataSet(ArrayList<String[]> dataset,int featureIndex,String featureValue
  250. ){
  251. ArrayList<String[]> tempDataSet=new ArrayList<String[]>();
  252. for (String[] strings : dataset) {
  253. if (strings[featureIndex].equals(featureValue)) {
  254.  
  255. String[] xx=strings.clone();//数组的clone方法实现的是浅拷贝,实质就是以下的过程
  256. /*
  257. for (int i = featureIndex; i < strings.length-1; i++) {
  258. xx[i]=strings[i];//就是把引用的值(地址)复制了一份,指向了同一个对象。
  259. }
  260.  
  261. */
  262. for (int i = featureIndex; i < strings.length-1; i++) {//xx中各个元素的值与strings中各个元素的值完全相等。
  263. xx[i]=xx[i+1];//只是复制了引用的值而已,跟引用指向的对象没一点关系。Java将基本类型和引用类型变量都看成是值而已·
  264. }
  265. //最最最需要注意的一点,以上代码不能以下面这种形式实现
  266. /*
  267. for (int i = featureIndex; i < strings.length-1; i++) {//
  268. strings[i]=strings[i+1];//这样会改变strings指向的对象,进而影响到dataset,改变了函数的参数dataset,
  269. 这样就在函数内“无意间”修改了dataset的值,集合类型,其实所有引用类型都是,以参数形式传入函数的话,可能会“无意间”就被修改了
  270. }
  271. */
  272. String[] tempStrings=new String[xx.length-1];
  273. for (int i = 0; i < tempStrings.length; i++) {
  274. tempStrings[i]=xx[i];
  275.  
  276. }
  277. tempDataSet.add(tempStrings);
  278. }
  279.  
  280. }
  281. return tempDataSet;
  282. }
  283.  
  284. }
  285. class DecisionNode{
  286. public String featureName;
  287. public String result;
  288. public String featureValue;
  289. public List<DecisionNode> children=new ArrayList<DecisionNode>();
  290. public void add(DecisionNode node){
  291. children.add(node);
  292. }
  293. public void init(String featureName,String result,String featureValue){
  294. this.featureName=featureName;
  295. this.result=result;
  296. this.featureValue=featureValue;
  297. }
  298. public void init(String featureName,String featureValue){
  299. this.featureName=featureName;
  300. this.featureValue=featureValue;
  301. }
  302. }

C4.5决策树--Java的更多相关文章

  1. Python3实现机器学习经典算法(四)C4.5决策树

    一.C4.5决策树概述 C4.5决策树是ID3决策树的改进算法,它解决了ID3决策树无法处理连续型数据的问题以及ID3决策树在使用信息增益划分数据集的时候倾向于选择属性分支更多的属性的问题.它的大部分 ...

  2. ID3决策树---Java

    1)熵与信息增益: 2)以下是实现代码: //import java.awt.color.ICC_ColorSpace; import java.io.*; import java.util.Arra ...

  3. 小啃机器学习(1)-----ID3和C4.5决策树

    第一部分:简介 ID3和C4.5算法都是被Quinlan提出的,用于分类模型,也被叫做决策树.我们给一组数据,每一行数据都含有相同的结构,包含了一系列的attribute/value对. 其中一个属性 ...

  4. C4.5决策树-为什么可以选用信息增益来选特征

    要理解信息增益,首先要明白熵是什么,开始很不理解熵,其实本质来看熵是一个度量值,这个值的大小能够很好的解释一些问题. 从二分类问题来看,可以看到,信息熵越是小的,说明分类越是偏斜(明确),可以理解为信 ...

  5. 机器学习之决策树(ID3 、C4.5算法)

    声明:本篇博文是学习<机器学习实战>一书的方式路程,系原创,若转载请标明来源. 1 决策树的基础概念 决策树分为分类树和回归树两种,分类树对离散变量做决策树 ,回归树对连续变量做决策树.决 ...

  6. ID3、C4.5、CART决策树介绍

    决策树是一类常见的机器学习方法,它可以实现分类和回归任务.决策树同时也是随机森林的基本组成部分,后者是现今最强大的机器学习算法之一. 1. 简单了解决策树 举个例子,我们要对”这是好瓜吗?”这样的问题 ...

  7. 决策树(ID3、C4.5、CART)

    ID3决策树 ID3决策树分类的根据是样本集分类前后的信息增益. 假设我们有一个样本集,里面每个样本都有自己的分类结果. 而信息熵可以理解为:“样本集中分类结果的平均不确定性”,俗称信息的纯度. 即熵 ...

  8. ID3、C4.5和CART决策树对比

    ID3决策树:利用信息增益来划分节点 信息熵是度量样本集合纯度最常用的一种指标.假设样本集合D中第k类样本所占的比重为pk,那么信息熵的计算则为下面的计算方式 当这个Ent(D)的值越小,说明样本集合 ...

  9. 【机器学习】决策树C4.5、ID3

    一.算法流程 step1:计算信息熵 step2: 划分数据集 step3: 创建决策树 step4: 利用决策树分类 二.信息熵Entropy.信息增益Gain 重点:选择一个属性进行分支.注意信息 ...

随机推荐

  1. Rails学习:create操作 局部模板

    学习Ruby on Rails实战真经 里面说rails4使用了strong parameters, 所以代码这么写:注意不是Event.new(params[:event])了,而是参数是函数返回值 ...

  2. [转]Oracle_ProC编程

    1.引言 由于PL/SQL不能用来开发面向普通用户的应用程序,必须借助其他语言或开发工具. 在Linux操作系统下应该用什么语言或开发工具来进行Oracle数据库应用的开发呢?本文将介绍2种方案:Pr ...

  3. 第一篇代码 嗨翻C语言 21点扑克

    /* *  计算牌面点数的程序. *  使用“拉斯难加斯公开许可证”. *  学院21点扑克游戏小组. */#include <stdio.h>#include <stdlib.h& ...

  4. 启动FM预算基金管理模块后,0L总账消失的解决办法

    只要用SE38运行一下:FMGL_CHANGE_APPL_IN_LEDGER 问题就解决了.

  5. hdu 1509 Windows Message Queue

    题目连接 http://acm.hdu.edu.cn/showproblem.php?pid=1509 Windows Message Queue Description Message queue ...

  6. go语言示例-Timer计时器的用法

    计时器用来定时执行任务,分享一段代码: package main import "time" import "fmt" func main() { //新建计时 ...

  7. [超简洁]EasyQ框架-应对WEB高并发业务(秒杀、抽奖)等业务

    背景介绍 这几年一直在摸索一种框架,足够简单,又能应付很多高并发高性能的需求.研究过一些框架思想如DDD DCI,也实践过CQRS框架. 但是总觉得复杂度高,门槛也高,自己学都吃力,如果团队新人更难接 ...

  8. QT 按钮类继承处理带定时器

    01.class KeyButton : public QPushButton  02.{  03.    Q_OBJECT  04.public:  05.    explicit KeyButto ...

  9. 状压DP

    今天稍微看了下状压DP,大概就是这样子的,最主要的就是位运算, i and (1<<k)=0 意味着i状态下没有 k : i and (1<<k)>0 意味着i状态下有 ...

  10. Bootstrap入门一:Hello Bootstrap

    一.Bootstrap简介 Bootstrap,来自 Twitter,是目前很受欢迎的前端框架.Bootstrap是基于 HTML5.CSS3和Javascriopt开发的,它在 jQuery的基础上 ...