注释:之前从未接触过决策树,直接上手对着书看源码,有点难,确实有点难~~

   本代码是基于ID3编写,之后的ID4.5和CART等还没学习到

一.决策树的原理

  没有看网上原理,直接看源码懂得原理,下面是我一个抛砖引玉的例子:

     

  太丑了,在Linux下面操作实在不习惯,用的Kolourpqint画板也不好用,凑合看吧!

  假设有两个特征:no surfing 、Flippers ,一个结果:Fish

  现在假如给你一个测试:no surfing = 1, Flippers=0, 如何知道Fish的结果?太简单了Fish==A...

  现在样本你不知道排序的情况下,那我们操作的步骤只能是两种:

                                1.no surfing = 1时判断Fish,直接得出结果Fish==A

                                2.Flippers=0时判断Fish,Fish可能是A也可能是B,再判断no surfing =1时,得出Fish == A

  从上面我们可以看出,你选择的特征顺序对结果无影响,但是对计算的过程影响很大,我们能不能找到一种很好的途径去解决这个问题呢?

  下面是两种方法:

方法一

方法二

  由以上的两种思路可以得出,不同的分类方法差距很大吧?

  决策树就是用来解决如何选用最佳的方法的一种算法!!!

  一点不了解的,先花几分钟看一下我“信息熵”,这是整个算法的核心。

 

