Databricks孟祥瑞:ALS 在 Spark MLlib 中的实现

发表于2015-05-07 21:58| 10255次阅读| 来源《程序员》电子刊| 9 条评论| 作者孟祥瑞
摘要:MLlib在1.3中添加了不少机器学习及数据挖掘算法:研究主题分布的LDA、估计点集分布的GMM、提取频繁项集的 FP-growth等等。本文主要聚焦ALS的实现及其在1.3中的提升。

深受用户喜爱的大数据处理平台 Apache Spark 1.3 于前不久发布,MLlib 作为 Spark 负责机器学习 (ML) 的核心组件在 1.3 中添加了不少机器学习及数据挖掘的算法:研究主题分布的 latent Dirichlet allocation (LDA)、估计点集分布的高斯混合模型 (GMM)、提取频繁项集的 FP-growth、生成图聚类的 power iteration clustering (PIC)等等。呃,这些我们暂放一边不谈。MLlib 还添加了 Python 的 ML 流水线接口、模型基于 Parquet 的存储、以及分布式分块矩阵模型。呃,这些我们暂放另一边,也不谈……

那我们谈些什么?我想借这个机会聊聊 ALS 算法和其在 MLlib 中的实现,特别是在 Spark 1.3 中的改进。希望可以起到抛砖引玉的作用,让更多的人关注在 Spark 上实现机器学习算法会遇到的算法重构和运行效率问题。

ALS 是什么?

ALS 是交替最小二乘 (alternating least squares)的简称。在机器学习的上下文中,ALS
特指使用交替最小二乘求解的一个协同推荐算法。它通过观察到的所有用户给产品的打分,来推断每个用户的喜好并向用户推荐适合的产品。举个例子,我们考虑下面这个包含用户打分的打分矩阵:

这个矩阵的每一行代表一个用户 (u1,u2,...,u9)、每一列代表一个产品 (v1,v2,…,v9)。用户的打分在 1-9
之间。我们只显示观察到的打分。那么问题来了:用户
u5 给产品 v4 的打分大概会是多少?粗略地观察一下……这不是数独么?是的,而且如果按照数独来做的话(比较耗时、不推荐),用户
u5 一定会给产品
v4 打 9
分。为什么看上去选择很多,答案却是唯一的?因为数独的规则很强,每添加一条规则,就让整个系统的自由度下降一个量级。当我们要满足所有的规则时,整个系统的自由度已然降为一了。现在请努力地把上面的数独题想成一个打分矩阵。如果我们不添加任何条件的话,打分之间是相互独立的,我们没有任何依据来推断
u5 给 v4
的打分。所以在这个打分矩阵的基础上,我们需要提出一个限制其自由度的合理假设,使得我们可以通过观察已有打分猜测未知打分。

ALS 的核心就是下面这个假设:打分矩阵是近似低秩的。换句话说,一个
的打分矩阵 A 可以用两个小矩阵的乘积来近似:。这样我们就把整个系统的自由度从一下降到了。当然,我们也可以随便提一个假设把自由度直接降到一。我们接下来就聊聊为什么
ALS
的低秩假设是合理的。世上万千事物,人们的喜好各不相同。但描述一个人的喜好经常是在一个抽象的低维空间上进行的,并不需要把其喜欢的事物一一列出。举个例子,我喜欢看略带黑色幽默的警匪电影,那么大家根据这个描述就知道我大概会喜欢昆汀的《低俗小说》、《落水狗》和韦家辉的《一个字头的诞生》。这些电影都符合我对自己喜好的描述,也就是说他们在这个抽象的低维空间的投影和我的喜好相似。再抽象一些,把人们的喜好和电影的特征都投到这个低维空间,一个人的喜好映射到了一个低维向量,一个电影的特征变成了纬度相同的向量,那么这个人和这个电影的相似度就可以表述成这两个向量之间的内积。 我们把打分理解成相似度,那么打分矩阵A就可以由用户喜好矩阵和产品特征矩阵的乘积来近似了。

