5 自动微分

求导是几乎所有深度学习优化算法的关键步骤。虽然求导的计算很简单,只需要一些基本的微积分。但对于复杂的模型,手工进行更新是一件很痛苦的事情(而且经常容易出错)。深度学习框架通过自动计算导数,即自动微分(automatic differentiation)来加快求导。实际中,根据设计好的模型,系统会构建一个计算图(computational graph),来跟踪计算是哪些数据通过哪些操作组合起来产生输出。自动微分使系统能够随后反向传播梯度。这里,反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。

5.1 一个简单的例子

作为一个演示例子,假设我们想对函数$ y = 2\mathbf{x}^T\mathbf{x}$关于列向量x求导。首先,我们创建变量x并为其分配一个初始值。

import torch
x = torch.arange(4.0)
x
tensor([0., 1., 2., 3.])

在我们计算y关于x的梯度之前,需要一个地方来存储梯度。重要的是,我们不会在每次对一个参数求导时都分配新的内存。因为我们经常会成千上万次地更新相同的参数,每次都分配新的内存可能很快就会将内存耗尽。注意,一个标量函数关于向量x的梯度是向量,并且与x具有相同的形状。

x.requires_grad_(True) # 等价于x=torch.arange(4.0,requires_grad=True)
x.grad # 默认值是None

现在计算y。

y = 2 * torch.dot(x, x)
y
tensor(28., grad_fn=<MulBackward0>)

x是一个长度为4的向量,计算x和x的点积,得到了我们赋值给y的标量输出。接下来,通过调用反向传播函数来自动计算y关于x每个分量的梯度,并打印这些梯度。

y.backward()
x.grad
tensor([ 0.,  4.,  8., 12.])

函数$ y = 2\mathbf{x}^T\mathbf{x}\(关于\)x$的梯度应为\(4x\)。让我们快速验证这个梯度是否计算正确。

x.grad == 4 * x
tensor([True, True, True, True])

现在计算x的另一个函数。

# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
x.grad
tensor([1., 1., 1., 1.])

5.2 非标量变量的反向传播

当y不是标量时,向量y关于向量x的导数的最自然解释是一个矩阵。对于高阶和高维的y和x,求导的结果可以是一个高阶张量。

然而,虽然这些更奇特的对象确实出现在高级机器学习中(包括深度学习中),但当调用向量的反向计算时,我们通常会试图计算一批训练样本中每个组成部分的损失函数的导数。这里,我们的目的不是计算微分矩阵,而是单独计算批量中每个样本的偏导数之和。

# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。
# 本例只想求偏导数的和,所以传递一个1的梯度是合适的
x.grad.zero_()
y = x * x
# 等价于y.backward(torch.ones(len(x)))
y.sum().backward()
x.grad
tensor([0., 2., 4., 6.])

5.3 分离计算

有时,我们希望将某些计算移动到记录的计算图之外。例如,假设y是作为x的函数计算的,而z则是作为y和x的函数计算的。想象一下,我们想计算z关于x的梯度,但由于某种原因,希望将y视为一个常数,并且只考虑到x在y被计算后发挥的作用。

这里可以分离y来返回一个新变量u,该变量与y具有相同的值,但丢弃计算图中如何计算y的任何信息。换句话说,梯度不会向后流经u到x。因此,下面的反向传播函数计算\(z=u*x\)关于x的偏导数,同时将u作为常数处理,而不是\(z=x*x*x\)关于x的偏导数。

x.grad.zero_()
y = x * x
u = y.detach()
z = u * x
z.sum().backward()
x.grad == u,u
(tensor([True, True, True, True]), tensor([0., 1., 4., 9.]))

由于记录了y的计算结果,我们可以随后在y上调用反向传播,得到y=xx关于的x的导数,即2x。

x.grad.zero_()
y.sum().backward()
x.grad == 2 * x
tensor([True, True, True, True])

5.4 Python控制流的梯度计算

使用自动微分的一个好处是:即使构建函数的计算图需要通过Python控制流(例如,条件、循环或任意函数调用),我们仍然可以计算得到的变量的梯度。在下面的代码中,while循环的迭代次数和if语句的结果都取决于输入a的值。

