【pytorch学习】之自动微分
5 自动微分
求导是几乎所有深度学习优化算法的关键步骤。虽然求导的计算很简单,只需要一些基本的微积分。但对于复杂的模型,手工进行更新是一件很痛苦的事情(而且经常容易出错)。深度学习框架通过自动计算导数,即自动微分(automatic differentiation)来加快求导。实际中,根据设计好的模型,系统会构建一个计算图(computational graph),来跟踪计算是哪些数据通过哪些操作组合起来产生输出。自动微分使系统能够随后反向传播梯度。这里,反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。
5.1 一个简单的例子
作为一个演示例子,假设我们想对函数$ y = 2\mathbf{x}^T\mathbf{x}$关于列向量x求导。首先,我们创建变量x并为其分配一个初始值。
import torch
x = torch.arange(4.0)
x
tensor([0., 1., 2., 3.])
在我们计算y关于x的梯度之前,需要一个地方来存储梯度。重要的是,我们不会在每次对一个参数求导时都分配新的内存。因为我们经常会成千上万次地更新相同的参数,每次都分配新的内存可能很快就会将内存耗尽。注意,一个标量函数关于向量x的梯度是向量,并且与x具有相同的形状。
x.requires_grad_(True) # 等价于x=torch.arange(4.0,requires_grad=True)
x.grad # 默认值是None
现在计算y。
y = 2 * torch.dot(x, x)
y
tensor(28., grad_fn=<MulBackward0>)
x是一个长度为4的向量,计算x和x的点积,得到了我们赋值给y的标量输出。接下来,通过调用反向传播函数来自动计算y关于x每个分量的梯度,并打印这些梯度。
y.backward()
x.grad
tensor([ 0., 4., 8., 12.])
函数$ y = 2\mathbf{x}^T\mathbf{x}\(关于\)x$的梯度应为\(4x\)。让我们快速验证这个梯度是否计算正确。
x.grad == 4 * x
tensor([True, True, True, True])
现在计算x的另一个函数。
# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
x.grad
tensor([1., 1., 1., 1.])
5.2 非标量变量的反向传播
当y不是标量时,向量y关于向量x的导数的最自然解释是一个矩阵。对于高阶和高维的y和x,求导的结果可以是一个高阶张量。
然而,虽然这些更奇特的对象确实出现在高级机器学习中(包括深度学习中),但当调用向量的反向计算时,我们通常会试图计算一批训练样本中每个组成部分的损失函数的导数。这里,我们的目的不是计算微分矩阵,而是单独计算批量中每个样本的偏导数之和。
# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。
# 本例只想求偏导数的和,所以传递一个1的梯度是合适的
x.grad.zero_()
y = x * x
# 等价于y.backward(torch.ones(len(x)))
y.sum().backward()
x.grad
tensor([0., 2., 4., 6.])
5.3 分离计算
有时,我们希望将某些计算移动到记录的计算图之外。例如,假设y是作为x的函数计算的,而z则是作为y和x的函数计算的。想象一下,我们想计算z关于x的梯度,但由于某种原因,希望将y视为一个常数,并且只考虑到x在y被计算后发挥的作用。
这里可以分离y来返回一个新变量u,该变量与y具有相同的值,但丢弃计算图中如何计算y的任何信息。换句话说,梯度不会向后流经u到x。因此,下面的反向传播函数计算\(z=u*x\)关于x的偏导数,同时将u作为常数处理,而不是\(z=x*x*x\)关于x的偏导数。
x.grad.zero_()
y = x * x
u = y.detach()
z = u * x
z.sum().backward()
x.grad == u,u
(tensor([True, True, True, True]), tensor([0., 1., 4., 9.]))
由于记录了y的计算结果,我们可以随后在y上调用反向传播,得到y=xx关于的x的导数,即2x。
x.grad.zero_()
y.sum().backward()
x.grad == 2 * x
tensor([True, True, True, True])
5.4 Python控制流的梯度计算
使用自动微分的一个好处是:即使构建函数的计算图需要通过Python控制流(例如,条件、循环或任意函数调用),我们仍然可以计算得到的变量的梯度。在下面的代码中,while循环的迭代次数和if语句的结果都取决于输入a的值。
def f(a):
b = a * 2
while b.norm() < 1000:
b = b * 2
if b.sum() > 0:
c = b
else:
c = 100 * b
return c
计算梯度。
a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()
我们现在可以分析上面定义的f函数。请注意,它在其输入a中是分段线性的。换言之,对于任何a,存在某个常量标量k,使得f(a)=k*a,其中k的值取决于输入a,因此可以用d/a验证梯度是否正确。
a.grad == d / a
tensor(True)
【pytorch学习】之自动微分的更多相关文章
- pytorch学习-AUTOGRAD: AUTOMATIC DIFFERENTIATION自动微分
参考:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autog ...
- PyTorch自动微分基本原理
序言:在训练一个神经网络时,梯度的计算是一个关键的步骤,它为神经网络的优化提供了关键数据.但是在面临复杂神经网络的时候导数的计算就成为一个难题,要求人们解出复杂.高维的方程是不现实的.这就是自动微分出 ...
- PyTorch 自动微分示例
PyTorch 自动微分示例 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后训练第一个神经网络.autograd 软件包为 Tensors 上的所有算子提供自动微分 ...
- PyTorch 自动微分
PyTorch 自动微分 autograd 包是 PyTorch 中所有神经网络的核心.首先简要地介绍,然后将会去训练的第一个神经网络.该 autograd 软件包为 Tensors 上的所有操作提供 ...
- 自动微分(AD)学习笔记
1.自动微分(AD) 作者:李济深链接:https://www.zhihu.com/question/48356514/answer/125175491来源:知乎著作权归作者所有.商业转载请联系作者获 ...
- Pytorch学习笔记(一)---- 基础语法
书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...
- 【pytorch】pytorch学习笔记(一)
原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...
- Pytorch学习笔记(一)——简介
一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...
- PyTorch 学习
PyTorch torch.autograd模块 深度学习的算法本质上是通过反向传播求导数, PyTorch的autograd模块实现了此功能, 在Tensor上的所有操作, autograd都会为它 ...
- MindSpore:自动微分
MindSpore:自动微分 作为一款「全场景 AI 框架」,MindSpore 是人工智能解决方案的重要组成部分,与 TensorFlow.PyTorch.PaddlePaddle 等流行深度学习框 ...
随机推荐
- day06-Java流程控制
Java流程控制 1.用户交互Scanner java.util.Scanner是Java5的新特征,我们可以通过Scannner类来获取用户的输入. 基本语法: Scanner s = new Sc ...
- 关于初始化page入参的设计思路
最近在重构老的代码,在写的过程中发现之前的逻辑如果遇到没有入参pageNo会Npe,于是乎我想找找公司项目有啥方式处理page入参的有两种如下 使用三元表达式直接判断是否null,然后赋值 使用map ...
- LoggerMessageAttribute 高性能的日志记录
.NET 6 引入了 LoggerMessageAttribute 类型. 使用时,它会以source-generators的方式生成高性能的日志记录 API. source-generators可在 ...
- PXE批量安装操作系统自动化
PXEz自动化 在PXE服务器操作: *yum -y install dhcp xinetd tftp tftp-server* *yum -y install system-config-kicks ...
- FLTK基于cmake编译以及使用(Windows、macOS以及Linux)
最近因为一些学习的原因,需要使用一款跨平台的轻量级的GUI+图像绘制 C/C++库.经过一番调研以后,最终从GTK+.FLTK中选出了FLTK,跨平台.够轻量.本文将在Windows.macOS以及L ...
- C++ Qt开发:QTcpSocket网络通信组件
Qt 是一个跨平台C++图形界面开发库,利用Qt可以快速开发跨平台窗体应用程序,在Qt中我们可以通过拖拽的方式将不同组件放到指定的位置,实现图形化开发极大的方便了开发效率,本章将重点介绍如何运用QTc ...
- 记录--Uni-app接入腾讯人脸核身
这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 人脸核身功能有多种接入方式,其中包含微信H5.微信小程序.APP.独立H5.PC端.API接入6种方式. 我们的产品是使用uni-ap ...
- 记录--通过Promise实现分批处理接口请求
这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 如何通过 Promise 实现百条接口请求? 实际项目中遇到需要批量发起上百条接口请求怎么办? 最新案例代码在此!点击看看 前言 不知你项 ...
- 记录--vue刷新当前页面
这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 背景 项目当中如果做新增/修改/删除等等操作通常情况下都需要刷新数据或者刷新当前页面. 思路 (1)如果页面简单,调用接口刷新数据即可. ...
- 记录mysql order by xxx limit xxx数据重复的问题
引用 http://vsalw.com/9768.html 记录mysql排序字段有重复值,分页数据错乱问题,下面2个sql 除了分页limit外,其他都一样, 但是第三页的结果却包含部分第二页的数据 ...