2014-03-17 12:12 15010人阅读 评论(41) 收藏 举报
 分类:
Data Mining(25)  Python(24)  Machine Learning(46) 

版权声明:本文为博主原创文章,未经博主允许不得转载。

本文基于python逐步实现Decision Tree(决策树),分为以下几个步骤:

  • 加载数据集
  • 熵的计算
  • 根据最佳分割feature进行数据分割
  • 根据最大信息增益选择最佳分割feature
  • 递归构建决策树
  • 样本分类

关于决策树的理论方面本文几乎不讲,详情请google keywords:“决策树 信息增益  熵”

将分别体现于代码。

本文只建一个.py文件,所有代码都在这个py里

1.加载数据集

我们选用UCI经典Iris为例

Brief of IRIS:

Data Set Characteristics:  

Multivariate

Number of Instances:

150

Area:

Life

Attribute Characteristics:

Real

Number of Attributes:

4

Date Donated

1988-07-01

Associated Tasks:

Classification

Missing Values?

No

Number of Web Hits:

533125

Code:

  1. from numpy import *
  2. #load "iris.data" to workspace
  3. traindata = loadtxt("D:\ZJU_Projects\machine learning\ML_Action\Dataset\Iris.data",delimiter = ',',usecols = (0,1,2,3),dtype = float)
  4. trainlabel = loadtxt("D:\ZJU_Projects\machine learning\ML_Action\Dataset\Iris.data",delimiter = ',',usecols = (range(4,5)),dtype = str)
  5. feaname = ["#0","#1","#2","#3"] # feature names of the 4 attributes (features)

Result:

           

左图为实际数据集,四个离散型feature,一个label表示类别(有Iris-setosa, Iris-versicolor,Iris-virginica 三个类)

2. 熵的计算

entropy是香农提出来的(信息论大牛),定义见wiki

注意这里的entropy是H(C|X=xi)而非H(C|X), H(C|X)的计算见第下一个点,还要乘以概率加和

Code:

  1. from math import log
  2. def calentropy(label):
  3. n = label.size # the number of samples
  4. #print n
  5. count = {} #create dictionary "count"
  6. for curlabel in label:
  7. if curlabel not in count.keys():
  8. count[curlabel] = 0
  9. count[curlabel] += 1
  10. entropy = 0
  11. #print count
  12. for key in count:
  13. pxi = float(count[key])/n #notice transfering to float first
  14. entropy -= pxi*log(pxi,2)
  15. return entropy
  16. #testcode:
  17. #x = calentropy(trainlabel)

Result:

3. 根据最佳分割feature进行数据分割

假定我们已经得到了最佳分割feature,在这里进行分割(最佳feature为splitfea_idx)

第二个函数idx2data是根据splitdata得到的分割数据的两个index集合返回datal (samples less than pivot), datag(samples greater than pivot), labell, labelg。 这里我们根据所选特征的平均值作为pivot

  1. #split the dataset according to label "splitfea_idx"
  2. def splitdata(oridata,splitfea_idx):
  3. arg = args[splitfea_idx] #get the average over all dimensions
  4. idx_less = [] #create new list including data with feature less than pivot
  5. idx_greater = [] #includes entries with feature greater than pivot
  6. n = len(oridata)
  7. for idx in range(n):
  8. d = oridata[idx]
  9. if d[splitfea_idx] < arg:
  10. #add the newentry into newdata_less set
  11. idx_less.append(idx)
  12. else:
  13. idx_greater.append(idx)
  14. return idx_less,idx_greater
  15. #testcode:2
  16. #idx_less,idx_greater = splitdata(traindata,2)
  17. #give the data and labels according to index
  18. def idx2data(oridata,label,splitidx,fea_idx):
  19. idxl = splitidx[0] #split_less_indices
  20. idxg = splitidx[1] #split_greater_indices
  21. datal = []
  22. datag = []
  23. labell = []
  24. labelg = []
  25. for i in idxl:
  26. datal.append(append(oridata[i][:fea_idx],oridata[i][fea_idx+1:]))
  27. for i in idxg:
  28. datag.append(append(oridata[i][:fea_idx],oridata[i][fea_idx+1:]))
  29. labell = label[idxl]
  30. labelg = label[idxg]
  31. return datal,datag,labell,labelg

这里args是参数,决定分裂节点的阈值(每个参数对应一个feature,大于该值分到>branch,小于该值分到<branch),我们可以定义如下:

  1. args = mean(traindata,axis = 0)

测试:按特征2进行分类,得到的less和greater set of indices分别为:

也就是按args[2]进行样本集分割,<和>args[2]的branch分别有57和93个样本。

4. 根据最大信息增益选择最佳分割feature

