PyTorch中的backward [转]
转自:https://sherlockliao.github.io/2017/07/10/backward/
backward只能被应用在一个标量上,也就是一个一维tensor,或者传入跟变量相关的梯度。
特别注意Variable
里面默认的参数requires_grad=False
,所以这里我们要重新传入requires_grad=True
让它成为一个叶子节点。
对其求偏导:
import torch as t
from torch.autograd import Variable as v # simple gradient
a = v(t.FloatTensor([2, 3]), requires_grad=True)
b = a + 3
c = b * b * 3
out = c.mean()
out.backward()
print('*'*10)
print('=====simple gradient======')
print('input')
print(a.data)
print('compute result is')
print(out.data[0])
print('input gradients are')
print(a.grad.data)
下面研究一下如何能够对非标量的情况下使用backward。backward里传入的参数是每次求导的一个系数。
首先定义好输入m=(x1,x2)=(2,3),然后我们做的操作就是n=,这样我们就定义好了一个向量输出,结果第一项只和x1有关,结果第二项只和x2有关,那么求解这个梯度,
# backward on non-scalar output
m = v(t.FloatTensor([[2, 3]]), requires_grad=True)
n = v(t.zeros(1, 2))
n[0, 0] = m[0, 0] ** 2
n[0, 1] = m[0, 1] ** 3
n.backward(t.FloatTensor([[1, 1]]))
print('*'*10)
print('=====non scalar output======')
print('input')
print(m.data)
print('input gradients are')
print(m.grad.data)
jacobian矩阵
对其求导:
k.backward(parameters)
接受的参数parameters
必须要和k
的大小一模一样,然后作为k
的系数传回去,backward里传入的参数是每次求导的一个系数。
# jacobian
j = t.zeros(2 ,2)
k = v(t.zeros(1, 2))
m.grad.data.zero_()
k[0, 0] = m[0, 0] ** 2 + 3 * m[0 ,1]
k[0, 1] = m[0, 1] ** 2 + 2 * m[0, 0]
# [1, 0] dk0/dm0, dk1/dm0
k.backward(t.FloatTensor([[1, 0]]), retain_variables=True) # 需要两次反向求导
j[:, 0] = m.grad.data
m.grad.data.zero_()
# [0, 1] dk0/dm1, dk1/dm1
k.backward(t.FloatTensor([[0, 1]]))
j[:, 1] = m.grad.data
print('jacobian matrix is')
print(j)
我们要注意backward()
里面另外的一个参数retain_variables=True
,这个参数默认是False,也就是反向传播之后这个计算图的内存会被释放,这样就没办法进行第二次反向传播了,所以我们需要设置为True,因为这里我们需要进行两次反向传播求得jacobian矩阵。
PyTorch中的backward [转]的更多相关文章
- 关于Pytorch中autograd和backward的一些笔记
参考自<Pytorch autograd,backward详解>: 1 Tensor Pytorch中所有的计算其实都可以回归到Tensor上,所以有必要重新认识一下Tensor. 如果我 ...
- pytorch中tensorboardX的用法
在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...
- pytorch 中的重要模块化接口nn.Module
torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...
- 转pytorch中训练深度神经网络模型的关键知识点
版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...
- pytorch中调用C进行扩展
pytorch中调用C进行扩展,使得某些功能在CPU上运行更快: 第一步:编写头文件 /* src/my_lib.h */ int my_lib_add_forward(THFloatTensor * ...
- 关于Pytorch中accuracy和loss的计算
这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚. 给出实例 def train(train_loader, model, criteon, optimizer, epoc ...
- 【PyTorch】PyTorch中的梯度累加
PyTorch中的梯度累加 使用PyTorch实现梯度累加变相扩大batch PyTorch中在反向传播前为什么要手动将梯度清零? - Pascal的回答 - 知乎 https://www.zhihu ...
- PyTorch中的C++扩展
今天要聊聊用 PyTorch 进行 C++ 扩展. 在正式开始前,我们需要了解 PyTorch 如何自定义module.这其中,最常见的就是在 python 中继承torch.nn.Module,用 ...
- PyTorch中的Batch Normalization
Pytorch中的BatchNorm的API主要有: 1 torch.nn.BatchNorm1d(num_features, 2 3 eps=1e-05, 4 5 momentum=0.1, 6 7 ...
随机推荐
- Python运维开发基础03-语法基础 【转】
上节作业回顾(讲解+温习60分钟) #!/usr/bin/env python3 # -*- coding:utf-8 -*- # author:Mr.chen #只用变量和字符串+循环实现“用户登陆 ...
- 【转】HashMap实现原理及源码分析
哈希表(hash table)也叫散列表,是一种非常重要的数据结构,应用场景极其丰富,许多缓存技术(比如memcached)的核心其实就是在内存中维护一张大的哈希表,而HashMap的实现原理也常常出 ...
- DCL单例模式
我们第一次写的单例模式是下面这样的: public class Singleton { private static Singleton instance = null; public static ...
- 鸟哥Linux私房菜基础学习篇学习笔记1
鸟哥Linux私房菜基础学习篇学习笔记1 第三章 主导分区(MBR),当系统在开机的时候会主动去读取这个区块的内容,必须对硬盘进行分区,这样硬盘才能被有效地使用. 所谓的分区只是针对64Bytes的分 ...
- OPENSSL编程 第二十章 椭圆曲线
20.1 ECC介绍 椭圆曲线算法可以看作是定义在特殊集合下数的运算,满足一定的规则.椭圆曲线在如下两个域中定义:Fp域和F2m域. Fp域,素数域,p为素数: F2m域:特征为2的有限域,称之为二 ...
- Git系列①之仓库管理互联网托管平台github.com的使用
互联网项目托管平台github.com的使用 1.安装git客户端 # yum install -y git 配置git全局用户以及邮箱 [root@web01 ~]# git config --gl ...
- FormData中delete方法在ios不兼容
1.移动端直接用的input的file上传图片(name=“file”必填) <input type="file" id="exampleInputFile1&qu ...
- 一篇文章教你读懂UI绘制流程
最近有好多人问我Android没信心去深造了,找不到好的工作,其实我以一个他们进行回复,发现他们主要是内心比较浮躁,要知道技术行业永远缺少的是高手.建议先阅读浅谈Android发展趋势分析,在工作中, ...
- kindEditor 富文本编辑器 使用介绍
第一版:存放位置: ---->把该创建的文件包放到javaWeb 过程的 WEB_INF 下:如图所示. 第二步:< kindEditor 插件的引用> :JS引用 <scr ...
- Confluence 6 配置 HTTP 超时设置
当宏,例如 RSS Macro 进行 HTTP 请求的时候,有可能因为请求的时间比较长,而导致超时.你可以通过设置系统参数来避免这个问题. 配置 HTTP 超时设置: 在屏幕的右上角单击 控制台按钮 ...