HaarTraining关键的部分是建立基分类器classifier,OpenCV中所採用的是CART(决策树的一种):通过调用cvCreateMTStumpClassifier来完毕。

这里我讨论利用回归的方法来分裂结点。分类的方法仅仅是在分裂结点的方法与之不同而已。

cvCreateMTStumpClassifier

	//设置决策树分类误差计算方法
stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error; //设置class step和ydata
ydata = trainClasses->data.ptr;
if( trainClasses->rows == 1 )
{
m = trainClasses->cols;
ystep = CV_ELEM_SIZE( trainClasses->type );
}
else
{
m = trainClasses->rows;
ystep = trainClasses->step;
}
//设置weight step和wdata
wdata = weights->data.ptr;
if( weights->rows == 1 )
{
assert( weights->cols == m );
wstep = CV_ELEM_SIZE( weights->type );
}
else
{
assert( weights->rows == m );
wstep = weights->step;
} //设置步长,地址等參数,用于获取idxCache内容
if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
{
sortedtype =
CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
|| sortedtype == CV_32FC1 );
sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
sortedsstep = CV_ELEM_SIZE( sortedtype );
sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
} if( trainData == NULL )
{
assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
assert( n > 0 );
}
//设置步长,地址等參数,用于获取dataCache内容
else
{
assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
data = trainData->data.ptr;
if( CV_IS_ROW_SAMPLE( flags ) )
{
cstep = CV_ELEM_SIZE( trainData->type );
sstep = trainData->step;
assert( m == trainData->rows );
datan = n = trainData->cols;
}
else
{
sstep = CV_ELEM_SIZE( trainData->type );
cstep = trainData->step;
assert( m == trainData->cols );
datan = n = trainData->rows;
}
if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
{
n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
}
}

可能研究代码到这里的朋友仍然不清楚idxCache和valCache的作用。

我这里做一点简单的说明:

valCache是设置在训练前有多少特征值被提前算出存放在内存中。idxCache是valCache中每种特征按特征值从小到大排列的样本的序号。

内存大小是通过执行程序的命令行參数设置的。在cvHaartraining.cpp中我们能够找到这句话,当中float和short,各自是valCache和idxCache存放内容的基本类型。

	//1MB == 1048576B  计算一个样本中有多少个特征能被pre计算放在内存中
numprecalculated = (int) ( ((size_t) mem) * ((size_t) 1048576) /
( ((size_t) (npos + nneg)) * (sizeof( float ) + sizeof( short )) ) );

为了方便理解,我把两者的内存模型画了出来

要注意idxCache中每行的index排列是示意图。

比方第一行代表feature1从小到大的index顺序。从图中能够看出。sample1的特征值feature1 < sample0的特征值feature1<...<sample n < sample n-1。

利用idxCache数组我们能够方便按特征值的从小到大遍历valCache,而且节省了空间。从float->short。

理解了这两个cache之后我们再回到上面的代码,能够发现这里做得仅仅是设置步长和cache首地址的一些操作。为以下開始的遍历做好准备。

跳过一些变量的初始化步骤,我们来到构建决策树stump的部分,而且为了方便阅读核心代码。去掉了其它一些基于移植的代码

 while( t_compidx < n )
{
//选择计算前100种特征
t_n = portion;
if( t_compidx < datan )
{
t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
t_data = data;
t_cstep = cstep;
t_sstep = sstep;
}
else
{ } if( sorteddata != NULL )
{ }
else
{
/* have sorted indices */
switch( sortedtype )
{
case CV_16SC1: //选择某个样本的某个特征值作为结点
for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
{
if( findStumpThreshold_16s[stumperror](
t_data + ti * t_cstep, t_sstep,
wdata, wstep, ydata, ystep,
sorteddata + ti * sortedcstep, sortedsstep, sortedm,
&lerror, &rerror,
&threshold, &left, &right,
&sumw, &sumwy, &sumwyy ) )
{
optcompidx = ti;
}
}
break; }
}
}}

这里datan代表的是一个检測窗体包括的特征数目。portion代表以多少的行为单位进行计算。每一个循环选取valCache中的portion行进行计算,应该是为了发挥并行计算的优势,假如设置了并行计算的宏的话。

findStumpThreshold_32[stumperror]是一个函数指针,利用这个函数我们能够选择某个样本的某个特征值作为决策树的一个结点。

这里结点分裂方法我选择的是最小残差和的方法。即统计利用某个特征值进行分类后,左右子树中类间的残差之和。最小残差和相应的特征值就是满足要求的结点。