我们大致解释了
ALS 低秩假设的合理性,接下来的问题是怎么选这个抽象的低维空间。这个低维空间要能够有效的区分事物,如果我说我喜欢看 16:9
宽屏的彩色立体声电影,那一定是我真心不想透露我的喜好。但 ALS
是很难从实质上理解“黑色幽默”和“彩色”的区别是什么的,它需要一个更明确的可以量化的目标,这就是重构误差。既然我们的假设是打分矩阵A可以通过来近似,那么一个最直接的可以量化的目标就是通过U,V重构A所产生的误差。在 ALS 里,我们使用 Frobenius范数,,来量化重构误差,就是每个元素的重构误差的平方和。这里存在一个问题,我们只观察到部分打分,A 中的大量未知元正是我们想推断的,所以这个重构误差是包含未知数的。解决方案很简单很暴力:就只看对已知打分的重构误差吧。所以 ALS 的优化目标是:。这里 R 指观察到的 (用户,产品)集。

我们把一个协同推荐的问题通过低秩假设成功转变成了一个优化问题。下面要讨论的内容很显然:这个优化问题怎么解?其实答案已经在
ALS 的名字里给出——交替最小二乘。ALS
的目标函数不是凸的,而且变量互相耦合在一起,所以它并不算好解。但如果我们把用户特征矩阵U和产品特征矩阵V固定其一,这个问题立刻变成了一个凸的而且可拆分的问题。比如我们固定U,那么目标函数就可以写成。其中关于每个产品特征的部分是独立的,也就是说固定U求我们只需要最小化就好了,这个问题就是经典的最小二乘问题。所谓“交替”,就是指我们先随机生成然后固定它求解,再固定求解,这样交替进行下去。因为每步迭代都会降低重构误差,并且误差是有下界的,所以 ALS 一定会收敛。但由于问题是非凸的,ALS 并不保证会收敛到全局最优解。但在实际应用中,ALS 对初始点不是很敏感,是不是全局最优解造成的影响并不大。

ALS 在 MLlib 中的实现

ALS 的算法介绍完了,但我们距离一个好的分布式实现还有一段距离。因为 ALS 每步迭代中优化问题的目标函数可以拆分成互相独立的最小二乘子问题,所以从计算的角度来看 ALS 是适合分布式求解的。但通过观察一个子问题,我们会发现求解 vj是需要知道上一步得到的每个已知打分对应的的值。如果分布式求解,我们可能会需要从其它节点获取这些数据,从而产生通信费用。和很多机器学习算法的分布实现类似,ALS 的分布式实现主要关心的是计算复杂度和通信复杂度。

计算复杂度比较容易估算,所以我们先讲。求解一个的最小二乘问题的复杂度是。当固定U求V时,我们一共有n个最小二乘子问题,所以总的复杂度是,其中 nnz 指观察到的打分数量。再加上固定V求U的复杂度,一步完整的迭代需要的计算量就是。MLlib 中的 ALS 实现通过法方程 (normal equation) 求解最小二乘子问题,需要的空间复杂度是。最小二乘有很多种求解方法,这里为什么选法方程以及其求解精度我们就略去不谈了。

通信复杂度是分布式实现一个算法时一定要重点考虑的问题,稍有不慎就会导致十倍甚至百倍的效率损失。我们先看一下最坏的情况:假设求解时所需要的用户特征都需要从其它节点获取,并且子问题之间完全独立。例如图1所示,求解 v1 需要获取 u1 和 u2,求解 v2 需要获取 u1、u2 和 u3等等。这种假设下每步迭代需要交换的数据量是,比输入数据要高一个量级。虽然还是比每步迭代需要的计算量低一个量级,但由于k一般不大,而且做一个浮点运算比通过网络传输一个字节要快很多,所以在这种情况下通信时间会远远超出计算时间。

图1:通信复杂度示例图

