转自:http://blog.csdn.net/abcjennifer/article/details/8198352

聚类算法K-Means, K-Medoids, GMM, Spectral clustering,Ncut一文中我们给出了GMM算法的基本模型与似然函数,在EM算法原理中对EM算法的实现与收敛性证明进行了详细说明。本文主要针对如何用EM算法在混合高斯模型下进行聚类进行代码上的分析说明。

1. GMM模型:

每个 GMM 由 K 个 Gaussian 分布组成,每个 Gaussian 称为一个“Component”,这些 Component 线性加成在一起就组成了 GMM 的概率密度函数:

根据上面的式子,如果我们要从 GMM 的分布中随机地取一个点的话,实际上可以分为两步:首先随机地在这 K个Gaussian Component 之中选一个,每个 Component 被选中的概率实际上就是它的系数 pi(k) ,选中了 Component 之后,再单独地考虑从这个 Component 的分布中选取一个点就可以了──这里已经回到了普通的 Gaussian 分布,转化为了已知的问题。

那么如何用 GMM 来做 clustering 呢?其实很简单,现在我们有了数据,假定它们是由 GMM 生成出来的,那么我们只要根据数据推出 GMM 的概率分布来就可以了,然后 GMM 的 K 个 Component 实际上就对应了 K 个 cluster 了。根据数据来推算概率密度通常被称作 density estimation ,特别地,当我们在已知(或假定)了概率密度函数的形式,而要估计其中的参数的过程被称作“参数估计”。

2. 参数与似然函数:

现在假设我们有 N 个数据点,并假设它们服从某个分布(记作 p(x) ),现在要确定里面的一些参数的值,例如,在 GMM 中,我们就需要确定 影响因子pi(k)、各类均值pMiu(k) 和 各类协方差pSigma(k) 这些参数。 我们的想法是,找到这样一组参数,它所确定的概率分布生成这些给定的数据点的概率最大,而这个概率实际上就等于  ,我们把这个乘积称作似然函数 (Likelihood Function)。通常单个点的概率都很小,许多很小的数字相乘起来在计算机里很容易造成浮点数下溢,因此我们通常会对其取对数,把乘积变成加和 ,得到 log-likelihood function 。接下来我们只要将这个函数最大化(通常的做法是求导并令导数等于零,然后解方程),亦即找到这样一组参数值,它让似然函数取得最大值,我们就认为这是最合适的参数,这样就完成了参数估计的过程。

下面让我们来看一看 GMM 的 log-likelihood function :

由于在对数函数里面又有加和,我们没法直接用求导解方程的办法直接求得最大值。为了解决这个问题,我们采取之前从 GMM 中随机选点的办法:分成两步,实际上也就类似于K-means 的两步。

3. 算法流程:

1.  估计数据由每个 Component 生成的概率(并不是每个 Component 被选中的概率):对于每个数据  来说,它由第  个 Component 生成的概率为

其中N(xi | μk,Σk)就是后验概率

2. 通过极大似然估计可以通过求到令参数=0得到参数pMiu,pSigma的值。具体请见这篇文章第三部分。

其中  ,并且  也顺理成章地可以估计为  。

3. 重复迭代前面两步,直到似然函数的值收敛为止。

4. matlab实现GMM聚类代码与解释:

 

说明:fea为训练样本数据,gnd为样本标号。算法中的思想和上面写的一模一样,在最后的判断accuracy方面,由于聚类和分类不同,只是得到一些 cluster ,而并不知道这些 cluster 应该被打上什么标签,或者说。由于我们的目的是衡量聚类算法的 performance ,因此直接假定这一步能实现最优的对应关系,将每个 cluster 对应到一类上去。一种办法是枚举所有可能的情况并选出最优解,另外,对于这样的问题,我们还可以用 Hungarian algorithm 来求解。具体的Hungarian代码我放在了资源里,调用方法已经写在下面函数中了。

注意:资源里我放的是Kmeans的代码,大家下载的时候只要用bestMap.m等几个文件就好~

 

