MXNet是一个支持多种编程语言的机器学习库,使用MXNet可以方便地实现机器学习算法,尤其是深度神经网络。通过嵌入在宿主语言中,它将声明式符号表达与命令式张量计算相结合。它提供自动求导以计算梯度。MXNet具有高效的计算和存储操作,可运行在从移动设备到分布式GPU集群的各种异构系统上。MXNet的源代码可以在这里获得。

本文描述了MXNet的API设计和系统实现,并解释了如何以统一的方式处理符号表达式和张量操作的嵌入。我们的初步实验揭示了使用多GPU机器训练的大规模深度神经网络的应是有希望的。

1. Introduction

随着机器学习算法的发展,结构化和计算的复杂性逐渐成为机器学习系统设计和实现的挑战。在机器学习系统的设计与实现中,常见的编程范式包括命令式声明式。在命令式编程范式中,用户指定计算机”如何”执行计算;在声明式编程范式中,用户主要关注“要做什么”。比如,Matlab就是一种命令式编程,而Caffe和CXXNet等就是声明式编程。当然,有时这两种编程范式的界限也比较模糊,如Tensorflow和Theano可以看作两种编程范式的混合体。

与编程范式问题相关的是如何执行计算。执行可以是确定的,其结果立即在同一线程上返回;也可以是异步或延迟的,其语句被收集并在发布到可用设备之前首先作为中间表示转换为数据流图。确定性执行是限制性的(例如并行矩阵乘法),而异步/延迟执行时会在数据流图的实例范围内自动识别所有的并行性。

通过不同编程范式和执行方式的组合,我们就有了很多不同且有效的系统设计方法。例如,Minerva结合了命令式编程范式和异步执行方法;而Theano采用声明式方法,实现全局的图感知优化;相反的是,Caffe和CXXNet采用声明式编程范式和确定性执行方法。表1比较了命令式编程范式与声明式编程范式。

MXNet结合了不同编程范式和执行方法的优点。声明式编程范式为全局计算图提供清晰的边界,可以发现更多优化机会;而命令式编程提供了更多的灵活性。在深度学习领域,声明式编程范式在指定神经网络的计算结构时非常有用;而命令式编程对参数更新和交互式debug而言更为自然。

尽管支持多种语言并结合了不同的编程范例,但我们能够将执行融合到同一个后端引擎中。后端引擎跟踪计算图和命令操作之间的数据依赖关系,并进行有效地联合调度。 我们积极地减少内存占用,尽可能地执行就地更新和内存空间重用。 最后,我们设计了一个紧凑的通信API,以便改动很少的代码就可以使MXNet程序在多台机器上运行。

表2对现阶段几个机器学习系统进行比较,可以看出,MXNet的优势还是比较大的。此外,MXNet还是一个轻量级的机器学习框架,源代码只有5万行左右。

2. Programming Interface

2.1. Symbol: Declarative Symbolic Expressions

MXNet使用多输出的符号表达式Symbol来声明一个计算图。每个符号表达式都是由不同的操作构成,每个操作可以接受接收几个变量作为输入,并产生一至多个变量作为输出。变量可以是空闲的直到我们为它绑定一个值,亦或是其他表达式的输出。图2展示了一个通过连接变量实现的多层感知机。

为了对符号表达式进行评估,我们需要将一个空闲的变量与输入数据绑定并声明必要的输出。除了评估(前向传播),符号表达式同样支持自动求导(即反向传播)。除此之外,Symbol同样提供关于读取、保存和可视化等相关函数,这里不再赘述。

2.2. NDArray: Imperative Tensor Computation

MXNet通过提供命令式张量计算接口NDArray来弥补声明式符号表达式与宿主语言之间的空白。图3是使用GPU计算矩阵乘法的例子。

NDArray可以和Symbol无缝合作,例如,我们可以使用如下代码实现梯度下降算法:

while(1) {
net.forward_backward();
net.w -= eta * net.g;
};

上面的实现与只使用Symbol的实现具有相近的效率,但是代码看起来更加简洁。之所以这样,是因为MXNet对NDArray使用了延迟计算,并且后端引擎可以无误地处理二者之间的数据依赖。

2.3. KVStore: Data Synchronization Over Devices

KVStore是一个支持多机数据同步的分布式键值存储系统。它支持两种操作:将一个键值对从某个设备推送到存储系统,或者从存储系统中拉取某个健对应的值。此外,用户可以定义如何在存储系统中合并推送的键值对。最后,MXNet通过模型一致性来减小偏差。

下面的代码通过数据并行实现了分布式梯度下降算法:

while(1){
kv.pull(net.w);
net.forward_backward();
kv.push(net.g);
}

上面的代码中,我们向KVStore注册权值更新函数,每个工作节点重复地从参数服务器拉取最新的权值,并向其推送本地新计算出来的梯度。

