1. 关于ID3和C4.5的原理介绍这里不赘述,网上到处都是,可以下载讲义c9641_c001.pdf或者参考李航的《统计学习方法》.

2. 数据与数据处理

  • 本文采用下面的训练数据:

  • 数据处理:本文只采用了"Outlook"、"Humidity"、"Windy"三个属性,并根据Humidity的值是否大于75,将其归为"> 75"和"<= 75"两类;Play Golf的值就是类别标签,只有Yes和No两类。
  • 训练集是字符和数字的混合,这会给编程带来麻烦,所以首先把训练集用数字表示出来:
    const unsigned att_num = 3;   // number of attributes: Outlook, Humidity, Windy
    const unsigned rule_num = 14; // number of training instances (rows of train_data)
    std::string decision_tree_name("Play Golf ?");
    std::string attribute_names[] = {"Outlook", "Humidity", "Windy"};
    std::string attribute_values[] = {"Sunny", "Overcast", "Rainy", "> 75", "<= 75", "True", "False", "Yes", "No"};
    // Every cell of train_data is an index into attribute_values:
    //   0-2 -> Outlook, 3-4 -> Humidity, 5-6 -> Windy, 7-8 -> class label.
    // The last column is the class label, so each row has att_num + 1 columns.
    // NOTE(review): the numeric literals were lost from this listing. The first
    // row is the one spelled out in the article; the remaining rows were
    // reconstructed from the standard 14-instance "Play Golf" weather data set
    // with Humidity split at 75 - verify against the original table.
    unsigned train_data[rule_num][att_num + 1] = {
        {0, 3, 6, 8}, {0, 3, 5, 8}, {1, 3, 6, 7},
        {2, 3, 6, 7}, {2, 3, 6, 7}, {2, 4, 5, 8},
        {1, 4, 5, 7}, {0, 3, 6, 8}, {0, 4, 6, 7},
        {2, 3, 6, 7}, {0, 4, 5, 7}, {1, 3, 5, 7},
        {1, 4, 6, 7}, {2, 3, 5, 8}
    };

    以train_data的第一行{0, 3, 6, 8}为例解释:前三列的值依次对应attribute_names中的三个属性,最后一列是类别标签。每个数字都是attribute_values的下标(从0开始):0 表示attribute_values的第1个元素,即"Sunny";3 是第4个元素"> 75";6 是"False";8 是"No"。所以{0, 3, 6, 8}代表的实例就是:Outlook = Sunny,Humidity > 75,Windy = False,Play Golf = No。



3. 编写必要函数


4. 确定数据结构


  // One node of the decision tree.
  struct Tree {
    // Value stored in the node. It is zero-initialised so that a freshly
    // constructed node never holds an indeterminate value.
    unsigned root = 0;
    std::vector<unsigned> branches; // values the node's attribute can take (empty for a leaf)
    std::vector<Tree> children;     // child nodes
  };


5. 构造决策树


6. 打印决策树


