OpenCV HaarTraining代码解析(二)cvCreateMTStumpClassifier(建立决策树)
HaarTraining关键的部分是建立基分类器classifier,OpenCV中所採用的是CART(决策树的一种):通过调用cvCreateMTStumpClassifier来完毕。
这里我讨论利用回归的方法来分裂结点。分类的方法仅仅是在分裂结点的方法与之不同而已。
cvCreateMTStumpClassifier
- //设置决策树分类误差计算方法
- stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;
- //设置class step和ydata
- ydata = trainClasses->data.ptr;
- if( trainClasses->rows == 1 )
- {
- m = trainClasses->cols;
- ystep = CV_ELEM_SIZE( trainClasses->type );
- }
- else
- {
- m = trainClasses->rows;
- ystep = trainClasses->step;
- }
- //设置weight step和wdata
- wdata = weights->data.ptr;
- if( weights->rows == 1 )
- {
- assert( weights->cols == m );
- wstep = CV_ELEM_SIZE( weights->type );
- }
- else
- {
- assert( weights->rows == m );
- wstep = weights->step;
- }
- //设置步长,地址等參数,用于获取idxCache内容
- if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
- {
- sortedtype =
- CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
- assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
- || sortedtype == CV_32FC1 );
- sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
- sortedsstep = CV_ELEM_SIZE( sortedtype );
- sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
- sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
- sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
- }
- if( trainData == NULL )
- {
- assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
- n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
- assert( n > 0 );
- }
- //设置步长,地址等參数,用于获取dataCache内容
- else
- {
- assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
- data = trainData->data.ptr;
- if( CV_IS_ROW_SAMPLE( flags ) )
- {
- cstep = CV_ELEM_SIZE( trainData->type );
- sstep = trainData->step;
- assert( m == trainData->rows );
- datan = n = trainData->cols;
- }
- else
- {
- sstep = CV_ELEM_SIZE( trainData->type );
- cstep = trainData->step;
- assert( m == trainData->cols );
- datan = n = trainData->rows;
- }
- if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
- {
- n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
- }
- }
可能研究代码到这里的朋友仍然不清楚idxCache和valCache的作用。
我这里做一点简单的说明:
valCache是设置在训练前有多少特征值被提前算出存放在内存中。idxCache是valCache中每种特征按特征值从小到大排列的样本的序号。
内存大小是通过执行程序的命令行參数设置的。在cvHaartraining.cpp中我们能够找到这句话,当中float和short,各自是valCache和idxCache存放内容的基本类型。
- //1MB == 1048576B 计算一个样本中有多少个特征能被pre计算放在内存中
- numprecalculated = (int) ( ((size_t) mem) * ((size_t) 1048576) /
- ( ((size_t) (npos + nneg)) * (sizeof( float ) + sizeof( short )) ) );
为了方便理解,我把两者的内存模型画了出来
要注意idxCache中每行的index排列是示意图。
比方第一行代表feature1从小到大的index顺序。从图中能够看出。sample1的特征值feature1 < sample0的特征值feature1<...<sample n < sample n-1。
利用idxCache数组我们能够方便按特征值的从小到大遍历valCache,而且节省了空间。从float->short。
理解了这两个cache之后我们再回到上面的代码,能够发现这里做得仅仅是设置步长和cache首地址的一些操作。为以下開始的遍历做好准备。
跳过一些变量的初始化步骤,我们来到构建决策树stump的部分,而且为了方便阅读核心代码。去掉了其它一些基于移植的代码
- while( t_compidx < n )
- {
- //选择计算前100种特征
- t_n = portion;
- if( t_compidx < datan )
- {
- t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
- t_data = data;
- t_cstep = cstep;
- t_sstep = sstep;
- }
- else
- {
- }
- if( sorteddata != NULL )
- {
- }
- else
- {
- /* have sorted indices */
- switch( sortedtype )
- {
- case CV_16SC1:
- //选择某个样本的某个特征值作为结点
- for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
- {
- if( findStumpThreshold_16s[stumperror](
- t_data + ti * t_cstep, t_sstep,
- wdata, wstep, ydata, ystep,
- sorteddata + ti * sortedcstep, sortedsstep, sortedm,
- &lerror, &rerror,
- &threshold, &left, &right,
- &sumw, &sumwy, &sumwyy ) )
- {
- optcompidx = ti;
- }
- }
- break;
- }
- }
- }}
这里datan代表的是一个检測窗体包括的特征数目。portion代表以多少的行为单位进行计算。每一个循环选取valCache中的portion行进行计算,应该是为了发挥并行计算的优势,假如设置了并行计算的宏的话。
findStumpThreshold_32[stumperror]是一个函数指针,利用这个函数我们能够选择某个样本的某个特征值作为决策树的一个结点。
这里结点分裂方法我选择的是最小残差和的方法。即统计利用某个特征值进行分类后,左右子树中类间的残差之和。最小残差和相应的特征值就是满足要求的结点。
OpenCV1.0中利用宏定义的方式实现了这个函数
- #define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error ) \
- CV_BOOST_IMPL int icvFindStumpThreshold_##suffix( \
- uchar* data, size_t datastep, \
- uchar* wdata, size_t wstep, \
- uchar* ydata, size_t ystep, \
- uchar* idxdata, size_t idxstep, int num, \
- float* lerror, \
- float* rerror, \
- float* threshold, float* left, float* right, \
- float* sumw, float* sumwy, float* sumwyy ) \
- { \
- int found = 0; \
- float wyl = 0.0F; \
- float wl = 0.0F; \
- float wyyl = 0.0F; \
- float wyr = 0.0F; \
- float wr = 0.0F; \
- \
- float curleft = 0.0F; \
- float curright = 0.0F; \
- float* prevval = NULL; \
- float* curval = NULL; \
- float curlerror = 0.0F; \
- float currerror = 0.0F; \
- float wposl; \
- float wposr; \
- \
- int i = 0; \
- int idx = 0; \
- \
- wposl = wposr = 0.0F; \
- if( *sumw == FLT_MAX ) \
- { \
- /* calculate sums */ \
- float *y = NULL; \
- float *w = NULL; \
- float wy = 0.0F; \
- \
- *sumw = 0.0F; \
- *sumwy = 0.0F; \
- *sumwyy = 0.0F; \
- for( i = 0; i < num; i++ ) \
- { \
- idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
- w = (float*) (wdata + idx * wstep); \
- *sumw += *w; \
- y = (float*) (ydata + idx * ystep); \
- wy = (*w) * (*y); \
- *sumwy += wy; \
- *sumwyy += wy * (*y); \
- } \
- } \
- \
- for( i = 0; i < num; i++ ) \
- { \
- idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
- curval = (float*) (data + idx * datastep); \
- /* for debug purpose */ \
- if( i > 0 ) assert( (*prevval) <= (*curval) ); \
- \
- wyr = *sumwy - wyl; \
- wr = *sumw - wl; \
- \
- if( wl > 0.0 ) curleft = wyl / wl; \
- else curleft = 0.0F; \
- \
- if( wr > 0.0 ) curright = wyr / wr; \
- else curright = 0.0F; \
- \
- error \
- \
- if( curlerror + currerror < (*lerror) + (*rerror) ) \
- { \
- (*lerror) = curlerror; \
- (*rerror) = currerror; \
- *threshold = *curval; \
- if( i > 0 ) { \
- *threshold = 0.5F * (*threshold + *prevval); \
- } \
- *left = curleft; \
- *right = curright; \
- found = 1; \
- } \
- \
- do \
- { \
- wl += *((float*) (wdata + idx * wstep)); \
- wyl += (*((float*) (wdata + idx * wstep))) \
- * (*((float*) (ydata + idx * ystep))); \
- wyyl += *((float*) (wdata + idx * wstep)) \
- * (*((float*) (ydata + idx * ystep))) \
- * (*((float*) (ydata + idx * ystep))); \
- } \
- while( (++i) < num && \
- ( *((float*) (data + (idx = \
- (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep)) \
- == *curval ) ); \
- --i; \
- prevval = curval; \
- } /* for each value */ \
- \
- return found; \
- }
这里几个关键的变量:ydata是某个特征所代表的类别,正负样本分别以1。-1进行标注,wdata是正负样本相应的权值,data指的就是valCache的某一行。
程序进来的时候推断sumw是否初始化,没有初始化就进行赋值。因为同一个训练集每一个样本都仅仅相应一个ydata和wdata(每一个样本相应非常多个Haar特征,两者有差别),因此这里的sumw,sumwyy,sumwy都是一个确定的值。提前计算好,在后面的迭代中就不必反复计算。
接下来,依据idxCache中某一行(视迭代次数而定)的index,按从小到大的顺序遍历ValCache中相应行的特征值。也就是不相同本的同一特征值。并将其作为结点,尝试对样本进行划分。curleft和curright分别代表左右子树的类别的加权平均值。然后利用error宏计算左右子树的残差
- #define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type ) \
- ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type, \
- /* calculate error (sum of squares) */ \
- /* err = sum( w * (y - left(rigt)Val)^2 ) */ \
- curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl; \
- currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
- )
最后的一个do-while循环。就是用来跳过和当前结点同样的特征值。尽管以兴许的、同样的值作为结点划分左右子树,残差平方和可能会改变,可是决策树划分的最小单位是特征值的种类,由于在利用决策树进行分类的时候,必须对同样的特征值做出一样的决策(该划入左子树还是该划入右子树)。
总结
OpenCV HaarTraining代码解析(二)cvCreateMTStumpClassifier(建立决策树)的更多相关文章
- java代码解析二维码
java代码解析二维码一般步骤 本文采用的是google的zxing技术进行解析二维码技术,解析二维码的一般步骤如下: 一.下载zxing-core的jar包: 二.创建一个BufferedImage ...
- GraphSAGE 代码解析(二) - layers.py
原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(三) - aggregator ...
- GraphSAGE 代码解析(四) - models.py
原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(二) - layers.py ...
- GraphSAGE 代码解析(三) - aggregators.py
原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(二) - layers.py ...
- GraphSAGE 代码解析(一) - unsupervised_train.py
原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(二) - layers.py GraphSAGE 代码解析(三) - aggregators.py GraphSA ...
- asp.net C#生成和解析二维码代码
类库文件我们在文件最后面下载 [ThoughtWorks.QRCode.dll 就是类库] 使用时需要增加: using ThoughtWorks.QRCode.Codec;using Thought ...
- JavaScript “跑马灯”抽奖活动代码解析与优化(二)
既然是要编写插件.那么叫做"插件"的东西肯定是具有的某些特征能够满足我们平时开发的需求或者是提高我们的开发效率.那么叫做插件的东西应该具有哪些基本特征呢?让我们来总结一下: 1.J ...
- RobHess的SIFT代码解析步骤三
平台:win10 x64 +VS 2015专业版 +opencv-2.4.11 + gtk_-bundle_2.24.10_win32 主要参考:1.代码:RobHess的SIFT源码 2.书:王永明 ...
- Fixflow引擎解析(二)(模型) - BPMN2.0读写
Fixflow引擎解析(四)(模型) - 通过EMF扩展BPMN2.0元素 Fixflow引擎解析(三)(模型) - 创建EMF模型来读写XML文件 Fixflow引擎解析(二)(模型) - BPMN ...
随机推荐
- android 从其他app接收分享的内容
Receiving Content from Other Apps[从其他app接收分享的内容] 就像你的程序能够发送数据到其他程序一样,其他程序也能够简单的接收发送过来的数据.需要考虑的是用户与你的 ...
- Join的实现步骤 以及连接的概念
Join的实现步骤 以及连接的概念 我们常说连接有三种,即 交叉连接.内连接.外连接,这三者的概念很容易模糊,现在我们先放下概念,搞清楚完整连接实现的步骤: 一个完整的连接有三个步骤:.做笛卡儿积: ...
- Ubuntu 挂载ISO文件的方法
1.在终端中输入:sudo mkdir /media/iso 在/media下生成一个iso文件夹用来挂载iso文件2.然后输入:sudo mount -o loop /home/X/X/XXXX.i ...
- Hough变换在opencv中的应用
霍夫曼变换(Hough Transform)的原理 霍夫曼变换是一种可以检测出某种特殊形状的算法,OpenCV中用霍夫曼变换来检测出图像中的直线.椭圆和其他几何图形.由它改进的算法,可以用来检测任何形 ...
- jQuery.localStorage() - jQuery SDK API
jQuery.localStorage() - jQuery SDK API jQuery.localStorage() From jQuery SDK API Jump to: navigati ...
- ASP.NET - 匹配标签中的内容
string str = @"<td>Csdn</td>\r\n<td>V1.0</td>\r\n<td>2014-10-23&l ...
- 编写自定义的JDBC框架与策略模式
本篇根据上一篇利用数据库的几种元数据来仿造Apache公司的开源DbUtils工具类集合来编写自己的JDBC框架.也就是说在本篇中很大程度上的代码都和DbUtils中相似,学完本篇后即更容易了解DbU ...
- Windows Azure入门教学系列 (八):使用Windows Azure Drive
我们知道,由于云端的特殊性,通常情况下,对文件系统的读写建议使用Blob Storage来代替.这就产生了一个问题:对于一个已经写好的本地应用程序,其中使用了NTFS API对本地文件系统读写的代码是 ...
- 在webx.ml中 配置struts2 后 welcome-file-list 失效的解决办法
struts2 <filter-mapping> <filter-name>struts2</filter-name> <url-pattern>*.a ...
- Swift - 给游戏添加背景音乐和音效(SpriteKit游戏开发)
游戏少不了背景音乐和音效.下面我们通过创建一个管理音效的类,来实现背景音乐的播放,同时点击屏幕可以播放相应的音效. 声音管理类 SoundManager.swift 1 2 3 4 5 6 7 8 9 ...