//随机树分类
Ptr<StatModel> lpmlBtnClassify::buildRtreesClassifier(Mat data, Mat responses, int ntrain_samples)
{

Ptr<RTrees> model;
Ptr<TrainData> tdata = prepareTrainData(data, responses, ntrain_samples);
model = RTrees::create();
model->setMaxDepth(10);
model->setMinSampleCount(10);
model->setRegressionAccuracy(0);
model->setUseSurrogates(false);
model->setMaxCategories(15);
model->setPriors(Mat());
model->setCalculateVarImportance(false);
model->setTermCriteria(setIterCondition(100, 0.01f));
model->train(tdata);

return model;
}

//adaboost分类
Ptr<StatModel> lpmlBtnClassify::buildAdaboostClassifier(Mat data, Mat responses, int ntrain_samples,int param0)
{
Mat weak_responses;
int i, j, k;
Ptr<Boost> model;

int nsamples_all = data.rows;
int var_count = data.cols;

Mat new_data(ntrain_samples*class_count, var_count + 1, CV_32F);
Mat new_responses(ntrain_samples*class_count, 1, CV_32S);

for (i = 0; i < ntrain_samples; i++)
{
const float* data_row = data.ptr<float>(i);
for (j = 0; j < class_count; j++)
{
float* new_data_row = (float*)new_data.ptr<float>(i*class_count + j);
memcpy(new_data_row, data_row, var_count*sizeof(data_row[0]));
new_data_row[var_count] = (float)j;
new_responses.at<int>(i*class_count + j) = responses.at<int>(i) == j;
}
}

Mat var_type(1, var_count + 2, CV_8U);
var_type.setTo(Scalar::all(VAR_ORDERED));
var_type.at<uchar>(var_count) = var_type.at<uchar>(var_count + 1) = VAR_CATEGORICAL;

Ptr<TrainData> tdata = TrainData::create(new_data, ROW_SAMPLE, new_responses,
noArray(), noArray(), noArray(), var_type);

model = Boost::create();
model->setBoostType(Boost::GENTLE);
model->setWeakCount(param0);
model->setWeightTrimRate(0.95);
model->setMaxDepth(5);
model->setUseSurrogates(false);
model->train(tdata);

return model;
}

//多层感知机分类(ANN)
Ptr<StatModel> lpmlBtnClassify::buildMlpClassifier(Mat data, Mat responses, int ntrain_samples)
{
//read_num_class_data(data_filename, 16, &data, &responses);
Ptr<ANN_MLP> model;
Mat train_data = data.rowRange(0, ntrain_samples);
Mat train_responses = Mat::zeros(ntrain_samples, class_count, CV_32F);

// 1. unroll the responses
for (int i = 0; i < ntrain_samples; i++)
{
int cls_label = responses.at<int>(i);
train_responses.at<float>(i, cls_label) = 1.f;
}

// 2. train classifier
int layer_sz[] = { data.cols, 100, 100, class_count };
int nlayers = (int)(sizeof(layer_sz) / sizeof(layer_sz[0]));
Mat layer_sizes(1, nlayers, CV_32S, layer_sz);

#if 1
int method = ANN_MLP::BACKPROP;
double method_param = 0.001;
int max_iter = 300;
#else
int method = ANN_MLP::RPROP;
double method_param = 0.1;
int max_iter = 1000;
#endif

Ptr<TrainData> tdata = TrainData::create(train_data, ROW_SAMPLE, train_responses);
model = ANN_MLP::create();
model->setLayerSizes(layer_sizes);
model->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
model->setTermCriteria(setIterCondition(max_iter, 0));
model->setTrainMethod(method, method_param);
model->train(tdata);
return model;
}

//贝叶斯分类
Ptr<StatModel> lpmlBtnClassify::buildNbayesClassifier(Mat data, Mat responses, int ntrain_samples)
{
Ptr<NormalBayesClassifier> model;
Ptr<TrainData> tdata = prepareTrainData(data, responses, ntrain_samples);
model = NormalBayesClassifier::create();
model->train(tdata);

return model;
}

Ptr<StatModel> lpmlBtnClassify::buildKnnClassifier(Mat data, Mat responses, int ntrain_samples, int K)
{
Ptr<TrainData> tdata = prepareTrainData(data, responses, ntrain_samples);
Ptr<KNearest> model = KNearest::create();
model->setDefaultK(K);
model->setIsClassifier(true);
model->train(tdata);

return model;
}

//svm分类
Ptr<StatModel> lpmlBtnClassify::buildSvmClassifier(Mat data, Mat responses, int ntrain_samples)
{
Ptr<SVM> model;
Ptr<TrainData> tdata = prepareTrainData(data, responses, ntrain_samples);
model = SVM::create();
model->setType(SVM::C_SVC);
model->setKernel(SVM::RBF);
model->setC(1);
model->train(tdata);
return model;
}

