详解聚类算法Kmeans的两大优化——mini-batch和Kmeans++
本文始发于个人公众号:TechFlow,原创不易,求个关注
今天是机器学习专题的第13篇文章,我们来看下Kmeans算法的优化。
在上一篇文章当中我们一起学习了Kmeans这个聚类算法,在算法的最后我们提出了一个问题:Kmeans算法虽然效果不错,但是每一次迭代都需要遍历全量的数据,一旦数据量过大,由于计算复杂度过大迭代的次数过多,会导致收敛速度非常慢。
想想看,如果我们是在面试当中遇到的这个问题,我们事先并不知道正解,我们应该怎么回答呢?
还是老套路,我们在回答问题之前,先来分析问题。问题是收敛速度慢,计算复杂度高。计算复杂度高的原因我们也知道了,一个是因为样本过大,另一个是因为迭代次数过多。所以显然,我们想要改进这个问题,应该从这两点入手。
这两点是问题的关键点,针对这两点我们其实可以想出很多种优化和改进的方法。也就是说这是一个开放性问题,相比标准答案,推导和思考问题的思路更加重要。相反,如果我们抓不住关键点,那么回答也会跑偏,这就是为什么我在面试的时候,有些候选人会回答使用分布式系统或者是增加资源加速计算,或者是换一种其他的算法的原因。
也就是说分析问题和解决问题的思路过程,比解决方法本身更加重要。
下面,我们就上面提到的两个关键点各介绍一个优化方法。
mini batch
mini batch的思想非常朴素,既然全体样本当中数据量太大,会使得我们迭代的时间过长,那么我们缩小数据规模行不行?
那怎么减小规模呢,很简单,我们随机从整体当中做一个抽样,选取出一小部分数据来代替整体。这样我们人为地缩小样本的规模,不就可以提升迭代的速度了?
通过抽样我们的确可以提升迭代的效率,但是这样能保证正确性吗?
这个问题很好回答,我们只需要简单做个实验就可以证明。
我们利用上周开发的并没有经过任何优化的代码,并且将生成的样本的数量增加到五万,从下面的这张图我们可以看出,朴素的Kmeans足足用了37.2秒才完成了计算。我们得到的聚类结果如下:
接着我们通过numpy下的random.choice,从中随机选择1000条样本,我们对比一下前后的耗时和结果。
我们再来看下两次聚类的中心,从图片上来看两者误差极小,我们打印出坐标来观察,误差在0.05以内,可以说是非常接近了。
虽然mini batch的原理说穿了一钱不值,但是它的的确确非常重要,不仅重要而且在机器学习领域广为使用。在大数据的场景下,几乎所有模型都需要做mini batch优化。
但是我们不禁有一个问题,这个方案全靠随机,看起来非常不靠谱,会不会出现我们选出来的结果偏差特别大的情况,比如刚好都在一个簇当中?从理论上来看,这当然是可能的,所以为了谨慎起见,我们可以重复多次采样,再对计算到的类簇坐标计算均值,直到簇中心趋于稳定为止。或者可以人工设置迭代次数,直到满足迭代次数要求时停止。
Kmeans ++
如果说mini batch是一种通用的方法,并且看起来有些儿戏的话,那么下面要介绍的方法则要硬核许多。这个方法直接在Kmeans算法本身上做优化因此被称为Kmeans++。
前文当中我们已经说过了,想要优化Kmeans算法的效率问题,大概有两个入手点。一个是样本数量太大,另一个是迭代次数过多。刚才我们介绍的mini batch针对的是样本数量过多的情况,Kmeans++的方法则是针对迭代次数。我们通过某种方法降低收敛需要的迭代次数,从而达到快速收敛的目的。
这个思路很明确,但是操作却不简单,迭代次数和收敛效果是相关的。也就是说在达到收敛之前,迭代次数是不能减少的,否则就会导致不收敛。而且聚类问题和分类问题不同,我们在分类问题当中有一个明确的损失函数用来优化。在我们使用梯度下降法的时候,还可以将梯度前的学习率设置得稍稍大一些,从而加快收敛的速度。但是聚类问题不同,尤其是Kmeans算法,我们的依次迭代,坐标变换的值是通过求平均坐标也就是质心的坐标得到的。除非我们修改迭代的逻辑,否则没办法加快迭代。
我们从算法运作的思路出发的确会得到这个结论,这个结论也是没问题的,但是有问题的是收敛的速度除了取决于每次迭代的变化率之外,还有另外一个重要的指标。就是迭代起始的位置。
也就是说我们是从怎样的情况开始收敛的,显然如果我们的初始状态离最终的收敛状态越近,那么收敛需要的迭代次数就越少,所以我们这个优化算法的目标就是想办法找到一个足够接近收敛结果的起始状态。这个思路应该也不难想通,但是这当中藏着一个巨大的疑问,我们在训练的时候并不知道收敛的状态是什么,又怎么能判断起始状态距离收敛结果的远近呢?
显然直接走是走不通的,我们需要迂回一下。
我们来分析一下,其实可以得到很多结论。首先,如果我们随机选择K个样本点作为起始的簇中心效果比随机K个坐标点更好。原因也很简单,因为我们随机坐标对应的是在最大和最小值框成的矩形面积当中选择K个点,而我们从样本当中选K个点的范围则要小得多。我们可以单纯从面积的占比就可以看得出来。由于样本具有聚集性,我们在样本当中选择起始状态,选到接近类簇的可能性要比随机选大得多。
但是还有一个小问题,比如说在上面的例子当中类簇是3,我们随机选择3个样本作为起始状态。但是问题来了,如果我们刚好选的3个点在一个类簇当中怎么办,那样到收敛状态不也需要很久吗?
这个问题的确是存在的,我们要避免选到同一个簇中点的情况。但是由于我们并不知道样本的分布情况,怎么来判断呢?
这个时候需要用到聚类的另一个性质,我们再来观察一下上面的图:
我们可以发现,簇是有向心性的。也就是说在同一个簇附近的点都会被纳入这个簇的范围内,反过来说就是两个离得远的点属于不同簇的可能性比离得近的大。
Kmeans++的思路正是基于上面的这两点,我们将目前已经想到的洞见整理一下,就可以得到算法原理了。
算法原理
首先,其实的簇中心是我们通过在样本当中随机得到的。不过我们并不是一次性随机K个,而是只随机1个。
接着,我们要从生下的n-1个点当中再随机出一个点来做下一个簇中心。但是我们的随机不是盲目的,我们希望设计一个机制,使得距离所有簇中心越远的点被选中的概率越大,离得越近被随机到的概率越小。
我们重复上述的过程,直到一共选出了K个簇中心为止。
轮盘法
我们来看一下如何根据权重来确定概率,实现这点的算法有很多,其中比较简单的是轮盘法。这个算法应该源于赌博或者是抽奖,原理也非常相似。
我们或多或少都玩过超市或者是其他场景下的转盘抽奖,在抽奖当中有一个指针一直保持不动。我们转动转盘,当转盘停下的时候,指针所指向的位置就是抽奖的结果。
我们都知道命中结果的概率和轮盘上对应的面积有关,面积越大抽中的概率也就越大,否则抽中的概率越小。
我们用公式表示一下,对于每一个点被选中的概率是:
其中是每个点到所有类簇的最短距离,表示点被选中作为类簇中心的概率。
轮盘法其实就是一个模拟转盘抽奖的过程,只不过我们用数组模拟了转盘。我们把转盘的扇形拉平,拉成条状,原来的每个扇形就对应了一个区间。扇形的面积就对应了区间的长度,显然长度越长,抽中的概率越大。然后我们来进行抽奖,我们用区间的长度总和乘上一个0-1区间内的数。
我们找到这个结果落在的区间,就是这次轮盘抽中的结果。这样我们就实现了控制随机每个结果的概率。
在上面这张图当中,我们随机出来的值是0.68,然后我们每一次减去区间长度,最后落到的区间,就是我们随机得到的结果。
总结
明白了轮盘算法之后,整个Kmeans++的思路已经是一览无余了。也就是说我们把抽取类簇中心类比成了轮盘抽奖,我们利用轮盘抽取K个样本来作为初始的类簇中心。从而尽可能地减少迭代次数,逼近最终的结果。
那么,这样的方法究竟有没有效果呢?
同样,我们通过实验来证明,首先我们来写出代码。我们需要一个辅助函数用来计算某个样本和已经选好的簇中心之间的最小距离,我们要用这个距离来做轮盘算法。
这个函数很简单,只是计算距离,取最小值而已:
def get_cloest_dist(point, centroids):
# 首先赋值成无穷大,依次递减
min_dist = math.inf
for centroid in centroids:
dist = calculateDistance(point, centroid)
if dist < min_dist:
min_dist = dist
return min_dist
接着就是用轮盘法选出K个中心,首先我们先随机选一个,然后再根据距离这个中心的举例用轮盘法选下一个,依次类推,直到选满K个中心为止。
import math
import random
def kmeans_plus(dataset, k):
clusters = []
n = dataset.shape[0]
# 首先先选出一个中心点
rdx = np.random.choice(range(n), 1)
# np.squeeze去除多余的括号
clusters.append(np.squeeze(dataset[rdx]).tolist())
d = [0 for _ in range(len(dataset))]
for _ in range(1, k):
tot = 0
# 计算当前样本到已有簇中心的最小距离
for i, point in enumerate(dataset):
d[i] = get_cloest_dist(point, clusters)
tot += d[i]
# random.random()返回一个0-1之间的小数
# 总数乘上它就表示我们随机转了轮盘
tot *= random.random()
# 轮盘法选择下一个簇中心
for i, di in enumerate(d):
tot -= di
if tot > 0:
continue
clusters.append(np.squeeze(dataset[i]).tolist())
break
return np.mat(clusters)
最后,我们把图画出来看下效果:
上图当中白色的点表示最后收敛的位置,红色的X表示我们用Kmeans++计算得到的起始位置,可以发现距离最终的结果已经非常接近了。显然,我们只需要很少几次迭代就可以达到收敛状态。
当然Kmeans++本身也具有随机性,并不一定每一次随机得到的起始点都能有这么好的效果,但是通过策略,我们可以保证即使出现最坏的情况也不会太坏。
在实际的场景当中,如果我们真的需要对大规模的数据应用Kmeans算法,我们往往会将多种优化策略结合在一起用,并且多次计算取平均,从而保证在比较短的时间内得到一个足够好的结果。这也是机器学习领域很多算法优化的精髓,即不再追求最优解,而只要一个足够好的解。很多时候,在结果上一点小小的退让,可以将算法效率提升很多。
今天关于Kmeans的优化内容就到这些,如果觉得有所收获,请顺手点个关注或者转发吧,你们的举手之劳对我来说很重要。
详解聚类算法Kmeans的两大优化——mini-batch和Kmeans++的更多相关文章
- dll的加载方式主要分为两大类,显式和隐式链接
之前简单写过如何创建lib和dll文件及简单的使用(http://blog.csdn.net/betabin/article/details/7239200).现在先再深入点写写dll的加载方式. d ...
- 【Java知识点专项练习】之 数据类型两大类
Java的数据类型分为两大类:基本类型和引用类型: 基本类型只能保存一些常量数据,引用类型除了可以保存数据,还能提供操作这些数据的功能: 为了操作基本类型的数据,java也对它们进行了封装, 得到八个 ...
- 关于fmri数据分析的两大类,四种方法
关于fmri数据分析的两大类,四种方法: 数据驱动: tca:其实这种方法,主要是提取时间维的特征.如果用它来进行数据的分析,则必须要利用其他的数据方法,比如结合ICA. ica:作为pca的一般化实 ...
- Java入门到精通——框架篇之Spring源码分析Spring两大核心类
一.Spring核心类概述. Spring里面有两个最核心的类这是Spring实现最重要的部分. 1.DefaultListableBeanFactory 这个类位于Beans项目下的org.spri ...
- Access Violation分成两大类:运行期和设计期(很全的解释)
用Delphi开发程序时,我们可以把遇到的Access Violation分成两大类:运行期和设计期. 一.设计期的Access Violation 1.硬件原因 在启动或关闭Delphi IDE以 ...
- 聚类K-Means和大数据集的Mini Batch K-Means算法
import numpy as np from sklearn.datasets import make_blobs from sklearn.cluster import KMeans from s ...
- CSS的选择器分为两大类
CSS的选择器分为两大类:基本选择题和扩展选择器. 基本选择器: 标签选择器:针对一类标签 ID选择器:针对某一个特定的标签使用 类选择器:针对你想要的所有标签使用 通用选择器(通配符):针对所有的标 ...
- 03 Java的数据类型分为两大类 类型转换 八大基本类型
数据类型 强类型语言:要求变量的使用要严格符合规定,所有变量都必须先定义后才能使用 Java的数据类型分为两大类 基本类型(primitive type) 数值类型 整数类型 byte占1个字节范围: ...
- java的数据类型分为两大类
java的数据类型分为两大类 基本类型(primitive type) 数据类型 整数类型 byte占一个字节范围:-128-127 short占两个字节范围:-32768-32767 int占四个字 ...
随机推荐
- js 实现手风琴
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...
- 常用JS代码片段
1.隐藏部分数字,如手机号码,身份证号码 1 2 3 function (str,start,length,mask_char){ return str.replace(str.substr(star ...
- Qt类声明中Q_OBJECT的作用与报错解决
2017-06-22 周四 大雨 北京 院里 新建作图类,继承自QCUstomPlot类 因为需要同时作8张图,都要单坐标缩放的功能,因此想干脆新建一个类,继承自QCUstomPlot,把需要的功能都 ...
- 事务以及Spring的事务管理
一.什么是事务? 事务是逻辑上的一组操作,要么都执行,要么都不执行 二.事务的特性(ACID) 原子性: 事务是最小的执行单位,不允许分割.事务的原子性确保动作要么全部完成,要么完全不起作用: 一致性 ...
- Mybatis调用存储过程报错
Mybatis调用存储过程 贴码 123456 Error querying database. Cause: java.sql.SQLException: User does not have ac ...
- bp(net core)+easyui+efcore实现仓储管理系统——入库管理之二(三十八)
abp(net core)+easyui+efcore实现仓储管理系统目录 abp(net core)+easyui+efcore实现仓储管理系统——ABP总体介绍(一) abp(net core)+ ...
- flask 参数校验
校验参数是否存在,不存在返回400 @app.route('/check',methods=['POST']) def check(): values = request.get_json() req ...
- Aajx
# Ajax入门及基本开发 ## # Ajax的基本概念 >> 概念: 界面异步传输技术: 将几种技术和在一起进行开发的一种编程方式: >> 基本应用场景: > Goog ...
- 7-41 jmu-python-最佳身高 (10 分)
最佳的情侣身高差遵循着一个公式:(女方的身高)×1.09 =(男方的身高).下面就请你写个程序,为任意一位用户计算他/她的情侣的最佳身高. 输入格式: 输入第一行给出正整数N(≤10),为前来查询的用 ...
- 这些Zepto中实用的方法集
前言 时间过得可真快,转眼间2017年已去大半有余,你就说吓不吓人,这一年你成长了多少,是否荒度了很多时光,亦或者天天向上,收获满满.今天主要写一些看Zepto基础模块时,比较实用的部分内部方法,在我 ...