EM算法是一种迭代算法,用于含有隐变量(hidden variable)的概率模型参数的极大似然估计,或极大后验概率估计。EM算法的每次迭代由两步组成:E步,求期望(expectation);M步,求极大(Maximization)。

EM算法的引入


给一些观察数据,可以使用极大似然估计法,或贝叶斯估计法估计模型参数。但是当模型含有隐变量时,就不能简单地使用这些方法。有些时候,参数的极大似然估计问题没有解析解,只能通过迭代的方法求解,EM算法就是可以用于求解这个问题的一种迭代算法。

  • EM算法
  • 输入:观测变量数据Y,隐变量数据Z,联合分布\(P(Y, X | \theta)\),条件分布\(P(Z|Y, \theta)\);
  • 输出:模型参数\(\theta\)
  1. 选择参数的初值\(\theta^{(0)}\),开始迭代;
  2. E步:记\(\theta^{(0)}\)为第\(i\)次迭代参数\(\theta\)的估计值,在第\(i+1\)次迭代的E步,计算\[Q(\theta, \theta^{(i)}) = E_Z[\log P(Y, Z|\theta) | Y, \theta^{(i)}]\] \[= \sum_Z \log P(Y, Z | \theta) P (Z|Y, \theta^{(i)})\] 这里,\(P(Z|Y, \theta^{(i)})\)是在给定观测数据Y和当前的参数估计\(\theta^{(i)}\)下隐变量数据Z的条件概率分布;
  3. M步:求使\(Q(\theta, \theta^{(i)})\)极大化的\(\theta\),确定第\(i+1\)次迭代的参数的估计值\(\theta^{(i+1)}\) \[\theta^{(i+1)} = \arg \max_{\theta} Q(\theta, \theta^{(i)})\]
  4. 重复第(2)步和第(3)步,直到收敛。

其中的函数\(Q(\theta, \theta^{(i)})\)是EM算法的核心。称为Q函数。

  • Q函数:完全数据的对数似然函数\(\log P(Y, Z | \theta)\)关于在给定观测数据Y和当前参数\(\theta^{(i)}\)下对未观测数据Z的条件概率分布\(P(Z|Y, \theta^{i})\)的期望,\[Q(\theta, \theta^{(i)}) = E_Z[\log P(Y, Z|\theta) | Y, \theta^{(i)}]\] \[= \sum_Z P(Z | Y, \theta^{(i)}) \log P(Y, Z | \theta)\]

  • 可以证明,每次迭代使似然函数增大或达到局部极值。

  • EM算法可以用于生成模型的非监督学习,生成模型由联合概率分布\(P(X, Y)\)表示,可以认为非监督学习训练数据是联合概率分布产生的数据,X为观测数据,Y为未观测数据。

  • 在应用中,初值的选择变得非常重要。常用的办法是选取几个不同的初值进行迭代,然后对得到的各个估计值加以比较,从中选择最好的。

EM算法在高斯混合模型学习中的应用


  • 高斯混合模型是指具有以下形式的概率分布模型:\[P(y | \theta) = \sum_{k = 1}^K \alpha_k \phi(y|\theta_k)\] 其中,\(\alpha_k\)是系数,\(\alpha_k \geq 0\),\(\sum_{k=1}^K \alpha_k = 1\);\(\phi\)是高斯分布密度,\(\theta_k = (\mu_k, \sigma_k^2)\),\[\phi(y|\theta_k) = \frac{1}{\sqrt{2\pi} \sigma_k} \exp (-\frac{(y - \mu_k)^2}{2 \sigma_k^2})\]称为第k个分模型。

使用EM算法进行高斯混合模型参数估计

