自动求梯度

在深度学习中,我们经常需要对函数求梯度(gradient)。PyTorch提供的autograd包能够根据输入和前向传播过程自动构建计算图,并执行反向传播。本节将介绍如何使用autograd包来进行自动求梯度的有关操作。

概念

上一节介绍的Tensor是这个包的核心类,如果将其属性.requires_grad设置为True,它将开始追踪(track)在其上的所有操作(这样就可以利用链式法则进行梯度传播了)。完成计算后,可以调用.backward()来完成所有梯度计算。此Tensor的梯度将累积到.grad属性中。

注意在y.backward()时,如果y是标量,则不需要为backward()传入任何参数;否则,需要传入一个与y同形的Tensor

如果不想要被继续追踪,可以调用.detach()将其从追踪记录中分离出来,这样就可以防止将来的计算被追踪,这样梯度就传不过去了。此外,还可以用with torch.no_grad()将不想被追踪的操作代码块包裹起来,这种方法在评估模型的时候很常用,因为在评估模型时,我们并不需要计算可训练参数(requires_grad=True)的梯度。

Function是另外一个很重要的类。TensorFunction互相结合就可以构建一个记录有整个计算过程的有向无环图(DAG)。每个Tensor都有一个.grad_fn属性,该属性即创建该TensorFunction, 就是说该Tensor是不是通过某些运算得到的,若是,则grad_fn返回一个与这些运算相关的对象,否则是None。

下面通过一些例子来理解这些概念。

Tensor

创建一个Tensor并设置requires_grad=True:

x = torch.ones(2, 2, requires_grad=True)
print(x)
print(x.grad_fn)

输出:

tensor([[1., 1.],
[1., 1.]], requires_grad=True)
None

再做一下运算操作:

y = x + 2
print(y)
print(y.grad_fn)

输出:

tensor([[3., 3.],
[3., 3.]], grad_fn=<AddBackward>)
<AddBackward object at 0x1100477b8>

注意x是直接创建的,所以它没有grad_fn, 而y是通过一个加法操作创建的,所以它有一个为<AddBackward>grad_fn

像x这种直接创建的称为叶子节点,叶子节点对应的grad_fnNone

print(x.is_leaf, y.is_leaf) # True False

再来点复杂度运算操作:

z = y * y * 3
out = z.mean()
print(z, out)

输出:

tensor([[27., 27.],
[27., 27.]], grad_fn=<MulBackward>) tensor(27., grad_fn=<MeanBackward1>)

通过.requires_grad_()来用in-place的方式改变requires_grad属性:

a = torch.randn(2, 2) # 缺失情况下默认 requires_grad = False
a = ((a * 3) / (a - 1))
print(a.requires_grad) # False
a.requires_grad_(True)
print(a.requires_grad) # True
b = (a * a).sum()
print(b.grad_fn)

输出:

False
True
<SumBackward0 object at 0x118f50cc0>

梯度

因为out是一个标量,所以调用backward()时不需要指定求导变量:

out.backward() # 等价于 out.backward(torch.tensor(1.))

我们来看看out关于x的梯度 \(\frac{d(out)}{dx}\):

print(x.grad)

输出:

tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])

我们令out为 \(o\) , 因为

\[o=\frac14\sum_{i=1}^4z_i=\frac14\sum_{i=1}^43(x_i+2)^2
\]

所以

\[\frac{\partial{o}}{\partial{x_i}}\bigr\rvert_{x_i=1}=\frac{9}{2}=4.5
\]

所以上面的输出是正确的。

数学上,如果有一个函数值和自变量都为向量的函数 \(\vec{y}=f(\vec{x})\), 那么 \(\vec{y}\) 关于 \(\vec{x}\) 的梯度就是一个雅可比矩阵(Jacobian matrix):

\[J=\left(\begin{array}{ccc}
\frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{1}}{\partial x_{n}}\\
\vdots & \ddots & \vdots\\
\frac{\partial y_{m}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}
\end{array}\right)
\]

torch.autograd这个包就是用来计算一些雅克比矩阵的乘积的。例如,如果 \(v\) 是一个标量函数的 \(l=g\left(\vec{y}\right)\) 的梯度:

\[v=\left(\begin{array}{ccc}\frac{\partial l}{\partial y_{1}} & \cdots & \frac{\partial l}{\partial y_{m}}\end{array}\right)
\]

那么根据链式法则我们有 \(l\) 关于 \(\vec{x}\) 的雅克比矩阵就为:

