今天要来讨论的是EM算法。第一眼看到EM我就想到了我大枫哥,EM Master,千里马。RUA!!!不知道看这个博客的人有没有懂这个梗的。

好的,言归正传。今天要讲的EM算法,全称是Expectation maximization。期望最大化。

怎么个意思呢,就是给你一堆观測样本。让你给出这个模型的參数预计。我靠,这套路我们前面讨论各种回归的时候不是已经用烂了吗?求期望,求对数期望,求导为0,得到參数预计值。这套路我懂啊,MLE!

但问题在于,假设这个问题存在中间的隐变量呢?会不会把我们的套路给带崩呢,我们通过两个样例来认识一下这两种情况。

====================================================================

不存在中间变量的EM。

如果有一天人类消除了性别的区别。全部的人都是同一个性别。

这个时候。我给了你一群人的身高让你给我做一个预计人身高的模型。

怎么办呢?感觉上身高应该是服从高斯分布吧,所以如果人的身高分布服从高斯分布N(Mu,Sigma^2),如今我又有了N个人的身高的数据,我就能够照着上面的套路进行了。先求对数似然函数

接着对两个參数求偏导为0

这样就得到了我们的參数预计

喜闻乐见的结果,又好求又符合我们的直觉,那我们再来看看还有一种情况。

====================================================================

存在中间变量的EM。

正如你所知。身高和人种的关系挺大的,而人类又有那么多种族,所以,再给你一堆人的数据,要做一个预计人身高的模型。那我们应该怎么做呢?

首先,如今分为不同种族若干类了,这些类别的概率肯定有个分布,其次,各种族其中身高是服从不同的分布的。那么这样身高的预计就变成了

Alpha代表了该样本属于某一人种的比例,事实上就是隐藏的中间变量。Muk和sigmak^2为各类高斯分布的參数。依照我们上面的套路就是求对数似然概率再求导得到參数的预计,那么先来看看似然函数

这下尴尬的情况出现了。对数里面带加号,这下求导就变得复杂异常了。并且没法求解,其实。这样的式子确实没有解析解。只是憋灰心啊。如果我们随便猜一个alpha的分布为Q,那么对数似然函数能够写成

因为Q是alpha的一个分布。所以似然函数能够看成是一个log(E(x)),log是个凹函数啊。割线始终在函数图像下方,Jensen不等式反向应用一下,有log(E(x))>=E(log(x))。所以上面的对数似然有

watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvc2luYXRfMjI1OTQzMDk=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center" alt="">

冷静分析一下如今的情况,我们如今得到了一个对数似然函数的下界函数,我们採用曲线救国的战略,我们求解它的局部最大值,那么更新后的參数带入这个下界函数一定比之前的參数值大,而它本身又是对数似然函数的下界函数,所以參数更新后。我们的对数似然函数一定是变大了!所以。就利用这样的方法进行迭代,最后就能得到比較好的參数预计。还有点晕吗。没事,我从百度扒个图给你来个形象的解释

红色那条线就是我们的对数似然函数,蓝色那条是我们在当前參数下找到的对数似然的下界函数,能够看到。我们找到它的局部极值那么參数更新成thetanew,此时对数似然函数的值也得到了上升。这样反复进行下去。是不是就能够收敛到对数似然函数的一个局部极值了嘛。对的,局部极值,并不能保证是全局最优。但它就是个预计嘛,你还要她如何?。

到了这里。我们好像跳着先把第二步參数更新的工作做完了,那么另一个事情是我们要注意的,Q呢,Q是啥,没Q你算啥极值,更新啥參数。我们已经知道Q是alpha的一个分布,然后我们肯定是希望这个下界函数尽量贴近原来的对数似然函数。这样我们才干更快地更新參数。那下界函数啥时最大呢。等号成立呗。等号成立说明你求期望的对象是个常数呀,所以log和Q谁在前后都无所谓,那么就有了

直观地能够理解成第i个样本来自第k个类别的可能性。好了,如今Q也确定了,我们依据上面所说的方法更新參数,再更新Q,再更新參数,迭代进行下去就能够了。

假设你能坚持看到这里。少侠我仅仅能说你大功已成。由于事实上我们已经把EM算法整个推导完了,或许你还是有点云里雾里,那我们再来细致梳理一下这个流程

