自动求导机制

本说明将概述Autograd如何工作并记录操作。了解这些并不是绝对必要的,但我们建议您熟悉它,因为它将帮助您编写更高效,更简洁的程序,并可帮助您进行调试。

从后向中排除子图

每个变量都有两个标志:requires_gradvolatile。它们都允许从梯度计算中精细地排除子图,并可以提高效率。

艾伯特(http://www.aibbt.com/)国内第一家人工智能门户

requires_grad

如果有一个单一的输入操作需要梯度,它的输出也需要梯度。相反,只有所有输入都不需要梯度,输出才不需要。如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行。

>>> x = Variable(torch.randn(5, 5))
>>> y = Variable(torch.randn(5, 5))
>>> z = Variable(torch.randn(5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True

这个标志特别有用,当您想要冻结部分模型时,或者您事先知道不会使用某些参数的梯度。例如,如果要对预先训练的CNN进行优化,只要切换冻结模型中的requires_grad标志就足够了,直到计算到最后一层才会保存中间缓冲区,其中的仿射变换将使用需要梯度的权重并且网络的输出也将需要它们。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)

# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

volatile

纯粹的inference模式下推荐使用volatile,当你确定你甚至不会调用.backward()时。它比任何其他自动求导的设置更有效——它将使用绝对最小的内存来评估模型。volatile也决定了require_grad is False

volatile不同于require_grad的传递。如果一个操作甚至只有有一个volatile的输入,它的输出也将是volatileVolatility比“不需要梯度”更容易传递——只需要一个volatile的输入即可得到一个volatile的输出,相对的,需要所有的输入“不需要梯度”才能得到不需要梯度的输出。使用volatile标志,您不需要更改模型参数的任何设置来用于inference。创建一个volatile的输入就够了,这将保证不会保存中间状态。

>>> regular_input = Variable(torch.randn(5, 5))
>>> volatile_input = Variable(torch.randn(5, 5), volatile=True)
>>> model = torchvision.models.resnet18(pretrained=True)
>>> model(regular_input).requires_grad
True
>>> model(volatile_input).requires_grad
False
>>> model(volatile_input).volatile
True
>>> model(volatile_input).creator is None
True

自动求导如何编码历史信息

每个变量都有一个.creator属性,它指向把它作为输出的函数。这是一个由Function对象作为节点组成的有向无环图(DAG)的入口点,它们之间的引用就是图的边。每次执行一个操作时,一个表示它的新Function就被实例化,它的forward()方法被调用,并且它输出的Variable的创建者被设置为这个Function。然后,通过跟踪从任何变量到叶节点的路径,可以重建创建数据的操作序列,并自动计算梯度。

需要注意的一点是,整个图在每次迭代时都是从头开始重新创建的,这就允许使用任意的Python控制流语句,这样可以在每次迭代时改变图的整体形状和大小。在启动训练之前不必对所有可能的路径进行编码—— what you run is what you differentiate.

Variable上的In-place操作

在自动求导中支持in-place操作是一件很困难的事情,我们在大多数情况下都不鼓励使用它们。Autograd的缓冲区释放和重用非常高效,并且很少场合下in-place操作能实际上明显降低内存的使用量。除非您在内存压力很大的情况下,否则您可能永远不需要使用它们。

限制in-place操作适用性主要有两个原因:

1.覆盖梯度计算所需的值。这就是为什么变量不支持log_。它的梯度公式需要原始输入,而虽然通过计算反向操作可以重新创建它,但在数值上是不稳定的,并且需要额外的工作,这往往会与使用这些功能的目的相悖。

2.每个in-place操作实际上需要实现重写计算图。不合适的版本只需分配新对象并保留对旧图的引用,而in-place操作则需要将所有输入的creator更改为表示此操作的Function。这就比较棘手,特别是如果有许多变量引用相同的存储(例如通过索引或转置创建的),并且如果被修改输入的存储被任何其他Variable引用,则in-place函数实际上会抛出错误。

In-place正确性检查

每个变量保留有version counter,它每次都会递增,当在任何操作中被使用时。当Function保存任何用于后向的tensor时,还会保存其包含变量的version counter。一旦访问self.saved_tensors,它将被检查,如果它大于保存的值,则会引起错误。

PyTorch官方中文文档:自动求导机制的更多相关文章

  1. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  2. PyTorch官方中文文档:torch.optim 优化器参数

    内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...

  3. PyTorch官方中文文档:PyTorch中文文档

    PyTorch中文文档 PyTorch是使用GPU和CPU优化的深度学习张量库. 说明 自动求导机制 CUDA语义 扩展PyTorch 多进程最佳实践 序列化语义 Package参考 torch to ...

  4. Pytorch学习(一)—— 自动求导机制

    现在对 CNN 有了一定的了解,同时在 GitHub 上找了几个 examples 来学习,对网络的搭建有了笼统地认识,但是发现有好多基础 pytorch 的知识需要补习,所以慢慢从官网 API进行学 ...

  5. PyTorch官方中文文档:torch

    torch 包 torch 包含了多维张量的数据结构以及基于其上的多种数学操作.另外,它也提供了多种工具,其中一些可以更有效地对张量和任意类型进行序列化. 它有CUDA 的对应实现,可以在NVIDIA ...

  6. PyTorch官方中文文档:torch.optim

    torch.optim torch.optim是一个实现了各种优化算法的库.大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法. 如何使用optimizer 为了使用t ...

  7. PyTorch官方中文文档:torch.Tensor

    torch.Tensor torch.Tensor是一种包含单一数据类型元素的多维矩阵. Torch定义了七种CPU tensor类型和八种GPU tensor类型: Data tyoe CPU te ...

  8. Pytorch Autograd (自动求导机制)

    Pytorch Autograd (自动求导机制) Introduce Pytorch Autograd库 (自动求导机制) 是训练神经网络时,反向误差传播(BP)算法的核心. 本文通过logisti ...

  9. ReactNative官方中文文档0.21

    整理了一份ReactNative0.21中文文档,提供给需要的reactnative爱好者.ReactNative0.21中文文档.chm  百度盘下载:ReactNative0.21中文文档 来源: ...

随机推荐

  1. 走进JavaScript——重拾函数

    创建函数 通过构造器的方式来创建函数,最后一个参数为函数体其他为形参 new Function('a','b','alert(a)') /* function anonymous(a,b) { ale ...

  2. redis新手入门,摸不着头脑可以看看<三>——lrange分页

    看了几天 redis开发与运维,写了个小demo练练手,直接上代码. 1.首先是数据库,本地要有redis,具体的如何安装redis,官网下个就好了,sososo. 2.启动redis 注意启动命令. ...

  3. CentOS 7 搭建基于携程Apollo(阿波罗)配置中心单机模式

    Apollo(阿波罗)是携程框架部门研发的配置管理平台,能够集中化管理应用不同环境.不同集群的配置,配置修改后能够实时推送到应用端,并且具备规范的权限.流程治理等特性.服务端基于Spring Boot ...

  4. iOS 9 HTTPS 的配置

    方法有两种: (1)废话少说直接上图: (2)右击info.plist 文件 open as ->source code 在里面注入如下代码就行了(位置不固定,但要在指定的文件夹选项里) < ...

  5. Struts2 基本的ResultType 【学习笔记】

    在struts2-core.jar/struts-default.xml中,我们可以找到关于result-type的一些配置信息,从中可以看出struts2组件默认为我们提供了这 些result-ty ...

  6. iOS程序闪退的原因以及处理办法

    iOS程序闪退是一种比较常见的现象.闪退的情况很多,造成程序闪退的原因也很多. ================================启动时闪退======================= ...

  7. python开发concurent.furtrue模块:concurent.furtrue的多进程与多线程&协程

    一,concurent.furtrue进程池和线程池 1.1 concurent.furtrue 开启进程,多进程&线程,多线程 # concurrent.futures创建并行的任务 # 进 ...

  8. CodeForces - 740C

    这题是思维考察.由于区间个数可能会很多,暴力完全没法下手.首先要明确区间长度最小的就决定了最后的答案,因为最小区间必须要要从0开始到区间长度减1才能满足让mex最大.接下来就是考虑如何填充数组才能让所 ...

  9. scrapy 中日志的使用

    我在后台调试 在后台调试scrapy spider的时候,总是觉得后台命令窗口 打印的东西太多了不便于观察日志,因此需要一个日志文件记录信息,这样以后会 方便查找问题. 分两种方法吧. 1.简单粗暴. ...

  10. 使用基于Android网络通信的OkHttp库实现Get和Post方式简单操作服务器JSON格式数据

     目录 前言 1 Get方式和Post方式接口说明 2 OkHttp库简单介绍及环境配置 3 具体实现 前言 本文具体实现思路和大部分代码参考自<第一行代码>第2版,作者:郭霖:但是文中讲 ...