终于可以引用PyTorch了!

0. 摘要

深度学习框架要么关注易用性,要么关注速度,很少同时关注二者。但PyTorch作为一个机器学习库,同时做到了这两点,说明了这两点是兼容的。

PyTorch提供了 基础必要 且 Pythonic 的编程风格,让调试更简单,并且支持许多科学计算库,同时提供硬件加速例如GPUs。

作者向我们展示:他们如何让PyTorch的关键组分协同工作的,并且提供了PyTorch实现的原则细节。作者强调:PyTorch就是一个常规的Python程序,并且完全由用户控制。

本文还在一些benchmarks测试了PyTorch的速度,并且展示了PyTorch每个独立子系统的效率。

1. 简介

许多框架,例如Caffe,TensorFlow和Theano,都是构建静态数据流图(static dataflow graph)来表示计算流程。

该静态图可以被数据的不同batch重复使用。静态图很好,可以提前让我们预知计算流程,并且可以在理论上用于提升性能和可拓展性。但另一方面,静态图的使用降低了易用性、调参体验,限制了能够被表示的计算类型。

因此,另一些框架(如Chainer)尝试用动态图机制(dynamic eager execution)处理深度学习任务。缺点是要么牺牲性能,要么使用更生涩的语言(如Torch等)。

然而,只要我们处理得当,动态图机制能够克服这些缺点。本文即介绍PyTorch。PyTorch是一个Python库,可以快速执行动态张量运算,并且配备了自动差分和GPU加速技术。最重要的是,其性能可比肩当前最快的库。

2. 背景

在科学计算领域,有4个面向深度学习的主要发展趋势:

  1. 基于数组的编程更高效。从1960年代开始,各领域专用的语言,例如MATLAB和R,将张量转换为一等对象(first-class object),再对这些张量进行操作。而这些一等对象由全面的数学算子提供支持。

  2. 自动差分技术。这项技术让用户从令人生畏的梯度计算中解脱出来。主要归功于autograd包。

  3. 开源的Python生态系统。这里有NumPy,SciPy,Pandas等。也正因如此,2014伊始,越来越多的框架支持Python界面。

  4. 硬件加速。普适的并行硬件越来越普及,例如GPUs。还有cuDNN等专用库提供支持,使得越来越多的平台 实现并受益于 硬件加速。

作者用一句话总结了以上趋势,从而解释了PyTorch的优越性:

PyTorch builds on these trends by providing an array-based programming model accelerated by GPUs and differentiable via automatic differentiation integrated in the Python ecosystem.

3. 设计原则

  1. Pythonic。PyTorch应该成为Python生态系统的一等对象。PyTorch应与Python保持连贯性和一致性,整合标准的画图、调试和数据处理工具。

  2. 用户至上。PyTorch应该把复杂的机器学习内核包裹起来,并提供没有性能损失的外接APIs。

  3. 保证易用性和兼容性。可以牺牲10%的速度,换取更好的易用性。此外,用户也可以不用预先提供的库,而自己编写代码以达到更高的性能。

  4. 不追求实现上的完美,而追求实用性。对于PyTorch的内部实现,我们不一定要追求完美;相反,为了适应AI的快速变化,为了能够搭载未来的额外的特征,我们可以简化内部实现,而牺牲一定的性能。

4. 针对易用性的核心设计

4.1 让深度学习模块不过是Python程序

机器学习发展迅速,从手写数字识别到星际争霸游戏。因此,神经网络也在不断发展,从简单的前向网络堆叠,到各式各样的复杂结构(例如包含循环和回路)。

为了适应这种快速变化,PyTorch放弃了基于图-元编程的方法(graph-metaprogramming based approach),而是保持了Python的必要编程结构。这种思路最早引用于Chainer和Dynet,而PyTorch将其拓展至深度学习的各个工作环节,例如定义网络、模型组建、数据装载、优化器运行和并行化训练。