信息增益为代码中的info_gain, 注释中是熵的计算

  1. #select the best branch to split
  2. def choosebest_splitnode(oridata,label):
  3. n_fea = len(oridata[0])
  4. n = len(label)
  5. base_entropy = calentropy(label)
  6. best_gain = -1
  7. for fea_i in range(n_fea): #calculate entropy under each splitting feature
  8. cur_entropy = 0
  9. idxset_less,idxset_greater = splitdata(oridata,fea_i)
  10. prob_less = float(len(idxset_less))/n
  11. prob_greater = float(len(idxset_greater))/n
  12. #entropy(value|X) = \sum{p(xi)*entropy(value|X=xi)}
  13. cur_entropy += prob_less*calentropy(label[idxset_less])
  14. cur_entropy += prob_greater * calentropy(label[idxset_greater])
  15. info_gain = base_entropy - cur_entropy #notice gain is before minus after
  16. if(info_gain>best_gain):
  17. best_gain = info_gain
  18. best_idx = fea_i
  19. return best_idx
  20. #testcode:
  21. #x = choosebest_splitnode(traindata,trainlabel)

这里的测试针对所有数据,分裂一次选择哪个特征呢?

5. 递归构建决策树

详见code注释,buildtree递归地构建树。

递归终止条件:

①该branch内没有样本(subset为空) or

②分割出的所有样本属于同一类 or

③由于每次分割消耗一个feature,当没有feature的时候停止递归,返回当前样本集中大多数sample的label

  1. #create the decision tree based on information gain
  2. def buildtree(oridata, label):
  3. if label.size==0: #if no samples belong to this branch
  4. return "NULL"
  5. listlabel = label.tolist()
  6. #stop when all samples in this subset belongs to one class
  7. if listlabel.count(label[0])==label.size:
  8. return label[0]
  9. #return the majority of samples' label in this subset if no extra features avaliable
  10. if len(feanamecopy)==0:
  11. cnt = {}
  12. for cur_l in label:
  13. if cur_l not in cnt.keys():
  14. cnt[cur_l] = 0
  15. cnt[cur_l] += 1
  16. maxx = -1
  17. for keys in cnt:
  18. if maxx < cnt[keys]:
  19. maxx = cnt[keys]
  20. maxkey = keys
  21. return maxkey
  22. bestsplit_fea = choosebest_splitnode(oridata,label) #get the best splitting feature
  23. print bestsplit_fea,len(oridata[0])
  24. cur_feaname = feanamecopy[bestsplit_fea] # add the feature name to dictionary
  25. print cur_feaname
  26. nodedict = {cur_feaname:{}}
  27. del(feanamecopy[bestsplit_fea]) #delete current feature from feaname
  28. split_idx = splitdata(oridata,bestsplit_fea) #split_idx: the split index for both less and greater
  29. data_less,data_greater,label_less,label_greater = idx2data(oridata,label,split_idx,bestsplit_fea)
  30. #build the tree recursively, the left and right tree are the "<" and ">" branch, respectively
  31. nodedict[cur_feaname]["<"] = buildtree(data_less,label_less)
  32. nodedict[cur_feaname][">"] = buildtree(data_greater,label_greater)
  33. return nodedict
  34. #testcode:
  35. #mytree = buildtree(traindata,trainlabel)
  36. #print mytree

Result:

mytree就是我们的结果,#1表示当前使用第一个feature做分割,'<'和'>'分别对应less 和 greater的数据。

6. 样本分类

根据构建出的mytree进行分类,递归走分支

  1. #classify a new sample
  2. def classify(mytree,testdata):
  3. if type(mytree).__name__ != 'dict':
  4. return mytree
  5. fea_name = mytree.keys()[0] #get the name of first feature
  6. fea_idx = feaname.index(fea_name) #the index of feature 'fea_name'
  7. val = testdata[fea_idx]
  8. nextbranch = mytree[fea_name]
  9. #judge the current value > or < the pivot (average)
  10. if val>args[fea_idx]:
  11. nextbranch = nextbranch[">"]
  12. else:
  13. nextbranch = nextbranch["<"]
  14. return classify(nextbranch,testdata)
  15. #testcode
  16. tt = traindata[0]
  17. x = classify(mytree,tt)
  18. print x

Result:

为了验证代码准确性,我们换一下args参数,把它们都设成0(很小)

args = [0,0,0,0]

建树和分类的结果如下:

可见没有小于pivot(0)的项,于是dict中每个<的key对应的value都为空。

本文中全部代码下载:决策树python实现

Reference: Machine Learning in Action

from: http://blog.csdn.net/abcjennifer/article/details/20905311

