OpenCV HaarTraining代码解析(二)cvCreateMTStumpClassifier(建立决策树)
HaarTraining关键的部分是建立基分类器classifier,OpenCV中所採用的是CART(决策树的一种):通过调用cvCreateMTStumpClassifier来完毕。
这里我讨论利用回归的方法来分裂结点。分类的方法仅仅是在分裂结点的方法与之不同而已。
cvCreateMTStumpClassifier
//设置决策树分类误差计算方法
stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error; //设置class step和ydata
ydata = trainClasses->data.ptr;
if( trainClasses->rows == 1 )
{
m = trainClasses->cols;
ystep = CV_ELEM_SIZE( trainClasses->type );
}
else
{
m = trainClasses->rows;
ystep = trainClasses->step;
}
//设置weight step和wdata
wdata = weights->data.ptr;
if( weights->rows == 1 )
{
assert( weights->cols == m );
wstep = CV_ELEM_SIZE( weights->type );
}
else
{
assert( weights->rows == m );
wstep = weights->step;
} //设置步长,地址等參数,用于获取idxCache内容
if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
{
sortedtype =
CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
|| sortedtype == CV_32FC1 );
sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
sortedsstep = CV_ELEM_SIZE( sortedtype );
sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
} if( trainData == NULL )
{
assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
assert( n > 0 );
}
//设置步长,地址等參数,用于获取dataCache内容
else
{
assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
data = trainData->data.ptr;
if( CV_IS_ROW_SAMPLE( flags ) )
{
cstep = CV_ELEM_SIZE( trainData->type );
sstep = trainData->step;
assert( m == trainData->rows );
datan = n = trainData->cols;
}
else
{
sstep = CV_ELEM_SIZE( trainData->type );
cstep = trainData->step;
assert( m == trainData->cols );
datan = n = trainData->rows;
}
if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
{
n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
}
}
可能研究代码到这里的朋友仍然不清楚idxCache和valCache的作用。
我这里做一点简单的说明:
valCache是设置在训练前有多少特征值被提前算出存放在内存中。idxCache是valCache中每种特征按特征值从小到大排列的样本的序号。
内存大小是通过执行程序的命令行參数设置的。在cvHaartraining.cpp中我们能够找到这句话,当中float和short,各自是valCache和idxCache存放内容的基本类型。
//1MB == 1048576B 计算一个样本中有多少个特征能被pre计算放在内存中
numprecalculated = (int) ( ((size_t) mem) * ((size_t) 1048576) /
( ((size_t) (npos + nneg)) * (sizeof( float ) + sizeof( short )) ) );
为了方便理解,我把两者的内存模型画了出来
要注意idxCache中每行的index排列是示意图。
比方第一行代表feature1从小到大的index顺序。从图中能够看出。sample1的特征值feature1 < sample0的特征值feature1<...<sample n < sample n-1。
利用idxCache数组我们能够方便按特征值的从小到大遍历valCache,而且节省了空间。从float->short。
理解了这两个cache之后我们再回到上面的代码,能够发现这里做得仅仅是设置步长和cache首地址的一些操作。为以下開始的遍历做好准备。
跳过一些变量的初始化步骤,我们来到构建决策树stump的部分,而且为了方便阅读核心代码。去掉了其它一些基于移植的代码
while( t_compidx < n )
{
//选择计算前100种特征
t_n = portion;
if( t_compidx < datan )
{
t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
t_data = data;
t_cstep = cstep;
t_sstep = sstep;
}
else
{ } if( sorteddata != NULL )
{ }
else
{
/* have sorted indices */
switch( sortedtype )
{
case CV_16SC1: //选择某个样本的某个特征值作为结点
for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
{
if( findStumpThreshold_16s[stumperror](
t_data + ti * t_cstep, t_sstep,
wdata, wstep, ydata, ystep,
sorteddata + ti * sortedcstep, sortedsstep, sortedm,
&lerror, &rerror,
&threshold, &left, &right,
&sumw, &sumwy, &sumwyy ) )
{
optcompidx = ti;
}
}
break; }
}
}}
这里datan代表的是一个检測窗体包括的特征数目。portion代表以多少的行为单位进行计算。每一个循环选取valCache中的portion行进行计算,应该是为了发挥并行计算的优势,假如设置了并行计算的宏的话。
findStumpThreshold_32[stumperror]是一个函数指针,利用这个函数我们能够选择某个样本的某个特征值作为决策树的一个结点。
这里结点分裂方法我选择的是最小残差和的方法。即统计利用某个特征值进行分类后,左右子树中类间的残差之和。最小残差和相应的特征值就是满足要求的结点。
OpenCV1.0中利用宏定义的方式实现了这个函数
#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error ) \
CV_BOOST_IMPL int icvFindStumpThreshold_##suffix( \
uchar* data, size_t datastep, \
uchar* wdata, size_t wstep, \
uchar* ydata, size_t ystep, \
uchar* idxdata, size_t idxstep, int num, \
float* lerror, \
float* rerror, \
float* threshold, float* left, float* right, \
float* sumw, float* sumwy, float* sumwyy ) \
{ \
int found = 0; \
float wyl = 0.0F; \
float wl = 0.0F; \
float wyyl = 0.0F; \
float wyr = 0.0F; \
float wr = 0.0F; \
\
float curleft = 0.0F; \
float curright = 0.0F; \
float* prevval = NULL; \
float* curval = NULL; \
float curlerror = 0.0F; \
float currerror = 0.0F; \
float wposl; \
float wposr; \
\
int i = 0; \
int idx = 0; \
\
wposl = wposr = 0.0F; \
if( *sumw == FLT_MAX ) \
{ \
/* calculate sums */ \
float *y = NULL; \
float *w = NULL; \
float wy = 0.0F; \
\
*sumw = 0.0F; \
*sumwy = 0.0F; \
*sumwyy = 0.0F; \
for( i = 0; i < num; i++ ) \
{ \
idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
w = (float*) (wdata + idx * wstep); \
*sumw += *w; \
y = (float*) (ydata + idx * ystep); \
wy = (*w) * (*y); \
*sumwy += wy; \
*sumwyy += wy * (*y); \
} \
} \
\
for( i = 0; i < num; i++ ) \
{ \
idx = (int) ( *((type*) (idxdata + i*idxstep)) ); \
curval = (float*) (data + idx * datastep); \
/* for debug purpose */ \
if( i > 0 ) assert( (*prevval) <= (*curval) ); \
\
wyr = *sumwy - wyl; \
wr = *sumw - wl; \
\
if( wl > 0.0 ) curleft = wyl / wl; \
else curleft = 0.0F; \
\
if( wr > 0.0 ) curright = wyr / wr; \
else curright = 0.0F; \
\
error \
\
if( curlerror + currerror < (*lerror) + (*rerror) ) \
{ \
(*lerror) = curlerror; \
(*rerror) = currerror; \
*threshold = *curval; \
if( i > 0 ) { \
*threshold = 0.5F * (*threshold + *prevval); \
} \
*left = curleft; \
*right = curright; \
found = 1; \
} \
\
do \
{ \
wl += *((float*) (wdata + idx * wstep)); \
wyl += (*((float*) (wdata + idx * wstep))) \
* (*((float*) (ydata + idx * ystep))); \
wyyl += *((float*) (wdata + idx * wstep)) \
* (*((float*) (ydata + idx * ystep))) \
* (*((float*) (ydata + idx * ystep))); \
} \
while( (++i) < num && \
( *((float*) (data + (idx = \
(int) ( *((type*) (idxdata + i*idxstep))) ) * datastep)) \
== *curval ) ); \
--i; \
prevval = curval; \
} /* for each value */ \
\
return found; \
}
这里几个关键的变量:ydata是某个特征所代表的类别,正负样本分别以1。-1进行标注,wdata是正负样本相应的权值,data指的就是valCache的某一行。
程序进来的时候推断sumw是否初始化,没有初始化就进行赋值。因为同一个训练集每一个样本都仅仅相应一个ydata和wdata(每一个样本相应非常多个Haar特征,两者有差别),因此这里的sumw,sumwyy,sumwy都是一个确定的值。提前计算好,在后面的迭代中就不必反复计算。
接下来,依据idxCache中某一行(视迭代次数而定)的index,按从小到大的顺序遍历ValCache中相应行的特征值。也就是不相同本的同一特征值。并将其作为结点,尝试对样本进行划分。curleft和curright分别代表左右子树的类别的加权平均值。然后利用error宏计算左右子树的残差
#define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type ) \
ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type, \
/* calculate error (sum of squares) */ \
/* err = sum( w * (y - left(rigt)Val)^2 ) */ \
curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl; \
currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
)
最后的一个do-while循环。就是用来跳过和当前结点同样的特征值。尽管以兴许的、同样的值作为结点划分左右子树,残差平方和可能会改变,可是决策树划分的最小单位是特征值的种类,由于在利用决策树进行分类的时候,必须对同样的特征值做出一样的决策(该划入左子树还是该划入右子树)。
总结
OpenCV HaarTraining代码解析(二)cvCreateMTStumpClassifier(建立决策树)的更多相关文章
- java代码解析二维码
java代码解析二维码一般步骤 本文采用的是google的zxing技术进行解析二维码技术,解析二维码的一般步骤如下: 一.下载zxing-core的jar包: 二.创建一个BufferedImage ...
- GraphSAGE 代码解析(二) - layers.py
原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(三) - aggregator ...
- GraphSAGE 代码解析(四) - models.py
原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(二) - layers.py ...
- GraphSAGE 代码解析(三) - aggregators.py
原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(一) - unsupervised_train.py GraphSAGE 代码解析(二) - layers.py ...
- GraphSAGE 代码解析(一) - unsupervised_train.py
原创文章-转载请注明出处哦.其他部分内容参见以下链接- GraphSAGE 代码解析(二) - layers.py GraphSAGE 代码解析(三) - aggregators.py GraphSA ...
- asp.net C#生成和解析二维码代码
类库文件我们在文件最后面下载 [ThoughtWorks.QRCode.dll 就是类库] 使用时需要增加: using ThoughtWorks.QRCode.Codec;using Thought ...
- JavaScript “跑马灯”抽奖活动代码解析与优化(二)
既然是要编写插件.那么叫做"插件"的东西肯定是具有的某些特征能够满足我们平时开发的需求或者是提高我们的开发效率.那么叫做插件的东西应该具有哪些基本特征呢?让我们来总结一下: 1.J ...
- RobHess的SIFT代码解析步骤三
平台:win10 x64 +VS 2015专业版 +opencv-2.4.11 + gtk_-bundle_2.24.10_win32 主要参考:1.代码:RobHess的SIFT源码 2.书:王永明 ...
- Fixflow引擎解析(二)(模型) - BPMN2.0读写
Fixflow引擎解析(四)(模型) - 通过EMF扩展BPMN2.0元素 Fixflow引擎解析(三)(模型) - 创建EMF模型来读写XML文件 Fixflow引擎解析(二)(模型) - BPMN ...
随机推荐
- JSP 网页格式判定执行哪一块html
JSP 网页格式判定执行哪一块html <!-- start --> <td height="166" colspan="3&q ...
- 暴力或随机-hdu-4712-Hamming Distance
题目链接: http://acm.hdu.edu.cn/showproblem.php?pid=4712 题目大意: 求n个20位0.1二进制串中,两两抑或最少的1的个数. 解题思路: 两种解法: 1 ...
- IT大数据服务管理高级课程(IT服务,大数据,云计算,智能城市)
个人简历 金石先生是马克思主义中国化的研究学者,上海财经大学经济学和管理学硕士,中国民主建国会成员,中国特色社会主义人文科技管理哲学的理论奠基人之一.金石先生博学多才,对问题有独到见解.专于工作且乐于 ...
- jquery 下拉多选插件
Jquery多选下拉列表插件jquery multiselect功能介绍及使用 Chosen 替代样式表 Bootstrap Chosen
- AVL树----java
AVL树----java AVL ...
- Swift - 使用相机拍摄照片
1,打开相机拍照 通过设置图片控制器UIImagePickerController的来源为UIImagePickerControllerSourceType.Camera,便可以打开相机 1 2 3 ...
- 第13章、布局Layouts之RelativeLayout相对布局(从零開始学Android)
RelativeLayout相对布局 RelativeLayout是一种相对布局,控件的位置是依照相对位置来计算的,后一个控件在什么位置依赖于前一个控件的基本位置,是布局最经常使用,也是最灵活的一种布 ...
- Oracle判断指定列是否全部为数字
select nvl2(translate(name,'\1234567890 ', '\'),'is characters ','is number ') from customer_inf ...
- OCA读书笔记(2) - 安装Oracle软件
Objectives: •Describe your role as a database administrator (DBA) and explain typical tasks and tool ...
- uva 657
很简单的题,就是题意不懂……! 就是判断每个'*'区域内‘X’区域块的个数 WA了好多次,就是太差了: 1.结果排序输出 2.因为是骰子所以不再1-6范围内的数字要舍弃 3.格式要求要空一行…… 4. ...