这是半成品, 已完成了 fit() 部分, 形成了包含一棵完整树的 node 对象.

后续工作是需解析该 node 对象, 完成 predict() 工作.

  1. # !/usr/bin/python
  2. # -*- coding:utf-8 -*-
  3. """
  4. Re-implement ID3 algorithm as a practice
  5. Only information gain criterion supplied in our DT algorithm.
  6. 使用该 ID3 re-implement 的前提:
  7. 1. train data 的标签必须转成0,1,2,...的形式
  8. 2. 只能处理连续特征
  9. """
  10. # Author: 相忠良(Zhong-Liang Xiang) <ugoood@163.com>
  11. # Finished at July ***, 2017
  12. import numpy as np
  13. from sklearn import datasets, cross_validation
  14. ## load data
  15. def load_data():
  16. iris = datasets.load_iris()
  17. return cross_validation.train_test_split(iris.data, iris.target, test_size=0.25, random_state=0)
  18. class DecisionNode():
  19. def __init__(self, feature_i=None, threshold=None, value=None, left_branch=None, right_branch=None):
  20. self.feature_i = feature_i # Best feature's index
  21. self.threshold = threshold # Best split threshold in the feature
  22. self.value = value # Value if the node is a leaf in the tree
  23. self.left_branch = left_branch # 'Left' subtree
  24. self.right_branch = right_branch # 'Right' subtree
  25. # print feature_i, 'feature_i'
  26. print self.value, 'value'
  27. class MyDecisionTreeClassifier():
  28. trees = []
  29. num_eles_in_class_label = 3 # 分类标签类的个数
  30. tree = {}
  31. predict_label = []
  32. X_train = []
  33. y_train = []
  34. max_depth = 3
  35. max_leaf_nodes = 30
  36. min_samples_leaf = 1
  37. count = 0
  38. def __init__(self, ):
  39. self.root = None
  40. # TODO
  41. def fit(self, X, y):
  42. self.root = DecisionNode(self.createTree(X, y))
  43. def predict(self, X):
  44. pass
  45. def score(self, X, y):
  46. pass
  47. ## entropy
  48. # e.g entropy(y_test)
  49. def __entropy(self, label_list):
  50. bincount = np.bincount(label_list, minlength=self.num_eles_in_class_label)
  51. sum = np.sum(bincount)
  52. # print 'sum in entropy ', sum
  53. temp = 1.0 * bincount / sum
  54. tot = 0
  55. # to avoid log2(0)
  56. for e in temp:
  57. if (e != 0):
  58. tot += e * (-np.log2(e))
  59. return tot
  60. def gain(self, pre_split_label_list, after_split_label_list_2d):
  61. total = 0
  62. n = after_split_label_list_2d[0].__len__() + after_split_label_list_2d[1].__len__()
  63. for item in after_split_label_list_2d:
  64. total += self.__entropy(item) * (1.0 * item.__len__() / n)
  65. return self.__entropy(pre_split_label_list) - total
  66. ## 针对np.bincount()的结果,如[37 34 41],判断是否为纯节点,既[0 22 0]的形式
  67. def isPure(self, bincount_list):
  68. sb = sorted(bincount_list)
  69. if ((sb[-1] != 0) & (sb[-2] == 0)):
  70. return True
  71. else:
  72. return False
  73. ## 计算出现次数最多的类别标签
  74. def maxCate(self, bincount_list):
  75. bincount_list = np.array(bincount_list)
  76. return bincount_list.argmax()
  77. ## 递归停止条件:
  78. # 如果样例小于等于10,停止
  79. # 如果样例大于10 且 点纯,停止
  80. # 否则 继续分裂
  81. def createTree(self, X, y):
  82. bincount_list = np.bincount(y, minlength=self.num_eles_in_class_label)
  83. if ((self.isPure(bincount_list)) & (np.sum(bincount_list) > 10)):
  84. print bincount_list, '11111'
  85. return DecisionNode(value=self.maxCate(bincount_list))
  86. elif (np.sum(bincount_list) <= 10):
  87. print bincount_list, '22222'
  88. return DecisionNode(value=self.maxCate(bincount_list))
  89. else:
  90. print bincount_list, '33333'
  91. f, v, g = self.seek_best_split_feature(X, y)
  92. mask_big = X[:, f] > v
  93. mask_sma = X[:, f] <= v
  94. bigger_X = []
  95. bigger_y = []
  96. smaller_X = []
  97. smaller_y = []
  98. bigger_X.append(X[mask_big])
  99. bigger_y.append(y[mask_big])
  100. smaller_X.append(X[mask_sma])
  101. smaller_y.append(y[mask_sma])
  102. left_branch = self.createTree(bigger_X[0], bigger_y[0])
  103. right_branch = self.createTree(smaller_X[0], smaller_y[0])
  104. return DecisionNode(feature_i=f, threshold=v, left_branch=left_branch, right_branch=right_branch)
  105. ## k>=2 特征区间切分点个数
  106. # samples 样本
  107. # labels 样本对应的标签
  108. # return: best_feature, best_split_point, gain_on_that_point
  109. def seek_best_split_feature(self, samples, labels, k=10): # 2 2.84 0.915290847812
  110. samples = np.array(samples)
  111. labels = np.array(labels)
  112. best_split_point_pool = {} # 最佳分裂特征,点,及对应的gain
  113. col_indx = 0
  114. # 遍历所有特征,寻找某特征最佳分裂点
  115. while col_indx < samples.shape[1]:
  116. max = np.max(samples[:, col_indx])
  117. min = np.min(samples[:, col_indx])
  118. split_point = np.linspace(min, max, k, False)[1:]
  119. # 寻找某特征最佳分裂点
  120. temp = []
  121. dic = {}
  122. for p in split_point:
  123. index_less = np.where(samples[:, col_indx] < p)[0] # [1 2]
  124. index_bigger = np.where(samples[:, col_indx] >= p)[0]
  125. label_less = labels[index_less]
  126. label_bigger = labels[index_bigger]
  127. temp.append(list(label_less))
  128. temp.append(list(label_bigger))
  129. g = self.gain(labels, temp)
  130. dic[p] = g
  131. temp = []
  132. best_key = sorted(dic, key=lambda x: dic[x])[-1] # 返回value最大的那个key
  133. dic_temp = {}
  134. dic_temp[best_key] = dic[best_key]
  135. best_split_point_pool[col_indx] = dic_temp
  136. col_indx += 1
  137. # 特征列表
  138. feature_name_box = list(best_split_point_pool.keys())
  139. b = list(best_split_point_pool.values()) # 临时表
  140. # 最大gain列表
  141. gain_box = []
  142. # 最佳切分点列表
  143. point_box = []
  144. for item in b:
  145. gain_box.append(item.values()[0])
  146. point_box.append(item.keys()[0])
  147. best_feature = feature_name_box[np.argmax(gain_box)]
  148. best_split_point = point_box[np.argmax(gain_box)]
  149. gain_on_that_point = np.max(gain_box)
  150. return best_feature, best_split_point, gain_on_that_point
  151. ## 测试用例
  152. X_train, X_test, y_train, y_test = load_data()
  153. cls = MyDecisionTreeClassifier()
  154. a = [[9, 2, 3, 4],
  155. [5, 6, 7, 8],
  156. [1, 10, 11, 12],
  157. [13, 14, 15, 16]]
  158. b = [0, 1, 2, 3]
  159. a = np.array(a)
  160. b = np.array(b)
  161. # xx = [2,1,1]
  162. # print cls.maxCate(xx),'11111111111111111111111'
  163. cls.fit(X_train, y_train)
  164. tree = cls.root
  165. print type(cls.root)
  166. '''
  167. 下面是编程过程中留下的经验
  168. '''
  169. # 重要1: np.linspace(0,1,5) 0-1之间,等分5份,包括首尾
  170. # np.linspace(0,1,5)
  171. # [ 0. 0.25 0.5 0.75 1. ]
  172. # 重要2: np.where(a[:,0]>2) 返回矩阵a中第0列值大于2的那些行的索引号
  173. # 返回值的样子 (array([1, 2]),)
  174. # 重要3: 返回value最大的那个key
  175. # print(sorted(dic, key=lambda x: dic[x])[-1])
  176. # 重要4: np.bincount()指定最小长度
  177. # xxx = [1,1,1,1,1]
  178. # print np.bincount(xxx,minlength=3)
  179. # 结果: [0 5 0]