OpenCV1.0中利用宏定义的方式实现了这个函数

#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \
CV_BOOST_IMPL int icvFindStumpThreshold_##suffix( \
uchar* data, size_t datastep, \
uchar* wdata, size_t wstep, \
uchar* ydata, size_t ystep, \
uchar* idxdata, size_t idxstep, int num, \
float* lerror, \
float* rerror, \
float* threshold, float* left, float* right, \
float* sumw, float* sumwy, float* sumwyy ) \
{ \
int found = 0; \
float wyl = 0.0F; \
float wl = 0.0F; \
float wyyl = 0.0F; \
float wyr = 0.0F; \
float wr = 0.0F; \
\
float curleft = 0.0F; \
float curright = 0.0F; \
float* prevval = NULL; \
float* curval = NULL; \
float curlerror = 0.0F; \
float currerror = 0.0F; \
float wposl; \
float wposr; \
\
int i = 0; \
int idx = 0; \
\
wposl = wposr = 0.0F; \
if( *sumw == FLT_MAX ) \
{ \
/* calculate sums */ \
float *y = NULL; \
float *w = NULL; \
float wy = 0.0F; \
\
*sumw = 0.0F; \
*sumwy = 0.0F; \
*sumwyy = 0.0F; \
for( i = 0; i < num; i++ ) \
{ \
idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
w = (float*) (wdata + idx * wstep); \
*sumw += *w; \
y = (float*) (ydata + idx * ystep); \
wy = (*w) * (*y); \
*sumwy += wy; \
*sumwyy += wy * (*y); \
} \
} \
\
for( i = 0; i < num; i++ ) \
{ \
idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
curval = (float*) (data + idx * datastep); \
/* for debug purpose */ \
if( i > 0 ) assert( (*prevval) <= (*curval) ); \
\
wyr = *sumwy - wyl; \
wr = *sumw - wl; \
\
if( wl > 0.0 ) curleft = wyl / wl; \
else curleft = 0.0F; \
\
if( wr > 0.0 ) curright = wyr / wr; \
else curright = 0.0F; \
\
error \
\
if( curlerror + currerror < (*lerror) + (*rerror) ) \
{ \
(*lerror) = curlerror; \
(*rerror) = currerror; \
*threshold = *curval; \
if( i > 0 ) { \
*threshold = 0.5F * (*threshold + *prevval); \
} \
*left = curleft; \
*right = curright; \
found = 1; \
} \
\
do \
{ \
wl += *((float*) (wdata + idx * wstep)); \
wyl += (*((float*) (wdata + idx * wstep))) \
* (*((float*) (ydata + idx * ystep))); \
wyyl += *((float*) (wdata + idx * wstep)) \
* (*((float*) (ydata + idx * ystep))) \
* (*((float*) (ydata + idx * ystep))); \
} \
while( (++i) < num && \
( *((float*) (data + (idx = \
(int) ( *((type*) (idxdata + i*idxstep))) ) * datastep)) \
== *curval ) ); \
--i; \
prevval = curval; \
} /* for each value */ \
\
return found; \
}

这里几个关键的变量:ydata是某个特征所代表的类别,正负样本分别以1。-1进行标注,wdata是正负样本相应的权值,data指的就是valCache的某一行。

程序进来的时候推断sumw是否初始化,没有初始化就进行赋值。因为同一个训练集每一个样本都仅仅相应一个ydata和wdata(每一个样本相应非常多个Haar特征,两者有差别),因此这里的sumw,sumwyy,sumwy都是一个确定的值。提前计算好,在后面的迭代中就不必反复计算。

接下来,依据idxCache中某一行(视迭代次数而定)的index,按从小到大的顺序遍历ValCache中相应行的特征值。也就是不相同本的同一特征值。并将其作为结点,尝试对样本进行划分。curleft和curright分别代表左右子树的类别的加权平均值。然后利用error宏计算左右子树的残差

#define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type )                                  \
ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type, \
/* calculate error (sum of squares) */ \
/* err = sum( w * (y - left(rigt)Val)^2 ) */ \
curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl; \
currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
)

最后的一个do-while循环。就是用来跳过和当前结点同样的特征值。尽管以兴许的、同样的值作为结点划分左右子树,残差平方和可能会改变,可是决策树划分的最小单位是特征值的种类,由于在利用决策树进行分类的时候,必须对同样的特征值做出一样的决策(该划入左子树还是该划入右子树)。

总结