决策树Decision Tree 及实现的更多相关文章

  1. 机器学习算法实践:决策树 (Decision Tree)(转载)

    前言 最近打算系统学习下机器学习的基础算法,避免眼高手低,决定把常用的机器学习基础算法都实现一遍以便加深印象.本文为这系列博客的第一篇,关于决策树(Decision Tree)的算法实现,文中我将对决 ...

  2. 数据挖掘 决策树 Decision tree

    数据挖掘-决策树 Decision tree 目录 数据挖掘-决策树 Decision tree 1. 决策树概述 1.1 决策树介绍 1.1.1 决策树定义 1.1.2 本质 1.1.3 决策树的组 ...

  3. 用于分类的决策树(Decision Tree)-ID3 C4.5

    决策树(Decision Tree)是一种基本的分类与回归方法(ID3.C4.5和基于 Gini 的 CART 可用于分类,CART还可用于回归).决策树在分类过程中,表示的是基于特征对实例进行划分, ...

  4. (ZT)算法杂货铺——分类算法之决策树(Decision tree)

    https://www.cnblogs.com/leoo2sk/archive/2010/09/19/decision-tree.html 3.1.摘要 在前面两篇文章中,分别介绍和讨论了朴素贝叶斯分 ...

  5. 决策树decision tree原理介绍_python sklearn建模_乳腺癌细胞分类器(推荐AAA)

    sklearn实战-乳腺癌细胞数据挖掘(博主亲自录制视频) https://study.163.com/course/introduction.htm?courseId=1005269003& ...

  6. 机器学习方法(四):决策树Decision Tree原理与实现技巧

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入. 前面三篇写了线性回归,lass ...

  7. 机器学习-决策树 Decision Tree

    咱们正式进入了机器学习的模型的部分,虽然现在最火的的机器学习方面的库是Tensorflow, 但是这里还是先简单介绍一下另一个数据处理方面很火的库叫做sklearn.其实咱们在前面已经介绍了一点点sk ...

  8. 决策树 Decision Tree

    决策树是一个类似于流程图的树结构:其中,每个内部结点表示在一个属性上的测试,每个分支代表一个属性输出,而每个树叶结点代表类或类分布.树的最顶层是根结点.  决策树的构建 想要构建一个决策树,那么咱们 ...

  9. 【机器学习算法-python实现】决策树-Decision tree(2) 决策树的实现

    (转载请注明出处:http://blog.csdn.net/buptgshengod) 1.背景      接着上一节说,没看到请先看一下上一节关于数据集的划分数据集划分.如今我们得到了每一个特征值得 ...

随机推荐

  1. 20161005 NOIP 模拟赛 T2 解题报告

    beautiful 2.1 题目描述 一个长度为 n 的序列,对于每个位置 i 的数 ai 都有一个优美值,其定义是:找到序列中最 长的一段 [l, r],满足 l ≤ i ≤ r,且 [l, r] ...

  2. POJ 1321 简单dfs

    1.POJ 1321  棋盘问题 2.总结: 题意:给定棋盘上放k个棋子,要求同行同列都不重. #include<iostream> #include<cstring> #in ...

  3. UIView常见属性总结

    一 UIVIew 常见属性 .frame 位置和尺寸(以父控件的左上角为原点(,)) .center 中点 (以父控件的左上角为原点(,)) .bounds 位置和尺寸(以自己的左上角为原点 (,)) ...

  4. OSG使用更新回调来更改模型

    OSG使用更新回调来更改模型 转自:http://blog.sina.com.cn/s/blog_668aae7801017gl7.html 使用回调类实现对场景图形节点的更新.本节将讲解如何使用回调 ...

  5. 接口测试之soupui&groovy

    原著地址:http://www.cnblogs.com/wade-xu/p/4236295.html#3334654 需注意下方code的设置

  6. nginx“虚拟目录”不支持php的解决办法

    这几天在配置Nginx,PHP用FastCGI,想装一个phpMyAdmin管理数据库,phpMyAdmin不想放在网站根目录 下,这样不容易和网站应用混在一起,这样phpMyAdmin的目录就放在别 ...

  7. github提交失败并报错java.io.IOException: Authentication failed:

    一.概述 我最近在写一个android的项目. 软件:android studio.Android studio VCS integration(插件) Android studio VCS inte ...

  8. Centos 6.5 挂载硬盘 4K对齐 (笔记 实测)

    环境: 系统硬件:vmware vsphere (CPU:2*4核,内存2G) 系统版本:Linux centos 2.6.32-431.17.1.el6.x86_64(Centos-6.5-x86_ ...

  9. [转] - bashrc与profile的区别

    bashrc与profile的区别 要搞清bashrc与profile的区别,首先要弄明白什么是交互式shell和非交互式shell,什么是login shell 和non-login shell. ...

  10. Js特效--模仿滚动条(兼容IE8+,FF,Google)

    <html> <head> <style> *{margin:0px;padding:0px;} #box{width:200px;height:500px;pos ...