def f(a):
b = a * 2
while b.norm() < 1000:
b = b * 2
if b.sum() > 0:
c = b
else:
c = 100 * b
return c

计算梯度。

a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()

我们现在可以分析上面定义的f函数。请注意,它在其输入a中是分段线性的。换言之,对于任何a,存在某个常量标量k,使得f(a)=k*a,其中k的值取决于输入a,因此可以用d/a验证梯度是否正确。

a.grad == d / a
tensor(True)

【pytorch学习】之自动微分的更多相关文章

  1. pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分

    参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...

  2. PyTorch自动微分基本原理

    序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...

  3. PyTorch 自动微分示例

    PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...

  4. PyTorch 自动微分

    PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...

  5. 自动微分(AD)学习笔记

    1.自动微分(AD) 作者:李济深链接:https://www.zhihu.com/question/48356514/answer/125175491来源:知乎著作权归作者所有.商业转载请联系作者获 ...

  6. Pytorch学习笔记(一)---- 基础语法

    书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...

  7. 【pytorch】pytorch学习笔记(一)

    原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...

  8. Pytorch学习笔记(一)——简介

    一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...

  9. PyTorch 学习

    PyTorch torch.autograd模块 深度学习的算法本质上是通过反向传播求导数, PyTorch的autograd模块实现了此功能, 在Tensor上的所有操作, autograd都会为它 ...

  10. MindSpore:自动微分

    MindSpore:自动微分 作为一款「全场景 AI 框架」,MindSpore 是人工智能解决方案的重要组成部分,与 TensorFlow.PyTorch.PaddlePaddle 等流行深度学习框 ...

随机推荐

  1. MySQL 双主集群搭建

    搭建 MySQL 双主集群涉及多个配置步骤,以及对于可能出现的问题的理解和解决.下面将详细说明搭建过程的每个步骤. 前提条件 环境准备:准备两台服务器(物理或虚拟),并确保它们可以互相通信(例如,通过 ...

  2. git clone error: RPC failed; curl 18 transfer closed with outstanding read data remaining

    备忘 git clone比较大的工程时,出现这种错误:error: RPC failed; curl 18 transfer closed with outstanding read data rem ...

  3. day07-Java方法01

    Java方法01 1.什么是方法? Java是语句的集合,它们在一起执行一个功能 方法是解决一类问题的步骤的有序集合 方法包含于类或者对象中 方法在程序中被创建,在其他地方被引用 设计方法的原则:方法 ...

  4. 标记SA_RESTART的作用

    在程序执行的过程中,有时候会收到信号,我们可以捕捉信号并执行信号处理函数,信号注册函数里有一个struct sigaction的结构体,其中有一个sa_flags的成员,如果sa_flags |= S ...

  5. Spring Boot学习日记13

    学习引入Thymeleaf Thymeleaf 官网:https://www.thymeleaf.org/ Thymeleaf 在Github 的主页:https://github.com/thyme ...

  6. 记录--关于前端的音频可视化-Web Audio

    这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 背景 最近听音乐的时候,看到各种动效,突然好奇这些音频数据是如何获取并展示出来的,于是花了几天功夫去研究相关的内容,这里只是给大家一些代码 ...

  7. 【Nginx】如何使用自签CA配置HTTPS加密反向代理访问?看了这篇我会了!!

    写在前面 随着互联网的发展,很多公司和个人越来越重视网络的安全性,越来越多的公司采用HTTPS协议来代替了HTTP协议.为何说HTTPS协议比HTTP协议安全呢?小伙伴们自行百度吧!我就不说了.今天, ...

  8. KingbaseES 查看函数中最耗时的sql

    测试 创建测试环境所需表及函数 create table test1(id int); INSERT INTO test1(id) VALUES (generate_series(1, 10000)) ...

  9. C++设计模式 - 门面模式(Facade)

    接口隔离模式 在组件构建过程中,某些接口之间直接的依赖常常会带来很多问题.甚至根本无法实现.采用添加一层间接(稳定)接口,来隔离本来互相紧密关联的接口是一种常见的解决方案. 典型模式 Facade P ...

  10. 学习Source Generators之IncrementalValueProvider

    前面我们使用了IIncrementalGenerator来生成代码,接下来我们来详细了解下IIncrementalGenerator的核心部分IncrementalValueProvider. 介绍 ...