不同于其它的机器学习模型,EM算法是一种非监督的学习算法,它的输入数据事先不需要进行标注。相反,该算法从给定的样本集中,能计算出高斯混和参数的最大似然估计。也能得到每个样本对应的标注值,类似于kmeans聚类(输入样本数据,输出样本数据的标注)。实际上,高斯混和模型GMM和kmeans都是EM算法的应用。

在opencv3.0中,EM算法的函数是trainEM,函数原型为:

  1. bool trainEM(InputArray samples, OutputArray logLikelihoods=noArray(),OutputArray labels=noArray(),OutputArray probs=noArray())

四个参数:

samples: 输入的样本,一个单通道的矩阵。从这个样本中,进行高斯混和模型估计。

logLikelihoods: 可选项,输出一个矩阵,里面包含每个样本的似然对数值。

labels: 可选项,输出每个样本对应的标注。

probs: 可选项,输出一个矩阵,里面包含每个隐性变量的后验概率

这个函数没有输入参数的初始化值,是因为它会自动执行kmeans算法,将kmeans算法得到的结果作为参数初始化。

这个trainEM函数实际把E步骤和M步骤都包含进去了,我们也可以对两个步骤分开执行,OPENCV3.0中也提供了分别执行的函数:

  1. bool trainE(InputArray samples, InputArray means0,
  2. InputArray covs0=noArray(),
  3. InputArray weights0=noArray(),
  4. OutputArray logLikelihoods=noArray(),
  5. OutputArray labels=noArray(),
  6. OutputArray probs=noArray())
  1. bool trainM(InputArray samples, InputArray probs0,
  2. OutputArray logLikelihoods=noArray(),
  3. OutputArray labels=noArray(),
  4. OutputArray probs=noArray())
  1. trainEM函数的功能和kmeans差不多,都是实现自动聚类,输出每个样本对应的标注值。但它比kmeans还多出一个功能,就是它还能起到训练分类器的作用,用于后续新样本的预测。
  2.  
  3. 预测函数原型为:
  1. Vec2d predict2(InputArray sample, OutputArray probs) const

sample: 待测样本

probs : 和上面一样,一个可选的输出值,包含每个隐性变量的后验概率

返回一个Vec2d类型的数,包括两个元素的double向量,第一个元素为样本的似然对数值,第二个元素为最大可能混和分量的索引值。

在本文中,我们用两个实例来学习opencv中的EM算法的应用。

一、opencv3.0中自带的例子,既包括聚类trianEM,也包括预测predict2

代码:

  1. #include "stdafx.h"
  2. #include "opencv2/opencv.hpp"
  3. #include <iostream>
  4. using namespace std;
  5. using namespace cv;
  6. using namespace cv::ml;
  7.  
  8. //使用EM算法实现样本的聚类及预测
  9. int main()
  10. {
  11. const int N = ; //分成4类
  12. const int N1 = (int)sqrt((double)N);
  13. //定义四种颜色,每一类用一种颜色表示
  14. const Scalar colors[] =
  15. {
  16. Scalar(, , ), Scalar(, , ),
  17. Scalar(, , ), Scalar(, , )
  18. };
  19.  
  20. int i, j;
  21. int nsamples = ; //100个样本点
  22. Mat samples(nsamples, , CV_32FC1); //样本矩阵,100行2列,即100个坐标点
  23. Mat img = Mat::zeros(Size(, ), CV_8UC3); //待测数据,每一个坐标点为一个待测数据
  24. samples = samples.reshape(, );
  25.  
  26. //循环生成四个类别样本数据,共样本100个,每类样本25个
  27. for (i = ; i < N; i++)
  28. {
  29.  
  30. Mat samples_part = samples.rowRange(i*nsamples / N, (i + )*nsamples / N);
  31.  
  32. //设置均值
  33. Scalar mean(((i%N1) + )*img.rows / (N1 + ),
  34. ((i / N1) + )*img.rows / (N1 + ));
  35. //设置标准差
  36. Scalar sigma(, );
  37. randn(samples_part, mean, sigma); //根据均值和标准差,随机生成25个正态分布坐标点作为样本
  38. }
  39. samples = samples.reshape(, );
  40. // 训练分类器
  41. Mat labels; //标注,不需要事先知道
  42. Ptr<EM> em_model = EM::create();
  43. em_model->setClustersNumber(N);
  44. em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
  45. em_model->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, , 0.1));
  46. em_model->trainEM(samples, noArray(), labels, noArray());
  47.  
  48. //对每个坐标点进行分类,并根据类别用不同的颜色画出
  49. Mat sample(, , CV_32FC1);
  50. for (i = ; i < img.rows; i++)
  51. {
  52. for (j = ; j < img.cols; j++)
  53. {
  54. sample.at<float>() = (float)j;
  55. sample.at<float>() = (float)i;
  56. //predict2返回的是double值,用cvRound进行四舍五入得到整型
  57. //此处返回的是两个值Vec2d,取第二个值作为样本标注
  58. int response = cvRound(em_model->predict2(sample, noArray())[]);
  59. Scalar c = colors[response]; //为不同类别设定颜色
  60. circle(img, Point(j, i), , c*0.75, FILLED);
  61. }
  62. }
  63.  
  64. //画出样本点
  65. for (i = ; i < nsamples; i++)
  66. {
  67. Point pt(cvRound(samples.at<float>(i, )), cvRound(samples.at<float>(i, )));
  68. circle(img, pt, , colors[labels.at<int>(i)], FILLED);
  69. }
  70.  
  71. imshow("EM聚类结果", img);
  72. waitKey();
  73.  
  74. return ;
  75. }

