backward函数

官方定义:

torch.autograd.backward(tensorsgrad_tensors=Noneretain_graph=Nonecreate_graph=Falsegrad_variables=None)

Computes the sum of gradients of given tensors w.r.t. graph leaves.The graph is differentiated using the chain rule. If any of tensors are non-scalar (i.e. their data has more than one element) and require gradient, the function additionally requires specifying grad_tensors. It should be a sequence of matching length, that contains gradient of the differentiated function w.r.t. corresponding tensors (None is an acceptable value for all tensors that don’t need gradient tensors). This function accumulates gradients in the leaves - you might need to zero them before calling it.

翻译和解释:参数tensors如果是标量,函数backward计算参数tensors对于给定图叶子节点的梯度( graph leaves,即为设置requires_grad=True的变量)。

参数tensors如果不是标量,需要另外指定参数grad_tensors参数grad_tensors必须和参数tensors的长度相同。在这一种情况下,backward实际上实现的是代价函数(loss = torch.sum(tensors*grad_tensors); 注:torch中向量*向量实际上是点积,因此tensors和grad_tensors的维度必须一致 )关于叶子节点的梯度计算,而不是参数tensors对于给定图叶子节点的梯度。如果指定参数grad_tensors=torch.ones((size(tensors))),显而易见,代价函数关于叶子节点的梯度,也就等于参数tensors对于给定图叶子节点的梯度。

每次backward之前,需要注意叶子梯度节点是否清零,如果没有清零,第二次backward会累计上一次的梯度。

下面给出具体的例子:

 import torch
x=torch.randn((3),dtype=torch.float32,requires_grad=True)
y = torch.randn((3),dtype=torch.float32,requires_grad=True)
z = torch.randn((3),dtype=torch.float32,requires_grad=True)
t = x + y
loss = t.dot(z) #求向量的内积

在调用 backward 之前,可以先手动求一下导数,应该是:

用代码实现求导:

 loss.backward(retain_graph=True)
print(z,x.grad,y.grad) #预期打印出的结果都一样
print(t,z.grad) #预期打印出的结果都一样
print(t.grad) #在这个例子中,x,y,z就是叶子节点,而t不是,t的导数在backward的过程中求出来回传之后就会被释放,因而预期结果是None

结果和预期一致:

tensor([-2.6752,  0.2306, -0.8356], requires_grad=True) tensor([-2.6752,  0.2306, -0.8356]) tensor([-2.6752,  0.2306, -0.8356])
tensor([-1.1916, -0.0156, 0.8952], grad_fn=<AddBackward0>) tensor([-1.1916, -0.0156, 0.8952])
None

敲重点:注意到前面函数的解释中,在参数tensors不是标量的情况下,tensor.backward(grad_tensors)实现的是代价函数torch.sum(tensors*grad_tensors))关于叶子节点的导数。在上面例子中,loss = t.dot(z),因此用t.backward(z),实现的就是loss对于所有叶子结点的求导,实际运算结果和预期吻合。

 t.backward(z,retain_graph=True)
print(z,x.grad,y.grad)
print(t,z.grad)

运行结果如下:

tensor([-0.7830,  1.4468,  1.2440], requires_grad=True) tensor([-0.7830,  1.4468,  1.2440]) tensor([-0.7830,  1.4468,  1.2440])
tensor([-0.7145, -0.7598, 2.0756], grad_fn=<AddBackward0>) None

上面的结果中,出现了一个问题,虽然loss关于x和y的导数正确,但是z不再是叶子节点了。

问题1:当使用t.backward(z,retain_graph=True)的时候, print(z.grad)结果是None,这意味着z不再是叶子节点,这是为什么呢?

另外一个尝试,loss = t.dot(z)=z.dot(t),但是如果用z.backward(t)替换t.backward(z,retain_graph=True),结果却不同。

 z.backward(t)
print(z,x.grad,y.grad)
print(t,z.grad)

运行结果:

tensor([-1.0716, -1.3643, -0.0016], requires_grad=True) None None
tensor([-0.7324, 0.9763, -0.4036], grad_fn=<AddBackward0>) tensor([-0.7324, 0.9763, -0.4036])

问题2:上面的结果中可以看到,使用z.backward(t),x和y都不再是叶子节点了,z仍然是叶子节点,且得到的loss相对于z的导数正确。

上述仿真出现的两个问题,我还不能解释,希望和大家交流,我的邮箱yangyuwen_yang@126.com,欢迎来信讨论。

问题1:当使用t.backward(z,retain_graph=True)的时候, print(z.grad)结果是None,这意味着z不再是叶子节点,这是为什么呢?

问题2:上面的结果中可以看到,使用z.backward(t),x和y都不再是叶子节点了,z仍然是叶子节点,且得到的loss相对于z的导数正确。



另外强调一下,每次backward之前,需要注意叶子梯度节点是否清零,如果没有清零,第二次backward会累计上一次的梯度。
简单的代码可以看出:
 #测试1,:对比上两次单独执行backward,此处连续执行两次backward
t.backward(z,retain_graph=True)
print(z,x.grad,y.grad)
print(t,z.grad)
z.backward(t)
print(z,x.grad,y.grad)
print(t,z.grad)
# 结果x.grad,y.grad本应该是None,因为保留了第一次backward的结果而打印出上一次梯度的结果

tensor([-0.5590, -1.4094, -1.5367], requires_grad=True) tensor([-0.5590, -1.4094, -1.5367]) tensor([-0.5590, -1.4094, -1.5367])
tensor([-1.7914, 0.8761, -0.3462], grad_fn=<AddBackward0>) None
tensor([-0.5590, -1.4094, -1.5367], requires_grad=True) tensor([-0.5590, -1.4094, -1.5367]) tensor([-0.5590, -1.4094, -1.5367])
tensor([-1.7914, 0.8761, -0.3462], grad_fn=<AddBackward0>) tensor([-1.7914, 0.8761, -0.3462])

 #测试2,:连续执行两次backward,并且清零,可以验证第二次backward没有计算x和y的梯度
t.backward(z,retain_graph=True)
print(z,x.grad,y.grad)
print(t,z.grad)
x.grad.data.zero_()
y.grad.data.zero_()
z.backward(t)
print(z,x.grad,y.grad)
print(t,z.grad)

tensor([ 0.8671, 0.6503, -1.6643], requires_grad=True) tensor([ 0.8671, 0.6503, -1.6643]) tensor([ 0.8671, 0.6503, -1.6643])
tensor([1.6231e+00, 1.3842e+00, 4.6492e-06], grad_fn=<AddBackward0>) None
tensor([ 0.8671, 0.6503, -1.6643], requires_grad=True) tensor([0., 0., 0.]) tensor([0., 0., 0.])
tensor([1.6231e+00, 1.3842e+00, 4.6492e-06], grad_fn=<AddBackward0>) tensor([1.6231e+00, 1.3842e+00, 4.6492e-06])

附参考学习的链接如下,并对作者表示感谢:
PyTorch 的 backward 为什么有一个 grad_variables 参数?该话题在我的知乎专栏里同时转载,欢迎更多讨论交流,参见链接