1. gmm.m,最核心的函数,进行模型与参数确定。

  1. function varargout = gmm(X, K_or_centroids)
  2. % ============================================================
  3. % Expectation-Maximization iteration implementation of
  4. % Gaussian Mixture Model.
  5. %
  6. % PX = GMM(X, K_OR_CENTROIDS)
  7. % [PX MODEL] = GMM(X, K_OR_CENTROIDS)
  8. %
  9. %  - X: N-by-D data matrix.
  10. %  - K_OR_CENTROIDS: either K indicating the number of
  11. %       components or a K-by-D matrix indicating the
  12. %       choosing of the initial K centroids.
  13. %
  14. %  - PX: N-by-K matrix indicating the probability of each
  15. %       component generating each point.
  16. %  - MODEL: a structure containing the parameters for a GMM:
  17. %       MODEL.Miu: a K-by-D matrix.
  18. %       MODEL.Sigma: a D-by-D-by-K matrix.
  19. %       MODEL.Pi: a 1-by-K vector.
  20. % ============================================================
  21. % @SourceCode Author: Pluskid (http://blog.pluskid.org)
  22. % @Appended by : Sophia_qing (http://blog.csdn.net/abcjennifer)
  23. %% Generate Initial Centroids
  24. threshold = 1e-15;
  25. [N, D] = size(X);
  26. if isscalar(K_or_centroids) %if K_or_centroid is a 1*1 number
  27. K = K_or_centroids;
  28. Rn_index = randperm(N); %random index N samples
  29. centroids = X(Rn_index(1:K), :); %generate K random centroid
  30. else % K_or_centroid is a initial K centroid
  31. K = size(K_or_centroids, 1);
  32. centroids = K_or_centroids;
  33. end
  34. %% initial values
  35. [pMiu pPi pSigma] = init_params();
  36. Lprev = -inf; %上一次聚类的误差
  37. %% EM Algorithm
  38. while true
  39. %% Estimation Step
  40. Px = calc_prob();
  41. % new value for pGamma(N*k), pGamma(i,k) = Xi由第k个Gaussian生成的概率
  42. % 或者说xi中有pGamma(i,k)是由第k个Gaussian生成的
  43. pGamma = Px .* repmat(pPi, N, 1); %分子 = pi(k) * N(xi | pMiu(k), pSigma(k))
  44. pGamma = pGamma ./ repmat(sum(pGamma, 2), 1, K); %分母 = pi(j) * N(xi | pMiu(j), pSigma(j))对所有j求和
  45. %% Maximization Step - through Maximize likelihood Estimation
  46. Nk = sum(pGamma, 1); %Nk(1*k) = 第k个高斯生成每个样本的概率的和,所有Nk的总和为N。
  47. % update pMiu
  48. pMiu = diag(1./Nk) * pGamma' * X; %update pMiu through MLE(通过令导数 = 0得到)
  49. pPi = Nk/N;
  50. % update k个 pSigma
  51. for kk = 1:K
  52. Xshift = X-repmat(pMiu(kk, :), N, 1);
  53. pSigma(:, :, kk) = (Xshift' * ...
  54. (diag(pGamma(:, kk)) * Xshift)) / Nk(kk);
  55. end
  56. % check for convergence
  57. L = sum(log(Px*pPi'));
  58. if L-Lprev < threshold
  59. break;
  60. end
  61. Lprev = L;
  62. end
  63. if nargout == 1
  64. varargout = {Px};
  65. else
  66. model = [];
  67. model.Miu = pMiu;
  68. model.Sigma = pSigma;
  69. model.Pi = pPi;
  70. varargout = {Px, model};
  71. end
  72. %% Function Definition
  73. function [pMiu pPi pSigma] = init_params()
  74. pMiu = centroids; %k*D, 即k类的中心点
  75. pPi = zeros(1, K); %k类GMM所占权重(influence factor)
  76. pSigma = zeros(D, D, K); %k类GMM的协方差矩阵,每个是D*D的
  77. % 距离矩阵,计算N*K的矩阵(x-pMiu)^2 = x^2+pMiu^2-2*x*Miu
  78. distmat = repmat(sum(X.*X, 2), 1, K) + ... %x^2, N*1的矩阵replicateK列
  79. repmat(sum(pMiu.*pMiu, 2)', N, 1) - ...%pMiu^2,1*K的矩阵replicateN行
  80. 2*X*pMiu';
  81. [~, labels] = min(distmat, [], 2);%Return the minimum from each row
  82. for k=1:K
  83. Xk = X(labels == k, :);
  84. pPi(k) = size(Xk, 1)/N;
  85. pSigma(:, :, k) = cov(Xk);
  86. end
  87. end
  88. function Px = calc_prob()
  89. %Gaussian posterior probability
  90. %N(x|pMiu,pSigma) = 1/((2pi)^(D/2))*(1/(abs(sigma))^0.5)*exp(-1/2*(x-pMiu)'pSigma^(-1)*(x-pMiu))
  91. Px = zeros(N, K);
  92. for k = 1:K
  93. Xshift = X-repmat(pMiu(k, :), N, 1); %X-pMiu
  94. inv_pSigma = inv(pSigma(:, :, k));
  95. tmp = sum((Xshift*inv_pSigma) .* Xshift, 2);
  96. coef = (2*pi)^(-D/2) * sqrt(det(inv_pSigma));
  97. Px(:, k) = coef * exp(-0.5*tmp);
  98. end
  99. end
  100. end

2. gmm_accuracy.m调用gmm.m,计算准确率:

[cpp] view plaincopy

 
  1. function [ Accuracy ] = gmm_accuracy( Data_fea, gnd_label, K )
  2. %Calculate the accuracy Clustered by GMM model
  3. px = gmm(Data_fea,K);
  4. [~, cls_ind] = max(px,[],1); %cls_ind = cluster label
  5. Accuracy = cal_accuracy(cls_ind, gnd_label);
  6. function [acc] = cal_accuracy(gnd,estimate_label)
  7. res = bestMap(gnd,estimate_label);
  8. acc = length(find(gnd == res))/length(gnd);
  9. end
  10. end

3. 主函数调用

gmm_acc = gmm_accuracy(fea,gnd,N_classes);

写了本文进行总结后自己很受益,也希望大家可以好好YM下上面pluskid的gmm.m,不光是算法,其中的矩阵处理代码也写的很简洁,很值得学习。

另外看了两份东西非常受益,一个是pluskid大牛的漫谈 Clustering (3): Gaussian Mixture Model》,一个是JerryLead的EM算法详解,大家有兴趣也可以看一下,写的很好

GMM的EM算法实现的更多相关文章

  1. GMM及EM算法

    GMM及EM算法 标签(空格分隔): 机器学习 前言: EM(Exception Maximizition) -- 期望最大化算法,用于含有隐变量的概率模型参数的极大似然估计: GMM(Gaussia ...

  2. 高斯混合模型GMM与EM算法的Python实现

    GMM与EM算法的Python实现 高斯混合模型(GMM)是一种常用的聚类模型,通常我们利用最大期望算法(EM)对高斯混合模型中的参数进行估计. 1. 高斯混合模型(Gaussian Mixture ...

  3. 【机器学习】GMM和EM算法

    机器学习算法-GMM和EM算法 目录 机器学习算法-GMM和EM算法 1. GMM模型 2. GMM模型参数求解 2.1 参数的求解 2.2 参数和的求解 3. GMM算法的实现 3.1 gmm类的定 ...

  4. [转载]GMM的EM算法实现

    在聚类算法K-Means, K-Medoids, GMM, Spectral clustering,Ncut一文中我们给出了GMM算法的基本模型与似然函数,在EM算法原理中对EM算法的实现与收敛性证明 ...

  5. GMM的EM算法

    在聚类算法K-Means, K-Medoids, GMM, Spectral clustering,Ncut一文中我们给出了GMM算法的基本模型与似然函数,在EM算法原理中对EM算法的实现与收敛性证明 ...

  6. GMM与EM算法

    用EM算法估计GMM模型参数 参考  西瓜书 再看下算法流程

  7. 5. EM算法-高斯混合模型GMM+Lasso

    1. EM算法-数学基础 2. EM算法-原理详解 3. EM算法-高斯混合模型GMM 4. EM算法-GMM代码实现 5. EM算法-高斯混合模型+Lasso 1. 前言 前面几篇博文对EM算法和G ...

  8. 4. EM算法-高斯混合模型GMM详细代码实现

    1. EM算法-数学基础 2. EM算法-原理详解 3. EM算法-高斯混合模型GMM 4. EM算法-高斯混合模型GMM详细代码实现 5. EM算法-高斯混合模型GMM+Lasso 1. 前言 EM ...

  9. 3. EM算法-高斯混合模型GMM

    1. EM算法-数学基础 2. EM算法-原理详解 3. EM算法-高斯混合模型GMM 4. EM算法-高斯混合模型GMM详细代码实现 5. EM算法-高斯混合模型GMM+Lasso 1. 前言 GM ...

随机推荐

  1. android-APP-bluetooth

    1.创建工程项目 2.工程界面(教程3) 如下目录所示:src目录下MainActivity.java是程序:res下面都是图标等资源文件,layout下的activity_main.xml是按钮等界 ...

  2. Spell-DBC

    Spell.dbc 1  ID2  Attributes               属性3  AttributesEx             属性 4  AttributesExB         ...

  3. Android proguard 详解

    本文转载于:http://blog.csdn.net/banketree/article/details/41928175 简介 Java代码是非常容易反编译的.为了很好的保护Java源代码,我们往往 ...

  4. 【转】android 属性动画之 ObjectAnimator

    原文网址:http://blog.csdn.net/feiduclear_up/article/details/39255083 前面一篇博客讲解了 android 简单动画之 animtion,这里 ...

  5. 远程调试js注意事项

    1:使用host切换工具,先注释掉93服务器的地址,打开链接,点击高级选项,进去后登陆账号密码(如果不行重启浏览器): 2:进入后,增加93服务器上的host地址,重启浏览器,css样式生效: 3:使 ...

  6. pm2使用

    简单教程 首先需要安装pm2: npm install -g pm2 运行: pm2 start app.js 初次安装并运行,会有一个高大上的界面: 高大上的界面 直接我们介绍过forever,那么 ...

  7. consul 安装

    1. linux 下consul 安装 首先查看机器信息: uname -a Linux centos-linux.shared 3.10.0-327.el7.x86_64 #1 SMP Thu No ...

  8. [Maven] - 安装与Eclipse搭建

    Maven的具体参考书可以看:<Maven实战> 下载maven可以到:http://maven.apache.org/ Maven的eclipse基本使用可以在这里看到:http://w ...

  9. FormatFloat 格式化浮点数

    #和0的区别: #是对应位有值显示,无值不显示 0是对应位有值显示,无值显示0 分号后的字符串是对负值的格式化特殊定义:  s := FormatFloat(.);   .);   .);   .); ...

  10. pstools使用教程

    pstools是sysinternals开发的一个功能强大的nt/2k远程管理工具包. 官方网址为http://www.sysinternals.com/ 下载地址为http://www.sysint ...