\[v J=\left(\begin{array}{ccc}\frac{\partial l}{\partial y_{1}} & \cdots & \frac{\partial l}{\partial y_{m}}\end{array}\right) \left(\begin{array}{ccc}
\frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{1}}{\partial x_{n}}\\
\vdots & \ddots & \vdots\\
\frac{\partial y_{m}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}
\end{array}\right)=\left(\begin{array}{ccc}\frac{\partial l}{\partial x_{1}} & \cdots & \frac{\partial l}{\partial x_{n}}\end{array}\right)
\]

注意:grad在反向传播过程中是累加的(accumulated),这意味着每一次运行反向传播,梯度都会累加之前的梯度,所以一般在反向传播之前需把梯度清零。

# 再来反向传播一次,注意grad是累加的
out2 = x.sum()
out2.backward()
print(x.grad) out3 = x.sum()
x.grad.data.zero_()
out3.backward()
print(x.grad)

输出:

tensor([[5.5000, 5.5000],
[5.5000, 5.5000]])
tensor([[1., 1.],
[1., 1.]])

为什么在y.backward()时,如果y是标量,则不需要为backward()传入任何参数;否则,需要传入一个与y同形的Tensor?

简单来说就是为了避免向量(甚至更高维张量)对张量求导,而转换成标量对张量求导。举个例子,假设形状为 m x n 的矩阵 X 经过运算得到了 p x q 的矩阵 Y,Y 又经过运算得到了 s x t 的矩阵 Z。那么按照前面讲的规则,dZ/dY 应该是一个 s x t x p x q 四维张量,dY/dX 是一个 p x q x m x n的四维张量。问题来了,怎样反向传播?怎样将两个四维张量相乘???这要怎么乘???就算能解决两个四维张量怎么乘的问题,四维和三维的张量又怎么乘?导数的导数又怎么求,这一连串的问题,感觉要疯掉……

为了避免这个问题,我们不允许张量对张量求导,只允许标量对张量求导,求导结果是和自变量同形的张量。所以必要时我们要把张量通过将所有张量的元素加权求和的方式转换为标量,举个例子,假设y由自变量x计算而来,w是和y同形的张量,则y.backward(w)的含义是:先计算l = torch.sum(y * w),则l是个标量,然后求l对自变量x的导数。

参考

来看一些实际例子。

x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
y = 2 * x
z = y.view(2, 2)
print(z)

输出:

tensor([[2., 4.],
[6., 8.]], grad_fn=<ViewBackward>)

现在 z 不是一个标量,所以在调用backward时需要传入一个和z同形的权重向量进行加权求和得到一个标量。

v = torch.tensor([[1.0, 0.1], [0.01, 0.001]], dtype=torch.float)
z.backward(v)
print(x.grad)

输出:

tensor([2.0000, 0.2000, 0.0200, 0.0020])

注意,x.grad是和x同形的张量。

再来看看中断梯度追踪的例子:

x = torch.tensor(1.0, requires_grad=True)
y1 = x ** 2
with torch.no_grad():
y2 = x ** 3
y3 = y1 + y2 print(x.requires_grad)
print(y1, y1.requires_grad) # True
print(y2, y2.requires_grad) # False
print(y3, y3.requires_grad) # True

输出:

True
tensor(1., grad_fn=<PowBackward0>) True
tensor(1.) False
tensor(2., grad_fn=<ThAddBackward>) True

可以看到,上面的y2是没有grad_fn而且y2.requires_grad=False的,而y3是有grad_fn的。如果我们将y3x求梯度的话会是多少呢?

y3.backward()
print(x.grad)

输出:

tensor(2.)

为什么是2呢?$ y_3 = y_1 + y_2 = x^2 + x^3$,当 \(x=1\) 时 \(\frac {dy_3} {dx}\) 不应该是5吗?事实上,由于 \(y_2\) 的定义是被torch.no_grad():包裹的,所以与 \(y_2\) 有关的梯度是不会回传的,只有与 \(y_1\) 有关的梯度才会回传,即 \(x^2\) 对 \(x\) 的梯度。

上面提到,y2.requires_grad=False,所以不能调用 y2.backward(),会报错:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

此外,如果我们想要修改tensor的数值,但是又不希望被autograd记录(即不会影响反向传播),那么我么可以对tensor.data进行操作。

x = torch.ones(1,requires_grad=True)

print(x.data) # 还是一个tensor
print(x.data.requires_grad) # 但是已经是独立于计算图之外 y = 2 * x
x.data *= 100 # 只改变了值,不会记录在计算图,所以不会影响梯度传播 y.backward()
print(x) # 更改data的值也会影响tensor的值
print(x.grad)

输出:

tensor([1.])
False
tensor([100.], requires_grad=True)
tensor([2.])

注: 本文主要参考PyTorch官方文档,与原书同一节有很大不同。

