6. EM算法-高斯混合模型GMM+Lasso详细代码实现
1. 前言
我们之前有介绍过4. EM算法-高斯混合模型GMM详细代码实现,在那片博文里面把GMM说涉及到的过程,可能会遇到的问题,基本讲了。今天我们升级下,主要一起解析下EM算法中GMM(搞事混合模型)带惩罚项的详细代码实现。
2. 原理
由于我们的极大似然公式加上了惩罚项,所以整个推算的过程在几个地方需要修改下。
在带penality的GMM中,我们假设协方差是一个对角矩阵,这样的话,我们计算高斯密度函数的时候,只需要把样本各个维度与对应的\(\mu_k\)和\(\sigma_k\)计算一维高斯分布,再相加即可。不需要通过多维高斯进行计算,也不需要协方差矩阵是半正定的要求。
我们给上面的(1)式加入一个惩罚项,
\]
其中的\(P\)是样本的维度。\(\bar{x}_j\)表示每个维度的平均值,\(s_j\)表示每个维度的标准差。这个penality是一个L1范式,对\(\mu_k\)进行约束。
加入penality后(1)变为
\]
这里需要注意的一点是,因为penality有一个绝对值,所以在对\(\mu_k\)求导的时候,需要分情况。于是(2)变成了
\]
\left \{\begin{array}{cc}
\frac{1}{n_k}(\sum_{i=1}^N\gamma_{ik}x_i - \frac{\lambda\sigma^2}{s_j}), & \mu_k >= \bar{x}_j\\
\frac{1}{n_k}(\sum_{i=1}^N\gamma_{ik}x_i + \frac{\lambda\sigma^2}{s_j}), & \mu_k < \bar{x}_j
\end{array}\right.
\]
3. 算法实现
- 和不带惩罚项的GMM不同的是,我们GMM+LASSO的计算高斯密度函数有所变化。
#计算高斯密度概率函数,样本的高斯概率密度函数,其实就是每个一维mu,sigma的高斯的和
def log_prob(self, X, mu, sigma):
N, D = X.shape
logRes = np.zeros(N)
for i in range(N):
a = norm.logpdf(X[i,:], loc=mu, scale=sigma)
logRes[i] = np.sum(a)
return logRes
- 在m-step中计算\(\mu_{k+1}\)的公式需要变化,先通过比较\(\mu_{kj}\)和\(means_{kj}\)的大小,来确定绝对值shift的符号。
def m_step(self, step):
gammaNorm = np.array(np.sum(self.gamma, axis=0)).reshape(self.K, 1)
self.alpha = gammaNorm / np.sum(gammaNorm)
for k in range(self.K):
Nk = gammaNorm[k]
if Nk == 0:
continue
for j in range(self.D):
if step >= self.beginPenaltyTime:
# 算出penality的偏移量shift,通过当前维度的mu和样本均值比较,确定shift的符号,相当于把lasso的绝对值拆开了
shift = np.square(self.sigma[k, j]) * self.penalty / (self.std[j] * Nk)
if self.mu[k, j] >= self.means[j]:
shift = shift
else:
shift = -shift
else:
shift = 0
self.mu[k, j] = np.dot(self.gamma[:, k].T, self.X[:, j]) / Nk - shift
self.sigma[k, j] = np.sqrt(np.sum(np.multiply(self.gamma[:, k], np.square(self.X[:, j] - self.mu[k, j]))) / Nk)
- 最后需要修改loglikelihood的计算公式
def GMM_EM(self):
self.init_paras()
for i in range(self.times):
#m step
self.m_step(i)
# e step
logGammaNorm, self.gamma= self.e_step(self.X)
#loglikelihood
loglike = self.logLikelihood(logGammaNorm)
#penalty
pen = 0
if i >= self.beginPenaltyTime:
for j in range(self.D):
pen += self.penalty * np.sum(abs(self.mu[:,j] - self.means[j])) / self.std[j]
# print("step = %s, alpha = %s, loglike = %s"%(i, [round(p[0], 5) for p in self.alpha.tolist()], round(loglike - pen, 5)))
# if abs(self.loglike - loglike) < self.tol:
# break
# else:
self.loglike = loglike - pen
4. GMM算法实现结果
用我实现的GMM+LASSO算法,对多个penality进行计算,选出loglikelihood最大的k和penality,与sklearn的结果比较。
fileName = amix1-est.dat, k = 2, penalty = 0 alpha = [0.52838, 0.47162], loglike = -693.34677
fileName = amix1-est.dat, k = 2, penalty = 0 alpha = [0.52838, 0.47162], loglike = -693.34677
fileName = amix1-est.dat, k = 2, penalty = 1 alpha = [0.52789, 0.47211], loglike = -695.26835
fileName = amix1-est.dat, k = 2, penalty = 1 alpha = [0.52789, 0.47211], loglike = -695.26835
fileName = amix1-est.dat, k = 2, penalty = 2 alpha = [0.52736, 0.47264], loglike = -697.17009
fileName = amix1-est.dat, k = 2, penalty = 2 alpha = [0.52736, 0.47264], loglike = -697.17009
myself GMM alpha = [0.52838, 0.47162], loglikelihood = -693.34677, bestP = 0
sklearn GMM alpha = [0.53372, 0.46628], loglikelihood = -176.73112
succ = 299/300
succ = 0.9966666666666667
[0 1 0 0 1 1 0 1 1 1 0 0 1 0 0 1 0 0 0 1]
[0 1 0 0 1 0 0 1 1 1 0 0 1 0 0 1 0 0 0 1]
fileName = amix1-tst.dat, loglike = -2389.1852339407087
fileName = amix1-val.dat, loglike = -358.1157431278091
fileName = amix2-est.dat, k = 2, penalty = 0 alpha = [0.56, 0.44], loglike = 53804.54265
fileName = amix2-est.dat, k = 2, penalty = 0 alpha = [0.82, 0.18], loglike = 24902.5522
fileName = amix2-est.dat, k = 2, penalty = 1 alpha = [0.82, 0.18], loglike = 23902.65183
fileName = amix2-est.dat, k = 2, penalty = 1 alpha = [0.56, 0.44], loglike = 52929.96459
fileName = amix2-est.dat, k = 2, penalty = 2 alpha = [0.82, 0.18], loglike = 22907.40397
fileName = amix2-est.dat, k = 2, penalty = 2 alpha = [0.82, 0.18], loglike = 22907.40397
myself GMM alpha = [0.56, 0.44], loglikelihood = 53804.54265, bestP = 0
sklearn GMM alpha = [0.56217, 0.43783], loglikelihood = 11738677.90164
succ = 200/200
succ = 1.0
[0 1 0 0 1 0 1 1 0 0 0 1 1 0 1 0 1 1 0 1]
[0 1 0 0 1 0 1 1 0 0 0 1 1 0 1 0 1 1 0 1]
fileName = amix2-tst.dat, loglike = 51502.878096147084
fileName = amix2-val.dat, loglike = 6071.217012747491
fileName = golub-est.dat, k = 2, penalty = 0 alpha = [0.575, 0.425], loglike = -24790.19895
fileName = golub-est.dat, k = 2, penalty = 0 alpha = [0.525, 0.475], loglike = -24440.82743
fileName = golub-est.dat, k = 2, penalty = 1 alpha = [0.55, 0.45], loglike = -25582.27485
fileName = golub-est.dat, k = 2, penalty = 1 alpha = [0.6, 0.4], loglike = -26137.97508
fileName = golub-est.dat, k = 2, penalty = 2 alpha = [0.55, 0.45], loglike = -26686.02411
fileName = golub-est.dat, k = 2, penalty = 2 alpha = [0.55, 0.45], loglike = -26941.68964
myself GMM alpha = [0.525, 0.475], loglikelihood = -24440.82743, bestP = 0
sklearn GMM alpha = [0.5119, 0.4881], loglikelihood = 13627728.10766
succ = 29/40
succ = 0.725
[0 1 0 1 0 1 0 1 0 1 0 0 0 0 0 1 1 1 0 1]
[0 1 0 1 1 1 0 0 1 1 0 1 0 1 0 0 1 1 0 0]
fileName = golub-tst.dat, loglike = -12949.606698037718
fileName = golub-val.dat, loglike = -11131.35137056415
5. 总结
通过一番改造,实现了GMM+LASSO的代码,如果读者有什么好的改进方法,或者我有什么错误的地方,希望多多指教。
6. EM算法-高斯混合模型GMM+Lasso详细代码实现的更多相关文章
- 5. EM算法-高斯混合模型GMM+Lasso
1. EM算法-数学基础 2. EM算法-原理详解 3. EM算法-高斯混合模型GMM 4. EM算法-GMM代码实现 5. EM算法-高斯混合模型+Lasso 1. 前言 前面几篇博文对EM算法和G ...
- 4. EM算法-高斯混合模型GMM详细代码实现
1. EM算法-数学基础 2. EM算法-原理详解 3. EM算法-高斯混合模型GMM 4. EM算法-高斯混合模型GMM详细代码实现 5. EM算法-高斯混合模型GMM+Lasso 1. 前言 EM ...
- 3. EM算法-高斯混合模型GMM
1. EM算法-数学基础 2. EM算法-原理详解 3. EM算法-高斯混合模型GMM 4. EM算法-高斯混合模型GMM详细代码实现 5. EM算法-高斯混合模型GMM+Lasso 1. 前言 GM ...
- EM算法和高斯混合模型GMM介绍
EM算法 EM算法主要用于求概率密度函数参数的最大似然估计,将问题$\arg \max _{\theta_{1}} \sum_{i=1}^{n} \ln p\left(x_{i} | \theta_{ ...
- 高斯混合模型GMM与EM算法的Python实现
GMM与EM算法的Python实现 高斯混合模型(GMM)是一种常用的聚类模型,通常我们利用最大期望算法(EM)对高斯混合模型中的参数进行估计. 1. 高斯混合模型(Gaussian Mixture ...
- 贝叶斯来理解高斯混合模型GMM
最近学习基础算法<统计学习方法>,看到利用EM算法估计高斯混合模型(GMM)的时候,发现利用贝叶斯的来理解高斯混合模型的应用其实非常合适. 首先,假设对于贝叶斯比较熟悉,对高斯分布也熟悉. ...
- 高斯混合模型 GMM
本文将涉及到用 EM 算法来求解 GMM 模型,文中会涉及几个统计学的概念,这里先罗列出来: 方差:用来描述数据的离散或波动程度. \[var(X) = \frac{\sum_{i=1}^N( X_ ...
- 基本算法思想Java实现的详细代码
基本算法思想Java实现的详细代码 算法是一个程序的灵魂,一个好的算法往往可以化繁为简,高效的求解问题.在程序设计中算法是独立于语言的,无论使用哪一种语言都可以使用这些算法,本文笔者将以Java语言为 ...
- Spark2.0机器学习系列之10: 聚类(高斯混合模型 GMM)
在Spark2.0版本中(不是基于RDD API的MLlib),共有四种聚类方法: (1)K-means (2)Latent Dirichlet allocation (LDA) ...
随机推荐
- 创建在“system.net/defaultProxy”配置节中指定的Web代理时出错解决办法。
出现这种问题会有很多原因,大致解决方法 方法1:在CMD下输入netsh winsock reset命令 简单来说netsh winsock reset命令含义是重置 Winsock 目录.如果一台机 ...
- 【struts2】<package>的配置
<package>元素可以把逻辑上相关的一组Action.Result.Intercepter等元素封装起来,形成一个独立的模块,package可以继承其他的package,也可以作为父包 ...
- [转]MVC 框架教程
Spring web MVC 框架提供了模型-视图-控制的体系结构和可以用来开发灵活.松散耦合的 web 应用程序的组件.MVC 模式导致了应用程序的不同方面(输入逻辑.业务逻辑和 UI 逻辑)的分离 ...
- postman 脚本学习
pm的脚本断言库默认似乎是集成chaijs的.所以重点也要掌握chaijs的用法,其实和其他断言库类似.玩着玩着就会了.推荐看看 简书 chaijs 中文文档 传送门: # pm 脚本的教程 http ...
- Python 字符串过滤
需求: str1 = " """<div class="m_wrap clearfix"><ul class=" ...
- 解决iPad/iPhone下手机号码会自动被加上a标签的问题
解决办法: 在页面顶部<head></head>中增加: <meta name="format-detection" content="te ...
- JAVA日期查询:季度、月份、星期等时间信息
package com.stt.dateChange; import java.text.SimpleDateFormat; import java.util.Calendar; import jav ...
- PHP的生成器、yield和协程
虽然之前就接触了PHP的yield关键字和与之对应的生成器,但是一直没有场景去使用它,就一直没有对它上心的研究.不过公司的框架是基于php的协程实现,觉得有必要深入的瞅瞅了. 由于之前对于生成器接触不 ...
- 修改hadoop FileUtil.java,解决权限检查的问题
在Hadoop Eclipse开发环境搭建这篇文章中,第15.)中提到权限相关的异常,如下: 15/01/30 10:08:17 WARN util.NativeCodeLoader: Una ...
- Mongodb是用Sum 和Where条件
Sum 按照条件求和 db.aa.aggregate([ { $group: { _id: null, total: { $sum: "$value" } } }, //$valu ...