来自OpenCV2.3.1 sample/c/mushroom.cpp

1.首先读入agaricus-lepiota.data的训练样本。

样本中第一项是e或p代表有毒或无毒的标志位;其他是特征,可以把每个样本看做一个特征向量;

cvSeqPush( seq, el_ptr );读入序列seq中,每一项都存储一个样本即特征向量;

之后,把特征向量与标志位分别读入CvMat* data与CvMat* reponses中

还有一个CvMat* missing保留丢失位当前小于0位置;

2.训练样本

  1. dtree = new CvDTree;
  2. dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
  3. CvDTreeParams( 8, // max depth
  4. 10, // min sample count 样本数小于10时,停止分裂
  5. 0, // regression accuracy: N/A here;回归树的限制精度
  6. true, // compute surrogate split, as we have missing data;;为真时,计算missing data和变量的重要性
  7. 15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义
  8. 10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation where K is equal to cv_folds
  9. true, // use 1SE rule => smaller tree;If true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确
  10. true, // throw away the pruned tree branches
  11. priors //错分类的代价我们判断的:有毒VS无毒 错误的代价比 the array of priors, the bigger p_weight, the more attention
  12. // to the poisonous mushrooms
  13. // (a mushroom will be judjed to be poisonous with bigger chance)
  14. ));

3.

  1. double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;