为了在
Spark 上提供一个高效的 ALS 实现,我们需要合理的设计数据分区和 RDD 缓存来减少数据交换。从上面的图我们会观察到,如果计算 v1 和
v2 是在同一个分区上完成的,我们只需要把 u1 和 u2 一次发给这个分区,然后在计算 v1 和 v2 的时时候在本机内存直接读取 u1 和
u2 即可。 这样就省掉了不必要的数据传输。图2描述了如何在分区的情况下通过
U来求解V,注意节点之间的数据交换量减少了。使用这种分区结构,我们需要在原始打分数据的基础上额外保存一些信息。在 P1,我们要知道把 u1 发给
Q1 和 Q2,把 u2 发给 Q1。我们可以查看和 u1 相关联的所有产品来确定需要把 u1
发给谁,但每次迭代都扫一遍数据是很不划算的。所以在 MLlib 的实现中我们只计算一次这个信息,然后把结果通过 RDD
缓存起来重复使用。这部分数据我们在代码里称作 OutBlock。在 Q1,我们需要知道 v1
和哪些用户向量有关联及其对应的打分,从而构建最小二乘问题并求解。这部分数据不仅包含原始打分数据,还包含从每个用户分区收到的向量排序信息,我们在代码里称作
InBlock。所以从 U 求解 V,我们需要通过用户的 OutBlock 信息把用户向量发给产品分区,然后通过产品的 InBlock
信息构建最小二乘问题并求解。从 V 求解 U,我们需要产品的 OutBlock 信息和用户的 InBlock 信息。所有的 InBlock 和
OutBlock 信息在迭代过程中都通过 RDD 缓存。大家会发现原始的打分数据其实在用户的 InBlock 和产品的 InBlock
各存了一份,但分区方式不同,这么做可以避免在迭代过程中对原始数据的交换。

图2:数据分区设计后的通信复杂度

接下来我们讨论一下
InBlock 的数据结构。以 Q1 为例,我们要知道所有关于 v1 和 v2 的所有打分:(v1, u1, a11),(v2, u1,
a12), (v1, u2, a21), (v2, u2, a22), (v2, u3, a32)。但是把这些打分直接按照 Tuple
存的话会有几个问题。首先是空间的额外开销,每个 Tuple 实例都需要一个指针,而每个 Tuple 所存的数据不过是两个 ID
和一个打分,非常不划算。而且存储大量的 Tuple 实例会降低 Java 垃圾回收效率。所以我们使用三个原始数组来存 InBlock
信息:([v1, v2, v1, v2, v2], [u1, u1, u2, u2, u3], [a11, a12, a21, a22,
a32])。这样不仅大幅减少了实例数量,还有效地利用了连续内存。但还存在一个问题,当我们求解 v1 时,我们要通过所有和 v1 关联的用户向量
(u1, u2) 来构建最小二乘问题。这里有两个选择:a) 扫一遍 InBlock 信息,同时对所有的产品构建对应的最小二乘问题;b)
对于每一个产品,扫描 InBlock 信息,构建并求解其对应的最小二乘问题。之前提到过通过法方程求解一个最小二乘问题的空间复杂度是,所以方法 a 所需要的空间是,比存储产品向量所需空间高出一个量级。而方法
b 也不算理想,因为要对 InBlock 信息多次扫描。在Spark  1.3 里,我们首先将 InBlock 信息按照产品 ID 排序:
([v1, v1, v2, v2, v2],[u1, u2, u1, u2, u3], [a11, a21, a12, a22,
a32])。这样我们只需要顺序扫描一遍数据,就可以逐个创建最小二乘问题并求解,这样所需的空间降到了。在
Java 里将三个很大的原始数组根据某一个排序并不是件很容易的事情。我们使用 Spark 中的 TimSort 实现来排序,这也是在
Petabyte Sort 比赛中 Databricks
小组所使用的排序算法。排序后的另外一个好处是我们可以把数据进一步压缩。对于每一个产品,我们只需纪录它所对应的打分开始和结束的位置即可。InBlock
就变成了这样:([v1, v2], [0, 2, 5], [u1, u2, u1, u2, u3], [a11, a21, a12, a22,
a32])。其中 [0, 2] 指 v1 对应的打分的区间是 [0, 2),[2, 5] 指 v2 对应的打分的区间是 [2,
5)。通过一系列的调整,我们在内存使用、时间和空间复杂度上都达到了较好的效果。