二.决策树的实现

  (1)计算信息熵

      为什么计算“信息熵”?自己去看原理就懂了。

  1. def claShannonEnt(setData):
  2. lengthData = len(setData)
  3. dicData = {}
  4. for cnt in range(lengthData):
  5. if setData[cnt,-1] not in dicData.keys():
  6. dicData[setData[cnt,-1]] = 0
  7. dicData[setData[cnt,-1]] += 1
  8. Hent = 0.0#输出信息ent
  9. for key in dicData.keys():
  10. pData = float(dicData[key])/lengthData
  11. Hent -= pData*math.log(pData,2)
  12. return Hent

  (2)划分数据集

      划分之后计算部分的信息熵之和,信息熵越小越好,信息增益越大越好。

  1. def splitData(setData,axis,value):
  2. ''' setData: sample sata
  3. axis : 轴的位置
  4. value : 满足条件的值
  5. '''
  6. lengthData = setData.shape[0]
  7. resultMat = np.zeros([1,setData.shape[1]])
  8. for count in range(lengthData):
  9. if int(setData[count,axis]) == int(value) :
  10. resultMat = np.vstack((resultMat,setData[count,:]))
  11. returnMat = resultMat[1:,:]
  12. resultMat = np.hstack((returnMat[:,0:axis],returnMat[:,axis+1:]))
  13. return resultMat

  (3)选择最佳的划分方案

      这里的原理就是划分之后的信息熵变小,信息增益变大,其中信息熵越小越好,也就是信息增益越大越好,循环比较每种划分之后的信息增益。

  1. def chooseBestTeature(setData):
  2. numFeature = setData.shape[1] - 1 #特征数量
  3. baceEntropy = claShannonEnt(setData) #信息熵
  4. bestGain = 0.0 #最好增益
  5. bestFeature = 0 #最好特征
  6. for i in range(numFeature):
  7. #featList = [example[i] for example in setData]
  8. featList = setData[:,i]
  9. uniquaVals = set(featList) #不同的Value值,set之后就变成无序集合
  10. newEntropy = 0.0
  11. for value in uniquaVals:
  12. subDataSet = splitData(setData,i,value)#分割特征
  13. prob = len(subDataSet)/float(len(setData))
  14. newEntropy += prob * claShannonEnt(subDataSet)#平均信息熵
  15. infoGain = baceEntropy - newEntropy
  16. if (infoGain > bestGain):#求得最大增益
  17. bestGain = infoGain
  18. bestFeature = i
  19. return bestFeature

  (4)计算分类之后的标签

      这里有点难理解,准备在下面程序讲解的,写到这里就直接讲解了。

      这是为了分类不了的情况做的准备,比如:[1,1,'yes'],[1,1,'no'],[1,0,'no'],[1,0,'yes'],[0,0,'no'],[0,0,'yes'],[0,1,'no'],[0,1,'yes'],大家可以按照上面的方法动手试试怎么分割?

      我们可以想象一下,就像以前中学学的解方程,Y1+Y2=10 && 2Y1 +2Y2 =10 ,你怎么求解Y1和Y2 ?两个有冲突的方程和上面的样本之间的冲突是一样的。

      这明显是一个出错的样本导致的,那怎么解决呢?

      再给出一组样本:[1,1,'yes'],[1,1,'yes'],[1,1,'no'],[1,1,'yes'],我们利用错误的样本为少数,多数的样本为正确的,所以[1,1] = 'YES'

  1. #计算分类之后的标签
  2. def majorityCnt(classList):
  3. classCount = {}
  4. for vote in classList:
  5. if vote not in classCount.keys():
  6. classCount[vote] = 0
  7. classCount[vote] += 1
  8. sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
  9. return sortedClassCount

  (5)建立决策树

      这里采用递归的方法进行划分

      调出循环的条件是:

                1.最后的标签相同--->>>也就是最后就省一个答案了,没必要划分直接得出结果了。

                2.就是第四点说的无解题,那就多的保留,少的丢弃。

  1. def creatTree(dataSet,labels):
  2. classList = dataSet[:,-1]
  3. #标签全部相等的时候退出
  4. if list(classList).count(classList[0]) == len(list(classList)):
  5. return classList[0]
  6. #最后的标签不相同,这个时候没办法分割,所以只能选择一个占比例大的标签了,博客会给具体例子
  7. if len(dataSet[0,:]) == 1:
  8. return majorityCnt(classList)
  9. bestFeat = chooseBestTeature(dataSet)
  10. bestFeatLabel = labels[bestFeat]
  11. myTree = {bestFeatLabel:{}}
  12. del(labels[bestFeat])
  13. featValue = dataSet[:,bestFeat]
  14. uniqueVals = set(featValue)
  15. for value in uniqueVals:
  16. subLabels = labels[:]
  17. myTree[bestFeatLabel][value] = creatTree(splitData(dataSet,bestFeat,value),subLabels)
  18. return myTree

   

   (6)使用决策树

      就像建立决策树一样,采用递归一层一层的去找到数据属于哪个类,看懂上面的建立之后现在这里不很简单

  1. def classify(inputTrees,featLabels,testVec):
  2. firstStr = list(inputTrees.keys())[0]#字典首元素
  3. secondDict = inputTrees[firstStr]#下一个字典
  4. featIndex = featLabels.index(firstStr)#标签中的位置
  5. for key in secondDict.keys():
  6. if testVec[featIndex] == int(key):#分支
  7. if type(secondDict[key]).__name__=='dict':#如果还是字典说明还得划分
  8. classLabels = classify(secondDict[key],featLabels,testVec)#迭代划分
  9. else: classLabels = secondDict[key]#不是字典说明已经分类
  10. return classLabels

     (7)存储决策树函数

  (8)总程序设计

      注意:我用的是Numpy数据,而不是List数据,这是有区别的,没有完全按照书上编写!

  1. import numpy as np
  2. import matplotlib.pyplot as ply
  3. import math
  4. import operator
  5.  
  6. def claShannonEnt(setData):
  7. lengthData = len(setData)
  8. dicData = {}
  9. for cnt in range(lengthData):
  10. if setData[cnt,-1] not in dicData.keys():
  11. dicData[setData[cnt,-1]] = 0
  12. dicData[setData[cnt,-1]] += 1
  13. Hent = 0.0#输出信息ent
  14. for key in dicData.keys():
  15. pData = float(dicData[key])/lengthData
  16. Hent -= pData*math.log(pData,2)
  17. return Hent
  18.  
  19. def splitData(setData,axis,value):
  20. ''' setData: sample sata
  21. axis : 轴的位置
  22. value : 满足条件的值
  23. '''
  24. lengthData = setData.shape[0]
  25. resultMat = np.zeros([1,setData.shape[1]])
  26. for count in range(lengthData):
  27. if int(setData[count,axis]) == int(value) :
  28. resultMat = np.vstack((resultMat,setData[count,:]))
  29. returnMat = resultMat[1:,:]
  30. resultMat = np.hstack((returnMat[:,0:axis],returnMat[:,axis+1:]))
  31. return resultMat
  32.  
  33. def chooseBestTeature(setData):
  34. numFeature = setData.shape[1] - 1 #特征数量
  35. baceEntropy = claShannonEnt(setData) #信息熵
  36. bestGain = 0.0 #最好增益
  37. bestFeature = 0 #最好特征
  38. for i in range(numFeature):
  39. #featList = [example[i] for example in setData]
  40. featList = setData[:,i]
  41. uniquaVals = set(featList) #不同的Value值,set之后就变成无序集合
  42. newEntropy = 0.0
  43. for value in uniquaVals:
  44. subDataSet = splitData(setData,i,value)#分割特征
  45. prob = len(subDataSet)/float(len(setData))
  46. newEntropy += prob * claShannonEnt(subDataSet)#平均信息熵
  47. infoGain = baceEntropy - newEntropy
  48. if (infoGain > bestGain):#求得最大增益
  49. bestGain = infoGain
  50. bestFeature = i
  51. return bestFeature
  52.  
  53. #计算分类之后的标签
  54. def majorityCnt(classList):
  55. classCount = {}
  56. for vote in classList:
  57. if vote not in classCount.keys():
  58. classCount[vote] = 0
  59. classCount[vote] += 1
  60. sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
  61. return sortedClassCount
  62.  
  63. def creatTree(dataSet,labels):
  64. classList = dataSet[:,-1]
  65. #标签全部相等的时候退出
  66. if list(classList).count(classList[0]) == len(list(classList)):
  67. return classList[0]
  68. #最后的标签不相同,这个时候没办法分割,所以只能选择一个占比例大的标签了,博客会给具体例子
  69. if len(dataSet[0,:]) == 1:
  70. return majorityCnt(classList)
  71. bestFeat = chooseBestTeature(dataSet)
  72. bestFeatLabel = labels[bestFeat]
  73. myTree = {bestFeatLabel:{}}
  74. del(labels[bestFeat])
  75. featValue = dataSet[:,bestFeat]
  76. uniqueVals = set(featValue)
  77. for value in uniqueVals:
  78. subLabels = labels[:]
  79. myTree[bestFeatLabel][value] = creatTree(splitData(dataSet,bestFeat,value),subLabels)
  80. return myTree
  1. import numpy as np
  2. import trees
  3.  
  4. if __name__ == '__main__':
  5. testData = np.array([[1,1,'yes'],[1,1,'no'],[1,0,'no'],[1,0,'yes'],[0,0,'no'],[0,0,'yes'],[0,1,'no'],[0,1,'yes']])
  6. myTree = trees.creatTree(testData,['no surfacing','flippers'])#['yes','yes','no','no','no']
  7. print(myTree)