pytorch 自动求梯度的更多相关文章

  1. PyTorch入门学习(二):Autogard之自动求梯度

    autograd包是PyTorch中神经网络的核心部分,简单学习一下. autograd提供了所有张量操作的自动求微分功能. 它的灵活性体现在可以通过代码的运行来决定反向传播的过程, 这样就使得每一次 ...

  2. Pytorch中的自动求梯度机制和Variable类

    自动求导机制是每一个深度学习框架中重要的性质,免去了手动计算导数,下面用代码介绍并举例说明Pytorch的自动求导机制. 首先介绍Variable,Variable是对Tensor的一个封装,操作和T ...

  3. 从头学pytorch(二) 自动求梯度

    PyTorch提供的autograd包能够根据输⼊和前向传播过程⾃动构建计算图,并执⾏反向传播. Tensor Tensor的几个重要属性或方法 .requires_grad 设为true的话,ten ...

  4. TensorFlow自动求梯度

    例1 import tensorflow as tf a=tf.Variable(tf.constant(1.0),name='a') b=tf.Variable(tf.constant(1.0),n ...

  5. 『PyTorch x TensorFlow』第六弹_从最小二乘法看自动求导

    TensoFlow自动求导机制 『TensorFlow』第二弹_线性拟合&神经网络拟合_恰是故人归 下面做了三个简单尝试, 利用包含gradients.assign等tf函数直接构建图进行自动 ...

  6. pytorch的自动求导机制 - 计算图的建立

    一.计算图简介 在pytorch的官网上,可以看到一个简单的计算图示意图, 如下. import torchfrom torch.autograd import Variable x = Variab ...

  7. Pytorch中的自动求导函数backward()所需参数含义

    摘要:一个神经网络有N个样本,经过这个网络把N个样本分为M类,那么此时backward参数的维度应该是[N X M] 正常来说backward()函数是要传入参数的,一直没弄明白backward需要传 ...

  8. Pytorch Tensor, Variable, 自动求导

    2018.4.25,Facebook 推出了 PyTorch 0.4.0 版本,在该版本及之后的版本中,torch.autograd.Variable 和 torch.Tensor 同属一类.更确切地 ...

  9. PytorchZerotoAll学习笔记(三)--自动求导

    Pytorch给我们提供了自动求导的函数,不用再自己再推导计算梯度的公式了 虽然有了自动求导的函数,但是这里我想给大家浅析一下:深度学习中的一个很重要的反向传播 references:https:// ...

随机推荐

  1. Spring Boot -- 认识Spring Boot

    在前面我们已经学习过Srping MVC框架,我们需要配置web.xml.spring mvc配置文件,tomcat,是不是感觉配置较为繁琐.那我们今天不妨来试试使用Spring Boot,Sprin ...

  2. 007.Oracle数据库 , 使用%进行模糊查询

    /*Oracle数据库查询日期在两者之间*/ SELECT PKID, OCCUR_DATE, ATA FROM LM_FAULT WHERE ( ( OCCUR_DATE >= to_date ...

  3. jquery动态选中radio,获取radio选中值

    //动态选中radio值,1:表示radio的name 2:表示后台传过来的radio值$(":radio[name='1'][value='" + 2 + "']&qu ...

  4. Java 布尔运算

    章节 Java 基础 Java 简介 Java 环境搭建 Java 基本语法 Java 注释 Java 变量 Java 数据类型 Java 字符串 Java 类型转换 Java 运算符 Java 字符 ...

  5. How To Configure NFS Client on CentOS 8 / RHEL 8

    https://computingforgeeks.com/configure-nfs-client-on-centos-rhel/

  6. 《ES6标准入门》(阮一峰)--8.函数的扩展

    1.函数参数的默认值 基本用法 ES6 之前,不能直接为函数的参数指定默认值,只能采用变通的方法. function log(x, y) { y = y || 'World'; console.log ...

  7. Jmeter安装插件Stepping Thread Group

    下载链接:https://jmeter-plugins.org/downloads/old/ 下载解压后,将JMeterPlugins-Standard.jar包放在jmeter安装目录的jmeter ...

  8. duilib+cef自定义浏览器控件编译错误

    新版博客已经搭建好了,有问题请访问 htt://www.crazydebug.com 公司二期好主播项目,决定用duilib开发界面,且从ie内核换成谷歌内核 再用duilib自定义一个Browser ...

  9. 基于springboot实现Java阿里短信发送

    1.接口TestController import java.util.Random; import com.aliyuncs.DefaultAcsClient; import com.aliyunc ...

  10. Maven项目工程目录

    maven工程目录规范: src/main/java   存放项目的.java文件 src/main/resources   存放项目的资源文件,如spring.hibernate配置文件 src/t ...