例如我们最熟悉的layer。PyTorch将其表达为Python的类。其constructors创建并初始化其参数,forward方法处理输入激励。

但要强调的是,PyTorch并不强迫用户使用上述代码。

由于PyTorch程序是动态加载的,因此所有特征在设计过程中都是可触及的。例如打印状态,标准调试器和通用的可视化工具(如matplotlib)。用户在运行程序前,不再需要等待冗长的编译过程。最重要的是,用户可以实时观察程序的执行过程。

4.2 互用性和可拓展性

这一点使得PyTorch可以拥抱Python的生态系统。例如,PyTorch张量和NumPy数组间可以互相转换。注意,这种互换过程不涉及数据复制;二者仅仅表征对同一共享的内存区域的各自解释(objects on both sides only describe how to interpret a memory region which is shared among them)。因此,无论转换的数组有多大,花费的时间是恒定的。

不仅如此,许多核心子系统都被设计为可拓展的。例如自动差分系统,可以支持用户自定义的函数。为了实现这一点,用户只需要定义torch.autograd.Function的一个新子类,并且包含forward()backward()方法。类似的,用户还可以自定义新的数据集,而只需要继承torch.utils.data.Dataset并且包含__getitem____len__方法。而具体实现方法完全由用户决定,可以随意DIY。

用户如果觉得PyTorch的哪一个模块无法达到要求,可随意更换之。

4.3 自动差分

首先,自动差分是非常必要的,因为目前的深度学习方法大多依赖于基于梯度的优化方法。然而,Python是一种动态编程语言,在运行时允许大多数行为的变化。这与预先的差分是冲突的。

为了解决这一问题,PyTorch采用了操作过载方法(operator overloading approach),即在每一次执行运算时,建立关于运算的表示。

最后,哪怕在运行过程中张量有突变,差分也能够完成。为了保证安全性,我们为张量设置了版本控制系统(versioning system)。在突变发生时,为了避免不易察觉的性能损失,我们通常会告知用户错误发生,而不是自动予以纠正。

5. 针对高性能的PyTorch实现

用Python解释器跑深度学习算法是出了名的麻烦:例如,全局锁(global interpreter lock)会保证任意时刻只有一个进程在运行(而避免多线程)。对于那些采用静态图的框架,它们通过定制的解释器来避免这个问题。

而PyTorch从另一个角度解决了这一问题。

5.1 高效的C++核

PyTorch的大部分都是用C++实现的,以保证高性能。例如libtorch库用来表示张量的数据结构、CPU和GPU操作以及基本的并行算子。还有自动差分系统等。这样,我们就可以实现多线程执行PyTorch的核心操作。

这带来了一个额外的好处。当我们认为Python不行时,我们还可以用C++的一等bindings。

5.2 分而治之

PyTorch将control(如程序回路、分支)和数据流(如张量及其操作)严格分开。控制流通过Python和C++代码控制,并在CPU上执行;算子既可以在CPU上运行,也可以在GPU上运行。

5.3 定制的张量分配缓存

几乎每一个算子都要动态分配一个输出张量。因此,针对这种动态内存分配的速度优化就显得尤为重要。PyTorch对此进行了优化。

PyTorch还对内存分配进行了优化,例如要求每次分配512字节倍数的内存。

还有很多,参见5.3节。

5.4 多进程

为了解决Python的全局锁问题,Python社区建立了标准的多进程模块。PyTorch进一步将其发展为torch.multiprocessing模块。

5.5 Reference counting

为了最大化效率,用户通常会用上所有GPU内存。因此PyTorch的内存管理显得尤为重要。

对于某张量,动态语义库可能无法事先知晓其未来会被如何使用。因此,垃圾回收就成了一种常用的内存管理策略。PyTorch采用了一种特别的策略,快速释放空闲GPU内存。详情参见5.5节。

实验略。

