Spark MLlib中KMeans聚类算法的解析和应用
聚类算法是机器学习中的一种无监督学习算法,它在数据科学领域应用场景很广泛,比如基于用户购买行为、兴趣等来构建推荐系统。
核心思想可以理解为,在给定的数据集中(数据集中的每个元素有可被观察的n个属性),使用聚类算法将数据集划分为k个子集,并且要求每个子集内部的元素之间的差异度尽可能低,而不同子集元素的差异度尽可能高。简而言之,就是通过聚类算法处理给定的数据集,将具有相同或类似的属性(特征)的数据划分为一组,并且不同组之间的属性相差会比较大。
K-Means算法是聚类算法中应用比较广泛的一种聚类算法,比较容易理解且易于实现。
>> "标准" K-Means算法
KMeans算法的基本思想是随机给定K个初始簇中心,按照最邻近原则把待分类样本点分到各个簇。然后按平均法重新计算各个簇的质心,从而确定新的簇心。一直迭代,直到簇心的移动距离小于某个给定的值或者满足已定条件。主要分为4个步骤:
- 为要聚类的点寻找聚类中心,比如随机选择K个点作为初始聚类中心
- 计算每个点到聚类中心的距离,将每个点划分到离该点最近的聚类中去
- 计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心
- 反复执行第2步和第3步,直到聚类中心不再改变或者聚类次数达到设定迭代上限或者达到指定的容错范围
示例图:
KMeans算法在做聚类分析的过程中主要有两个难题:初始聚类中心的选择和聚类个数K的选择。
>> Spark MLlib对KMeans的实现分析
Spark MLlib针对"标准"KMeans的问题,在实现自己的KMeans上主要做了如下核心优化:
1. 选择合适的初始中心点
Spark MLlib在初始中心点的选择上,有两种算法:
随机选择:依据给的种子seed,随机选择K个随机中心点
k-means||:默认的算法
- val RANDOM = "random"
- val K_MEANS_PARALLEL = "k-means||"
2. 计算样本属于哪一个中心点时对距离计算的优化
假设中心点是(a1,b1),要计算的点是(a2,b2),那么Spark MLlib采取的计算方法是(记为lowerBoundOfSqDist):
对比欧几里得距离(记为EuclideanDist):
可轻易证明lowerBoundOfSqDist小于或等于EuclideanDist,并且计算lowerBoundOfSqDist很方便,只需处理中心点和要计算的点的L2范数。那么在实际处理中就分两种情况:
当lowerBoundOfSqDist大于"最近距离"(之前计算好的,记为bestdistance),那么可以推导欧式距离也大于bestdistance,不需要计算欧式距离,省去了很多计算工作
当lowerBoundOfSqDist小于bestdistance,则会调用fastSquaredDistance进行距离的快速计算
关于fastSquaredDistance:
- 首先计算一个精度:
- val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
- if (precisionBound1 < precision) {
- // 精度满足squared distance期望的精度
- // val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
- // 2.0 * dot(v1, v2)为2(a1*a2 + b1*b2)可以利用之前计算的L2范数
- sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)
- } else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
- val dotValue = dot(v1, v2)
- sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
- val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
- (sqDist + EPSILON)
- if (precisionBound2 > precision) {
- sqDist = Vectors.sqdist(v1, v2)
- }
- } else {
- sqDist = Vectors.sqdist(v1, v2)
- }
- //精度不满足要求时,则进行Vectors.sqdist(v1, v2)的处理,即原始的距离计算
>> Spark MLlib中KMeans相关源码分析
基于mllib包下的KMeans相关源码涉及的类和方法(ml包下与下面略有不同,比如涉及到的fit方法):
- KMeans类和伴生对象
- train方法:根据设置的KMeans聚类参数,构建KMeans聚类,并执行run方法进行训练
- run方法:主要调用runAlgorithm方法进行聚类中心点等的核心计算,返回KMeansModel
initialModel:可以直接设置KMeansModel作为初始化聚类中心选择,也支持随机和k-means || 生成中心点
predict:预测样本属于哪个"类"
computeCost:通过计算数据集中所有的点到最近中心点的平方和来衡量聚类效果。一般同样的迭代次数,cost值越小,说明聚类效果越好。
注意:该方法在Spark 2.4.X版本已经过时,并且会在Spark 3.0.0被移除,具体取代方法可以查看ClusteringEvaluator
主要看一下train和runAlgorithm的核心源码:
- def train(
- // 数据样本
- data: RDD[Vector],
- // 聚类数量
- k: Int,
- // 最大迭代次数
- maxIterations: Int,
- // 初始化中心,支持"random"或者"k-means||"
- initializationMode: String,
- // 初始化时的随机种子
- seed: Long): KMeansModel = {
- new KMeans().setK(k)
- .setMaxIterations(maxIterations)
- .setInitializationMode(initializationMode)
- .setSeed(seed)
- .run(data)
- }
- /**
- * Implementation of K-Means algorithm.
- */
- private def runAlgorithm( data: RDD[VectorWithNorm],
- instr: Option[Instrumentation]): KMeansModel = {
- val sc = data.sparkContext
- val initStartTime = System.nanoTime()
- val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)
- val centers = initialModel match {
- case Some(kMeansCenters) =>
- kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
- case None =>
- if (initializationMode == KMeans.RANDOM) {
- // random
- initRandom(data)
- } else {
- // k-means||
- initKMeansParallel(data, distanceMeasureInstance)
- }
- }
- val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
- logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")
- var converged = false
- var cost = 0.0
- var iteration = 0
- val iterationStartTime = System.nanoTime()
- instr.foreach(_.logNumFeatures(centers.head.vector.size))
- // Execute iterations of Lloyd's algorithm until converged
- // Kmeans迭代执行,计算每个样本属于哪个中心点,中心点累加的样本值以及计数。然后根据中心点的所有样本数据进行中心点的更新,并且比较更新前的数值,根据两者距离判断是否完成
- //迭代次数小于最大迭代次数,并行计算的中心点还没有收敛
- while (iteration < maxIterations && !converged) {
- // 损失值累加器
- val costAccum = sc.doubleAccumulator
- // 广播中心点
- val bcCenters = sc.broadcast(centers)
- // Find the new centers
- val collected = data.mapPartitions { points =>
- // 当前聚类中心
- val thisCenters = bcCenters.value
- // 中心点的维度
- val dims = thisCenters.head.vector.size
- val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
- val counts = Array.fill(thisCenters.length)(0L)
- points.foreach { point =>
- // 通过当前的聚类中心点,找出最近的聚类中心点
- // findClosest是为了计算bestDistance,参考上述Spark对距离计算的优化
- val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
- costAccum.add(cost)
- distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
- counts(bestCenter) += 1
- }
- counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
- }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
- axpy(1.0, sum2, sum1)
- (sum1, count1 + count2)
- }.collectAsMap()
- if (iteration == 0) {
- instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
- }
- val newCenters = collected.mapValues { case (sum, count) =>
- distanceMeasureInstance.centroid(sum, count)
- }
- bcCenters.destroy(blocking = false)
- // Update the cluster centers and costs
- converged = true
- newCenters.foreach { case (j, newCenter) =>
- if (converged &&
- !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) {
- // 距离大于,则说明中心点位置改变
- converged = false
- }
- // 更新中心点
- centers(j) = newCenter
- }
- cost = costAccum.value
- iteration += 1
- }
- val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
- logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")
- if (iteration == maxIterations) {
- logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
- } else {
- logInfo(s"KMeans converged in $iteration iterations.")
- }
- logInfo(s"The cost is $cost.")
- new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration)
- }
>> Spark MLlib的KMeans应用示例
1. 准备数据
- 诺丹姆吉本斯主教中学(Notre Dame-Bishop Gibbons School) 71 0 0 283047.0 13289.0
- 海景基督高中(Ocean View Christian Academy) 45 0 0 276403.0 13289.0
- 卡弗里学院(Calvary Baptist Academy) 58 0 0 227567.0 13289.0
- ...
2. 示例代码
- //将加载的rdd数据,每一条变成一个向量,整个数据集变成矩阵
- val parsedata = rdd.map { case Row(schoolid, schoolname, locationid, school_type, zs, fee, byj) =>
- //"特征因子":学校位置id,学校类型,住宿方式,学费,备用金
- val features = Array[Double](locationid.toString.toDouble, school_type.toString.toDouble, zs.toString.toDouble, fee.toString.toDouble, byj.toString.toDouble)
- //将数组变成机器学习中的向量
- Vectors.dense(features)
- }.cache() //默认缓存到内存中,可以调用persist()指定缓存到哪
- //用kmeans对样本向量进行训练得到模型
- //聚类中心
- val numclusters = List(3, 6, 9)
- //指定最大迭代次数
- val numIters = List(10, 15, 20)
- var bestModel: Option[KMeansModel] = None
- var bestCluster = 0
- var bestIter = 0
- val bestRmse = Double.MaxValue
- for (c <- numclusters; i <- numIters) {
- val model = KMeans.train(parsedata, c, i)
- //集内均方差总和(WSSSE),一般可以通过增加类簇的个数 k 来减小误差,一般越小越好(有可能出现过拟合)
- val d = model.computeCost(parsedata)
- println("选择K:" + (c, i, d))
- if (d < bestRmse) {
- bestModel = Some(model)
- bestCluster = c
- bestIter = i
- }
- }
- println("best:" + (bestCluster, bestIter, bestModel.get.computeCost(parsedata)))
- //用模型对我们的数据进行预测
- val resrdd = df.map { case Row(schoolid, schoolname, locationid, school_type, zs, fee, byj) =>
- //提取到每一行的特征值
- val features = Array[Double](locationid.toString.toDouble, school_type.toString.toDouble, zs.toString.toDouble, fee.toString.toDouble, byj.toString.toDouble)
- //将特征值转换成特征向量
- val linevector = Vectors.dense(features)
- //将向量输入model中进行预测,得到预测值
- val prediction = bestModel.get.predict(linevector)
- //返回每一行结果((sid,sname),所属类别)
- ((schoolid.toString, schoolname.toString), prediction)
- }
- //中心点
- /*val centers: Array[linalg.Vector] = model.clusterCenters
- centers.foreach(println)*/
- //按照所属"类别"分组,并根据"类别"排序,然后保存到数据库
- // saveData2Mysql是封装好的保存数据到mysql的方法
- resrdd.groupBy(_._2).sortBy(_._1).foreachPartition(saveData2Mysql(_))
上述示例只是一个简单的demo,实际应用中会更复杂,牵涉到数据的预处理,比如对数据进行量化、归一化,以及如何调参以获取最优训练模型。
推荐文章:
Spark实现推荐系统中的相似度算法
关于一些技术点的随笔记录(二)
Spark存储Parquet数据到Hive,对map、array、struct字段类型的处理
Kafka中sequence IO、PageCache、SendFile的应用详解
对Spark硬件配置的建议
关注微信公众号:大数据学习与分享,获取更对技术干货
Spark MLlib中KMeans聚类算法的解析和应用的更多相关文章
- Spark MLBase分布式机器学习系统入门:以MLlib实现Kmeans聚类算法
1.什么是MLBaseMLBase是Spark生态圈的一部分,专注于机器学习,包含三个组件:MLlib.MLI.ML Optimizer. ML Optimizer: This layer aims ...
- Matlab中K-means聚类算法的使用(K-均值聚类)
K-means聚类算法采用的是将N*P的矩阵X划分为K个类,使得类内对象之间的距离最大,而类之间的距离最小. 使用方法:Idx=Kmeans(X,K)[Idx,C]=Kmeans(X,K) [Idx, ...
- 机器学习中K-means聚类算法原理及C语言实现
本人以前主要focus在传统音频的软件开发,接触到的算法主要是音频信号处理相关的,如各种编解码算法和回声消除算法等.最近切到语音识别上,接触到的算法就变成了各种机器学习算法,如GMM等.K-means ...
- Spark中的聚类算法
Spark - Clustering 官方文档:https://spark.apache.org/docs/2.2.0/ml-clustering.html 这部分介绍MLlib中的聚类算法: 目录: ...
- 使用 Spark MLlib 做 K-means 聚类分析[转]
原文地址:https://www.ibm.com/developerworks/cn/opensource/os-cn-spark-practice4/ 引言 提起机器学习 (Machine Lear ...
- Spark MLlib KMeans 聚类算法
一.简介 KMeans 算法的基本思想是初始随机给定K个簇中心,按照最邻近原则把分类样本点分到各个簇.然后按平均法重新计算各个簇的质心,从而确定新的簇心.一直迭代,直到簇心的移动距离小于某个给定的值. ...
- MLlib 中的聚类和分类
聚类和分类是机器学习中两个常用的算法,聚类将数据分开为不同的集合,分类对新数据进行类别预测,下面将就两类算法进行介绍. 1. 聚类和分类(1)什么是聚类 聚类( Clustering)指将数据对象分组 ...
- K-means聚类算法及python代码实现
K-means聚类算法(事先数据并没有类别之分!所有的数据都是一样的) 1.概述 K-means算法是集简单和经典于一身的基于距离的聚类算法 采用距离作为相似性的评价指标,即认为两个对象的距离越近,其 ...
- 通过IDEA及hadoop平台实现k-means聚类算法
由于实验室任务方向变更,本文不再更新~ 有段时间没有操作过,发现自己忘记一些步骤了,这篇文章会记录相关步骤,并随时进行补充修改. 1 基础步骤,即相关环境部署及数据准备 数据文件类型为.csv文件,e ...
随机推荐
- Python中的enumerate函数的作用
enumerate函数是将一个可迭代对象中元素,按元素顺序每个增加一个索引值,将其组成一个索引序列,利用它可以同时获得索引和值,这样做的目的是为了将一个可迭代对象中元素组成一个"索引,值&q ...
- PyQt程序执行时报错:AttributeError: 'winTest' object has no attribute 'setCentralWidget'的解决方法
用QtDesigner设计了一个UI界面,保存在文件Ui_wintest.ui中,界面中使用了MainWindow窗口,窗口名字也叫MainWindow,用PyUIC将其转换成了 Ui_wintest ...
- jQuery笔记(一)
day01 - jQuery 学习目标: 能够说出什么是 jQuery 能够说出 jQuery 的优点 能够简单使用 jQuery 能够说出 DOM 对象和 jQuery 对象的区别 能够写出常用的 ...
- .net core 注入的几种方式
一.注册的几种类型: services.TryAddSingleton<IHttpContextAccessor, HttpContextAccessor>();//单利模式,整个应用程序 ...
- 华大MCU单片机之HC32F003/HC32F005 hc32f005_ddl_Rev1.9.0 Lite精简版库使用心得
之前几个项目开发都是用的华大HC32F003_DDL_Rev1.0.2的库函数,今年刚开始入手华大,刚开始不是很了解这个芯片,看到库能用就上手了.这个版本的库编译效率很低,16K的芯片一下就写爆了.后 ...
- 题解-CF1239D Catowice City
CF1239D Catowice City 有 \(n\) 个人和 \(n\) 只猫.有 \(m\) 对人猫友谊,即第 \(u_i\) 个人认识第 \(v_i\) 只猫,保证第 \(i\) 个人和第 ...
- spring入门学习
开发步骤: 1.导入Spring开发的基本坐标 2.编写接口和实现类 3.创建Spring核心配置文件 4.在Spring核心配置文件中配置实现类 5.使用Spring的API获得Bean实例Bean ...
- C++异常之二 基本语法
2. 异常处理的基本语法 下面是一个基本的代码例子,说明 throw.try.catch的基本用法,与 catch 的类型自动匹配: 1 #include <iostream> 2 #in ...
- 情话爬虫工具[windows版]
有没有在气氛暧昧的情况下想说点什么却又无话可说?女朋友有没有抱怨过你,只会写代码,一点都不懂情调?这次,是时候要改变她对你的看法了!一键爬取情话,情话全都躺在txt里面.想怎么玩就怎么玩!张口一句情话 ...
- JDK11 下载安装与配置环境变量
1.jdk11本身也包含jre,不需要安装jre,低版本需要安装jre 2.jdk下载地址:https://www.oracle.com/technetwork/java/javase/downloa ...