IMPLEMENTED IN PYTHON +1 | CART生成树
Introduction:
分类与回归树(classification and regression tree, CART)模型由Breiman等人在1984年提出,CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归,以下简要讨论树生成部分,在随后的博文中再探讨树剪枝的问题。
Algorithm:
step 1. 分别计算所有特征中各个分类的基尼系数
step 2. 选择有最小基尼系数的特征作为最优切分点,因$Gini(D,A_i=j)$最小,所以$A_i=j$作为最优切割点,$A_i$作为根节点
step 3. 在剩余的特征中重复step 1和2,获取最优特征及最优切割点,直至所有特征用尽或者是所有值都一一归类,最后所生成的决策树与ID3算法所生成的完全一致
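Formula:
$$Gini(D)=\sum_{k=1}^{K}p_k(1-p_k)=1-\sum_{k=1}^{K}\left(\frac{|C_k|}{|D|}\right)^2$$
$$Gini(D,A_i=j)=\frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)$$
其中$C_k$为$D$中属于第$k$类的样本子集,$D_1=\{x\in D\mid A_i(x)=j\}$,$D_2=D\setminus D_1$。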
Code:
"""
Created on Thu Jan 30 15:36:39 2014

@filename: test.py

Train the CART tree on 'decision_tree_text.txt' and print the result.
"""
import cart

c = cart.Cart()
c.trainDecisionTree('decision_tree_text.txt')
# The parenthesised single-argument form prints the same on Python 2 and 3.
print(c.trainresult)
view test.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 18:05:22 2014

@filename: cart.py

Generate a decision tree from a tab-separated text file using the Gini
index as the split criterion.
"""
import numpy as np

FILENAME = 'decision_tree_text.txt'
MAXDEPTH = 10


class Cart(object):
    """Decision-tree generator driven by the Gini index.

    ``trainDecisionTree`` reads a tab-separated file whose last column is
    used as the class label, prints every chosen split, stores the nested
    dict describing the tree in ``trainresult`` and plots it.
    """

    def __init__(self):
        # Placeholder shown to callers who read the result before training.
        self.trainresult = 'WARNING : please trainDecisionTree first!'

    def trainDecisionTree(self, filename):
        """Load ``filename`` and build (and plot) the tree from it."""
        self.__loadDataSet(filename)
        self.__optimalTree(self.__datamat)

    def __loadDataSet(self, filename):
        """Read the training file and encode every cell as a float code.

        Each column gets its own text -> code table (codes 0.0, 1.0, ... in
        order of first appearance), so two different texts that occur in the
        same column can never share a code.

        Raises:
            ValueError: if the file holds no data rows or the rows do not all
                have the same number of tab-separated cells.
        """
        with open(filename) as fread:
            # Blank lines would otherwise become one-cell rows.
            rows = [line.strip().split('\t') for line in fread
                    if line.strip()]
        self.__dataset = np.array(rows)
        if self.__dataset.ndim != 2:
            raise ValueError('%s must contain at least one row and every row '
                             'must have the same number of tab-separated '
                             'cells' % filename)

        # One text -> code table per column, indexed by column position.
        self.__textdic = []
        for col in self.__dataset.T:
            codes = {}
            for cell in col:
                if cell not in codes:
                    codes[cell] = float(len(codes))
            self.__textdic.append(codes)

        self.__datamat = np.array(
            [[codes[cell] for codes, cell in zip(self.__textdic, row)]
             for row in self.__dataset])

    def __getSampleCount(self, setd, col=-1, s=None):
        """Count how often each value occurs.

        Without ``s`` the values of column ``col`` of ``setd`` are counted.
        With ``s`` the last-column values of the rows whose column ``col``
        equals ``s`` are counted.

        Returns:
            dict mapping value -> float count.
        """
        if s is not None:
            # __getSampleMat returns (matching, rest); only the matching rows
            # are wanted here.
            newset = self.__getSampleMat(setd, col, s)[0][:, -1]
        else:
            newset = setd[:, col]
        dic = {}
        for cell in newset:
            dic[cell] = dic.get(cell, 0.) + 1
        return dic

    def __getSampleMat(self, setd, col, s):
        """Split ``setd`` into (rows where column ``col`` == ``s``, the rest).

        Boolean-mask indexing keeps both results two-dimensional even when
        one of them has no rows, so they can still be sliced by column.
        """
        mask = setd[:, col] == s
        return setd[mask], setd[~mask]

    def __getGiniD(self, setd):
        """Return the Gini index of the last column of ``setd``.

        An empty ``setd`` yields 0 because there are no values to sum over.
        """
        total = len(setd)
        gini = 0
        for count in self.__getSampleCount(setd).values():
            p = count / total
            gini += p * (1 - p)
        return gini

    def __getGiniDA(self, setd, a):
        """Evaluate every binary split "column ``a`` == value" of ``setd``.

        Returns:
            ((best_value, best_gini), table) where ``table`` maps each value
            of column ``a`` to the weighted Gini index of its split.
        """
        total = len(setd)
        dic = {}
        for value, count in self.__getSampleCount(setd, a).items():
            part_a, part_b = self.__getSampleMat(setd, a, value)
            weight = count / total
            dic[value] = (weight * self.__getGiniD(part_a) +
                          (1 - weight) * self.__getGiniD(part_b))
        # Rank by Gini index first; the value only breaks ties so the result
        # does not depend on dict ordering.
        best = min(dic.items(), key=lambda kv: (kv[1], kv[0]))
        return best, dic

    def __optimalNode(self, setd):
        """Find the feature column and value with the lowest Gini index.

        Every column except the last one is treated as a feature.

        Returns:
            (column index, split value, Gini index of that split).

        Raises:
            ValueError: if ``setd`` has no feature columns.
        """
        best = None
        ginicol = 0
        for coln in range(setd.shape[1] - 1):
            candidate, _ = self.__getGiniDA(setd, coln)
            if best is None or candidate[1] < best[1]:
                best = candidate
                ginicol = coln
        if best is None:
            raise ValueError('no feature columns left to split on')
        return ginicol, best[0], best[1]

    def __optimalNodeText(self, col, value):
        """Return the original text encoded as ``value`` in column ``col``.

        Returns None when no cell of that column carries the code.
        """
        for text, code in zip(self.__dataset.T[col], self.__datamat[:, col]):
            if code == value:
                return text
        return None

    def __optimalTree(self, setd):
        """Grow the tree one split at a time and plot it when finished.

        Each round takes the best split of the remaining data. The round is
        terminal when that split has Gini index 0 or only one feature column
        is left; the accumulated split list is then stored in
        ``trainresult`` and handed to ``__plotTree``. Otherwise the split is
        recorded, the used column is dropped, and the loop continues on the
        non-matching part when the matching part holds a single label, or on
        the matching part when it does not.

        At most ``MAXDEPTH - 1`` rounds are run; if none of them is terminal
        the method returns without touching ``trainresult``.
        """
        arr = setd
        # Maps column positions in the shrinking ``arr`` back to the column
        # indices of the full data matrix.
        features = np.arange(arr.shape[1])
        lst = []
        defaultc = None
        for _ in range(MAXDEPTH - 1):
            ginicol, value, gini = self.__optimalNode(arr)
            parts = self.__getSampleMat(arr, ginicol, value)
            args = [np.unique(part[:, -1]) for part in parts]
            realvalues = [np.unique(part[:, ginicol])[0] for part in parts]
            realcol = features[ginicol]
            features = np.delete(features, ginicol)

            if gini == 0 or arr.shape[1] == 2:
                # array_equal gives a plain bool even when a part still
                # holds several labels; ``==`` would give an array there.
                if defaultc is not None and np.array_equal(args[0], defaultc):
                    value = realvalues[0]
                else:
                    value = realvalues[1]
                self.trainresult = self.__buildList(lst, realcol, value, gini)
                self.__plotTree(self.trainresult)
                return

            if len(args[0]) == 1:
                defaultc = args[0]
                self.__buildList(lst, realcol, realvalues[0], gini)
                keep = parts[1]
            else:
                defaultc = args[1]
                self.__buildList(lst, realcol, realvalues[1], gini)
                keep = parts[0]
            # Drop the column that was just used for the split.
            arr = np.concatenate((keep[:, :ginicol], keep[:, ginicol + 1:]),
                                 axis=1)

    def __plotTree(self, lst):
        """Turn the split list into a nested dict, store it and plot it.

        The first entry of ``lst`` becomes the innermost node with leaves
        'c1' and 'c2'; every later entry wraps the tree built so far under
        its 'ELSE' branch.
        """
        # Imported here so the rest of the class can be used without the
        # plotting module being importable.
        import plottree

        dic = {}
        for item in lst:
            if dic == {}:
                dic[item[0]] = {item[1]: 'c1', 'ELSE': 'c2'}
            else:
                dic = {item[0]: {item[1]: 'c1', 'ELSE': dic}}
        tree = plottree.retrieveTree(dic)
        self.trainresult = tree
        plottree.createPlot(tree)

    def __buildList(self, lst, col, value, gini):
        """Print one split and insert it at the front of ``lst``.

        The inserted entry is ``[col, '<text>:<value>']``. ``lst`` is changed
        in place and also returned.
        """
        text = self.__optimalNodeText(col, value)
        # A single pre-formatted string prints the same on Python 2 and 3.
        print('feature col: %s  feature val: %s  Gini: %s \n'
              % (col, text, gini))
        lst.insert(0, [col, str(text) + ':' + str(value)])
        return lst


if __name__ == '__main__':
    cart = Cart()
view cart.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 11:45:18 2014

@filename: plottree.py

Draw a decision tree stored as nested dicts of the form
``{node: {branch: leaf_or_subtree, ...}}``.
"""

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="1.0")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Annotate ``createPlot.ax1`` with a boxed node and an arrow.

    The text box is placed at ``centerPt`` and the annotated point is
    ``parentPt``; both are in axes-fraction coordinates. ``nodeType`` is the
    bbox style dict. ``createPlot`` must have set ``createPlot.ax1`` first.
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def getNumLeafs(myTree):
    """Return the number of leaves (non-dict values) below the first key."""
    numLeafs = 0
    # list(...)[0] works for Python 2 lists and Python 3 dict views alike.
    firstStr = list(myTree)[0]
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels on the longest root-leaf path."""
    maxDepth = 0
    firstStr = list(myTree)[0]
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def retrieveTree(dic=None):
    """Return ``dic`` unchanged, or a fresh sample tree when it is omitted.

    The sample is rebuilt on every call so that a caller mutating the
    returned tree cannot alter what later calls receive.
    """
    if dic is None:
        dic = {'have house': {'yes': 'c1',
                              'no': {'have job': {'yes': 'c1',
                                                  'no': 'c2'}}}}
    return dic


def plotMidText(centrPt, parentPt, txtString):
    """Write ``txtString`` halfway between ``centrPt`` and ``parentPt``."""
    xMid = (parentPt[0] - centrPt[0]) / 2.0 + centrPt[0]
    yMid = (parentPt[1] - centrPt[1]) / 2.0 + centrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` with its root linked to ``parentPt``.

    ``nodeTxt`` is written halfway along the link to the parent. The layout
    state lives in the function attributes ``xOff``, ``yOff``, ``totalW``
    and ``totalD``, which ``createPlot`` initialises.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree)[0]
    centrPt = [plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
               plotTree.yOff]
    plotMidText(centrPt, parentPt, nodeTxt)
    plotNode(firstStr, centrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down for the children ...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, centrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), centrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), centrPt, str(key))
    # ... and back up again before returning to the caller's level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Draw ``inTree`` in matplotlib figure 1 and show it."""
    # Imported here so the tree-inspection helpers above can be imported
    # without matplotlib being available.
    import matplotlib.pyplot as plt

    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


if __name__ == '__main__':
    myTree = retrieveTree()
    createPlot(myTree)
view plottree.py
输入数据
输出结果
feature col: 2 feature val: 是 Gini: 0.266666666667
feature col: 1 feature val: 是 Gini: 0.0
Reference:
Harrington P. Machine Learning in Action
李航. 统计学习方法
IMPLEMENTED IN PYTHON +1 | CART生成树的更多相关文章
- Python实现CART(基尼指数)
Python实现CART(基尼指数) 运行环境 Python3 treePlotter模块(画图所需,不画图可不必) matplotlib(如果使用上面的模块必须) 计算过程 st=>start ...
- 机器学习之分类回归树(python实现CART)
之前有文章介绍过决策树(ID3).简单回顾一下:ID3每次选取最佳特征来分割数据,这个最佳特征的判断原则是通过信息增益来实现的.按照某种特征切分数据后,该特征在以后切分数据集时就不再使用,因此存在切分 ...
- Algorithm: quick sort implemented in python 算法导论 快速排序
import random def partition(A, lo, hi): pivot_index = random.randint(lo, hi) pivot = A[pivot_index] ...
- leetcode-happy number implemented in python
视频分析: http://v.youku.com/v_show/id_XMTMyODkyNDA0MA==.html?from=y1.7-1.2 class Solution(object): def ...
- Awesome Python
Awesome Python A curated list of awesome Python frameworks, libraries, software and resources. Insp ...
- Python开源框架、库、软件和资源大集合
A curated list of awesome Python frameworks, libraries, software and resources. Inspired by awesome- ...
- Python 库汇总英文版
Awesome Python A curated list of awesome Python frameworks, libraries, software and resources. Insp ...
- Python框架、库以及软件资源汇总
转自:http://developer.51cto.com/art/201507/483510.htm 很多来自世界各地的程序员不求回报的写代码为别人造轮子.贡献代码.开发框架.开放源代码使得分散在世 ...
- Python Scopes and Namespaces
Before introducing classes, I first have to tell you something about Python's scope rules. Class def ...
随机推荐
- jQuery代码优化 事件委托篇
<转自 http://www.jb51.net/article/28770.htm> 参考文章: 解密jQuery事件核心 - 绑定设计(一) 参考文章: 解密jQuery事件核心 - ...
- java.lang.UnsupportedClassVersionError: Bad version number in .class file
java.lang.UnsupportedClassVersionError: Bad version number in .class file造成这种过错是ni的支撑Tomcat运行的JDK版本与 ...
- GridView的初级使用
使用GridView自带的分页功能,需要激发PageIndexChanging protected void gvNewsList_PageIndexChanging(object sender, G ...
- (转)Android调用系统自带的文件管理器进行文件选择并获得路径
Android区别于iOS的沙盒模式,可以通过文件浏览器浏览本地的存储器.Android API也提供了相应的接口. 基本思路,先通过Android API调用系统自带的文件浏览器选取文件获得URI, ...
- Swift - 10 - assert(断言)
//: Playground - noun: a place where people can play import UIKit var str = "Hello, playground& ...
- 关于margin-top失效的解决方法
常出现两种情况:(一)margin-top失效 先看下面代码: <div><div class="box1" >float:left</div> ...
- 解决Eclipse中编辑xml文件的智能提示问题,最简单的是第二种方法。
Eclipse for Android xml 文件代码自动提示功能,介绍Eclipse 编辑器中实现xml 文件代码自动智能提示功能,解决eclipse 代码提示失效.eclipse 不能自动提示. ...
- [转]STL的内存分配器
题记:内存管理一直是C/C++程序的红灯区.关于内存管理的话题,大致有两类侧重点,一类是内存的正确使用,例如C++中new和delete应该成对出现,用RAII技巧管理内存资源,auto_ptr等方面 ...
- JQuery中$.ajax()方法参数
url: 要求为String类型的参数,(默认为当前页地址)发送请求的地址. type: 要求为String类型的参数,请求方式(post或get)默认为get.注意其他http请求方法,例如put和 ...
- js 鼠标双击滚动单击停止
<!DOCTYPE html> <html> <head> <title>双击滚动代码</title> <script src=&qu ...