在 Spark 1.3 中,我们还对 ALS
做了一些其它的改进。为了避免不必要的 map 查询和支持多种 ID 类型,我们在实现中并没有直接在 InBlock 中存储用户的原始
ID,而只记录了需要的用户向量应该是哪个分区发过来的第几个。比如在 Q1 分区 ,u2 就是从 P1 发过来的第二个,而 u2 原始的 ID
是多少并不影响问题的求解。我们把分区和索引信息编码到一个整型里,在高位存分区 ID,在低位存对应分区的索引,在空间上也尽量做到不浪费。此外,因为
ALS 对求解的精度要求不高,为了减少数据交换量,我们把Spark  1.2 中使用的 Double 改成了 Float
来存储用户和产品向量。还有一些优化我们就不一一提及了,有兴趣的读者可以参看 ALS 源码以及相关的 JIRA。

通过对实现的改进,新版的
ALS 在速度、资源和稳定性上都有大幅度提升。下图是我们在 Amazon Reviews 数据集上做的一些比较。测试使用 16 个
m3.2xlarge 节点的 Amazon EC2 集群。可以看到,ALS 在速度上对比 Spark 1.2 有 2-4x
的提升,而且表现出了更好的伸缩性。我们还在更大的集群上测试了一个大概有 500 亿打分的数据集,ALS 表示无压力。

小结

本文简单介绍了
ALS 算法和其在 MLlib 中的实现。希望通过分析 ALS
可以让大家直观的看到,同样的算法,在分布式系统上实现时,不同的选择会带来性能上巨大的差异。大家在 Spark
上实现机器学习算法时,不妨先分析一下空间、时间、和通信复杂度,然后合理的利用 Spark 的分区和缓存机制做到高效的实现。希望在 2015
年看到更多的人加入 MLlib 的开发和维护,让 MLlib 的算法更好更快更易用!

孟祥瑞,Databricks 软件工程师、Apache Spark PMC成员 ,Apache Spark Committer。


