参考自《Pytorch autograd,backward详解》:

1 Tensor

Pytorch中所有的计算其实都可以回归到Tensor上,所以有必要重新认识一下Tensor。

如果我们需要计算某个Tensor的导数,那么我们需要设置其.requires_grad属性为True。为方便说明,在本文中对于这种我们自己定义的变量,我们称之为叶子节点(leaf nodes),而基于叶子节点得到的中间或最终变量则可称之为结果节点

另外一个Tensor中通常会记录如下图中所示的属性:

  • data: 即存储的数据信息
  • requires_grad: 设置为True则表示该 Tensor 需要求导
  • grad: 该 Tensor 的梯度值,每次在计算 backward 时都需要将前一时刻的梯度归零,否则梯度值会一直累加,这个会在后面讲到。
  • grad_fn: 叶子节点通常为 None,只有结果节点的 grad_fn 才有效,用于指示梯度函数是哪种类型。
  • is_leaf: 用来指示该 Tensor 是否是叶子节点。

举例:

  1. x = torch.rand(3, requires_grad=True)
  2. y = x ** 2
  3. z = x + x
  4.  
  5. print(
  6. 'x requires grad: {}, is leaf: {}, grad: {}, grad_fn: {}.'
  7. .format(x.requires_grad, x.is_leaf, x.grad, x.grad_fn)
  8. )
  9. print(
  10. 'y requires grad: {}, is leaf: {}, grad: {}, grad_fn: {}.'
  11. .format(y.requires_grad, y.is_leaf, y.grad, y.grad_fn)
  12. )
  13. print(
  14. 'z requires grad: {}, is leaf: {}, grad: {}, grad_fn: {}.'
  15. .format(z.requires_grad, z.is_leaf, z.grad, z.grad_fn)
  16. )

运行结果:

  1. x requires grad: True, is leaf: True, grad: None, grad_fn: None.
  2. y requires grad: True, is leaf: False, grad: None, grad_fn: <PowBackward0 object at 0x0000021A3002CD88>.
  3. z requires grad: True, is leaf: False, grad: None, grad_fn: <AddBackward0 object at 0x0000021A3002CD88>.

2 torch.autograd.backward

如下代码:

  1. x = torch.tensor(1.0, requires_grad=True)
  2. y = torch.tensor(2.0, requires_grad=True)
  3. z = x**2+y
  4. z.backward()
  5. print(z, x.grad, y.grad)
  6.  
  7. >>> tensor(3., grad_fn=<AddBackward0>) tensor(2.) tensor(1.)

当 z 是一个标量,当调用它的 backward 方法后会根据链式法则自动计算出叶子节点的梯度值。

但是如果遇到 z 是一个向量或者是一个矩阵的情况,这个时候又该怎么计算梯度呢?这种情况我们需要定义grad_tensor来计算矩阵的梯度。在介绍为什么使用之前我们先看一下源代码中backward的接口是如何定义的:

  1. torch.autograd.backward(
  2. tensors,
  3. grad_tensors=None,
  4. retain_graph=None,
  5. create_graph=False,
  6. grad_variables=None)
  • tensor: 用于计算梯度的 tensor。也就是说这两种方式是等价的:torch.autograd.backward(z) == z.backward()
  • grad_tensors: 在计算非标量的梯度时会用到。他其实也是一个tensor,它的shape一般需要和前面的tensor保持一致。
  • retain_graph: 通常在调用一次 backward 后,pytorch 会自动把计算图销毁,所以要想对某个变量重复调用 backward,则需要将该参数设置为True
  • create_graph: 当设置为True的时候可以用来计算更高阶的梯度
  • grad_variables: 这个官方说法是 grad_variables' is deprecated. Use 'grad_tensors' instead. 也就是说这个参数后面版本中应该会丢弃,直接使用grad_tensors就好了。

pytorch设计了grad_tensors这么一个参数。它的作用相当于“权重”。

先看一个例子:

  1. x = torch.ones(2,requires_grad=True)
  2. z = x + 2
  3. z.backward()
  4.  
  5. >>> ...
  6. RuntimeError: grad can be implicitly created only for scalar outputs

上面的报错信息意思是只有对标量输出它才会计算梯度,而求一个矩阵对另一矩阵的导数束手无策。

$X = \begin{bmatrix} x_0 & x_1 \end{bmatrix} \Rightarrow Z = \begin{bmatrix} x_0 + 2 & x_1 + 2 \end{bmatrix} \Rightarrow \frac{\partial Z}{\partial X} = ?$

那么我们只要想办法把 $Z$ 转变成一个标量不就好了?比如我们可以对 $Z$ 求和,然后用求和得到的标量在分别对 $x_0, x_1$ 求导,这样不会对结果有影响,例如:

$Z_{sum} = \sum z_i = x_0 + x_1 + 4$

$\frac{\partial Z_{sum}}{\partial x_0} = \frac{\partial Z_{sum}}{\partial x_1} = 1$

  1. x = torch.ones(2,requires_grad=True)
  2. z = x + 2
  3. z.sum().backward()
  4. print(x.grad)
  5.  
  6. >>> tensor([1., 1.])

grad_tensors这个参数就扮演了帮助求和的作用。

换句话说,就是对 $Z$ 和一个权重张量grad_tensors进行 hadamard product 后求和。这也是 grad_tensors 需要与传入的 tensor 大小一致的原因。

  1. x = torch.ones(2,requires_grad=True)
  2. z = x + 2
  3. z.backward(torch.ones_like(z)) # grad_tensors需要与输入tensor大小一致
  4. print(x.grad)
  5.  
  6. >>> tensor([1., 1.])