同样地,上面这种混合编程模式与单纯的声明式编程具有相近的效率,原因同2.2节。

2.4. Other Modules

MXNet中附带了一些工具库,包括打包工具、数据迭代器以及数据预处理工具等,常见的优化算法,如随机梯度下降,均在training模块中实现。

3. Implementation

3.1. Computation Graph



执行绑定操作后的符号表达式代表了一个计算图。图4展示了图2中多层感知机中的前向与反向传播。在执行计算之前,MXNet会转换计算图以优化效率并为内部变量分配内存。

3.1.2. Graph Optimization

关于计算图优化,我们有以下几点。首先,我们只需要获得绑定期间指定的输出所需的子图。 例如,在预测中我们仅需要前向图;而为了从内部层提取特征,我们可以略过后面几层。只计算子图,我们可以减少大部分计算量。其次,多个操作可以分类汇总成单个操作。例如,\(a\times b +1\)可以被替换成单个的BLAS或GPU调用。最后,我们可以人工实现一些大型操作,如神经网络的某一层。

3.1.3. Memory Allocation

在内存分配方面,注意到每个变量的生存周期,即创建与最后一次调用之间的时间段,都是计算图已知的。因此,我们可以为生存周期不交叉的变量重用内存。但关键的是,理想的分配策略需要\(O(n^2)\)的时间复杂度,其中\(n\)是变量个数。

我们提出了两种拥有线性时间复杂度的启发式内存分配策略。第一种算法叫做\(inplace\),模拟遍历图的过程,并保留到目前为止未使用的依赖节点的引用计数器。如果计数器归零,则回收内存。第二种算法叫做\(co-share\),允许两个节点在非并行执行时共享一块内存,这会产生一个额外的依赖性约束。特别地,在每次进行调度时,我们在图的候选路径中找到最长路径并执行所需的内存分配。

3.2. Dependency Engine

在MXNet中,每个资源单元(包括NDArray,随机数生成器等)都使用唯一标记注册到后端引擎。然后,通过指定所需的资源标记,将任一操作(例如矩阵操作或数据通信)推送到引擎中。如果依赖性问题已被解决,引擎会持续调度推送的操作以便执行。由于通常存在多个计算资源,例如CPU,GPU和内存/PCIe总线,因此引擎通过多线程来调度操作以实现并行化和更高效的资源利用率。

与大部分数据流引擎不同,MXNet的引擎将变异操作跟踪为现有资源单元。这使得能够在numpy和其他张量库中调度数组变异操作。(这里数组变异操作指的是这些操作会改变调用它们的原始数组,非变异操作则不会改变原始数组,它们总会返回一个新数组。)通过将参数更新表示为改变参数数组,使参数的内存重用变得更容易。同时,这也使得一些特殊操作的调度变得容易。比如,当生成具有相同随机种子的两个随机数时,我们可以通知引擎它们将写入种子变量,所以它们不应该并行执行,这样做有助于重现结果。

3.3. Data Communication

我们基于参数服务器架构实现了KVStore。它与前人的工作有以下两方面的不同:首先,我们使用引擎来调度KVStore操作并管理数据一致性。该策略不仅使数据同步与计算紧密协作,且大大简化了实现。其次,我们采用两级架构。 一级服务器管理单台机器内设备之间的数据同步,二级服务器管理机器间的数据同步。我们可以聚合来自一级服务器的出站数据,从而降低带宽需求;此外,机器内和机器间的同步可以使用不同的一致性模型。

4. Evaluation

4.1. Raw performance

这一部分比较了MXNet与Torch7、Caffe、Tensorflow在卷积网络基准测试中的性能。TensorFlow使用的是CUDA 7.0和CUDNN 2,而其他的框架使用的是CUDA 7.5和CUDNN 3。批量大小设置为32,所有的卷积神经网络在Nvidia GTX 980显卡上进行训练。图6是测试结果。可以看到,MXNet、Torch7和Caffe拥有相近的性能,这是因为大部分操作都由CUDA/CUDNN实现。Tensorflow比其它框架慢约2倍,可能是使用的CUDA/CUDNN版本较低的缘故。

4.2. Memory usage

图7是除输出变量外的内部变量的内存使用情况。由图可知,“inplace”和“co-share”都可以有效地减少内存开销。通过将这两种方法结合,我们可以在所有网络的训练过程中降低2倍的内存开销,在预测过程中降低4倍的内存开销。举例来说,对于最耗费内存的VGG模型,训练过程中也只不过使用了(相对于模型本身大小)额外的16M内存空间。

4.3. Scalability

为了测量可扩展性,我们在Amazon EC2 g2.8x实例上进行实验,每个实例拥有4块Nvidia GK104 GPU和10G的以太网。我们在ILSVRC12数据集上训练包含批量归一化(Batch Normalization)的GoogleNet,固定学习率为\(0.05\),动量为\(0.9\),权重衰减系数为\(10^{-4}\),每个GPU在一个批次中读取36张图片。