Paper | PyTorch: An Imperative Style, High-Performance Deep Learning Library的更多相关文章

  1. 英特尔深度学习框架BigDL——a distributed deep learning library for Apache Spark

    BigDL: Distributed Deep Learning on Apache Spark What is BigDL? BigDL is a distributed deep learning ...

  2. What are some good books/papers for learning deep learning?

    What's the most effective way to get started with deep learning?       29 Answers     Yoshua Bengio, ...

  3. Machine and Deep Learning with Python

    Machine and Deep Learning with Python Education Tutorials and courses Supervised learning superstiti ...

  4. Can deep learning help you find the perfect girl?

    Can deep learning help you find the perfect girl? One of the first things I did when I moved to Mont ...

  5. 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks

    In recent years, there’s been a resurgence in the field of Artificial Intelligence. It’s spread beyo ...

  6. (转) Awesome Deep Learning

    Awesome Deep Learning  Table of Contents Free Online Books Courses Videos and Lectures Papers Tutori ...

  7. Top Deep Learning Projects in github

    Top Deep Learning Projects A list of popular github projects related to deep learning (ranked by sta ...

  8. [DEEP LEARNING An MIT Press book in preparation]Deep Learning for AI

    动人的DL我们有六个月的时间,积累了一定的经验,实验,也DL有了一些自己的想法和理解.曾经想扩大和加深DL相关方面的一些知识. 然后看到了一个MIT按有关的对出版物DL图书http://www.iro ...

  9. Deep Learning Libraries by Language

    Deep Learning Libraries by Language Tweet         Python Theano is a python library for defining and ...

随机推荐

  1. Mac Pro 2017款自带php与用brew重装PHP后的地址

    mac pro 2017款自带PHP与apache位置: [apache]apache配置文件 :/etc/apache2/httpd.confDocumentRoot : /Library/WebS ...

  2. Java学习笔记 -- 头代码

    每次写Java程序都会忘记这个main代码怎么写,特意把他写下来,之后忘了还可以温故而知新. 程序猿们请千万不要鄙视我o(╯□╰)o public static void main(String[] ...

  3. 09-Node.js学习笔记-异步编程

    同步API,异步API 同步API:只有当前API执行完成后,才能继续执行下一个API console.log('before'); console.log('after'); 异步API:当前API ...

  4. 第2次作业-titanic数据集练习

    一.读入titanic.xlsx文件,按照教材示例步骤,完成数据清洗. titanic数据集包含11个特征,分别是: Survived:0代表死亡,1代表存活Pclass:乘客所持票类,有三种值(1, ...

  5. IT兄弟连 HTML5教程 多媒体应用 新增多媒体播放元素

    在HTML5之前,要在网站上展示视频.音频.动画等多媒体信息,除了使用第三方自主开发的播放器,使用最多的工具应该算是Flash了,但是它们都需要在浏览器中安装各种插件才能使用,有时速度很慢.HTML5 ...

  6. Idea中新建yml不显示叶子形状的原因

    IntelliJ IDEA 2019.2.4 x64 (版本),不显示叶子形状,导致写配置无法自动提示(自动提示请安装插件)Spring Assistant 先看一下Editor--->File ...

  7. 使用matplotlib.pyplot中scatter()绘制散点图

    1.二维散点图 二维散点图的函数原型: matplotlib.pyplot.scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=Non ...

  8. python frozenset

    1. 一旦初始化,并不可以改变 2. 可以作为字典的键值

  9. 多线程六 同步容器&并发容器

    同步容器(使用的是synchronized,并且不一定是百分百安全) 本篇续 -- 线程之间的通信 ,介绍java提供的并发集合,既然正确的使用wait和notify比较困难,java平台为我们提供了 ...

  10. C#_服务器EXCEL模板文件导出

    A-1:EXCEL模板导出 非常简单,将EXCEL模板上传到项目中后,将其浏览URL保存下来(excelUrl),然后: window.location.href="http://local ...