转载:https://www.cnblogs.com/marsggbo/p/11549631.html

平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考。以下笔记基于Pytorch1.0

1|0Tensor

Pytorch中所有的计算其实都可以回归到Tensor上,所以有必要重新认识一下Tensor。如果我们需要计算某个Tensor的导数,那么我们需要设置其.requires_grad属性为True。为方便说明,在本文中对于这种我们自己定义的变量,我们称之为叶子节点(leaf nodes),而基于叶子节点得到的中间或最终变量则可称之为结果节点。例如下面例子中的x则是叶子节点,y则是结果节点。

x = torch.rand(3, requires_grad=True)
y = x**2
z = x + x

另外一个Tensor中通常会记录如下图中所示的属性:

  • data: 即存储的数据信息
  • requires_grad: 设置为True则表示该Tensor需要求导
  • grad: 该Tensor的梯度值,每次在计算backward时都需要将前一时刻的梯度归零,否则梯度值会一直累加,这个会在后面讲到。
  • grad_fn: 叶子节点通常为None,只有结果节点的grad_fn才有效,用于指示梯度函数是哪种类型。例如上面示例代码中的y.grad_fn=<PowBackward0 at 0x213550af048>, z.grad_fn=<AddBackward0 at 0x2135df11be0>
  • is_leaf: 用来指示该Tensor是否是叶子节点。

2|0torch.autograd.backward

有如下代码:

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2+y
z.backward()
print(z, x.grad, y.grad)
>>> tensor(3., grad_fn=<AddBackward0>) tensor(2.) tensor(1.)

可以z是一个标量,当调用它的backward方法后会根据链式法则自动计算出叶子节点的梯度值。

但是如果遇到z是一个向量或者是一个矩阵的情况,这个时候又该怎么计算梯度呢?这种情况我们需要定义grad_tensor来计算矩阵的梯度。在介绍为什么使用之前我们先看一下源代码中backward的接口是如何定义的:

torch.autograd.backward(
tensors,
grad_tensors=None,
retain_graph=None,
create_graph=False,
grad_variables=None)
  • tensor: 用于计算梯度的tensor。也就是说这两种方式是等价的:torch.autograd.backward(z) == z.backward()
  • grad_tensors: 在计算矩阵的梯度时会用到。他其实也是一个tensor,shape一般需要和前面的tensor保持一致。
  • retain_graph: 通常在调用一次backward后,pytorch会自动把计算图销毁,所以要想对某个变量重复调用backward,则需要将该参数设置为True
  • create_graph: 当设置为True的时候可以用来计算更高阶的梯度
  • grad_variables: 这个官方说法是grad_variables' is deprecated. Use 'grad_tensors' instead.也就是说这个参数后面版本中应该会丢弃,直接使用grad_tensors就好了。

好了,参数大致作用都介绍了,下面我们看看pytorch为什么设计了grad_tensors这么一个参数,以及它有什么用呢?

还是用代码做个示例

x = torch.ones(2,requires_grad=True)
z = x + 2
z.backward()
>>> ...

RuntimeError: grad can be implicitly created only for scalar outputs

当我们运行上面的代码的话会报错,报错信息为RuntimeError: grad can be implicitly created only for scalar outputs

上面的报错信息意思是只有对标量输出它才会计算梯度,而求一个矩阵对另一矩阵的导数束手无策。

X=[x0x1] Z=X+2=[x0+2x1+2]⇒∂Z∂X=?X=[x0x1] Z=X+2=[x0+2x1+2]⇒∂Z∂X=?

那么我们只要想办法把矩阵转变成一个标量不就好了?比如我们可以对z求和,然后用求和得到的标量在对x求导,这样不会对结果有影响,例如:

Zsum=∑zi=x0+x1+8then∂Zsum∂x0=∂Zsum∂x1=1Zsum=∑zi=x0+x1+8then∂Zsum∂x0=∂Zsum∂x1=1

我们可以看到对z求和后再计算梯度没有报错,结果也与预期一样:

x = torch.ones(2,requires_grad=True)
z = x + 2
z.sum().backward()
print(x.grad)
>>> tensor([1., 1.])

我们再仔细想想,对z求和不就是等价于z点乘一个一样维度的全为1的矩阵吗?即sum(Z)=dot(Z,I)sum(Z)=dot(Z,I),而这个I也就是我们需要传入的grad_tensors参数。(点乘只是相对于一维向量而言的,对于矩阵或更高为的张量,可以看做是对每一个维度做点乘)

代码如下:

x = torch.ones(2,requires_grad=True)
z = x + 2
z.backward(torch.ones_like(z)) # grad_tensors需要与输入tensor大小一致
print(x.grad)
>>> tensor([1., 1.])

弄个再复杂一点的:

x = torch.tensor([2., 1.], requires_grad=True).view(1, 2)
y = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True) z = torch.mm(x, y)

print(f"z:{z}")

z.backward(torch.Tensor([[1., 0]]), retain_graph=True)

print(f"x.grad: {x.grad}")

print(f"y.grad: {y.grad}")
>>> z:tensor([[5., 8.]], grad_fn=<MmBackward>)

x.grad: tensor([[1., 3.]])

y.grad: tensor([[2., 0.],

[1., 0.]])

结果解释如下:

总结:

说了这么多,grad_tensors的作用其实可以简单地理解成在求梯度时的权重,因为可能不同值的梯度对结果影响程度不同,所以pytorch弄了个这种接口,而没有固定为全是1。引用自知乎上的一个评论:如果从最后一个节点(总loss)来backward,这种实现(torch.sum(y*w))的意义就具体化为 multiple loss term with difference weights 这种需求了吧。

3|0torch.autograd.grad

torch.autograd.grad(
outputs,
inputs,
grad_outputs=None,
retain_graph=None,
create_graph=False,
only_inputs=True,
allow_unused=False)

看了前面的内容后在看这个函数就很好理解了,各参数作用如下:

  • outputs: 结果节点,即被求导数
  • inputs: 叶子节点
  • grad_outputs: 类似于backward方法中的grad_tensors
  • retain_graph: 同上
  • create_graph: 同上
  • only_inputs: 默认为True, 如果为True, 则只会返回指定input的梯度值。 若为False,则会计算所有叶子节点的梯度,并且将计算得到的梯度累加到各自的.grad属性上去。
  • allow_unused: 默认为False, 即必须要指定input,如果没有指定的话则报错。

4|0参考

【转载】关于grad_tensors的解惑的更多相关文章

  1. 【转】字符编码笔记:ASCII,Unicode和UTF-8

    今天整理笔记,关于NSString转NSData时,什么时候使用NSUTF8StringEncoding,或者NSASCIIStringEncoding,或者 NSUnicodeStringEncod ...

  2. 转载:Python 包管理工具解惑

    Python 包管理工具解惑 本站文章除注明转载外,均为本站原创或者翻译. 本站文章欢迎各种形式的转载,但请18岁以上的转载者注明文章出处,尊重我的劳动,也尊重你的智商: 本站部分原创和翻译文章提供m ...

  3. ASP.NET 跨域请求之jQuery的ajax jsonp的使用解惑 (转载)

    前天在项目中写的一个ajax jsonp的使用,出现了问题:可以成功获得请求结果,但没有执行success方法,直接执行了error方法提示错误——ajax jsonp之前并没有用过,对其的理解为跟普 ...

  4. GCC中-fpic解惑(转载)

    参考: 1.<3.18 Options for Code Generation Conventions>2.<Options for Linking>3.<GCC -fP ...

  5. 转载:第五弹!全球首个微信小程序(应用号)开发教程!通宵吐血赶稿,每日更新!

    博卡君今天继续更新,忙了一天,终于有时间开工写教程.不罗嗦了,今天我们来看看如何实现一些前端的功能和效果. 第八章:微信小程序分组开发与左滑功能实现 先来看看今天的整体思路: 进入分组管理页面--&g ...

  6. (译)iOS Code Signing: 解惑

    子龙山人 Learning,Sharing,Improving! (译)iOS Code Signing: 解惑 免责申明(必读!):本博客提供的所有教程的翻译原稿均来自于互联网,仅供学习交流之用,切 ...

  7. Gulp入门与解惑

    Gulp简介 Gulp.js 是一个自动化构建工具,开发者可以使用它在项目开发过程中自动执行常见任务.Gulp.js是基于 Node.js构建的,利用Node.js流的威力,你可以快速构建项目. 安装 ...

  8. JVM菜鸟进阶高手之路九(解惑)

    转载请注明原创出处,谢谢! 在第八系列最后有些疑惑的地方,后来还是在我坚持不懈不断打扰笨神,阿飞,ak大神等,终于解决了该问题.第八系列地址:http://www.jianshu.com/p/7f7c ...

  9. JVM 菜鸟进阶高手之路九(解惑)

    转载请注明原创出处,谢谢! 在第八系列最后有些疑惑的地方,后来还是在我坚持不懈不断打扰笨神,阿飞,ak大神等,终于解决了该问题.第八系列地址:http://www.cnblogs.com/lirenz ...