《机器学习实战》ID3算法实现的更多相关文章

  1. 机器学习笔记----- ID3算法的python实战

    本文申明:本文原创,如有转载请申明.数据代码来自实验数据都是来自[美]Peter Harrington 写的<Machine Learning in Action>这本书,侵删. Hell ...

  2. 机器学习决策树ID3算法,手把手教你用Python实现

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第21篇文章,我们一起来看一个新的模型--决策树. 决策树的定义 决策树是我本人非常喜欢的机器学习模型,非常直观容易理解 ...

  3. 学习笔记之机器学习实战 (Machine Learning in Action)

    机器学习实战 (豆瓣) https://book.douban.com/subject/24703171/ 机器学习是人工智能研究领域中一个极其重要的研究方向,在现今的大数据时代背景下,捕获数据并从中 ...

  4. Python四步实现决策树ID3算法,参考机器学习实战

    一.编写计算历史数据的经验熵函数 from math import log def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCo ...

  5. 决策树ID3算法python实现 -- 《机器学习实战》

    from math import log import numpy as np import matplotlib.pyplot as plt import operator #计算给定数据集的香农熵 ...

  6. 《机器学习实战》学习笔记第三章 —— 决策树之ID3、C4.5算法

    主要内容: 一.决策树模型 二.信息与熵 三.信息增益与ID3算法 四.信息增益比与C4.5算法 五.决策树的剪枝 一.决策树模型 1.所谓决策树,就是根据实例的特征对实例进行划分的树形结构.其中有两 ...

  7. python机器学习笔记 ID3决策树算法实战

    前面学习了决策树的算法原理,这里继续对代码进行深入学习,并掌握ID3的算法实践过程. ID3算法是一种贪心算法,用来构造决策树,ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性 ...

  8. 机器学习实战 -- 决策树(ID3)

    机器学习实战 -- 决策树(ID3)   ID3是什么我也不知道,不急,知道他是干什么的就行   ID3是最经典最基础的一种决策树算法,他会将每一个特征都设为决策节点,有时候,一个数据集中,某些特征属 ...

  9. 《机器学习实战》学习笔记第九章 —— 决策树之CART算法

    相关博文: <机器学习实战>学习笔记第三章 —— 决策树 主要内容: 一.CART算法简介 二.分类树 三.回归树 四.构建回归树 五.回归树的剪枝 六.模型树 七.树回归与标准回归的比较 ...

  10. 机器学习实战笔记(Python实现)-01-K近邻算法(KNN)

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