结果:

二、只用trainEM实现自动聚类功能,进行图片中的目标检测

代码:

  1. #include "stdafx.h"
  2. #include "opencv2/opencv.hpp"
  3. #include <iostream>
  4. using namespace std;
  5. using namespace cv;
  6. using namespace cv::ml;
  7.  
  8. int main()
  9. {
  10. const int MAX_CLUSTERS = ;
  11. Vec3b colorTab[] =
  12. {
  13. Vec3b(, , ),
  14. Vec3b(, , ),
  15. Vec3b(, , ),
  16. Vec3b(, , ),
  17. Vec3b(, , )
  18. };
  19. Mat data, labels;
  20. Mat pic = imread("d:/woman.png");
  21. for (int i = ; i < pic.rows; i++)
  22. for (int j = ; j < pic.cols; j++)
  23. {
  24. Vec3b point = pic.at<Vec3b>(i, j);
  25. Mat tmp = (Mat_<float>(, ) << point[], point[], point[]);
  26. data.push_back(tmp);
  27. }
  28.  
  29. int N =; //聚成3类
  30. Ptr<EM> em_model = EM::create();
  31. em_model->setClustersNumber(N);
  32. em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
  33. em_model->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, , 0.1));
  34. em_model->trainEM(data, noArray(), labels, noArray());
  35.  
  36. int n = ;
  37. //显示聚类结果,不同的类别用不同的颜色显示
  38. for (int i = ; i < pic.rows; i++)
  39. for (int j = ; j < pic.cols; j++)
  40. {
  41. int clusterIdx = labels.at<int>(n);
  42. pic.at<Vec3b>(i, j) = colorTab[clusterIdx];
  43. n++;
  44. }
  45. imshow("pic", pic);
  46. waitKey();
  47.  
  48. return ;
  49. }

测试图片

测试结果:

opencv3中的机器学习算法之:EM算法的更多相关文章

  1. 在opencv3中的机器学习算法

    在opencv3.0中,提供了一个ml.cpp的文件,这里面全是机器学习的算法,共提供了这么几种: 1.正态贝叶斯:normal Bayessian classifier    我已在另外一篇博文中介 ...

  2. 机器学习优化算法之EM算法

    EM算法简介 EM算法其实是一类算法的总称.EM算法分为E-Step和M-Step两步.EM算法的应用范围很广,基本机器学习需要迭代优化参数的模型在优化时都可以使用EM算法. EM算法的思想和过程 E ...

  3. MM 算法与 EM算法概述

    1.MM 算法: MM算法是一种迭代优化方法,利用函数的凸性来寻找它们的最大值或最小值. MM表示 “majorize-minimize MM 算法” 或“minorize maximize MM 算 ...

  4. 在opencv3中实现机器学习算法之:利用最近邻算法(knn)实现手写数字分类

    手写数字digits分类,这可是深度学习算法的入门练习.而且还有专门的手写数字MINIST库.opencv提供了一张手写数字图片给我们,先来看看 这是一张密密麻麻的手写数字图:图片大小为1000*20 ...

  5. python机器学习笔记:EM算法

    EM算法也称期望最大化(Expectation-Maximum,简称EM)算法,它是一个基础算法,是很多机器学习领域的基础,比如隐式马尔科夫算法(HMM),LDA主题模型的变分推断算法等等.本文对于E ...

  6. Python实现机器学习算法:EM算法

    ''' 数据集:伪造数据集(两个高斯分布混合) 数据集长度:1000 ------------------------------ 运行结果: ---------------------------- ...

  7. 【机器学习】K-means聚类算法与EM算法

    初始目的 将样本分成K个类,其实说白了就是求一个样本例的隐含类别y,然后利用隐含类别将x归类.由于我们事先不知道类别y,那么我们首先可以对每个样例假定一个y吧,但是怎么知道假定的对不对呢?怎样评价假定 ...

  8. 【机器学习笔记】EM算法及其应用

    极大似然估计 考虑一个高斯分布\(p(\mathbf{x}\mid{\theta})\),其中\(\theta=(\mu,\Sigma)\).样本集\(X=\{x_1,...,x_N\}\)中每个样本 ...

  9. python大战机器学习——聚类和EM算法

    注:本文中涉及到的公式一律省略(公式不好敲出来),若想了解公式的具体实现,请参考原著. 1.基本概念 (1)聚类的思想: 将数据集划分为若干个不想交的子集(称为一个簇cluster),每个簇潜在地对应 ...

随机推荐

  1. iOS 完美解决 interactivePopGestureRecognizer 卡住的问题

    interactivePopGestureRecognizer是iOS7推出的解决VeiwController滑动后退的新功能,虽然很实用,但是坑也很多啊,用过的同学肯定知道问题在哪里,所以具体问题我 ...

  2. 解决log4j:WARN Error initializing output writer. log4j:WARN Unsupported encoding?的问题

    异常名:log4j:WARN Error initializing output writer. log4j:WARN Unsupported encoding? 异常截图: 在一般的javaweb项 ...

  3. tomcat下部署润乾报表

    因为项目需要,需要在项目中配置润乾报表,之前一直是用的jboss服务器,此处调整为tomcat时出错,然后各种找错,找答案,最后终于好了,然后总结一下. 首先在apache-tomcat-6.0.43 ...

  4. mybatis3.3 + struts2.3.24 + mysql5.1.22开发环境搭建及相关说明

    一.新建Web工程,并在lib目录下添加jar包 主要jar包:struts2相关包,mybatis3.3相关包,mysql-connector-java-5.1.22-bin.jar, gson-2 ...

  5. 最近开始研究PMD(一款采用BSD协议发布的Java程序代码检查工具)

    PMD是一款采用BSD协议发布的Java程序代码检查工具.该工具可以做到检查Java代码中是否含有未使用的变量.是否含有空的抓取块.是否含有不必要的对象等.该软件功能强大,扫描效率高,是Java程序员 ...

  6. Windows Server 2012之搭建域控制器DC

    安装域控制器,域(Domain) 1,本地管理员权限 2,设置静态IP 地址 3,至少有一个NTFS分区 4,操作系统版本(web版除外)   设置静态IP地址    dcpromo.exe命令不生效 ...

  7. mongo学习笔记(五):分片

    分片  人脸:       代表客户端,客户端肯定说,你数据库分片不分片跟我没关系,我叫你干啥就干啥,没什么好商量的. mongos: 首先我们要了解”片键“的概念,也就是说拆分集合的依据是什么?按照 ...

  8. 最短路径之迪杰斯特拉(Dijkstra)算法

    迪杰斯特拉(Dijkstra)算法主要是针对没有负值的有向图,求解其中的单一起点到其他顶点的最短路径算法.本文主要总结迪杰斯特拉(Dijkstra)算法的原理和算法流程,最后通过程序实现在一个带权值的 ...

  9. Node.js的模块载入方式与机制

    Node.js中模块可以通过文件路径或名字获取模块的引用.模块的引用会映射到一个js文件路径,除非它是一个Node内置模块.Node的内置模块公开了一些常用的API给开发者,并且它们在Node进程开始 ...

  10. git一些常用设置

    用法:git config [选项] 配置文件位置    --global              使用全局配置文件    --system              使用系统级配置文件    -- ...