各机器学习方法代码(OpenCV2)
#include <iostream>
#include <math.h>
#include <string>
#include "cv.h"
#include "ml.h"
#include "highgui.h" using namespace cv;
using namespace std; bool plotSupportVectors=true;
int numTrainingPoints=;
int numTestPoints=;
int size=;
int eq=; // accuracy
float evaluate(cv::Mat& predicted, cv::Mat& actual) {
assert(predicted.rows == actual.rows);
int t = ;
int f = ;
for(int i = ; i < actual.rows; i++) {
float p = predicted.at<float>(i,);
float a = actual.at<float>(i,);
if((p >= 0.0 && a >= 0.0) || (p <= 0.0 && a <= 0.0)) {
t++;
} else {
f++;
}
}
return (t * 1.0) / (t + f);
} // plot data and class
void plot_binary(cv::Mat& data, cv::Mat& classes, string name) {
cv::Mat plot(size, size, CV_8UC3);
plot.setTo(cv::Scalar(255.0,255.0,255.0));
for(int i = ; i < data.rows; i++) { float x = data.at<float>(i,) * size;
float y = data.at<float>(i,) * size; if(classes.at<float>(i, ) > ) {
cv::circle(plot, Point(x,y), , CV_RGB(,,),);
} else {
cv::circle(plot, Point(x,y), , CV_RGB(,,),);
}
}
cv::imshow(name, plot);
} // function to learn
int f(float x, float y, int equation) {
switch(equation) {
case :
return y > sin(x*) ? - : ;
break;
case :
return y > cos(x * ) ? - : ;
break;
case :
return y > *x ? - : ;
break;
case :
return y > tan(x*) ? - : ;
break;
default:
return y > cos(x*) ? - : ;
}
} // label data with equation
cv::Mat labelData(cv::Mat points, int equation) {
cv::Mat labels(points.rows, , CV_32FC1);
for(int i = ; i < points.rows; i++) {
float x = points.at<float>(i,);
float y = points.at<float>(i,);
labels.at<float>(i, ) = f(x, y, equation);
}
return labels;
} void svm(cv::Mat& trainingData, cv::Mat& trainingClasses, cv::Mat& testData, cv::Mat& testClasses) {
CvSVMParams param = CvSVMParams(); param.svm_type = CvSVM::C_SVC;
param.kernel_type = CvSVM::RBF; //CvSVM::RBF, CvSVM::LINEAR ...
param.degree = ; // for poly
param.gamma = ; // for poly/rbf/sigmoid
param.coef0 = ; // for poly/sigmoid param.C = ; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
param.nu = 0.0; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
param.p = 0.0; // for CV_SVM_EPS_SVR param.class_weights = NULL; // for CV_SVM_C_SVC
param.term_crit.type = CV_TERMCRIT_ITER +CV_TERMCRIT_EPS;
param.term_crit.max_iter = ;
param.term_crit.epsilon = 1e-; // SVM training (use train auto for OpenCV>=2.0)
CvSVM svm(trainingData, trainingClasses, cv::Mat(), cv::Mat(), param); cv::Mat predicted(testClasses.rows, , CV_32F); for(int i = ; i < testData.rows; i++) {
cv::Mat sample = testData.row(i); float x = sample.at<float>(,);
float y = sample.at<float>(,); predicted.at<float>(i, ) = svm.predict(sample);
} cout << "Accuracy_{SVM} = " << evaluate(predicted, testClasses) << endl;
plot_binary(testData, predicted, "Predictions SVM"); // plot support vectors
if(plotSupportVectors) {
cv::Mat plot_sv(size, size, CV_8UC3);
plot_sv.setTo(cv::Scalar(255.0,255.0,255.0)); int svec_count = svm.get_support_vector_count();
for(int vecNum = ; vecNum < svec_count; vecNum++) {
const float* vec = svm.get_support_vector(vecNum);
cv::circle(plot_sv, Point(vec[]*size, vec[]*size), , CV_RGB(, , ));
}
cv::imshow("Support Vectors", plot_sv);
}
} void mlp(cv::Mat& trainingData, cv::Mat& trainingClasses, cv::Mat& testData, cv::Mat& testClasses) { cv::Mat layers = cv::Mat(, , CV_32SC1); layers.row() = cv::Scalar();
layers.row() = cv::Scalar();
layers.row() = cv::Scalar();
layers.row() = cv::Scalar(); CvANN_MLP mlp;
CvANN_MLP_TrainParams params;
CvTermCriteria criteria;
criteria.max_iter = ;
criteria.epsilon = 0.00001f;
criteria.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
params.train_method = CvANN_MLP_TrainParams::BACKPROP;
params.bp_dw_scale = 0.05f;
params.bp_moment_scale = 0.05f;
params.term_crit = criteria; mlp.create(layers); // train
mlp.train(trainingData, trainingClasses, cv::Mat(), cv::Mat(), params); cv::Mat response(, , CV_32FC1);
cv::Mat predicted(testClasses.rows, , CV_32F);
for(int i = ; i < testData.rows; i++) {
cv::Mat response(, , CV_32FC1);
cv::Mat sample = testData.row(i); mlp.predict(sample, response);
predicted.at<float>(i,) = response.at<float>(,); } cout << "Accuracy_{MLP} = " << evaluate(predicted, testClasses) << endl;
plot_binary(testData, predicted, "Predictions Backpropagation");
} void knn(cv::Mat& trainingData, cv::Mat& trainingClasses, cv::Mat& testData, cv::Mat& testClasses, int K) { CvKNearest knn(trainingData, trainingClasses, cv::Mat(), false, K);
cv::Mat predicted(testClasses.rows, , CV_32F);
for(int i = ; i < testData.rows; i++) {
const cv::Mat sample = testData.row(i);
predicted.at<float>(i,) = knn.find_nearest(sample, K);
} cout << "Accuracy_{KNN} = " << evaluate(predicted, testClasses) << endl;
plot_binary(testData, predicted, "Predictions KNN"); } void bayes(cv::Mat& trainingData, cv::Mat& trainingClasses, cv::Mat& testData, cv::Mat& testClasses) { CvNormalBayesClassifier bayes(trainingData, trainingClasses);
cv::Mat predicted(testClasses.rows, , CV_32F);
for (int i = ; i < testData.rows; i++) {
const cv::Mat sample = testData.row(i);
predicted.at<float> (i, ) = bayes.predict(sample);
} cout << "Accuracy_{BAYES} = " << evaluate(predicted, testClasses) << endl;
plot_binary(testData, predicted, "Predictions Bayes"); } void decisiontree(cv::Mat& trainingData, cv::Mat& trainingClasses, cv::Mat& testData, cv::Mat& testClasses) { CvDTree dtree;
cv::Mat var_type(, , CV_8U); // define attributes as numerical
var_type.at<unsigned int>(,) = CV_VAR_NUMERICAL;
var_type.at<unsigned int>(,) = CV_VAR_NUMERICAL;
// define output node as numerical
var_type.at<unsigned int>(,) = CV_VAR_NUMERICAL; dtree.train(trainingData,CV_ROW_SAMPLE, trainingClasses, cv::Mat(), cv::Mat(), var_type, cv::Mat(), CvDTreeParams());
cv::Mat predicted(testClasses.rows, , CV_32F);
for (int i = ; i < testData.rows; i++) {
const cv::Mat sample = testData.row(i);
CvDTreeNode* prediction = dtree.predict(sample);
predicted.at<float> (i, ) = prediction->value;
} cout << "Accuracy_{TREE} = " << evaluate(predicted, testClasses) << endl;
plot_binary(testData, predicted, "Predictions tree"); } int main() { cv::Mat trainingData(numTrainingPoints, , CV_32FC1);
cv::Mat testData(numTestPoints, , CV_32FC1); cv::randu(trainingData,,);
cv::randu(testData,,); cv::Mat trainingClasses = labelData(trainingData, eq);
cv::Mat testClasses = labelData(testData, eq); plot_binary(trainingData, trainingClasses, "Training Data");
plot_binary(testData, testClasses, "Test Data"); svm(trainingData, trainingClasses, testData, testClasses);
mlp(trainingData, trainingClasses, testData, testClasses);
knn(trainingData, trainingClasses, testData, testClasses, );
bayes(trainingData, trainingClasses, testData, testClasses);
decisiontree(trainingData, trainingClasses, testData, testClasses); cv::waitKey(); return ;
}
图像分类结果:
各机器学习方法代码(OpenCV2)的更多相关文章
- R语言进行机器学习方法及实例(一)
版权声明:本文为博主原创文章,转载请注明出处 机器学习的研究领域是发明计算机算法,把数据转变为智能行为.机器学习和数据挖掘的区别可能是机器学习侧重于执行一个已知的任务,而数据发掘是在大数据中寻找有 ...
- Stanford机器学习---第六讲. 怎样选择机器学习方法、系统
原文:http://blog.csdn.net/abcjennifer/article/details/7797502 本栏目(Machine learning)包括单参数的线性回归.多参数的线性回归 ...
- 美团网基于机器学习方法的POI品类推荐算法
美团网基于机器学习方法的POI品类推荐算法 前言 在美团商家数据中心(MDC),有超过100w的已校准审核的POI数据(我们一般将商家标示为POI,POI基础信息包括:门店名称.品类.电话.地址.坐标 ...
- 程序编码(机器级代码+汇编代码+C代码+反汇编)
[-1]相关声明 本文总结于csapp: 了解详情,或有兴趣,建议看原版书籍: [0]程序编码 GCC调用了一系列程序,将源代码转化成可执行代码的流程如下: (1)C预处理器扩展源代码,插入所有用#i ...
- 关于”机器学习方法“,"深度学习方法"系列
"机器学习/深度学习方法"系列,我本着开放与共享(open and share)的精神撰写,目的是让很多其它的人了解机器学习的概念,理解其原理,学会应用.如今网上各种技术类文章非常 ...
- 机器学习方法、距离度量、K_Means
特征向量 1.特征向量:以人为例,每个元素可能就对应这人的某些方面,这就是特征,例如:身高.年龄.性别.国际....2.特征工程:目的就是将现有数据中可作为信号的特征与那些仅是噪声的特征区分开来:当数 ...
- 不平衡数据下的机器学习方法简介 imbalanced time series classification
imbalanced time series classification http://www.vipzhuanli.com/pat/books/201510229367.5/2.html?page ...
- 基于CRF工具的机器学习方法命名实体识别的过
[转自百度文库] 基于CRF工具的机器学习方法命名实体识别的过程 | 浏览:226 | 更新:2014-04-11 09:32 这里只讲基本过程,不涉及具体实现,我也是初学者,想给其他初学者一些帮助, ...
- 机器学习方法(六):随机森林Random Forest,bagging
欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 前面机器学习方法(四)决策树讲了经典 ...
随机推荐
- dubbo多网卡时,服务提供者的错误IP注册到注册中心导致消费端连接不上
使用了虚拟机之后,启动了dubbo服务提供者应用,又连了正式环境的注册中心: 一旦dubbo获取的ip错误后, 这种情况即使提供者服务停掉,目前dubbo没有能力清除这类错误的提供者: (需要修改源码 ...
- gitlab或github下fork后如何同步源的新更新内容?
两种方式: 项目 fetch 到本地,通过命令行的方式 merge 懒人方法,只用 Github ,不用命令行 1.项目 fetch 到本地,通过命令行的方式 merge 提示:跟上游仓库同步代码之前 ...
- Python入门 日志打印
logging # logging导入 import logging # 设置打印的最低级别 logging.basicConfig(level = logging.DEBUG) 使用 debug, ...
- [原][译]JSBSim官方源码文档翻译(google翻译)
/*%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% CLASS DOCUMENTATION ...
- windows service 2008 R2 安装net4.6环境失败,windows service 2008 R2 升级sp1问题
一.错误 1.因为我的程序是以vs2017开发的,在windows service 2008 R2 IIS部署项目文件报出错误,因此要安装net4.6的环境. 2.windows service 2 ...
- python入门-基础语法
一.变量 定义字符串要加单引号‘’ 变量命名规范: 变量名只能是字母.数字或下划线的任意组合 变量名的第一个字符不能是数字 变量名不能用关键字 变量名不要用中文 变量名不要太长,区分大小写 面就用单引 ...
- Mongo 查询(可视化工具)
distinct MongoDB 的 distinct 命令是获取特定字段中不同值列表的最简单工具. 该命令适用于普通字段.数组字段以及数组内嵌文档(集合对象). db.getCollection(' ...
- 通用Mapper环境下,mapper接口无法注入问题
写了一个mapper接口 package com.nyist.mapper; import com.nyist.entity.User; import tk.mybatis.mapper.common ...
- zzw原创_非root用户启动apache的问题解决(非root用户启动apache的1024以下端口)
场景:普通用户编译的apache,要在该用户下启动1024端口以下的apache端口 1.假设普通用户为sims20,用该用户编译 安装了一个apache,安装路径为/opt/aspire/produ ...
- libcurl编译使用,实现ftp功能
Libcurl实现ftp的下载,上传功能.版本为curl-7.63.0 1.编译vs2015 参考资料:https://blog.csdn.net/yaojingkao/article/details ...