收敛结果如图8所示。可以看出,与单机训练相比,分布式训练在开始时收敛较慢,但在10次数据传递后表现优异。1台机器和10台机器的平均数据传输成本分别为14K和1.4K秒。因此,该实验获得了超线性加速比。

总的来说,MXNet是一个非常优秀的开源深度学习框架。在MXNet中可以有多种方式实现深度学习模型,如NDArraySymbolGluon等。最近Keras好像也开始支持MXNet作为后端,后续可以关注一下。

CoRR 2015 | MXNet: A Flexible and Efficient Machine Learning Library for Heterogeneous Distributed Systems的更多相关文章

  1. [翻译] TensorFlow 分布式之论文篇 "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

    [翻译] TensorFlow 分布式之论文篇 "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed ...

  2. Machine Learning Library (MLlib) Guide, BOOKS

    download.microsoft.com/download/0/9/6/096170E9-23A2.../9780735698178.pdf   Microsoft Azure Essential ...

  3. 在QT中引用Shark Machine Learning library

    最近因为项目需要,看了看机器学习方面的东西.Google一番,发现Shark正是朕需要的东西.于是准备按官方文档来使用它了.但是官方文档只有怎么生成静态库,并没有在QT里引用的sample. 废话不多 ...

  4. How do I learn machine learning?

    https://www.quora.com/How-do-I-learn-machine-learning-1?redirected_qid=6578644   How Can I Learn X? ...

  5. 【机器学习Machine Learning】资料大全

    昨天总结了深度学习的资料,今天把机器学习的资料也总结一下(友情提示:有些网站需要"科学上网"^_^) 推荐几本好书: 1.Pattern Recognition and Machi ...

  6. 机器学习(Machine Learning)&深度学习(Deep Learning)资料【转】

    转自:机器学习(Machine Learning)&深度学习(Deep Learning)资料 <Brief History of Machine Learning> 介绍:这是一 ...

  7. 机器学习(Machine Learning)与深度学习(Deep Learning)资料汇总

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  8. SOME USEFUL MACHINE LEARNING LIBRARIES.

    from: http://www.erogol.com/broad-view-machine-learning-libraries/ http://www.slideshare.net/Vincenz ...

  9. 17 Great Machine Learning Libraries

    17 Great Machine Learning Libraries 08 October 2013 After wonderful feedback on my previous post on ...

随机推荐

  1. dart系列之:手写Library,Library编写最佳实践

    目录 简介 使用part和part of src中的文件 package中的lib文件 总结 简介 Library是dart用来组织代码的一种非常有用的方式,通过定义不同的Library,可以将非常有 ...

  2. 达索CATIA许可证(License)管理使用和优化

    现下主流的V6版本CATIA,是由达索公司提供授权的浮动型License,其客户端通过企业内网从许可证服务器获得许可证,最少要有一个服务器端DS License Server提供一定数量的Licens ...

  3. Javascript中字符串常用方法

    JavaScript字符串常用方法 (1)获取相应位置的字符(charAt()) var str="你好,小小鸟!" var s=str.charAt(1) //获取到索引为1的字 ...

  4. 【刷题-LeetCode】188 Best Time to Buy and Sell Stock IV

    Best Time to Buy and Sell Stock IV Say you have an array for which the i-th element is the price of ...

  5. 使用VS Code的MySQL扩展管理数据库

    我将在本文告诉你如何用VS Code的扩展程序管理MySQL数据库,包括连接到MySQL.新建数据库和表.修改字段定义.简单的查询方法以及导入导出. 在许多情况下,我们需要随时查看数据库的记录来确保程 ...

  6. keepalived的抢占与非抢占模式

    目录 一:keepalived的抢占与非抢占模式 1.抢占模式 2.非抢占模式 二:接下来分4种情况说明 三:以上3种,只要级别高就会获取master,与state状态是无关的 一:keepalive ...

  7. 裸k8s搭建中遇到的两个坑

    在装docker的时候报错了,需要先安装selinux版本.才能安装容器. 需要按照提示安装这个包. 采用强制安装.rpm -ivh 包名字 --force --nodeps 在k8s的master上 ...

  8. K8S的安装部署以及基础知识

    Kubernetes(K8S)概述 Kubernetes又称作k8s,是Google在2014年发布的一个开源项目. 最初Google开发了一个叫Borg的系统(现在命名为Omega),来调度近20多 ...

  9. 学习JAVAWEB第十七天

    今天还是在做登陆界面,做到后台servlet了 知识点太不熟练了,还得继续做

  10. Understanding C++ Modules In C++20 (1)

    Compiling evironment: linux (ubuntu 16.04)+ gcc-10.2. The Post will clarify and discuss what modules ...