【pytorch学习】之自动微分
5 自动微分
求导是几乎所有深度学习优化算法的关键步骤。虽然求导的计算很简单,只需要一些基本的微积分。但对于复杂的模型,手工进行更新是一件很痛苦的事情(而且经常容易出错)。深度学习框架通过自动计算导数,即自动微分(automatic differentiation)来加快求导。实际中,根据设计好的模型,系统会构建一个计算图(computational graph),来跟踪计算是哪些数据通过哪些操作组合起来产生输出。自动微分使系统能够随后反向传播梯度。这里,反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。
5.1 一个简单的例子
作为一个演示例子,假设我们想对函数$ y = 2\mathbf{x}^T\mathbf{x}$关于列向量x求导。首先,我们创建变量x并为其分配一个初始值。
import torch
x = torch.arange(4.0)
x
tensor([0., 1., 2., 3.])
在我们计算y关于x的梯度之前,需要一个地方来存储梯度。重要的是,我们不会在每次对一个参数求导时都分配新的内存。因为我们经常会成千上万次地更新相同的参数,每次都分配新的内存可能很快就会将内存耗尽。注意,一个标量函数关于向量x的梯度是向量,并且与x具有相同的形状。
x.requires_grad_(True) # 等价于x=torch.arange(4.0,requires_grad=True)
x.grad # 默认值是None
现在计算y。
y = 2 * torch.dot(x, x)
y
tensor(28., grad_fn=<MulBackward0>)
x是一个长度为4的向量,计算x和x的点积,得到了我们赋值给y的标量输出。接下来,通过调用反向传播函数来自动计算y关于x每个分量的梯度,并打印这些梯度。
y.backward()
x.grad
tensor([ 0., 4., 8., 12.])
函数$ y = 2\mathbf{x}^T\mathbf{x}\(关于\)x$的梯度应为\(4x\)。让我们快速验证这个梯度是否计算正确。
x.grad == 4 * x
tensor([True, True, True, True])
现在计算x的另一个函数。
# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
x.grad
tensor([1., 1., 1., 1.])
5.2 非标量变量的反向传播
当y不是标量时,向量y关于向量x的导数的最自然解释是一个矩阵。对于高阶和高维的y和x,求导的结果可以是一个高阶张量。
然而,虽然这些更奇特的对象确实出现在高级机器学习中(包括深度学习中),但当调用向量的反向计算时,我们通常会试图计算一批训练样本中每个组成部分的损失函数的导数。这里,我们的目的不是计算微分矩阵,而是单独计算批量中每个样本的偏导数之和。
# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。
# 本例只想求偏导数的和,所以传递一个1的梯度是合适的
x.grad.zero_()
y = x * x
# 等价于y.backward(torch.ones(len(x)))
y.sum().backward()
x.grad
tensor([0., 2., 4., 6.])
5.3 分离计算
有时,我们希望将某些计算移动到记录的计算图之外。例如,假设y是作为x的函数计算的,而z则是作为y和x的函数计算的。想象一下,我们想计算z关于x的梯度,但由于某种原因,希望将y视为一个常数,并且只考虑到x在y被计算后发挥的作用。
这里可以分离y来返回一个新变量u,该变量与y具有相同的值,但丢弃计算图中如何计算y的任何信息。换句话说,梯度不会向后流经u到x。因此,下面的反向传播函数计算\(z=u*x\)关于x的偏导数,同时将u作为常数处理,而不是\(z=x*x*x\)关于x的偏导数。
x.grad.zero_()
y = x * x
u = y.detach()
z = u * x
z.sum().backward()
x.grad == u,u
(tensor([True, True, True, True]), tensor([0., 1., 4., 9.]))
由于记录了y的计算结果,我们可以随后在y上调用反向传播,得到y=xx关于的x的导数,即2x。
x.grad.zero_()
y.sum().backward()
x.grad == 2 * x
tensor([True, True, True, True])
5.4 Python控制流的梯度计算
使用自动微分的一个好处是:即使构建函数的计算图需要通过Python控制流(例如,条件、循环或任意函数调用),我们仍然可以计算得到的变量的梯度。在下面的代码中,while循环的迭代次数和if语句的结果都取决于输入a的值。
def f(a):
b = a * 2
while b.norm() < 1000:
b = b * 2
if b.sum() > 0:
c = b
else:
c = 100 * b
return c
计算梯度。
a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()
我们现在可以分析上面定义的f函数。请注意,它在其输入a中是分段线性的。换言之,对于任何a,存在某个常量标量k,使得f(a)=k*a,其中k的值取决于输入a,因此可以用d/a验证梯度是否正确。
a.grad == d / a
tensor(True)
【pytorch学习】之自动微分的更多相关文章
- pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分
参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...
- PyTorch自动微分基本原理
序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...
- PyTorch 自动微分示例
PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...
- PyTorch 自动微分
PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...
- 自动微分(AD)学习笔记
1.自动微分(AD) 作者:李济深链接:https://www.zhihu.com/question/48356514/answer/125175491来源:知乎著作权归作者所有.商业转载请联系作者获 ...
- Pytorch学习笔记(一)---- 基础语法
书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...
- 【pytorch】pytorch学习笔记(一)
原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...
- Pytorch学习笔记(一)——简介
一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...
- PyTorch 学习
PyTorch torch.autograd模块 深度学习的算法本质上是通过反向传播求导数, PyTorch的autograd模块实现了此功能, 在Tensor上的所有操作, autograd都会为它 ...
- MindSpore:自动微分
MindSpore:自动微分 作为一款「全场景 AI 框架」,MindSpore 是人工智能解决方案的重要组成部分,与 TensorFlow.PyTorch.PaddlePaddle 等流行深度学习框 ...
随机推荐
- 什么叫运行时的Java程序?
Java程序的运行包含编写.编译和运行三个主要步骤. 1.在编写阶段: 开发人员在Java开发环境中输入程序代码,形成后缀名为.java的Java源文件. 2.在编译阶段: 使用Java编译器对源文件 ...
- 后端基础PHP—正则表达
后端基础PHP-正则表达式 1.正则表达式的介绍 2.正则表达式的语法 一.正则表达式的介绍 正则表达式的介绍 · 正则表达式,又称规则表达式,通过一种特殊的语言来挑选符合条件的数据 · 在代码中简写 ...
- 各种O总结及阿里代码规范总结
首先梳理下POJO POJO包括 DO/DTO/BO/VO(所有的POJO类属性必须使用包装数据类型.) 定义 DO/DTO/VO 等 POJO 类时,不要设定任何属性默认值. controller使 ...
- 标记SA_RESTART的作用
在程序执行的过程中,有时候会收到信号,我们可以捕捉信号并执行信号处理函数,信号注册函数里有一个struct sigaction的结构体,其中有一个sa_flags的成员,如果sa_flags |= S ...
- Android混淆后的bug日志通过mapping文件找对应行号
背景 由于项目中提测以及线上的apk都是经过混淆处理的,因此拿到日志后也无法正常查看崩溃日志的行号 这个原因是因为混淆了文件,输出的日志是对应不上源文件的,为了正确找到行号需要用到mapping.tx ...
- js使用typeof与instanceof相结合编写一个判断常见变量类型的函数
/** * 常见类型判断 * @param {any} param */ function getParamType(param) { // 先判断是否能用typeof 直接判断 let types1 ...
- 解决Idea找不到URL问题
解决Idea找不到URL问题 我这几天遇到一个特别恶心的问题,查了很多资料,都是没用的后来自己静下心来,发现自己的import导包错了,我用的是jakarta,jakarta主要是利用Tomcat ...
- 基于Rust的Tile-Based游戏开发杂记(01)导入
什么是Tile-Based游戏? Tile-based游戏是一种使用tile(译为:瓦片,瓷砖)作为基本构建单位来设计游戏关卡.地图或其他视觉元素的游戏类型.在这样的游戏中,游戏世界的背景.地形.环境 ...
- DETR:Facebook提出基于Transformer的目标检测新范式,性能媲美Faster RCNN | ECCV 2020 Oral
DETR基于标准的Transorfmer结构,性能能够媲美Faster RCNN,而论文整体思想十分简洁,希望能像Faster RCNN为后续的很多研究提供了大致的思路 来源:晓飞的算法工程笔记 ...
- Modbus报文详解
Modbus是一种串行通信协议,最初由Modicon公司(现为施耐德电气的一部分)在1979年为使用其PLC(可编程逻辑控制器)而开发.Modbus已成为工业领域内广泛使用的一种通信协议,特别是对于监 ...