4.interactive_classification通过人工输入特征来判断。

    1. #include "opencv2/core/core_c.h"
    2. #include "opencv2/ml/ml.hpp"
    3. #include <stdio.h>
    4. void help()
    5. {
    6. printf("\nThis program demonstrated the use of OpenCV's decision tree function for learning and predicting data\n"
    7. "Usage :\n"
    8. "./mushroom <path to agaricus-lepiota.data>\n"
    9. "\n"
    10. "The sample demonstrates how to build a decision tree for classifying mushrooms.\n"
    11. "It uses the sample base agaricus-lepiota.data from UCI Repository, here is the link:\n"
    12. "\n"
    13. "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
    14. "UCI Repository of machine learning databases\n"
    15. "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
    16. "Irvine, CA: University of California, Department of Information and Computer Science.\n"
    17. "\n"
    18. "// loads the mushroom database, which is a text file, containing\n"
    19. "// one training sample per row, all the input variables and the output variable are categorical,\n"
    20. "// the values are encoded by characters.\n\n");
    21. }
    22. int mushroom_read_database( const char* filename, CvMat** data, CvMat** missing, CvMat** responses )
    23. {
    24. const int M = 1024;
    25. FILE* f = fopen( filename, "rt" );
    26. CvMemStorage* storage;
    27. CvSeq* seq;
    28. char buf[M+2], *ptr;
    29. float* el_ptr;
    30. CvSeqReader reader;
    31. int i, j, var_count = 0;
    32. if( !f )
    33. return 0;
    34. // read the first line and determine the number of variables
    35. if( !fgets( buf, M, f ))
    36. {
    37. fclose(f);
    38. return 0;
    39. }
    40. for( ptr = buf; *ptr != '\0'; ptr++ )
    41. var_count += *ptr == ',';//计算每个样本的数量,每个样本一个“,”,样本数量=var_count+1;
    42. assert( ptr - buf == (var_count+1)*2 );
    43. // create temporary memory storage to store the whole database
    44. //把样本存入seq中,存储空间是storage;
    45. el_ptr = new float[var_count+1];
    46. storage = cvCreateMemStorage();
    47. seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );//
    48. for(;;)
    49. {
    50. for( i = 0; i <= var_count; i++ )
    51. {
    52. int c = buf[i*2];
    53. el_ptr[i] = c == '?' ? -1.f : (float)c;
    54. }
    55. if( i != var_count+1 )
    56. break;
    57. cvSeqPush( seq, el_ptr );
    58. if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
    59. break;
    60. }
    61. fclose(f);
    62. // allocate the output matrices and copy the base there
    63. *data = cvCreateMat( seq->total, var_count, CV_32F );//行数:样本数量;列数:样本大小;
    64. *missing = cvCreateMat( seq->total, var_count, CV_8U );
    65. *responses = cvCreateMat( seq->total, 1, CV_32F );//样本标志;
    66. cvStartReadSeq( seq, &reader );
    67. for( i = 0; i < seq->total; i++ )
    68. {
    69. const float* sdata = (float*)reader.ptr + 1;
    70. float* ddata = data[0]->data.fl + var_count*i;
    71. float* dr = responses[0]->data.fl + i;
    72. uchar* dm = missing[0]->data.ptr + var_count*i;
    73. for( j = 0; j < var_count; j++ )
    74. {
    75. ddata[j] = sdata[j];
    76. dm[j] = sdata[j] < 0;
    77. }
    78. *dr = sdata[-1];//样本的第一个位置是标志;
    79. CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    80. }
    81. cvReleaseMemStorage( &storage );
    82. delete el_ptr;
    83. return 1;
    84. }
    85. CvDTree* mushroom_create_dtree( const CvMat* data, const CvMat* missing,
    86. const CvMat* responses, float p_weight )
    87. {
    88. CvDTree* dtree;
    89. CvMat* var_type;
    90. int i, hr1 = 0, hr2 = 0, p_total = 0;
    91. float priors[] = { 1, p_weight };
    92. var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
    93. cvSet( var_type, cvScalarAll(CV_VAR_CATEGORICAL) ); // all the variables are categorical
    94. dtree = new CvDTree;
    95. dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
    96. CvDTreeParams( 8, // max depth
    97. 10, // min sample count样本数小于10时,停止分裂
    98. 0, // regression accuracy: N/A here;回归树的限制精度
    99. true, // compute surrogate split, as we have missing data;为真时,计算missing data和可变的重要性正确度
    100. 15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义
    101. 10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation
    102. true, // use 1SE rule => smaller treeIf true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确
    103. true, // throw away the pruned tree branches
    104. priors // the array of priors, the bigger p_weight, the more attention
    105. // to the poisonous mushrooms
    106. // (a mushroom will be judjed to be poisonous with bigger chance)
    107. ));
    108. // compute hit-rate on the training database, demonstrates predict usage.
    109. for( i = 0; i < data->rows; i++ )
    110. {
    111. CvMat sample, mask;
    112. cvGetRow( data, &sample, i );
    113. cvGetRow( missing, &mask, i );
    114. double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;
    115. int d = fabs(r - responses->data.fl[i]) >= FLT_EPSILON;//大于阈值FLT_EPSILON被判断为误检
    116. if( d )
    117. {
    118. if( r != 'p' )
    119. hr1++;
    120. else
    121. hr2++;
    122. }
    123. p_total += responses->data.fl[i] == 'p';
    124. }
    125. printf( "Results on the training database:\n"
    126. "\tPoisonous mushrooms mis-predicted: %d (%g%%)\n"
    127. "\tFalse-alarms: %d (%g%%)\n", hr1, (double)hr1*100/p_total,
    128. hr2, (double)hr2*100/(data->rows - p_total) );
    129. cvReleaseMat( &var_type );
    130. return dtree;
    131. }
    132. static const char* var_desc[] =
    133. {
    134. "cap shape (bell=b,conical=c,convex=x,flat=f)",
    135. "cap surface (fibrous=f,grooves=g,scaly=y,smooth=s)",
    136. "cap color (brown=n,buff=b,cinnamon=c,gray=g,green=r,\n\tpink=p,purple=u,red=e,white=w,yellow=y)",
    137. "bruises? (bruises=t,no=f)",
    138. "odor (almond=a,anise=l,creosote=c,fishy=y,foul=f,\n\tmusty=m,none=n,pungent=p,spicy=s)",
    139. "gill attachment (attached=a,descending=d,free=f,notched=n)",
    140. "gill spacing (close=c,crowded=w,distant=d)",
    141. "gill size (broad=b,narrow=n)",
    142. "gill color (black=k,brown=n,buff=b,chocolate=h,gray=g,\n\tgreen=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y)",
    143. "stalk shape (enlarging=e,tapering=t)",
    144. "stalk root (bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r)",
    145. "stalk surface above ring (ibrous=f,scaly=y,silky=k,smooth=s)",
    146. "stalk surface below ring (ibrous=f,scaly=y,silky=k,smooth=s)",
    147. "stalk color above ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    148. "stalk color below ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    149. "veil type (partial=p,universal=u)",
    150. "veil color (brown=n,orange=o,white=w,yellow=y)",
    151. "ring number (none=n,one=o,two=t)",
    152. "ring type (cobwebby=c,evanescent=e,flaring=f,large=l,\n\tnone=n,pendant=p,sheathing=s,zone=z)",
    153. "spore print color (black=k,brown=n,buff=b,chocolate=h,green=r,\n\torange=o,purple=u,white=w,yellow=y)",
    154. "population (abundant=a,clustered=c,numerous=n,\n\tscattered=s,several=v,solitary=y)",
    155. "habitat (grasses=g,leaves=l,meadows=m,paths=p\n\turban=u,waste=w,woods=d)",
    156. 0
    157. };
    158. void print_variable_importance( CvDTree* dtree, const char** var_desc )
    159. {
    160. const CvMat* var_importance = dtree->get_var_importance();
    161. int i;
    162. char input[1000];
    163. if( !var_importance )
    164. {
    165. printf( "Error: Variable importance can not be retrieved\n" );
    166. return;
    167. }
    168. printf( "Print variable importance information? (y/n) " );
    169. scanf( "%1s", input );
    170. if( input[0] != 'y' && input[0] != 'Y' )
    171. return;
    172. for( i = 0; i < var_importance->cols*var_importance->rows; i++ )
    173. {
    174. double val = var_importance->data.db[i];
    175. if( var_desc )
    176. {
    177. char buf[100];
    178. int len = strchr( var_desc[i], '(' ) - var_desc[i] - 1;
    179. strncpy( buf, var_desc[i], len );
    180. buf[len] = '\0';
    181. printf( "%s", buf );
    182. }
    183. else
    184. printf( "var #%d", i );
    185. printf( ": %g%%\n", val*100. );
    186. }
    187. }
    188. void interactive_classification( CvDTree* dtree, const char** var_desc )
    189. {
    190. char input[1000];
    191. const CvDTreeNode* root;
    192. CvDTreeTrainData* data;
    193. if( !dtree )
    194. return;
    195. root = dtree->get_root();
    196. data = dtree->get_data();
    197. for(;;)
    198. {
    199. const CvDTreeNode* node;
    200. printf( "Start/Proceed with interactive mushroom classification (y/n): " );
    201. scanf( "%1s", input );
    202. if( input[0] != 'y' && input[0] != 'Y' )
    203. break;
    204. printf( "Enter 1-letter answers, '?' for missing/unknown value...\n" );
    205. // custom version of predict
    206. //传统的预测方式;
    207. node = root;
    208. for(;;)
    209. {
    210. CvDTreeSplit* split = node->split;
    211. int dir = 0;
    212. if( !node->left || node->Tn <= dtree->get_pruned_tree_idx() || !node->split )
    213. break;
    214. for( ; split != 0; )
    215. {
    216. int vi = split->var_idx, j;
    217. int count = data->cat_count->data.i[vi];
    218. const int* map = data->cat_map->data.i + data->cat_ofs->data.i[vi];
    219. printf( "%s: ", var_desc[vi] );
    220. scanf( "%1s", input );
    221. if( input[0] == '?' )
    222. {
    223. split = split->next;
    224. continue;
    225. }
    226. // convert the input character to the normalized value of the variable
    227. for( j = 0; j < count; j++ )
    228. if( map[j] == input[0] )
    229. break;
    230. if( j < count )
    231. {
    232. dir = (split->subset[j>>5] & (1 << (j&31))) ? -1 : 1;
    233. if( split->inversed )
    234. dir = -dir;
    235. break;
    236. }
    237. else
    238. printf( "Error: unrecognized value\n" );
    239. }
    240. if( !dir )
    241. {
    242. printf( "Impossible to classify the sample\n");
    243. node = 0;
    244. break;
    245. }
    246. node = dir < 0 ? node->left : node->right;
    247. }
    248. if( node )
    249. printf( "Prediction result: the mushroom is %s\n",
    250. node->class_idx == 0 ? "EDIBLE" : "POISONOUS" );
    251. printf( "\n-----------------------------\n" );
    252. }
    253. }
    254. int main( int argc, char** argv )
    255. {
    256. CvMat *data = 0, *missing = 0, *responses = 0;
    257. CvDTree* dtree;
    258. const char* base_path = argc >= 2 ? argv[1] : "agaricus-lepiota.data";
    259. help();
    260. if( !mushroom_read_database( base_path, &data, &missing, &responses ) )
    261. {
    262. printf( "\nUnable to load the training database\n\n");
    263. help();
    264. return -1;
    265. }
    266. dtree = mushroom_create_dtree( data, missing, responses,
    267. 10 // poisonous mushrooms will have 10x higher weight in the decision tree
    268. );
    269. cvReleaseMat( &data );
    270. cvReleaseMat( &missing );
    271. cvReleaseMat( &responses );
    272. print_variable_importance( dtree, var_desc );
    273. interactive_classification( dtree, var_desc );
    274. delete dtree;
    275. return 0;
    276. }
    277. //from: http://blog.csdn.net/yangtrees/article/details/7490852