3 torch.autograd.grad

  1. torch.autograd.grad(
  2. outputs,
  3. inputs,
  4. grad_outputs=None,
  5. retain_graph=None,
  6. create_graph=False,
  7. only_inputs=True,
  8. allow_unused=False)

看了前面的内容后再看这个函数就很好理解了,各参数作用如下:

  • outputs: 结果节点,即被求导数
  • inputs: 叶子节点
  • grad_outputs: 类似于backward方法中的grad_tensors
  • retain_graph: 同上
  • create_graph: 同上
  • only_inputs: 默认为True,如果为True,则只会返回指定input的梯度值。 若为False,则会计算所有叶子节点的梯度,并且将计算得到的梯度累加到各自的.grad属性上去。
  • allow_unused: 默认为False, 即必须要指定input,如果没有指定的话则报错。

注意该函数返回的是 tuple 类型。

关于Pytorch中autograd和backward的一些笔记的更多相关文章

  1. Pytorch中的自动求导函数backward()所需参数含义

    摘要:一个神经网络有N个样本,经过这个网络把N个样本分为M类,那么此时backward参数的维度应该是[N X M] 正常来说backward()函数是要传入参数的,一直没弄明白backward需要传 ...

  2. pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分

    参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...

  3. pytorch 中的重要模块化接口nn.Module

    torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...

  4. pytorch中调用C进行扩展

    pytorch中调用C进行扩展,使得某些功能在CPU上运行更快: 第一步:编写头文件 /* src/my_lib.h */ int my_lib_add_forward(THFloatTensor * ...

  5. PyTorch中的C++扩展

    今天要聊聊用 PyTorch 进行 C++ 扩展. 在正式开始前,我们需要了解 PyTorch 如何自定义module.这其中,最常见的就是在 python 中继承torch.nn.Module,用 ...

  6. [源码解析] PyTorch 分布式 Autograd (1) ---- 设计

    [源码解析] PyTorch 分布式 Autograd (1) ---- 设计 目录 [源码解析] PyTorch 分布式 Autograd (1) ---- 设计 0x00 摘要 0x01 分布式R ...

  7. [源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础

    [源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础 目录 [源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础 0x00 摘要 0x0 ...

  8. [源码解析] PyTorch 分布式 Autograd (3) ---- 上下文相关

    [源码解析] PyTorch 分布式 Autograd (3) ---- 上下文相关 0x00 摘要 我们已经知道 dist.autograd 如何发送和接受消息,本文再来看看如何其他支撑部分,就是如 ...

  9. [源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

    [源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎 目录 [源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎 0x00 摘要 0 ...

随机推荐

  1. NO33 第6--7关题目讲解

    客户端(电脑)通过浏览器输入域名,先找hosts文件及本地dns缓存,若都没有,就找localDNS服务器,若没有,localDNF服务器找根服务器(全球13台的那个根”.“服务器),根就把.com这 ...

  2. MQTT 协议学习:000-有关概念入门

    背景 从本章开始,在没有特殊说明的情况下,文章中的MQTT版本均为 3.1.1. MQTT 协议是物联网中常见的协议之一,"轻量级物联网消息推送协议",MQTT同HTTP属于第七层 ...

  3. Centos7 rsync+inotify两台服务器同步文件(单向)

    注:本篇介绍的是单向同步,即A文件同步到B,但B的文件不同步到A,双向同步的在下一篇文章中. rsync与inotify不再赘述,直接进入实战. 0.背景 两台服务器IP地址分别为: 源服务器:192 ...

  4. 赶在EW2020之前,FreeRTOS发布V10.3.0,将推出首个LTS版本

    点击下载:FreeRTOSv10.3.0.exe 说明: 1.新版更新: (1)对于IAR For RISC-V进行支持,并且加强了对RISC-V内核芯片支持,做了多处修正. (2)对阿里平头哥CH2 ...

  5. cenos7配置confluence+mysql5.6

    一.准备阶段 我的环境为 腾讯云镜像centos7.4 ,centos 内置 mariadb  需要先删除 #检查是否安装了 mariadb rpm -qa |grep mariadb #删除mari ...

  6. 【Jasypt】给你的配置加把锁

    前言 前几天,有个前同事向我吐槽,他们公司有个大神把公司的项目代码全部上传到了 github,并且是公开项目,所有人都可以浏览.更加恐怖的是项目里面包含配置文件,数据库信息.redis 配置.各种公钥 ...

  7. C# 绘制矩形方框读写内存类 cs1.6人物透视例子

     封装的有问题 其中方框可能在别的方向可能 会显示不出来建议不要下载了 抽时间我会用纯c#写一个例子的  其中绘制方框文字和直线调用的外部dll采用DX11(不吃CUP)绘制我封装成了DLL命名为 S ...

  8. 吴裕雄 Bootstrap 前端框架开发——Bootstrap 字体图标(Glyphicons):glyphicon glyphicon-cloud

    <!DOCTYPE html> <html> <head> <meta charset="utf-8"> <meta name ...

  9. Centos7忘记mysql的root用户密码

    1.先停止mysql服务 ​[root@CentOS ~]# ps -ef | grep mysql root : pts/ :: /bin/sh /usr/local/mysql/bin/mysql ...

  10. POJ 3393:Lucky and Good Months by Gregorian Calendar 年+星期 模拟

    Lucky and Good Months by Gregorian Calendar Time Limit: 1000MS   Memory Limit: 65536K Total Submissi ...