HaarTraining的代码是有4,5K行。可是认真学习之后会收获非常多机器学习的算法和优秀代码的书写习惯。我会随着学习的深入不断更新自己的源代码研究体会,写的尽管不是非常具体,可是力求把重点突出来,将自己在阅读代码时碰到的困惑总结出来,给相同学习Training算法的朋友一点点帮助

OpenCV HaarTraining代码解析(二)cvCreateMTStumpClassifier(建立决策树)的更多相关文章

  1. java代码解析二维码

    java代码解析二维码一般步骤 本文采用的是google的zxing技术进行解析二维码技术,解析二维码的一般步骤如下: 一.下载zxing-core的jar包: 二.创建一个BufferedImage ...

  2. GraphSAGE 代码解析(二) - layers.py

    原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(三) - aggregator ...

  3. GraphSAGE 代码解析(四) - models.py

    原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(二) - layers.py ...

  4. GraphSAGE 代码解析(三) - aggregators.py

    原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(二) - layers.py ...

  5. GraphSAGE 代码解析(一) - unsupervised_train.py

    原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(二) - layers.py GraphSAGE 代码解析(三) - aggregators.py GraphSA ...

  6. asp.net C#生成和解析二维码代码

    类库文件我们在文件最后面下载 [ThoughtWorks.QRCode.dll 就是类库] 使用时需要增加: using ThoughtWorks.QRCode.Codec;using Thought ...

  7. JavaScript “跑马灯”抽奖活动代码解析与优化(二)

    既然是要编写插件.那么叫做"插件"的东西肯定是具有的某些特征能够满足我们平时开发的需求或者是提高我们的开发效率.那么叫做插件的东西应该具有哪些基本特征呢?让我们来总结一下: 1.J ...

  8. RobHess的SIFT代码解析步骤三

    平台:win10 x64 +VS 2015专业版 +opencv-2.4.11 + gtk_-bundle_2.24.10_win32 主要参考:1.代码:RobHess的SIFT源码 2.书:王永明 ...

  9. Fixflow引擎解析(二)(模型) - BPMN2.0读写

    Fixflow引擎解析(四)(模型) - 通过EMF扩展BPMN2.0元素 Fixflow引擎解析(三)(模型) - 创建EMF模型来读写XML文件 Fixflow引擎解析(二)(模型) - BPMN ...

随机推荐

  1. 利用d3.js绘制中国地图

    d3.js是一个比較强的数据可视化js工具. 利用它画了一幅中国地图,例如以下图所看到的: watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvc3ZhcDE=/ ...

  2. javascript笔记整理(数组)

    数组是一个可以存储一组或是一系列相关数据的容器. 一.为什么要使用数组. a.为了解决大量相关数据的存储和使用的问题. b.模拟真是的世界. 二.如何创建数组 A.通过对象的方式来创建——var a= ...

  3. 基于visual Studio2013解决面试题之1403插入排序

     题目

  4. [Cocos2d-x开发问题-3] cocos2dx动画Animation介绍

    Cocos2d-x为了减少开发难度,对于动画的实现採用的帧动画的方案.这也就是说Cocos2d-x中的动画是帧动画. 帧动画的原理相信大家都不陌生,就是多张图片循环播放以实现动画的效果. 一个简单的动 ...

  5. Android KeyCode(官方)

    Constants public static final int ACTION_DOWN Added in API level 1 getAction() value: the key has be ...

  6. Windows Azure入门教学系列 (七):使用REST API访问Storage Service

    本文是Windows Azure入门教学的第七篇文章. 本文将会介绍如何使用REST API来直接访问Storage Service. 在前三篇教学中,我们已经学习了使用Windows Azure S ...

  7. Web前端,高性能优化

    高性能HTML 一.避免使用iframe iframe也叫内联frame,可将一个HTML文档嵌入另一个HTML文档中. iframe的好处是,嵌入的文档独立于父文档,通常也借此使浏览器模拟多线程.缺 ...

  8. Rationnal Rose2003安装并破解

    1.安装Rational Rose2003时,在需选择安装项的时候,只选择Rational Rose EnterPrise Edition即可,不需选择其他项,之后选择“DeskTop Install ...

  9. .net程序员面试不完全指南

    程序员找工作难,想要被成功聘用更难.最常见的办法是经历一次又一次的面试失败后自己琢磨出面试技巧,当然也可以花钱到一些培训机构去接受专业的书面简历和模拟面试的指导.这些方法可能都会奏效,但是却并不是时间 ...

  10. 【 .NET 面向对象程序设计进阶》】【 《.NET 面向对象编程基础》】【《正则表达式助手》】

    <.NET 面向对象程序设计进阶> <.NET 面向对象程序设计进阶> <正则表达式助手>