转载:Databricks孟祥瑞:ALS 在 Spark MLlib 中的实现的更多相关文章

  1. Apache Spark源码走读之23 -- Spark MLLib中拟牛顿法L-BFGS的源码实现

    欢迎转载,转载请注明出处,徽沪一郎. 概要 本文就拟牛顿法L-BFGS的由来做一个简要的回顾,然后就其在spark mllib中的实现进行源码走读. 拟牛顿法 数学原理 代码实现 L-BFGS算法中使 ...

  2. Spark MLlib中的OneHot哑变量实践

    在机器学习中,线性回归和逻辑回归算是最基础入门的算法,很多书籍都把他们作为第一个入门算法进行介绍.除了本身的公式之外,逻辑回归和线性回归还有一些必须要了解的内容.一个很常用的知识点就是虚拟变量(也叫做 ...

  3. Spark MLlib中KMeans聚类算法的解析和应用

    聚类算法是机器学习中的一种无监督学习算法,它在数据科学领域应用场景很广泛,比如基于用户购买行为.兴趣等来构建推荐系统. 核心思想可以理解为,在给定的数据集中(数据集中的每个元素有可被观察的n个属性), ...

  4. spark MLLib的基础统计部分学习

    参考学习链接:http://www.itnose.net/detail/6269425.html 机器学习相关算法,建议初学者去看看斯坦福的机器学习课程视频:http://open.163.com/s ...

  5. 使用 Spark MLlib 做 K-means 聚类分析[转]

    原文地址:https://www.ibm.com/developerworks/cn/opensource/os-cn-spark-practice4/ 引言 提起机器学习 (Machine Lear ...

  6. 使用Spark MLlib进行情感分析

    使用Spark MLlib进行情感分析             使用Spark MLlib进行情感分析 一.实验说明 在当今这个互联网时代,人们对于各种事情的舆论观点都散布在各种社交网络平台或新闻提要 ...

  7. 在Java Web中使用Spark MLlib训练的模型

    PMML是一种通用的配置文件,只要遵循标准的配置文件,就可以在Spark中训练机器学习模型,然后再web接口端去使用.目前应用最广的就是基于Jpmml来加载模型在javaweb中应用,这样就可以实现跨 ...

  8. Spark机器学习中ml和mllib中矩阵、向量

    1:Spark ML与Spark MLLIB区别? Spark MLlib是面向RDD数据抽象的编程工具类库,现在已经逐渐不再被Spark团队支持,逐渐转向Spark ML库,Spark ML是面向D ...

  9. FP-Growth in Spark MLLib

    并行FP-Growth算法思路 上图的单线程形成的FP-Tree. 分布式算法事实上是对FP-Tree进行分割,分而治之 首先,假设我们只关心...|c这个conditional transactio ...

随机推荐

  1. 2)NET CORE特性与优势

    先看看netcore有哪些特性,哪些优点,与.net frameworkd 差异吧: l  跨平台: 可以在 Windows.macOS 和 Linux 操作系统上运行. l  跨体系结构保持一致:  ...

  2. ASP.NET SignalR 系列(五)之群组推送

    在上一章介绍了 一对一推送的方式,这章重点介绍下群组推送和多人推送 群组主要就是用到了方法:Groups.Add(Context.ConnectionId, groupName); 将不同的连接id加 ...

  3. linux限定用户或组对磁盘空间的使用

    实验环境 环境:centos7.3 ,一块磁盘sdb分一个分区sdb1. 安装磁盘配额支持软件 yum install quota 制作文件系统,并以支持配额功能的方式挂载文件系统 mkfs.ext4 ...

  4. python3基础之“术语表(2)”

    51.编程: 让计算机执行的指令. 52.代码: 让计算机执行的命令. 53.底层编程语言: 与高级语言相比,更接近二进制的语言. 54.高级编程语言: 读起来像英语的易于理解的语言. 55.汇编语言 ...

  5. angular http interceptors 拦截器使用分享

    拦截器 在开始创建拦截器之前,一定要了解 $q和延期承诺api 出于全局错误处理,身份验证或请求的任何同步或异步预处理或响应的后处理目的,希望能够在将请求移交给服务器之前拦截请求,并在将请求移交给服务 ...

  6. django 自定义身份认证

    自定义身份认证: Django 自带的认证系统足够应付大多数情况,但你或许不打算使用现成的认证系统.定制自己的项目的权限系统需要了解哪些一些关键点,即Django中哪些部分是能够扩展或替换的.这个文档 ...

  7. C语言判断字符串是否是 hex string的代码

    把写内容过程中经常用到的一些内容段备份一下,如下内容内容是关于C语言判断字符串是否是 hex string的内容. { static unsigned int hex2bin[256]={0}; me ...

  8. ar 解压一个.a文件报错: xxx.a is a fat file (use libtool(1) or lipo(1) and ar(1) on it)

    Linux  使用终端指令 ar x /Users/apple/Desktop/libWC_LIB_SDKT.a解压一个文件 报错如图所示: 是因为该.a文件包含了多个cpu架构,比如armv7,ar ...

  9. 如何使用Fiddler抓取APP接口和微信授权网页源代码

    Fiddler,一个抓包神器,不仅可以通过手机访问APP抓取接口甚至一些数据,还可以抓取微信授权网页的代码. 下载安装 1. 下载地址(官网):  https://www.telerik.com/do ...

  10. 【问题】如何在Linux与Windows间共享文件

    实验环境 Linux LSB Version: :core-4.1-amd64:core-4.1-noarch Distributor ID: CentOS Description: CentOS L ...