1 拿到全部的观測样本,依据先验或者喜好先给一个參数预计。

2 依据这个參数预计和样本计算类别分布Q。得到最贴近对数似然函数的下界函数。

3 对下界函数求极值。更新參数分布。

4 迭代计算,直至收敛。

说起来啊。EM算法据说是机器学习进阶的一个算法,但至少眼下来看,它的思路还是非常easy理解的嘛。整个过程中唯一一个可能刚開始学习的人认为有点绕的地方就是应用Jensen不等式的那一步,那我再啰嗦两句。所谓Jensen不等式,你能够这么理解。对于一个凸函数f而言。它的割线始终在函数图像上方你承认吧,我在上面任取两点x1。x2。參数theta介于0到1之间,那么theta*x1+(1-theta)*x2就是介于x1和x2之间的一点吧。在这点上过x1x2割线的值大于函数值吧。是不是就有了theta*f(x1)+(1-theta)*f(x2)>f(theta*x1+(1-theta)*x2)。依据这个结论再推广开来,就有E(f(x))>f(E(x)),在对数似然函数中,因为log是个凹函数,所以把它反过来用,老铁没毛病吧?!

这一点想通了我认为整个EM算法的流程还是蛮好懂的。

以下呢,我们还回到这个身高模型的预測。如果给了m个样本,有k个种族,每一个种族的身高都是服从高斯分布的,那么这就变成了EM算法中最具代表性的一个样例,高斯混合模型GMM。

====================================================================高斯混合模型(GMM)

刚才已经讲了EM算法的套路了,如果我们如今处于某一步迭代中,那么我们该干嘛呢?

E-step 求最佳的类别分布

能够将其理解为第i个样本属于第J类的概率。

M-step 更新參数

求得了Q之后,我们就得到了最贴近原对数似然函数的下界函数。那我们对它求极值就能够得到更新后的參数。先看一下这个下界函数

Log函数里面全是乘积项这是我们最喜欢的形式,这样求导的时候但凡不相关的我们直接扔掉即可。待求參数mu,sigma^2,psi,依次求导为0就成。

watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvc2luYXRfMjI1OTQzMDk=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center" alt="">

对于psi的求解可能复杂一些。首先我们把下界函数中与psi不相关的项所有去掉。然后psi作为各类别的比例有一个天然的约束条件就是所有的psi之和为1,所以目标函数变成

这样的带约束的优化前面在SVM的时候不知道用了多少回。拉格朗日乘子法

接下来对psi求导

两边同一时候再对j从1到k连加,psi那一项就没了,右式就变成样本数目m。这样就求得了beta,回代我们就能够求得psi參数的更新

至此,全部的參数更新工作就已完毕,以下反复进行迭代即可了。

我们先把GMM的算法梳理一下

1 给參数取初始值,開始迭代。

2 求每一个样本对每一个类别的概率。科学的叫法叫求响应度

3 更新模型參数

4 反复23两步直至收敛。

我们再来看看这些參数的意义,事实上未尝不符合我们的直觉认识。W(i,j)能够看做第i个样本属于第j类的概率。那么全部样本中属于第j类的个数就是w(i,j)之和,每一个样本xi相应第j类的值就是W(i,j) xi,这样算的平均数就是第j类相应的mu,继续依照这个思路算的方差就是第j类的sigma^2。第j类的概率就是第j类的个数除以总样本数。所以,GMM模型尽管推导起来有点吓人,但细致想想它最后的结果也是符合我们的直觉认识的,每一个样本都是一部分属于某一类,全部样本中的某一类的部分构成了这一类的分布,perfect。!。

====================================================================

这种话,理论部分我们就讲完了。接下来又是调包侠的时刻了,上次写完后我想到鸢尾花数据无监督算法也能做啊。不给标签我们强行给它分类看看效果怎样。所以这里我们K-Means和GMM算法分别对鸢尾花进行处理,看看它们的聚类效果怎样。

代码例如以下