OpenCV码源笔记——Decision Tree决策树的更多相关文章

  1. OpenCV码源笔记——RandomTrees (二)(Forest)

    源码细节: ● 训练函数 bool CvRTrees::train( const CvMat* _train_data, int _tflag,                        cons ...

  2. OpenCV码源笔记——RandomTrees (一)

    OpenCV2.3中Random Trees(R.T.)的继承结构: API: CvRTParams 定义R.T.训练用参数,CvDTreeParams的扩展子类,但并不用到CvDTreeParams ...

  3. Decision tree(决策树)算法初探

    0. 算法概述 决策树(decision tree)是一种基本的分类与回归方法.决策树模型呈树形结构(二分类思想的算法模型往往都是树形结构) 0x1:决策树模型的不同角度理解 在分类问题中,表示基于特 ...

  4. decision tree 决策树(一)

    一 决策树 原理:分类决策树模型是一种描述对实例进行分类的树形结构.决策树由结点(node)和有向边(directed edge)组成.结点有两种类型:内部结点(internal node)和叶结点( ...

  5. Decision tree——决策树

    基本流程 决策树是通过分次判断样本属性来进行划分样本类别的机器学习模型.每个树的结点选择一个最优属性来进行样本的分流,最终将样本类别划分出来. 决策树的关键就是分流时最优属性$a$的选择.使用所谓信息 ...

  6. 决策树Decision Tree 及实现

    Decision Tree 及实现 标签: 决策树熵信息增益分类有监督 2014-03-17 12:12 15010人阅读 评论(41) 收藏 举报  分类: Data Mining(25)  Pyt ...

  7. Spark MLlib - Decision Tree源码分析

    http://spark.apache.org/docs/latest/mllib-decision-tree.html 以决策树作为开始,因为简单,而且也比较容易用到,当前的boosting或ran ...

  8. [ML学习笔记] 决策树与随机森林(Decision Tree&Random Forest)

    [ML学习笔记] 决策树与随机森林(Decision Tree&Random Forest) 决策树 决策树算法以树状结构表示数据分类的结果.每个决策点实现一个具有离散输出的测试函数,记为分支 ...

  9. 【机器学习】决策树(Decision Tree) 学习笔记

    [机器学习]决策树(decision tree) 学习笔记 标签(空格分隔): 机器学习 决策树简介 决策树(decision tree)是一个树结构(可以是二叉树或非二叉树).其每个非叶节点表示一个 ...

随机推荐

  1. 【转】 管理CPU 亲和性

    简单地说,CPU 亲和性(affinity) 就是进程要在某个给定的 CPU 上尽量长时间地运行而不被迁移到其他处理器的倾向性.Linux 内核进程调度器天生就具有被称为 软 CPU 亲和性(affi ...

  2. Teradata 的rank() 和 row_number() 函数

    Teradata数据库中也有和oracle类似的分析函数,功能基本一样.示例如下: RANK() 函数   SELECT * FROM salestbl ORDER BY 1,2; storeid p ...

  3. CPU 时间片 分时 轮转调度

    时间片即CPU分配给各个程序的时间,每个线程被分配一个时间段,称作它的时间片,即该进程允许运行的时间,使各个程序从表面上看是同时进行的.如果在时间片结束时进程还在运行,则CPU将被剥夺并分配给另一个进 ...

  4. 20145120 《Java程序设计》实验一实验报告

    20145120 <Java程序设计>实验一实验报告 实验名称:Java开发环境的熟悉 实验目的与要求: 1.使用JDK编译.运行简单的Java程序:(第1周学习总结) 2.编辑.编译.运 ...

  5. Oracle 多行记录合并/连接/聚合字符串的几种方法

    怎么合并多行记录的字符串,一直是oracle新手喜欢问的SQL问题之一,关于这个问题的帖子我看过不下30个了,现在就对这个问题,进行一个总结.-什么是合并多行字符串(连接字符串)呢,例如: SQL&g ...

  6. Backbone.Events—纯净MVC框架的双向绑定基石

    Backbone.Events-纯净MVC框架的双向绑定基石 为什么Backbone是纯净MVC? 在这个大前端时代,各路MV*框架如雨后春笋搬涌现出来,在infoQ上有一篇 12种JavaScrip ...

  7. Http之Get/Post请求区别

    Http之Get/Post请求区别 1.HTTP请求格式: <request line> <headers> <blank line> [<request-b ...

  8. QT for android 比较完美解决 全屏问题

    项目用到QT qml,需要在android下面全屏显示,折腾了一晚上,搞定,分享下,希望能帮助他人. 参考 Qt on Android:让 Qt Widgets 和 Qt Quick 应用全屏显示 该 ...

  9. Exception in thread "http-bio-8081-exec-3" java.lang.OutOfMemoryError: PermGen space

    前言: 在http://www.cnblogs.com/wql025/p/4865673.html一文中我曾描述这种异常也提供了解决方式,但效果不太理想,现在用本文的方式,效果显著. 目前此项目只能登 ...

  10. Eclipse插件开发 swt ComboBoxCellEditor CCombo 下拉框高度

    效果图:     代码如下 bindingPageTableViewer.setCellModifier(new ICellModifier() { public boolean canModify( ...