Pytorch中torch.autograd ---backward函数的使用方法详细解析,具体例子分析的更多相关文章

  1. PyTorch 中 torch.matmul() 函数的文档详解

    官方文档 torch.matmul() 函数几乎可以用于所有矩阵/向量相乘的情况,其乘法规则视参与乘法的两个张量的维度而定. 关于 PyTorch 中的其他乘法函数可以看这篇博文,有助于下面各种乘法的 ...

  2. Pytorch中randn和rand函数的用法

    Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...

  3. 模式识别 - libsvm该函数的调用方法 详细说明

    libsvm该函数的调用方法 详细说明 本文地址: http://blog.csdn.net/caroline_wendy/article/details/26261173 须要载入(load)SVM ...

  4. Pytorch中torch.load()中出现AttributeError: Can't get attribute

    原因:保存下来的模型和参数不能在没有类定义时直接使用. Pytorch使用Pickle来处理保存/加载模型,这个问题实际上是Pickle的问题,而不是Pytorch. 解决方法也非常简单,只需显式地导 ...

  5. 【学习笔记】pytorch中squeeze()和unsqueeze()函数介绍

    squeeze用来减少维度, unsqueeze用来增加维度 具体可见下方博客. pytorch中squeeze和unsqueeze

  6. pytorch 中的Variable一般常用的使用方法

    Variable一般的初始化方法,默认是不求梯度的 import torch from torch.autograd import Variable x_tensor = torch.randn(2, ...

  7. javascript中的闭包、函数的toString方法

    闭包: 闭包可以理解为定义在一个函数内部的函数, 函数A内部定义了函数B, 函数B有访问函数A内部变量的权力: 闭包是函数和子函数之间的桥梁: 举个例子: let func = function() ...

  8. pytorch autograd backward函数中 retain_graph参数的作用,简单例子分析,以及create_graph参数的作用

    retain_graph参数的作用 官方定义: retain_graph (bool, optional) – If False, the graph used to compute the grad ...

  9. pytorch中torch.narrow()函数

    torch.narrow(input, dim, start, length) → Tensor Returns a new tensor that is a narrowed version of  ...

随机推荐

  1. WebSocket和Socket

    WebSocket和Socket tags:WebSocket和Socket 引言:好多朋友想知道WebSocket和Socket的联系和区别,下面应该就是你们想要的 先来一张之前收集的图,我看到这张 ...

  2. spawn-fcgi运行fcgiwrap

    http://linuxjcq.blog.51cto.com/3042600/718002 标签:休闲 spawn-fcgi fcgiwarp fcgi 职场 原创作品,允许转载,转载时请务必以超链接 ...

  3. 玩转JPA(一)---异常:Repeated column in mapping for entity/should be mapped with insert="false" update="fal

    最近用JPA遇到这样一个问题:Repeated column in mapping for entity: com.ketayao.security.entity.main.User column: ...

  4. MySQL基本命令1

    在ubuntu系统中操作命令:登录:mysql -uroot -p启动:service mysql start停止:service mysql stop重启:service mysql restart ...

  5. SpringBoot JMS(ActiveMQ) 使用实践

    ActiveMQ 1. 下载windows办的activeMQ后,在以下目录可以启动: 2. 启动后会有以下提示 3. 所以我们可以通过http://localhost:8161访问管理页面,通过tc ...

  6. BZOJ_2303_[Apio2011]方格染色 _并查集

    BZOJ_2303_[Apio2011]方格染色 _并查集 Description Sam和他的妹妹Sara有一个包含n × m个方格的 表格.她们想要将其的每个方格都染成红色或蓝色. 出于个人喜好, ...

  7. Spring中bean的注入方式

    首先,要学习Spring中的Bean的注入方式,就要先了解什么是依赖注入.依赖注入是指:让调用类对某一接口的实现类的实现类的依赖关系由第三方注入,以此来消除调用类对某一接口实现类的依赖. Spring ...

  8. Django基础四(model和数据库)

    上一篇博文学习了Django的form和template.到目前为止,我们所涉及的内容都是怎么利用Django在浏览器页面上显示内容.WEB开发除了数据的显示之外,还有数据的存储,本文的内容就是如何在 ...

  9. 【移动端web】软键盘兼容问题

    软键盘收放事件 这周几天遇到了好几个关于web移动端兼容性的问题.并花了很长时间去研究如何处理这几种兼容问题. 这次我们来说说关于移动端软键盘的js处理吧. 一般情况下,前端是无法监控软键盘到底是弹出 ...

  10. python 防止sql注入字符串拼接的正确用法

    在使用pymysql模块时,在使用字符串拼接的注意事项错误用法1 sql='select * from where id="%d" and name="%s" ...