import numpy as np
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
#读取数据
iris=datasets.load_iris()
x=iris.data[:,:2]
y=iris.target
mu = np.array([np.mean(x[y == i], axis=0) for i in range(3)])
print '实际均值 = \n', mu
#K-Means
kmeans=KMeans(n_clusters=3,init='k-means++',random_state=0)
y_hat1=kmeans.fit_predict(x)
mu1=np.array([np.mean(x[y_hat1 == i], axis=0) for i in range(3)])
print 'K-Means均值 = \n', mu1
print '分类正确率为',np.mean(y_hat1==y)
gmm=GaussianMixture(n_components=3,covariance_type='full', random_state=0)
gmm.fit(x)
print 'GMM均值 = \n', gmm.means_
y_hat2=gmm.predict(x)
print '分类正确率为',np.mean(y_hat2==y)

输出结果为

实际均值 =

[[5.006  3.418]

[5.936  2.77 ]

[6.588  2.974]]

K-Means均值 =

[[5.77358491  2.69245283]

[ 5.006      3.418     ]

[ 6.81276596 3.07446809]]

分类正确率为 0.233333333333

GMM均值 =

[[5.01494511  3.44040237]

[ 6.69225795 3.03018616]

[ 5.90652226 2.74740414]]

分类正确率为 0.533333333333怒摔键盘啊,什么破正确率呀!

!憋急啊,我看事情并不简单。

机智的我们观察一下均值矩阵。K-Means给出的第一行似乎和实际的第二行非常接近,第二行和实际的第一行非常接近。相同。GMM给出的均值矩阵也有相同的问题。第二行和第三行似乎对调了。这不是算法的锅啊。它仅仅管给你聚类。哪里还能保证标签和你一样啊,三个类别六种标签方式人家算法也仅仅能随机一种好吗,所以如今我们把预測的结果的标签改一下看看实际的正确率怎样。

import numpy as np
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
#读取数据
iris=datasets.load_iris()
x=iris.data[:,:2]
y=iris.target
mu = np.array([np.mean(x[y == i], axis=0) for i in range(3)])
print '实际均值 = \n', mu
#K-Means
kmeans=KMeans(n_clusters=3,init='k-means++',random_state=0)
y_hat1=kmeans.fit_predict(x)
y_hat1[y_hat1==0]=3
y_hat1[y_hat1==1]=0
y_hat1[y_hat1==3]=1
mu1=np.array([np.mean(x[y_hat1 == i], axis=0) for i in range(3)])
print 'K-Means均值 = \n', mu1
print '分类正确率为',np.mean(y_hat1==y)
gmm=GaussianMixture(n_components=3,covariance_type='full', random_state=0)
gmm.fit(x)
print 'GMM均值 = \n', gmm.means_
y_hat2=gmm.predict(x)
y_hat2[y_hat2==1]=3
y_hat2[y_hat2==2]=1
y_hat2[y_hat2==3]=2
print '分类正确率为',np.mean(y_hat2==y)

输出结果为

实际均值 =

[[5.006  3.418]

[ 5.936 2.77 ]

[ 6.588 2.974]]

K-Means均值 =

[[5.006       3.418     ]

[ 5.77358491 2.69245283]

[ 6.81276596 3.07446809]]

分类正确率为 0.82

GMM均值 =

[[5.01494511  3.44040237]

[ 6.69225795 3.03018616]

[ 5.90652226 2.74740414]]

分类正确率为 0.786666666667

啊。这种结果还是比較让人惬意的,甚至比前面的一些监督学习的结果还要好一些……另外,标签不一致的问题我这里採用的是最蠢的手动调整,大家当然能够依据你算出的均值矩阵每行与原始均值矩阵哪行的距离最小,确定它在原始数据中的标签自己主动调整,这当然是OK的,我这里偷一点懒。

好了。愉快的工作日又要结束了,哈哈哈,周末你好!!。