重写轮子之 ID3的更多相关文章

  1. 重写轮子之 GaussionNB

    我仿照sk-learn 中 GaussionNB 的结构, 重写了该算法的轮子,命名为 MyGaussionNB, 如下: # !/usr/bin/python # -*- coding:utf-8 ...

  2. 重写轮子之 kNN

    # !/usr/bin/python # -*- coding:utf-8 -*- """ Re-implement kNN algorithm as a practic ...

  3. 关于重写ID3 Algorithm Based On MapReduceV1/C++/Streaming的一些心得体会

    心血来潮,同时想用C++连连手.面对如火如荼的MP,一阵念头闪过,如果把一些ML领域的玩意整合到MP里面是不是很有意思 确实很有意思,可惜mahout来高深,我也看不懂.干脆自动动手丰衣足食,加上自己 ...

  4. 【转】C# 重写WndProc 拦截 发送 系统消息 + windows消息常量值(1)

    C# 重写WndProc 拦截 发送 系统消息 + windows消息常量值(1) #region 截获消息        /// 截获消息  处理XP不能关机问题        protected ...

  5. Asp.net Mvc 请求是如何到达 MvcHandler的——UrlRoutingModule、MvcRouteHandler分析,并造个轮子

    这个是转载自:http://www.cnblogs.com/keyindex/archive/2012/08/11/2634005.html(那个比较容易忘记,希望博主不要生气的) 前言 本文假定读者 ...

  6. 拆解轮子之XRecyclerView

    简介 这个轮子是对RecyclerView的封装,主要完成了下拉刷新.上拉加载更多.RecyclerView头部.在我的Material Design学习项目中使用到了项目地址,感觉还不错.趁着毕业答 ...

  7. 跨平台技术实践案例: 用 reactxp 重写墨刀的移动端

    Authors:  Gao Cong, Perry Poon Illustrators:  Shena Bian April 20, 2019 重新编写,又一次,我们又一次重新编写了移动端应用和移动端 ...

  8. 星级评分原理 N次重写的分析

    使用的是雪碧图,用的软件是CSS Sprite Tools 第一次实现与分析: <!DOCTYPE html> <html> <head> <meta cha ...

  9. [18/11/29] 继承(extends)和方法的重写(override,不是重载)

    一.何为继承?(对原有类的扩充) 继承让我们更加容易实现类的扩展. 比如,我们定义了人类,再定义Boy类就只需要扩展人类即可.实现了代码的重用,不用再重新发明轮子(don’t  reinvent  w ...

随机推荐

  1. HTTP头HOST

    http request header 中的host行的作用 在早期的Http 1.0版中,Http 的request请求头中是不带host行的,在Http 1.0的加强版和Http 1.1中加入了h ...

  2. javascript学习(2)修改html元素和提示对话框

    一.修改html元素 1.修改p元素 1.1.源代码 1.2.执行前 1.3.执行后 2.修改div元素的className 2.1.源代码 1.2.执行前 1.3.执行后 3.直接在当前位置输出内容 ...

  3. spring-oauth-server实践:OAuth2.0 通过header 传递 access_token 验证

    一.解析查找 access_token 1.OAuth2AuthenticationProcessingFilter.tokenExtractor 2.发现来源可以有两处:请求的头或者请求的参数 二. ...

  4. jedis配置

    public interface IJedisClientFactory { Jedis getJedis(); } JedisClientFactoryImpl.java @Service publ ...

  5. python开发:初识python

    python简介 Python可以应用于众多领域,如:数据分析.组件集成.网络服务.图像处理.数值计算和科学计算等众多领域.目前业内几乎所有大中型互联网企业都在使用Python,如:Youtube.D ...

  6. 使用Git进行代码版本管理及协同工作

    Git简介: git是一种较为先进的代码版本管理及协同工作平台,采用分布式文件块存储: 1.  分布式: 代码保存在所有协同成员的计算机上,网速较差时依然可用:而传统的集中式代码版本管理系统则较难脱离 ...

  7. pymysql.err.ProgrammingError: 1064 (Python字符串转义问题)

    代码: sql = """INSERT INTO video_info(video_id, title) VALUES("%s","%s&q ...

  8. java创建线程的三种方法

    这里不会贴代码,只是将创建线程的三种方法做个笼统的介绍,再根据源码添加上自己的分析. 通过三种方法可以创建java线程: 1.继承Thread类. 2.实现Runnable接口. 3.实现Callab ...

  9. [LeetCode] Daily Temperatures 日常温度

    Given a list of daily temperatures, produce a list that, for each day in the input, tells you how ma ...

  10. python包安装和使用机制

    python语言的魅力之一就是大量的外置数据包,能够帮助使用者节省很多时间,提高效率.模块下载和引用是最常见的操作,现在解析内部的原理和背后发生的故事,做到心里有数. 导航: 基本定义 模块使用 模块 ...