数据集如下:

 色泽    根蒂    敲声    纹理    脐部    触感    好瓜
青绿 蜷缩 浊响 清晰 凹陷 硬滑 是
乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 是
乌黑 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 蜷缩 沉闷 清晰 凹陷 硬滑 是
浅白 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 稍蜷 浊响 清晰 稍凹 软粘 是
乌黑 稍蜷 浊响 稍糊 稍凹 软粘 是
乌黑 稍蜷 浊响 清晰 稍凹 硬滑 是
乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 否
青绿 硬挺 清脆 清晰 平坦 软粘 否
浅白 硬挺 清脆 模糊 平坦 硬滑 否
浅白 蜷缩 浊响 模糊 平坦 软粘 否
青绿 稍蜷 浊响 稍糊 凹陷 硬滑 否
浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 否
乌黑 稍蜷 浊响 清晰 稍凹 软粘 否
浅白 蜷缩 浊响 模糊 平坦 硬滑 否
青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 否

基于信息增益的ID3决策树的原理这里不再赘述,读者如果不明白可参考西瓜书对这部分内容的讲解。

python实现代码如下:

 from math import log2
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties # 统计label出现次数
def get_counts(data):
total = len(data)
results = {}
for d in data:
results[d[-1]] = results.get(d[-1], 0) + 1
return results, total # 计算信息熵
def calcu_entropy(data):
results, total = get_counts(data)
ent = sum([-1.0*v/total*log2(v/total) for v in results.values()])
return ent # 计算每个feature的信息增益
def calcu_each_gain(column, update_data):
total = len(column)
grouped = update_data.iloc[:, -1].groupby(by=column)
temp = sum([len(g[1])/total*calcu_entropy(g[1]) for g in list(grouped)])
return calcu_entropy(update_data.iloc[:, -1]) - temp # 获取最大的信息增益的feature
def get_max_gain(temp_data):
columns_entropy = [(col, calcu_each_gain(temp_data[col], temp_data)) for col in temp_data.iloc[:, :-1]]
columns_entropy = sorted(columns_entropy, key=lambda f: f[1], reverse=True)
return columns_entropy[0] # 去掉数据中已存在的列属性内容
def drop_exist_feature(data, best_feature):
attr = pd.unique(data[best_feature])
new_data = [(nd, data[data[best_feature] == nd]) for nd in attr]
new_data = [(n[0], n[1].drop([best_feature], axis=1)) for n in new_data]
return new_data # 获得出现最多的label
def get_most_label(label_list):
label_dict = {}
for l in label_list:
label_dict[l] = label_dict.get(l, 0) + 1
sorted_label = sorted(label_dict.items(), key=lambda ll: ll[1], reverse=True)
return sorted_label[0][0] # 创建决策树
def create_tree(data_set, column_count):
label_list = data_set.iloc[:, -1]
if len(pd.unique(label_list)) == 1:
return label_list.values[0]
if all([len(pd.unique(data_set[i])) ==1 for i in data_set.iloc[:, :-1].columns]):
return get_most_label(label_list)
best_attr = get_max_gain(data_set)[0]
tree = {best_attr: {}}
exist_attr = pd.unique(data_set[best_attr])
if len(exist_attr) != len(column_count[best_attr]):
no_exist_attr = set(column_count[best_attr]) - set(exist_attr)
for nea in no_exist_attr:
tree[best_attr][nea] = get_most_label(label_list)
for item in drop_exist_feature(data_set, best_attr):
tree[best_attr][item[0]] = create_tree(item[1], column_count)
return tree # 决策树绘制基本参考《机器学习实战》书内的代码以及博客:http://blog.csdn.net/c406495762/article/details/76262487
# 获取树的叶子节点数目
def get_num_leafs(decision_tree):
num_leafs = 0
first_str = next(iter(decision_tree))
second_dict = decision_tree[first_str]
for k in second_dict.keys():
if isinstance(second_dict[k], dict):
num_leafs += get_num_leafs(second_dict[k])
else:
num_leafs += 1
return num_leafs # 获取树的深度
def get_tree_depth(decision_tree):
max_depth = 0
first_str = next(iter(decision_tree))
second_dict = decision_tree[first_str]
for k in second_dict.keys():
if isinstance(second_dict[k], dict):
this_depth = 1 + get_tree_depth(second_dict[k])
else:
this_depth = 1
if this_depth > max_depth:
max_depth = this_depth
return max_depth # 绘制节点
def plot_node(node_txt, center_pt, parent_pt, node_type):
arrow_args = dict(arrowstyle='<-')
font = FontProperties(fname=r'C:\Windows\Fonts\STXINGKA.TTF', size=15)
create_plot.ax1.annotate(node_txt, xy=parent_pt, xycoords='axes fraction', xytext=center_pt,
textcoords='axes fraction', va="center", ha="center", bbox=node_type,
arrowprops=arrow_args, FontProperties=font) # 标注划分属性
def plot_mid_text(cntr_pt, parent_pt, txt_str):
font = FontProperties(fname=r'C:\Windows\Fonts\MSYH.TTC', size=10)
x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
create_plot.ax1.text(x_mid, y_mid, txt_str, va="center", ha="center", color='red', FontProperties=font) # 绘制决策树
def plot_tree(decision_tree, parent_pt, node_txt):
d_node = dict(boxstyle="sawtooth", fc="0.8")
leaf_node = dict(boxstyle="round4", fc='0.8')
num_leafs = get_num_leafs(decision_tree)
first_str = next(iter(decision_tree))
cntr_pt = (plot_tree.xoff + (1.0 +float(num_leafs))/2.0/plot_tree.totalW, plot_tree.yoff)
plot_mid_text(cntr_pt, parent_pt, node_txt)
plot_node(first_str, cntr_pt, parent_pt, d_node)
second_dict = decision_tree[first_str]
plot_tree.yoff = plot_tree.yoff - 1.0/plot_tree.totalD
for k in second_dict.keys():
if isinstance(second_dict[k], dict):
plot_tree(second_dict[k], cntr_pt, k)
else:
plot_tree.xoff = plot_tree.xoff + 1.0/plot_tree.totalW
plot_node(second_dict[k], (plot_tree.xoff, plot_tree.yoff), cntr_pt, leaf_node)
plot_mid_text((plot_tree.xoff, plot_tree.yoff), cntr_pt, k)
plot_tree.yoff = plot_tree.yoff + 1.0/plot_tree.totalD def create_plot(dtree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
plot_tree.totalW = float(get_num_leafs(dtree))
plot_tree.totalD = float(get_tree_depth(dtree))
plot_tree.xoff = -0.5/plot_tree.totalW
plot_tree.yoff = 1.0
plot_tree(dtree, (0.5, 1.0), '')
plt.show() if __name__ == '__main__':
my_data = pd.read_csv('./watermelon2.0.csv', encoding='gbk')
column_count = dict([(ds, list(pd.unique(my_data[ds]))) for ds in my_data.iloc[:, :-1].columns])
d_tree = create_tree(my_data, column_count)
create_plot(d_tree)

绘制的决策树如下:

python实现简单决策树(信息增益)——基于周志华的西瓜书数据的更多相关文章

  1. 周志华-机器学习西瓜书-第三章习题3.5 LDA

    本文为周志华机器学习西瓜书第三章课后习题3.5答案,编程实现线性判别分析LDA,数据集为书本第89页的数据 首先介绍LDA算法流程: LDA的一个手工计算数学实例: 课后习题的代码: # coding ...

  2. 支持向量机(SVM)算法分析——周志华的西瓜书学习

    1.线性可分 对于一个数据集: 如果存在一个超平面X能够将D中的正负样本精确地划分到S的两侧,超平面如下: 那么数据集D就是线性可分的,否则,不可分. w称为法向量,决定了超平面的方向:b为位移量,决 ...

  3. (二)《机器学习》(周志华)第4章 决策树 笔记 理论及实现——“西瓜树”——CART决策树

    CART决策树 (一)<机器学习>(周志华)第4章 决策树 笔记 理论及实现——“西瓜树” 参照上一篇ID3算法实现的决策树(点击上面链接直达),进一步实现CART决策树. 其实只需要改动 ...

  4. 【深度森林第三弹】周志华等提出梯度提升决策树再胜DNN

    [深度森林第三弹]周志华等提出梯度提升决策树再胜DNN   技术小能手 2018-06-04 14:39:46 浏览848 分布式 性能 神经网络   还记得周志华教授等人的“深度森林”论文吗?今天, ...

  5. 【Todo】【读书笔记】机器学习-周志华

    书籍位置: /Users/baidu/Documents/Data/Interview/机器学习-数据挖掘/<机器学习_周志华.pdf> 一共442页.能不能这个周末先囫囵吞枣看完呢.哈哈 ...

  6. [重磅]Deep Forest,非神经网络的深度模型,周志华老师最新之作,三十分钟理解!

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 深度学习最大的贡献,个人认为就是表征 ...

  7. 偶尔转帖:AI会议的总结(by南大周志华)

    偶尔转帖:AI会议的总结(by南大周志华) 说明: 纯属个人看法, 仅供参考. tier-1的列得较全, tier-2的不太全, tier-3的很不全. 同分的按字母序排列. 不很严谨地说, tier ...

  8. 【转载】 AI会议的总结(by南大周志华)

    原文地址: https://blog.csdn.net/LiFeitengup/article/details/8441054 最近在查找期刊会议级别的时候发现这篇博客,应该是2012年之前的内容,现 ...

  9. AI产业将更凸显个人英雄主义 周志华老师的观点是如此的有深度

    今天无意间在网上看的了一则推送,<周志华:AI产业将更凸显个人英雄主义> http://tech.163.com/18/0601/13/DJ7J39US00098IEO.html 摘录一些 ...

随机推荐

  1. scss-@at-root

    @at-root指令可以使一个或多个规则被限定输出在文档的根层级上,而不是被嵌套在其父选择器下. 下面就通过scss代码实例介绍一下它的作用: 没有使用@at-root命令的默认情况. .parent ...

  2. CSS文字有关属性

    font-size|family|weight|style 大小字体加粗斜体 color|opacity 颜色透明度 height+line-height:垂直居中 overflow:hidden|v ...

  3. wxpython 对话框

    . 消息对话框(wx.MessageDialog) 消息对话框 与用户通信最基本的机制是wx.MessageDialog,它是一个简单的提示框. wx.MessageDialog可用作一个简单的OK框 ...

  4. wxpyhon 对话框

    Python内置了好多定义好了的对话框供我们使用,这里先介绍三个最常用的: 1 Message dialog 2 Text entry 3 Choosing from a list 当然python还 ...

  5. 在 Windows Vista、Windows 7 和 Windows Server 2008 上设置 SharePoint 2010 开发环境

    适用范围: SharePoint Foundation 2010 | SharePoint Server 2010 本文内容 步骤 1:选择和预配置操作系统 步骤 2:安装 SharePoint 20 ...

  6. js:JSON对象与JSON字符串转换

    JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式,采用完全独立于语言的文本格式,是理想的数据交换格式. 同时,JSON是 JavaScript 原生格式,这 ...

  7. javascript正则表达式 - 学习笔记

    JavaScript 正则表达式 学习笔记 标签(空格分隔): 基础 JavaScript 正则表达式是用于匹配字符串中字符组合的模式.在javascript中,正则表达式也是对象.这些模式被用于Re ...

  8. 【Spring实战】—— 5 设值注入

    本篇主要讲解了Spring的最常用的功能——依赖注入. 注入的方式,是使用Getter Setter注入,平时大多的编程也都是使用这种方法. 举个简单的例子,还是表演者. 表演者有自己的属性,年龄或者 ...

  9. STL容器及算法题:删除奇数的QQ号

    最近思考到这样一个题目:在STL的set和vector容器里存储了1亿个QQ号,编写函数删除奇数QQ号. 1. STL容器简介 首先了解一下 set 和 vector 以及其他类似的 STL 容器: ...

  10. vuejs 开发中踩到的坑

    用 v-for 循环式  每个item的值相等的情况下,会影响v-model的双向绑定: Modal 组件开发,主要用slot 标签来实现 <template> <transitio ...