机器学习笔记(十)EM算法及实践(以混合高斯模型(GMM)为例来次完整的EM)的更多相关文章

  1. 机器学习进阶-背景建模-(帧差法与混合高斯模型) 1.cv2.VideoCapture(进行视频读取) 2.cv2.getStructureElement(构造形态学的卷积) 3.cv2.createBackgroundSubtractorMOG2(构造高斯混合模型) 4.cv2.morpholyEx(对图像进行形态学的变化)

    1. cv2.VideoCapture('test.avi') 进行视频读取 参数说明:‘test.avi’ 输入视频的地址2. cv2.getStructureElement(cv2.MORPH_E ...

  2. PRML读书会第九章 Mixture Models and EM(Kmeans,混合高斯模型,Expectation Maximization)

    主讲人 网络上的尼采 (新浪微博: @Nietzsche_复杂网络机器学习) 网络上的尼采(813394698) 9:10:56 今天的主要内容有k-means.混合高斯模型. EM算法.对于k-me ...

  3. 记录:EM 算法估计混合高斯模型参数

    当概率模型依赖于无法观测的隐性变量时,使用普通的极大似然估计法无法估计出概率模型中参数.此时需要利用优化的极大似然估计:EM算法. 在这里我只是想要使用这个EM算法估计混合高斯模型中的参数.由于直观原 ...

  4. 混合高斯模型(Mixtures of Gaussians)和EM算法

    这篇讨论使用期望最大化算法(Expectation-Maximization)来进行密度估计(density estimation). 与k-means一样,给定的训练样本是,我们将隐含类别标签用表示 ...

  5. EM算法与混合高斯模型

    非常早就想看看EM算法,这个算法在HMM(隐马尔科夫模型)得到非常好的应用.这个算法公式太多就手写了这部分主体部分. 好的參考博客:最大似然预计到EM,讲了详细样例通熟易懂. JerryLead博客非 ...

  6. <转>与EM相关的两个算法-K-mean算法以及混合高斯模型

    转自http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006924.html http://www.cnblogs.com/jerrylead/ ...

  7. 机器学习3_EM算法与混合高斯模型

    ①EM算法: http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006936.html 李航 <统计学习方法>9.1节 ②混合高斯模 ...

  8. EM相关两个算法 k-mean算法和混合高斯模型

    转自http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006924.html http://www.cnblogs.com/jerrylead/ ...

  9. 混合高斯模型的EM求解(Mixtures of Gaussians)及Python实现源代码

    今天为大家带来混合高斯模型的EM推导求解过程. watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQveHVhbnl1YW5zZW4=/font/5a6L5L2T/ ...

随机推荐

  1. HDU 4708 Rotation Lock Puzzle (简单题)

    Rotation Lock Puzzle Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Oth ...

  2. 高手写的“iOS 的多核编程和内存管理”

    原文地址:http://anxonli.iteye.com/blog/1097777 多核运算 在iOS中concurrency编程的框架就是GCD(Grand Central Dispatch), ...

  3. 2008技术内幕:T-SQL语言基础 单表查询摘记

    这里的摘抄来自<Microsoft SQL Server 2008技术内幕:T-SQL语言基础>,书中用到的案例数据库是这个 TSQLFundamentals2008 ,官网给出的连接是这 ...

  4. Python已成为网络攻击的首选编程语言

    Python已成为网络攻击的首选编程语言 最新的调查数据表明,Python已经变成了世界上最热门的编程语言了,而Python的热门风也刮到了信息安全领域中.Python,摇身一变,也变成了黑客开发网络 ...

  5. sscanf,sprintf用法

    #include<string.h> #include<stdio.h> int main() { ],sztime1[],sztime2[]; sscanf("12 ...

  6. bashrc和profile的用途和区别

    使用终端登录Linux操作系统的控制台后,会出现一个提示符号(例如:#或~),在这个提示符号之后可以输入命令,Linux根据输入的命令会做回应,这一连串的动作是由一个所谓的Shell来做处理. She ...

  7. Flask 学习(三)模板

    Flask 学习(三)模板 Flask 为你配置 Jinja2 模板引擎.使用 render_template() 方法可以渲染模板,只需提供模板名称和需要作为参数传递给模板的变量就可简单执行. 至于 ...

  8. 混沌数学之Henon模型

    相关DEMO参见:混沌数学之离散点集图形DEMO 相关代码: // http://wenku.baidu.com/view/d51372a60029bd64783e2cc0.html?re=view ...

  9. 如何在Window 7 64位 PL/SQL 访问oracle 数据库

    一般 PLSQL Developer 没有64位版本,所以在64位系统上运行该程链接64位Oracle时就会报错.解决的方法如下: 第零步:在windows 7 中安装Oracle 11g 64 数据 ...

  10. scala 学习笔记五 foreach, map, reduce

    例子 val v = Vector(,,,) ) println(s) //输出:Vector(2, 4, 6, 8) val v2 = Vector(,,,) var v3 = v2.reduce( ...