随机推荐

  1. C#:memcached安装及.NET中的Memcached.ClientLibrary使用详解

    memcached分布式缓存的负载均衡配置比例,数据压缩,socket的详细配置等,以及在.net中的常用方法. 下载地址:http://pan.baidu.com/s/1yVILw       提取 ...

  2. php上传导入文件 nginx-502错误

    4. php程序执行时间过长而超时,检查nginx和fastcgi中各种timeout设置.(nginx 中的  fastcgi_connect_timeout 300;fastcgi_send_ti ...

  3. [3] 注解(Annotation)-- 深入理解Java:注解(Annotation)--注解处理器

    转载 http://www.cnblogs.com/peida/archive/2013/04/26/3038503.html 深入理解Java:注解(Annotation)--注解处理器 如果没有用 ...

  4. IKAnalyzer 添加扩展词库和自定义词

    原文链接http://blog.csdn.net/whzhaochao/article/details/50130605 IKanalyzer分词器 IK分词器源码位置 http://git.osch ...

  5. esp8266尝鲜

    请将当前用户添加到dialout组,否则会提示打开/dev/ttyUSB0权限不足 sudo usermod -a -G dialout `whoami` dmeg查看驱动安装信息 dmesg | g ...

  6. VirtualBox 虚拟机复制

    本文简单讲两种情况下的复制方式 1 跨电脑复制 2 同一virtrul box下 虚拟机复制 ---------------------------------------------- 1 跨电脑复 ...

  7. centos 配置puTTY rsa自动登录

    vim /etc/ssh/sshd_config, 下面三行去掉注释符号# RSAAuthentication yes PubkeyAuthentication yes AuthorizedKeysF ...

  8. 【RPC】使用Hessian构建RPC的简单示例

    服务接口和实现 public interface HelloService { // 服务方法 String sayHello(String name); } public class HelloSe ...

  9. 关于ros里ppp拨号隧道比如pptp,l2tp,sstp等等,造成多条路由,ospf的时候需要汇总为一条宣告的解决方案

    官方解决方案: https://wiki.mikrotik.com/wiki/OSPF_and_PPPoE_Setup 实际解决步骤: So to get rid of /32 routes * on ...

  10. python面向对象 : 反射和内置方法

    一. 反射 1. isinstance()和issubclass() isinstance( 对象名, 类名) : 判断对象所属关系,包括父类  (注:type(对象名) is 类名 : 判断对象所属 ...