SparkMLlib-----GMM算法
Gaussian Mixture Model(GMM)是一个很流行的聚类算法。它与K-Means的很像,但是K-Means的计算结果是算出每个数据点所属的簇,而GMM是计算出这些数据点分配到各个类别的概率。与K-Means对比K-Means存在一些缺点,比如K-Means的聚类结果易受样本中的一些极值点影响。此外GMM的计算结果由于是得出一个概率,得出一个概率包含的信息量要比简单的一个结果多,对于49%和51%的发生的事件如果仅仅使用简单的50%作为阈值来分为两个类别是非常危险的。
Gaussian Mixture Model,顾名思义,它是假设数据服从高斯混合分布,或者说是从多个高斯分布中生成出来的。每个GMM由K个高斯分布组成,每个高斯分布称为一个"Component",这些Component线性加在一起就组成了GMM的概率密度函数:
使用GMM做聚类的方法,我们先使用R等工具采样数据绘出数据点分布的图观察是否符合高斯混合分布,或者直接假设我们的数据是符合高斯混合分布的,之后根据数据推算出GMM的概率分布,对应的每个高斯分布就是每个类别,因为我们已知(假设)了概率密度分布的形式,要去求出其中参数,所以是一个参数估计的过程,我们要推导出每个混合成分的参数(均值向量mu,协方差矩阵sigma,权重weight),高斯混合模型在训练时使用了极大似然估计法,最大化以下对数似然函数:
该式无法直接解析求解,因此采用了期望-最大化方法(Expectation-Maximization,EM)方法求解,具体步骤如下:
1.根据给定的K值,初始化K个多元高斯分布以及其权重;
2.根据贝叶斯定理,估计每个样本由每个成分生成的后验概率;(EM方法中的E步)
3.根据均值,协方差的定义以及2步求出的后验概率,更新均值向量、协方差矩阵和权重;(EM方法的M步)重复2~3步,直到似然函数增加值已小于收敛阈值,或达到最大迭代次数
接下来进行模型的训练与分析,我们采用了mllib包封装的GMM算法,具体代码如下
package com.xj.da.gmm import breeze.linalg.DenseVector
import breeze.numerics.sqrt
import org.apache.commons.math.stat.correlation.Covariance
import org.apache.spark.mllib.clustering.{GaussianMixture, GaussianMixtureModel}
import org.apache.spark.mllib.linalg
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext} import scala.collection.mutable.ArrayBuffer /**
* author : kongcong
* number : 27
* date : 2017/7/19
*/
object GMMWithMultivariate {
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
//.setMaster("local")
.setAppName("GMMWithMultivariate")
val sc = new SparkContext(conf) val rawData: RDD[String] = sc.textFile("hdfs://master:8020/home/kongc/data/query_result.csv")
//val rawData: RDD[String] = sc.textFile("data/query_result.csv")
println("count: " + rawData.count())
//println(rawData.count())
// col1, col2, status
val data: RDD[linalg.Vector] = rawData.map { line =>
val raw: Array[String] = line.split(",")
Vectors.dense(raw(0).toDouble, raw(1).toDouble, raw(4).toDouble)
}
// data.collect().take(10).foreach(println(_))
// col1, col2, status
val trainData: RDD[linalg.Vector] = rawData.map { line =>
val raw: Array[String] = line.split(",")
Vectors.dense(raw(0).toDouble, raw(1).toDouble)
}
// trainData.collect().take(10).foreach(println(_)) // 指定初始模型
// 0
val filter0: RDD[linalg.Vector] = data.filter(_.toArray(2) == 0)
println(filter0.count()) //23195
// 1
val filter1: RDD[linalg.Vector] = data.filter(_.toArray(2) == 1)
println(filter1.count()) //14602 val w1: Double = (filter0.count()/319377.toDouble)
val w2: Double = (filter1.count()/319377.toDouble)
println(s"w1 = $w1") // 均值
val m0x: Double = filter0.map(_.toArray(0)).mean()
val m0y: Double = filter0.map(_.toArray(1)).mean()
val m1x: Double = filter1.map(_.toArray(0)).mean()
val m1y: Double = filter1.map(_.toArray(1)).mean()
// 方差
val vx0: Double = filter0.map(_.toArray(0)).variance()
val vy0: Double = filter0.map(_.toArray(1)).variance()
val vx1: Double = filter1.map(_.toArray(0)).variance()
val vy1: Double = filter1.map(_.toArray(1)).variance() // 均值向量
val mu1: linalg.Vector = Vectors.dense(Array(m0x, m0y))
val mu2: linalg.Vector = Vectors.dense(Array(m1x, m1y))
println(s"mu1 : $mu1")
println(s"mu2 : $mu2") val array: RDD[Array[Double]] = rawData.map { line =>
val raw: Array[String] = line.split(",")
Array(raw(0).toDouble, raw(1).toDouble, raw(4).toDouble)
} val f0: RDD[Array[Double]] = array.filter(_(2) == 0)
val f1: RDD[Array[Double]] = array.filter(_(2) == 1)
println("f0.count:"+f0.count())
println("f1.count:"+f1.count()) // 0 x,y求协方差矩阵
val x0: RDD[Double] = f0.map(_(0))
val y0: RDD[Double] = f0.map(_(1))
//println(x0.collect().length == y0.collect().length)
// 1 x,y求协方差矩阵
val x1: RDD[Double] = f1.map(_(0))
val y1: RDD[Double] = f1.map(_(1))
val ma0: Array[Array[Double]] = Array(x0.collect(),y0.collect())
val ma1: Array[Array[Double]] = Array(x1.collect(),y1.collect()) val r0: RDD[Array[Double]] = sc.parallelize(ma0)
val r1: RDD[Array[Double]] = sc.parallelize(ma1) val rdd0: RDD[linalg.Vector] = r0.map(f => Vectors.dense(f))
val rdd1: RDD[linalg.Vector] = r1.map(f => Vectors.dense(f)) val RM0: RowMatrix = new RowMatrix(rdd0)
val RM1: RowMatrix = new RowMatrix(rdd1) // 计算协方差矩阵
//println(RM0.computeCovariance().numCols) /*val i: Double = DenseVector(1.0, 2.0, 3.0, 4.0) dot DenseVector(1.0, 1.0, 1.0, 1.0)
val c0yx: Double = i - m0x * m0y*/ val c0yx: Double = DenseVector(x0.collect()) dot DenseVector(y0.collect()) - m0x * m0y
val c1yx: Double = DenseVector(x1.collect()) dot DenseVector(y1.collect()) - m1x * m1y //cov(Vectors.dense(x0.collect()),Vectors.dense(y0.collect()))
val sigma1 = Matrices.dense(2, 2, Array(vx0, c0yx, c0yx, vy0))
val sigma2 = Matrices.dense(2, 2, Array(vx1, c1yx, c1yx, vy1))
val gmm1 = new MultivariateGaussian(mu1, sigma1)
val gmm2 = new MultivariateGaussian(mu2, sigma2) val gaussians = Array(gmm1, gmm2) // 构建一个GaussianMixtureModel需要两个参数 一个是权重数组 一个是组成混合高斯分布的每个高斯分布
val initModel = new GaussianMixtureModel(Array(w1, w2), gaussians) for (i <- 0 until initModel.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(initModel.weights(i), initModel.gaussians(i).mu, initModel.gaussians(i).sigma))
} val gaussianMixture = new GaussianMixture()
val mixtureModel = gaussianMixture
.setInitialModel(initModel)
.setK(2)
.setConvergenceTol(0.0001)
.run(trainData) val predict: RDD[Int] = mixtureModel.predict(trainData)
rawData.zip(predict).saveAsTextFile("hdfs://master:8020/home/kongc/data/out/gmm/predict2") for (i <- 0 until mixtureModel.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(mixtureModel.weights(i), mixtureModel.gaussians(i).mu, mixtureModel.gaussians(i).sigma))
} }
}
参考:http://blog.pluskid.org/?p=39
http://dblab.xmu.edu.cn/blog/1456/
SparkMLlib-----GMM算法的更多相关文章
- GMM算法k-means算法的比较
1.EM算法 GMM算法是EM算法族的一个具体例子. EM算法解决的问题是:要对数据进行聚类,假定数据服从杂合的几个概率分布,分布的具体参数未知,涉及到的随机变量有两组,其中一组可观测另一组不可观测. ...
- SparkMLlib分类算法之支持向量机
SparkMLlib分类算法之支持向量机 (一),概念 支持向量机(support vector machine)是一种分类算法,通过寻求结构化风险最小来提高学习机泛化能力,实现经验风险和置信范围的最 ...
- SparkMLlib回归算法之决策树
SparkMLlib回归算法之决策树 (一),决策树概念 1,决策树算法(ID3,C4.5 ,CART)之间的比较: 1,ID3算法在选择根节点和各内部节点中的分支属性时,采用信息增益作为评价标准.信 ...
- GMM算法的matlab程序
GMM算法的matlab程序 在“GMM算法的matlab程序(初步)”这篇文章中已经用matlab程序对iris数据库进行简单的实现,下面的程序最终的目的是求准确度. 作者:凯鲁嘎吉 - 博客园 h ...
- GMM算法的matlab程序(初步)
GMM算法的matlab程序 在https://www.cnblogs.com/kailugaji/p/9648508.html文章中已经介绍了GMM算法,现在用matlab程序实现它. 作者:凯鲁嘎 ...
- SparkMLlib分类算法之决策树学习
SparkMLlib分类算法之决策树学习 (一) 决策树的基本概念 决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风 ...
- SparkMLlib分类算法之逻辑回归算法
SparkMLlib分类算法之逻辑回归算法 (一),逻辑回归算法的概念(参考网址:http://blog.csdn.net/sinat_33761963/article/details/5169383 ...
- 机器学习——EM算法与GMM算法
目录 最大似然估计 K-means算法 EM算法 GMM算法(实际是高斯混合聚类) 中心思想:①极大似然估计 ②θ=f(θold) 此算法非常老,几乎不会问到,但思想很重要. EM的原理推导还是蛮复杂 ...
- Kmeans算法学习与SparkMlLib Kmeans算法尝试
K-means算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一.K-means算法的基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类.通过迭代的方法,逐次更新各聚类中心的 ...
- EM算法和GMM算法的相关推导及原理
极大似然估计 我们先从极大似然估计说起,来考虑这样的一个问题,在给定的一组样本x1,x2······xn中,已知它们来自于高斯分布N(u, σ),那么我们来试试估计参数u,σ. 首先,对于参数估计的方 ...
随机推荐
- python-散列表
散列表 简单地来说,通过某种函数关系将输入的数据映射为数字,使得数字与数据有着一一对应的关系. 其中,散列函数必须满足一定的要求: 它必须是一致的.例如,当你输入mag时得到4,那么每当输入mag时, ...
- Ubuntu环境下 matplotlib 图例中文乱码
最近做了一个最小二乘法的代码编写并用 matplotlib 绘制了一张图,但是碰到了中文乱码问题.简单搜索之后,发现有人总结出了比较好的方案,亲测可行.推荐给大家. 本文前提条件是 已经 安装好 ma ...
- Ionic3新特性--页面懒加载2加载其他组件
在第一节中,我们介绍了页面的懒加载方式,并进行了初步的分析,这里,我们将进一步介绍如何配合页面懒加载进行其他组件Component.Pipe.Directive等的模块化,和加载使用. 首先说明一点, ...
- 用redis实现TOMCAT集群下的session共享
上篇实现了 LINUX中NGINX反向代理下的TOMCAT集群(http://www.cnblogs.com/yuanjava/p/6850764.html) 这次我们在上篇的基础上实现session ...
- spring-boot开发:使用内嵌容器进行快速开发及测试
一.简述一下spring-boot微框架 1.spring-boot微框架是什么? 大家都知道,在使用spring框架进行应用开发时需要很多*.xml的初始化配置文件,而springBoot就是用来简 ...
- python编写知乎爬虫实践
爬虫的基本流程 网络爬虫的基本工作流程如下: 首先选取一部分精心挑选的种子URL 将种子URL加入任务队列 从待抓取URL队列中取出待抓取的URL,解析DNS,并且得到主机的ip,并将URL对应的网页 ...
- node.js零基础详细教程(6):mongodb数据库操作
第六章 建议学习时间4小时 课程共10章 学习方式:详细阅读,并手动实现相关代码 学习目标:此教程将教会大家 安装Node.搭建服务器.express.mysql.mongodb.编写后台业务逻辑. ...
- ELK-Kibana-01
1.Kibana介绍 Kibana 是一个设计使用和Elasticsearch配置工作的开源分析和可视化平台.可以用它进行搜索.查看.集成Elasticsearch中的数据索引.可以利用各种图 ...
- 关于JS数组的定义
关于js数组的定义的一些内容: 数组是一个对象 只用一个变量,储存多个同类型的信息 数组--连续的储存空间 数组的下标从0开始 ps:定义一个数组可以看作是一个旅馆.里面有很多小房子. 1.创建数组- ...
- 大话Python中*args和**kargs的使用
对于初学者来说,看到*args和**kargs就头大,到底它们有何用处,怎么使用?这篇文章将为你揭开可变参数的神秘面纱 1.*args 实质就是将函数传入的参数,存储在元组类型的变量args当中 de ...