7. 代码实现

  1. /*************************************************
  2. Copyright:1.0
  3. Author:90Zeng
  4. Date:2014-11-25
  5. Description:ID3/C4.5 algorithm
  6. **************************************************/
  #include <iostream>
  #include <cmath>
  #include <vector>
  #include <string>
  #include <algorithm>
  using namespace std;

  const unsigned att_num = 3;   // number of attributes: Outlook, Humidity, Windy
  const unsigned rule_num = 14; // number of training instances (rows of train_data)
  std::string decision_tree_name("Play Golf ?");
  std::string attribute_names[] = {"Outlook", "Humidity", "Windy"};
  std::string attribute_values[] = {"Sunny", "Overcast", "Rainy", "> 75", "<= 75", "True", "False", "Yes", "No"};
  // Every cell of train_data is an index into attribute_values:
  //   0-2 -> Outlook, 3-4 -> Humidity, 5-6 -> Windy, 7-8 -> class label.
  // The last column is the class label, so each row has att_num + 1 columns.
  // NOTE(review): the numeric literals were lost from this listing. The first
  // row is the one spelled out in the accompanying article; the remaining rows
  // were reconstructed from the standard 14-instance "Play Golf" weather data
  // set with Humidity split at 75 - verify against the original table.
  unsigned train_data[rule_num][att_num + 1] = {
      {0, 3, 6, 8}, {0, 3, 5, 8}, {1, 3, 6, 7},
      {2, 3, 6, 7}, {2, 3, 6, 7}, {2, 4, 5, 8},
      {1, 4, 5, 7}, {0, 3, 6, 8}, {0, 4, 6, 7},
      {2, 3, 6, 7}, {0, 4, 5, 7}, {1, 3, 5, 7},
      {1, 4, 6, 7}, {2, 3, 5, 8}
  };
  /*************************************************
  Function:    unique()
  Description: Collapses repeated elements of a vector so that
               each distinct value is kept exactly once. Values
               are returned in order of first occurrence; the
               input does not need to be sorted. Only operator==
               is required of T.
  Calls:       std::find()
  Input:       vals - vector that may contain duplicates
  Output:      vector holding each distinct value of vals once
               (empty when vals is empty)
  *************************************************/
  template <typename T>
  std::vector<T> unique(const std::vector<T>& vals)
  {
    std::vector<T> unique_vals;
    for (typename std::vector<T>::const_iterator itr = vals.begin(); itr != vals.end(); ++itr)
    {
      // Keep the value only the first time it is seen.
      if (std::find(unique_vals.begin(), unique_vals.end(), *itr) == unique_vals.end())
        unique_vals.push_back(*itr);
    }
    return unique_vals;
  }
  /*************************************************
  Function:    log2()
  Description: Base-2 logarithm of a number, computed as
               log10(n) / log10(2).
  Calls:       std::log10()
  Input:       n - double
  Output:      double, the base-2 logarithm of n
  NOTE(review): this has the same name and signature as the
               C library log2 that <cmath> declares. Do not
               rewrite the body in terms of std::log2: on
               common standard libraries that name appears to
               refer to this very function, so the call would
               recurse - confirm on the target toolchain.
  *************************************************/
  double log2(double n)
  {
    return std::log10(n) / std::log10(2.0);
  }
  /*************************************************
  Function:    compute_entropy()
  Description: Computes the entropy (in bits) of the values in
               v: the sum over each distinct value of
               -p * log2(p), where p is that value's relative
               frequency in v.
  Calls:       log2(), std::find(), std::count()
  Input:       v - vector<unsigned> of observed values
  Output:      double, the entropy; 0.0 for an empty vector
  *************************************************/
  double compute_entropy(const std::vector<unsigned>& v)
  {
    const double total = static_cast<double>(v.size());
    double entropy = 0.0;
    for (std::vector<unsigned>::const_iterator itr = v.begin(); itr != v.end(); ++itr)
    {
      // Handle every distinct value exactly once, at its first occurrence,
      // so the terms are summed in first-occurrence order.
      if (std::find(v.begin(), itr, *itr) != itr)
        continue;
      // Nothing before itr equals *itr, so counting from itr covers all of v.
      const double p = static_cast<double>(std::count(itr, v.end(), *itr)) / total;
      entropy -= p * log2(p);
    }
    return entropy;
  }
  101. /*************************************************
  102. Function: compute_gain()
  103. Description: 计算数据集中所有属性的信息增益
  104. Calls: compute_entropy(),unique()
  105. Input: vector<vector<unsigned> >
  106. 相当于一个二维数组,存储着训练数据集
  107. Output: vector<double> 存储着所有属性的信息
  108. 增益
  109. *************************************************/
  110. vector<double> compute_gain(vector<vector<unsigned> > truths)
  111. {
  112. vector<double> gain(truths[].size() - , );
  113. vector<unsigned> attribute_vals;
  114. vector<unsigned> labels;
  115. for(unsigned j = ; j < truths.size(); j++)
  116. {
  117. labels.push_back(truths[j].back());
  118. }
  120. for(unsigned i = ; i < truths[].size() - ; i++)//最后一列是类别标签,没必要计算信息增益
  121. {
  122. for(unsigned j = ; j < truths.size(); j++)
  123. attribute_vals.push_back(truths[j][i]);
  125. vector<unsigned> unique_vals = unique(attribute_vals);
  126. vector<unsigned>::iterator itr = unique_vals.begin();
  127. vector<unsigned> subset;
  128. while(itr != unique_vals.end())
  129. {
  130. for(unsigned k = ; k < truths.size(); k++)
  131. {
  132. if (*itr == attribute_vals[k])
  133. {
  134. subset.push_back(truths[k].back());
  135. }
  136. }
  137. double A = (double)subset.size();
  138. gain[i] += A / truths.size() * compute_entropy(subset);
  139. itr++;
  140. subset.clear();
  141. }
  142. gain[i] = compute_entropy(labels) - gain[i];
  143. attribute_vals.clear();
  144. }
  145. return gain;
  146. }
  148. /*************************************************
  149. Function: compute_gain_ratio()
  150. Description: 计算数据集中所有属性的信息增益比
  151. C4.5算法中用到
  152. Calls: compute_gain();compute_entropy()
  153. Input: 训练数据集
  154. Output: 信息增益比
  155. *************************************************/
  156. vector<double> compute_gain_ratio(vector<vector<unsigned> > truths)
  157. {
  158. vector<double> gain = compute_gain(truths);
  159. vector<double> entropies;
  160. vector<double> gain_ratio;
  162. for(unsigned i = ; i < truths[].size() - ; i++)//最后一列是类别标签,没必要计算信息增益比
  163. {
  164. vector<unsigned> attribute_vals(truths.size(), );
  165. for(unsigned j = ; j < truths.size(); j++)
  166. {
  167. attribute_vals[j] = truths[j][i];
  168. }
  169. double current_entropy = compute_entropy(attribute_vals);
  170. if (current_entropy)
  171. {
  172. gain_ratio.push_back(gain[i] / current_entropy);
  173. }
  174. else
  175. gain_ratio.push_back(0.0);
  177. }
  178. return gain_ratio;
  179. }
  /*************************************************
  Function:    find_most_common_label()
  Description: Finds the class label (last element of each row)
               that occurs most often in a data set. When
               several labels are tied, the one that appears
               first in the data set is returned.
  Calls:       std::count()
  Input:       data - the data set; every row must be non-empty
  Output:      the most frequent label; a value-initialised T
               when data has no rows
  *************************************************/
  template <typename T>
  T find_most_common_label(const std::vector<std::vector<T> >& data)
  {
    std::vector<T> labels;
    labels.reserve(data.size());
    for (unsigned i = 0; i < data.size(); ++i)
      labels.push_back(data[i].back());

    T most_common_label = T(); // defined result even when there are no rows
    typename std::vector<T>::difference_type most_counter = 0;
    for (typename std::vector<T>::const_iterator itr = labels.begin(); itr != labels.end(); ++itr)
    {
      const typename std::vector<T>::difference_type current_counter =
          std::count(labels.begin(), labels.end(), *itr);
      // Strict '>' keeps the earliest label among equally frequent ones.
      if (current_counter > most_counter)
      {
        most_common_label = *itr;
        most_counter = current_counter;
      }
    }
    return most_common_label;
  }
  /*************************************************
  Function:    find_attribute_values()
  Description: Collects the values that one attribute column
               takes in a data set, without repetitions and in
               order of first occurrence.
  Calls:       std::find()
  Input:       attribute - index of the column to inspect (T is
                           used both as the index type and as
                           the cell type)
               data      - the data set; every row must have a
                           column at index attribute
  Output:      the distinct values found in that column; empty
               when data has no rows
  *************************************************/
  template <typename T>
  std::vector<T> find_attribute_values(T attribute, const std::vector<std::vector<T> >& data)
  {
    std::vector<T> values;
    for (unsigned i = 0; i < data.size(); ++i)
    {
      const T& value = data[i][attribute];
      // Record each value only the first time it is seen.
      if (std::find(values.begin(), values.end(), value) == values.end())
        values.push_back(value);
    }
    return values;
  }
  /*************************************************
  Function:    drop_one_attribute()
  Description: While the tree is being built, an attribute that
               has already been examined is removed from the
               data set. The column is not physically removed:
               every cell of that column is overwritten with
               the marker value 110. Any other number would do,
               as long as it cannot be confused with a value of
               the original training set.
  Calls:       none
  Input:       attribute - index of the column to mark
               data      - the data set (taken by value; the
                           caller's copy is left untouched)
  Output:      a copy of data in which column attribute holds
               110 in every row
  *************************************************/
  template <typename T>
  std::vector<std::vector<T> > drop_one_attribute(T attribute, std::vector<std::vector<T> > data)
  {
    for (unsigned i = 0; i < data.size(); ++i)
    {
      data[i][attribute] = 110; // marker for "attribute already used"
    }
    return data;
  }
  // One node of the decision tree.
  struct Tree {
    // Value stored in the node. It is zero-initialised so that a freshly
    // constructed node never holds an indeterminate value.
    unsigned root = 0;
    std::vector<unsigned> branches; // values the node's attribute can take (empty for a leaf)
    std::vector<Tree> children;     // child nodes
  };
  260. /*************************************************
  261. Function: build_decision_tree()
  262. Description: 递归构建决策树
  264. Calls: unique(),count(),
  265. find_most_common_label()
  266. compute_gain()(ID3),
  267. compute_gain_ratio()(C4.5),
  268. find_attribute_values(),
  269. drop_one_attribute(),
  270. build_decision_tree()(递归,
  271. 当然要调用函数本身)
  272. Input: 训练数据集,一个空决策树
  273. Output: 无
  274. *************************************************/
  275. void build_decision_tree(vector<vector<unsigned> > examples, Tree &tree)
  276. {
  277. //第一步:判断所有实例是否都属于同一类,如果是,则决策树是单节点
  278. vector<unsigned> labels(examples.size(), );
  279. for (unsigned i = ; i < examples.size(); i++)
  280. {
  281. labels[i] = examples[i].back();
  282. }
  283. if (unique(labels).size() == )
  284. {
  285. tree.root = labels[];
  286. return;
  287. }
  289. //第二步:判断是否还有剩余的属性没有考虑,如果所有属性都已经考虑过了,
  290. //那么此时属性数量为0,将训练集中最多的类别标记作为该节点的类别标记
  291. if (count(examples[].begin(),examples[].end(),) == examples[].size() - )//只剩下一列类别标记
  292. {
  293. tree.root = find_most_common_label(examples);
  294. return;
  295. }
  296. //第三步:在上面两步的条件都判断失败后,计算信息增益,选择信息增益最大
  297. //的属性作为根节点,并找出该节点的所有取值
  299. vector<double> standard = compute_gain(examples);
  301. //要是采用C4.5,将上面一行注释掉,把下面一行的注释去掉即可
  302. //vector<double> standard = compute_gain_ratio(examples);
  303. tree.root = ;
  304. for (unsigned i = ; i < standard.size(); i++)
  305. {
  306. if (standard[i] >= standard[tree.root] && examples[][i] != )
  307. tree.root = i;
  308. }
  310. tree.branches = find_attribute_values(tree.root, examples);
  311. //第四步:根据节点的取值,将examples分成若干子集
  312. vector<vector<unsigned> > new_examples = drop_one_attribute(tree.root, examples);
  313. vector<vector<unsigned> > subset;
  314. for (unsigned i = ; i < tree.branches.size(); i++)
  315. {
  316. for (unsigned j = ; j < examples.size(); j++)
  317. {
  318. for (unsigned k = ; k < examples[].size(); k++)
  319. {
  320. if (tree.branches[i] == examples[j][k])
  321. subset.push_back(new_examples[j]);
  322. }
  323. }
  324. // 第五步:对每一个子集递归调用build_decision_tree()函数
  325. Tree new_tree;
  326. build_decision_tree(subset,new_tree);
  327. tree.children.push_back(new_tree);
  328. subset.clear();
  329. }
  330. }
  332. /*************************************************
  333. Function: print_tree()
  334. Description: 从第根节点开始,逐层将决策树输出到终
  335. 端显示
  337. Calls: print_tree();
  338. Input: 决策树,层数
  339. Output: 无
  340. *************************************************/
  341. void print_tree(Tree tree,unsigned depth)
  342. {
  343. for (unsigned d = ; d < depth; d++) cout << "\t";
  344. if (!tree.branches.empty()) //不是叶子节点
  345. {
  346. cout << attribute_names[tree.root] << endl;
  348. for (unsigned i = ; i < tree.branches.size(); i++)
  349. {
  350. for (unsigned d = ; d < depth + ; d++) cout << "\t";
  351. cout << attribute_values[tree.branches[i]] << endl;
  352. print_tree(tree.children[i],depth + );
  353. }
  354. }
  355. else //是叶子节点
  356. {
  357. cout << attribute_values[tree.root] << endl;
  358. }
  360. }
  362. int main()
  363. {
  364. vector<vector<unsigned> > rules(rule_num, vector<unsigned>(att_num + , ));
  365. for(unsigned i = ; i < rule_num; i++)
  366. {
  367. for(unsigned j = ; j <= att_num; j++)
  368. rules[i][j] = train_data[i][j];
  369. }
  370. Tree tree;
  371. build_decision_tree(rules, tree);
  372. cout << decision_tree_name << endl;
  373. print_tree(tree,);
  374. return ;
  375. }




所谓"百鸟在林,不如一鸟在手",ID3和C4.5的思想都很简单,容易理解,但是在实现的过程中,由于数据结构的确定和递归调用等问题,还是调试了很久,收获很多。实践出真知!