opencv3.0机器学习算法使用的更多相关文章

  1. opencv3中的机器学习算法之:EM算法

    不同于其它的机器学习模型,EM算法是一种非监督的学习算法,它的输入数据事先不需要进行标注.相反,该算法从给定的样本集中,能计算出高斯混和参数的最大似然估计.也能得到每个样本对应的标注值,类似于kmea ...

  2. 在opencv3中的机器学习算法

    在opencv3.0中,提供了一个ml.cpp的文件,这里面全是机器学习的算法,共提供了这么几种: 1.正态贝叶斯:normal Bayessian classifier    我已在另外一篇博文中介 ...

  3. 在opencv3中实现机器学习算法之:利用最近邻算法(knn)实现手写数字分类

    手写数字digits分类,这可是深度学习算法的入门练习.而且还有专门的手写数字MINIST库.opencv提供了一张手写数字图片给我们,先来看看 这是一张密密麻麻的手写数字图:图片大小为1000*20 ...

  4. Spark2.0机器学习系列之1: 聚类算法(LDA)

    在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法:      (1)K-means      (2)Latent Dirichlet allocation (LDA)  ...

  5. opencv3.0中contrib模块的添加+实现SIFT/SURF算法

    平台:win10 x64 +VS 2015专业版 +opencv-3.x.+CMake+Anaconda3(python3.7.0) Issue说明:Opencv3.0版本已经发布了有一段时间,在这段 ...

  6. Atitit opencv3.0  3.1 3.2 新特性attilax总结

    Atitit opencv3.0  3.1 3.2 新特性attilax总结 1. 3.0OpenCV 3 的改动在哪?1 1.1. 模块构成该看哪些模块?2 2. 3.1新特性 2015-12-21 ...

  7. OpenCV3 Java 机器学习使用方法汇总

    原文链接:OpenCV3 Java 机器学习使用方法汇总  前言 按道理来说,C++版本的OpenCV训练的版本XML文件,在java中可以无缝使用.但要注意OpenCV本身的版本问题.从2.4 到3 ...

  8. 机器学习&数据挖掘笔记_16(常见面试之机器学习算法思想简单梳理)

    前言: 找工作时(IT行业),除了常见的软件开发以外,机器学习岗位也可以当作是一个选择,不少计算机方向的研究生都会接触这个,如果你的研究方向是机器学习/数据挖掘之类,且又对其非常感兴趣的话,可以考虑考 ...

  9. 建模分析之机器学习算法(附python&R代码)

    0序 随着移动互联和大数据的拓展越发觉得算法以及模型在设计和开发中的重要性.不管是现在接触比较多的安全产品还是大互联网公司经常提到的人工智能产品(甚至人类2045的的智能拐点时代).都基于算法及建模来 ...

随机推荐

  1. Vue二次精度随笔(1)

    1.button.input标签的disabled属性 该标签可以控制按钮是否可用,如果他的值为以上几种的话,则他都不会在标签上渲染出这个属性,一旦这个属性出现的话,就说明他是禁用的 2.移除动态绑定 ...

  2. 「NOIP2011」Mayan游戏

    传送门 Luogu 解题思路 爆搜,并考虑几个剪枝. 不交换颜色相同的方块(有争议,但是可以过联赛数据 \(Q \omega Q\)) 左边为空才往左换 右边不为空才往右换 因为对于两个相邻方块,右边 ...

  3. DotNet中的继承,剖析面向对象中继承的意义

    继承是面向对象程序设计不可缺少的机制,有了继承这个东西,可以提高代码的重用,提高代码的效率,减轻代码员的负担. 面向对象三大特性:封装.继承.多态是相辅相成的.封装为了继承,限制了父类的哪些成员被子类 ...

  4. Codeforces Round #568 (Div. 2) 选做

    A.B 略,相信大家都会做 ^_^ C. Exam in BerSU 题意 给你一个长度为 \(n\) 的序列 \(a_i\) .对于每个 \(i\in [1,N]\) 求 \([1,i-1]\) 中 ...

  5. HDU 5045 状压DP 上海网赛

    比赛的时候想的是把n个n个的题目进行状压 但这样不能讲究顺序,当时精神面貌也不好,真是挫死了 其实此题的另一个角度就是一个n个数的排列,如果我对n个人进行状压,外面套一个按题目循序渐进的大循环,那么, ...

  6. Deepctr框架代码阅读

    DeepCtr是一个简易的CTR模型框架,集成了深度学习流行的所有模型,适合学推荐系统模型的人参考. 我在参加比赛中用到了这个框架,但是效果一般,为了搞清楚原因从算法和框架两方面入手.在读代码的过程中 ...

  7. DEDE后台升级后不显示编辑器

    dede5.7不显示编辑器不能编辑文章的解决办法:进入系统后台系统配置-系统基本参数-核心设置将fck换成ckeditor保存,当然需要fck编辑器也可以到dede官网下载.dede5.7不显示编辑器 ...

  8. P1066 图像过滤

    P1066 图像过滤 转跳点:

  9. mysql性能优化及 Comparison method violates its general contract

    项目上嵌套结果集查询,查询的列表再根据每个id进行查询计算,嵌套的sql如下: SELECT SUM(IFNULL(t.out_rate,0)) totalOutRate, SUM(IF(IFNULL ...

  10. GoJS组织结构图

    <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...