随机推荐

  1. Codeforces Round #650 (Div. 3) C. Social Distance (前缀和)

    题意:有一排座位,要求每人之间隔\(k\)个座位坐,\(1\)代表已做,\(0\)代表空座,问最多能坐几人. 题解:我们分别从前和从后跑个前缀和,将已经有人坐的周围的位置标记,然后遍历求每一段连续的\ ...

  2. 如何使用Gephi工具进行可视化复杂网络图

    在Gephi安装官网中也介绍了一些如何使用该工具的方法,我将根据自己的数据和可视化的图片进行介绍 第一步:整理数据格式,我的数据是.csv格式的(.xlsx,.xls等等) 数据第一行第一列必须是相同 ...

  3. 国产网络测试仪MiniSMB - 如何配置VLAN数据流

    国产网络测试仪MiniSMB(www.minismb.com)是复刻smartbits的IP网络性能测试工具,是一款专门用于测试智能路由器,网络交换机的性能和稳定性的软硬件相结合的工具.可以通过此以太 ...

  4. 3.keepalived+脚本实现nginx高可用

    标题 : 3.keepalived+脚本实现nginx高可用 目录 : Nginx 序号 : 3 else exit 0 fi else exit 0 fi - 需要保证脚本有执行权限,可以使用chm ...

  5. woj1010 alternate sum 数学 woj1011 Finding Teamates 数学

    title: woj1010 alternate sum 数学 date: 2020-03-10 categories: acm tags: [acm,woj,数学] 一道数学题.简单. 题意 给一个 ...

  6. Windows下TeamViewer远程控制的安装与使用

    Windows下远程控制的安装与使用 由于疫情,在家里写论文,资料数据都在学校,只能远程控制了,选的是TeamViewer. 分为控制机和被控制机,安装使用略有不同. 从该网址安装:https://w ...

  7. codeforces 878A

    A. Short Program time limit per test 2 seconds memory limit per test 256 megabytes input standard in ...

  8. HTTP2.0 的学习笔记

    1 1 1 HTTP2.0 1 11 1 1 1 1 1 1 超文本传输安全协议(英语:Hypertext Transfer Protocol Secure,缩写:HTTPS,也被称为HTTP ove ...

  9. 使用 js 实现十大排序算法: 归并排序

    使用 js 实现十大排序算法: 归并排序 归并排序 refs js 十大排序算法 All In One https://www.cnblogs.com/xgqfrms/p/13947122.html ...

  10. 为什么国内的好多具备 HTTPS 的网站却没有使用 HTTPS 重定向功能

    为什么国内的好多具备 HTTPS 的网站却没有使用 HTTPS 重定向功能 HTTPS 重定向 good demos ️ HTTPS http://www.xgqfrms.xyz/ https://w ...