假设观测数据\(y_1,y_2,...,y_N\)由高斯混合模型生成,\[P(y | \theta) = \sum_{k = 1}^K \alpha_k \phi(y|\theta_k)\]其中,\(\theta = (\alpha_1, \alpha_2, ..., \alpha_K; \theta_1, \theta_2,...,\theta_K)\)。我们用EM算法估计高斯混合模型的参数\(\theta\)。

  1. 明确隐变量,写出完全数据的对数似然函数

    可以认为观测数据是这样产生的:首先依概率\(\alpha_k\)选择第k个分模型,然后根据这个分模型的概率分布生成一个数据。计算出完全数据的对数似然函数为\[\log P(y, \gamma | \theta) = \sum_{k = 1}^K n_k \log \alpha_k + \sum_{j = 1}^N \gamma_{jk} [\log(\frac{1}{\sqrt{2\pi}}) - \log \sigma_k - \frac{1}{2\sigma^2} (y_i - \mu_k)^2]\]其中,\(\gamma_{jk}\)为1代表第j个观测来自第k个分模型,否则为0,\(n_k = \sum_{j=1}^N \gamma_{jk}\),\(\sum_{k=1}^K n_k = N\)。
    观测数据\(y_j\)及未观测数据\(\gamma_{jk}\),完全数据是\((y_j,\gamma_{j1},\gamma_{j2},...,\gamma_{jK}, j = 1,2,...,N\)

  2. EM算法的E步:确定Q函数
    计算出Q函数为:\[Q(\theta, \theta^{(i)}) = \sum_{k=1}^K n_k \log \alpha_k + \sum_{k=1}^N \hat{\gamma}_{jk} [\log (\frac{1}{\sqrt{2\pi}}) - \log \sigma_k - \frac{1}{2\sigma_k^2} (y_i - \mu_k)^2]\] 其中,\[\hat{\gamma}_{jk} = E(\gamma_{jk} | y, \theta) = \frac{\alpha_k \phi (y_i | \theta_k)}{\sum_{k=1}^K \alpha_k \phi (y_i | \theta_k)}\] \[n_k = \sum_{j = 1}^N E\gamma_{jk}\]

  3. 确定EM算法中的M步
    迭代的M步是求函数\(Q(\theta, \theta^{(i)})\)对\(\theta\)的极大值,即求新一轮迭代的模型参数:\[\theta^{(i+1)} = \arg \max_{\theta} Q(\theta, \theta^{(i)})\]
    通过求偏导可以得到计算相应参数的表达式。

  • 高斯混合模型参数估计的EM算法
  • 输入:观测数据\(y_1, y_2,...,y_N\),高斯混合模型
  • 输出:高斯混合模型参数。
  1. 取参数的初始值开始迭代
  2. E步:依据当前模型参数,计算分模型k对观测数据\(y_i\)的响应度\[\hat{\gamma}_{jk} = \frac{\alpha_k \phi (y_i | \theta_k)}{\sum_{k=1}^K \alpha_k \phi (y_i | \theta_k)}, j = 1,2,...,N; k = 1,2,...,K\]
  3. M步:计算新一轮迭代的模型参数:\[\hat{\mu_k} = \frac{\sum_{j=1}^N \hat{\gamma}_{jk} y_j}{\sum_{j = 1}^N \hat{\gamma}_{jk}}, k = 1,2,...,K\] \[\hat{\sigma}_k^2 = \frac{\sum_{j=1}^N \hat{\gamma}_{jk} (y_j - \mu_k)^2}{\sum_{k=1}^N \hat{\gamma}_{jk}}, k = 1,2,...,K\] \[\hat{\alpha}_k = \frac{\sum_{j=1}^N \hat{\gamma}_{jk}}{N}, k = 1,2,...,K\]
  4. 重复第(2)步和第(3)步,直到收敛。
  • EM算法还可以解释为F函数的极大-极大算法,基于这个解释有若干变形与推广,如广义期望极大(GEM)算法。

(注:本文为读书笔记与总结,侧重算法原理,来源为《统计学习方法》一书第九章)


作者:rubbninja
出处:http://www.cnblogs.com/rubbninja/
关于作者:目前主要研究领域为机器学习与无线定位技术,欢迎讨论与指正!

学习笔记——EM算法的更多相关文章

  1. [ML学习笔记] XGBoost算法

    [ML学习笔记] XGBoost算法 回归树 决策树可用于分类和回归,分类的结果是离散值(类别),回归的结果是连续值(数值),但本质都是特征(feature)到结果/标签(label)之间的映射. 这 ...

  2. 统计学习方法笔记--EM算法--三硬币例子补充

    本文,意在说明<统计学习方法>第九章EM算法的三硬币例子,公式(9.5-9.6如何而来) 下面是(公式9.5-9.8)的说明, 本人水平有限,怀着分享学习的态度发表此文,欢迎大家批评,交流 ...

  3. 机器学习笔记—EM 算法

    EM 算法所面对的问题跟之前的不一样,要复杂一些. EM 算法所用的概率模型,既含有观测变量,又含有隐变量.如果概率模型的变量都是观测变量,那么给定数据,可以直接用极大似然估计法,或贝叶斯估计法来估计 ...

  4. 学习笔记 - Manacher算法

    Manacher算法 - 学习笔记 是从最近Codeforces的一场比赛了解到这个算法的~ 非常新奇,毕竟是第一次听说 \(O(n)\) 的回文串算法 我在 vjudge 上开了一个[练习],有兴趣 ...

  5. 数据挖掘学习笔记--AdaBoost算法(一)

    声明: 这篇笔记是自己对AdaBoost原理的一些理解,如果有错,还望指正,俯谢- 背景: AdaBoost算法,这个算法思路简单,但是论文真是各种晦涩啊-,以下是自己看了A Short Introd ...

  6. 学习笔记-KMP算法

    按照学习计划和TimeMachine学长的推荐,学习了一下KMP算法. 昨晚晚自习下课前粗略的看了看,发现根本理解不了高端的next数组啊有木有,不过好在在今天系统的学习了之后感觉是有很大提升的了,起 ...

  7. Java学习笔记——排序算法之快速排序

    会当凌绝顶,一览众山小. --望岳 如果说有哪个排序算法不能不会,那就是快速排序(Quick Sort)了 快速排序简单而高效,是最适合学习的进阶排序算法. 直接上代码: public class Q ...

  8. Java学习笔记——排序算法之进阶排序(堆排序与分治并归排序)

    春蚕到死丝方尽,蜡炬成灰泪始干 --无题 这里介绍两个比较难的算法: 1.堆排序 2.分治并归排序 先说堆. 这里请大家先自行了解完全二叉树的数据结构. 堆是完全二叉树.大顶堆是在堆中,任意双亲值都大 ...

  9. Java学习笔记——排序算法之希尔排序(Shell Sort)

    落日楼头,断鸿声里,江南游子.把吴钩看了,栏杆拍遍,无人会,登临意. --水龙吟·登建康赏心亭 希尔算法是希尔(D.L.Shell)于1959年提出的一种排序算法.是第一个时间复杂度突破O(n²)的算 ...

随机推荐

  1. T-SQL 去除特定字段的前导0

    在工作过程中遇到一个需求,要从特定字段中删除前导零,这是一个简单的VARCHAR(10)字段. 例如,如果字段包含"00001A",则SELECT语句需要将数据返回为"1 ...

  2. Solr实战:使用Hue+Solr实现标签查询

    公司最近在研究多条件组合查询方案,Google的一位技术专家Sam和我们讨论了几个备选方案. Sam的信: 我做了进一步研究,目前有这么几种做法: 1) 最直接粗暴,只做一个主index,比如按行业+ ...

  3. Yii2.0.7 限制user module登录遇到的问题

    在Yii2.0.6的时候我是在以下文件通过以下方法实现的. frontend/modules/user/Module.php namespace frontend\modules\user; clas ...

  4. 简单socket()编程

    客户端: 1.socket( int af, int type, int protocol) socket()函数用于根据指定的地址族.数据类型和协议来分配一个套接口的描述字及其所用的资源.如果协议p ...

  5. easy_UI 投票列表

    首先我们考虑一下在项目投票种用到的属性(ID,投票标题,备选项目,参与人数) entity package cn.entity; public class GridNode { private Lon ...

  6. Re 模块

    re模块提供方法如compile, search, findall, match和其他的方法.这些函数是使用REGEX语法建立了一个模式来处理文本的. 第一个方法:search. 一个基本的搜索工作原 ...

  7. 杂项之图像处理pillow

    杂项之图像处理pillow 本节内容 参考文献 生成验证码源码 一些小例子 1. 参考文献 http://pillow-cn.readthedocs.io/zh_CN/latest/ pillow中文 ...

  8. ngx_http_proxy_module模块.md

    ngx_http_proxy_module ngx_http_proxy_module模块允许将请求传递到另一个服务器. proxy_bind Syntax: proxy_bind address [ ...

  9. Oracle 判断某個字段的值是不是数字

    转:https://my.oschina.net/bairrfhoinn/blog/207835 摘要: 壹共有三种方法,分别是使用 to_number().regexp_like() 和 trans ...

  10. Maven打包含有Main方法jar并运行

    最近使用Kettle做定时数据抽取,因为Job更新或需求变更,修改Bug等种种原因,需要对重跑Job一般是针对每天的数据重跑一次.刚开始的做法是直接在自己的开发机器上重跑,这样速度比较慢,因为这时候你 ...