『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数
一、封装新的PyTorch函数
继承Function类
forward:输入Variable->中间计算Tensor->输出Variable
backward:均使用Variable
线性映射
from torch.autograd import Function
class MultiplyAdd(Function): # <----- 类需要继承Function类
@staticmethod # <-----forward和backward都是静态方法
def forward(ctx, w, x, b): # <-----ctx作为内部参数在前向反向传播中协调
print('type in forward',type(x))
ctx.save_for_backward(w,x) # <-----ctx保存参数
output = w * x + b
return output # <-----forward输入参数和backward输出参数必须一一对应
@staticmethod # <-----forward和backward都是静态方法
def backward(ctx, grad_output): # <-----ctx作为内部参数在前向反向传播中协调
w,x = ctx.saved_variables # <-----ctx读取参数
print('type in backward',type(x))
grad_w = grad_output * x
grad_x = grad_output * w
grad_b = grad_output * 1
return grad_w, grad_x, grad_b # <-----backward输入参数和forward输出参数必须一一对应
调用方法一
类名.apply(参数)
输出变量.backward()
import torch as t
from torch.autograd import Variable as V x = V(t.ones(1))
w = V(t.rand(1), requires_grad = True)
b = V(t.rand(1), requires_grad = True)
print('开始前向传播')
z=MultiplyAdd.apply(w, x, b) # <-----forward
print('开始反向传播')
z.backward() # 等效 # <-----backward # x不需要求导,中间过程还是会计算它的导数,但随后被清空
print(x.grad, w.grad, b.grad)
开始前向传播
type in forward <class 'torch.FloatTensor'>
开始反向传播
type in backward <class 'torch.autograd.variable.Variable'>(None,
Variable containing:
1
[torch.FloatTensor of size 1],
Variable containing:
1
[torch.FloatTensor of size 1])
调用方法二
类名.apply(参数)
输出变量.grad_fn.apply()
x = V(t.ones(1))
w = V(t.rand(1), requires_grad = True)
b = V(t.rand(1), requires_grad = True)
print('开始前向传播')
z=MultiplyAdd.apply(w,x,b) # <-----forward
print('开始反向传播') # 调用MultiplyAdd.backward
# 会自动输出grad_w, grad_x, grad_b
z.grad_fn.apply(V(t.ones(1))) # <-----backward,在计算中间输出,buffer并未清空,所以x的梯度不是None
开始前向传播
type in forward <class 'torch.FloatTensor'>
开始反向传播
type in backward <class 'torch.autograd.variable.Variable'>(Variable containing:
1
[torch.FloatTensor of size 1], Variable containing:
0.7655
[torch.FloatTensor of size 1], Variable containing:
1
[torch.FloatTensor of size 1])
之所以forward函数的输入是tensor,而backward函数的输入是variable,是为了实现高阶求导。backward函数的输入输出虽然是variable,但在实际使用时autograd.Function会将输入variable提取为tensor,并将计算结果的tensor封装成variable返回。在backward函数中,之所以也要对variable进行操作,是为了能够计算梯度的梯度(backward of backward)。下面举例说明,有关torch.autograd.grad的更详细使用请参照文档。
二、高阶导数
grad_x =t.autograd.grad(y, x, create_graph=True)
grad_grad_x = t.autograd.grad(grad_x[0],x)
x = V(t.Tensor([5]), requires_grad=True)
y = x ** 2 grad_x = t.autograd.grad(y, x, create_graph=True)
print(grad_x) # dy/dx = 2 * x grad_grad_x = t.autograd.grad(grad_x[0],x)
print(grad_grad_x) # 二阶导数 d(2x)/dx = 2
(Variable containing:
10
[torch.FloatTensor of size 1],)(Variable containing:
2
[torch.FloatTensor of size 1],)
三、梯度检查
t.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)
此外在实现了自己的Function之后,还可以使用gradcheck函数来检测实现是否正确。gradcheck通过数值逼近来计算梯度,可能具有一定的误差,通过控制eps的大小可以控制容忍的误差。
class Sigmoid(Function):
@staticmethod
def forward(ctx, x):
output = 1 / (1 + t.exp(-x))
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_variables
grad_x = output * (1 - output) * grad_output
return grad_x
# 采用数值逼近方式检验计算梯度的公式对不对
test_input = V(t.randn(3,4), requires_grad=True)
t.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)
True
测试效率,
def f_sigmoid(x):
y = Sigmoid.apply(x)
y.backward(t.ones(x.size())) def f_naive(x):
y = 1/(1 + t.exp(-x))
y.backward(t.ones(x.size())) def f_th(x):
y = t.sigmoid(x)
y.backward(t.ones(x.size())) x=V(t.randn(100, 100), requires_grad=True)
%timeit -n 100 f_sigmoid(x)
%timeit -n 100 f_naive(x)
%timeit -n 100 f_th(x)
实际测试结果,
245 µs ± 70.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
211 µs ± 23.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
219 µs ± 36.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
书中说的结果,
100 loops, best of 3: 320 µs per loop
100 loops, best of 3: 588 µs per loop
100 loops, best of 3: 271 µs per loop
很奇怪,我的结果竟然是:简单堆砌<官方封装<自己封装……不过还是引用一下书中的结论吧:
显然
f_sigmoid要比单纯利用autograd加减和乘方操作实现的函数快不少,因为f_sigmoid的backward优化了反向传播的过程。另外可以看出系统实现的buildin接口(t.sigmoid)更快。
『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数的更多相关文章
- 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法
在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...
- 『PyTorch』第五弹_深入理解autograd_中:Variable梯度探究
查看非叶节点梯度的两种方法 在反向传播过程中非叶子节点的导数计算完之后即被清空.若想查看这些变量的梯度,有两种方法: 使用autograd.grad函数 使用hook autograd.grad和ho ...
- 『PyTorch』第五弹_深入理解Tensor对象_中下:数学计算以及numpy比较_&_广播原理简介
一.简单数学操作 1.逐元素操作 t.clamp(a,min=2,max=4)近似于tf.clip_by_value(A, min, max),修剪值域. a = t.arange(0,6).view ...
- 『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor
Tensor存储结构如下, 如图所示,实际上很可能多个信息区对应于同一个存储区,也就是上一节我们说到的,初始化或者普通索引时经常会有这种情况. 一.几种共享内存的情况 view a = t.arang ...
- 『PyTorch』第五弹_深入理解Tensor对象_中上:索引
一.普通索引 示例 a = t.Tensor(4,5) print(a) print(a[0:1,:2]) print(a[0,:2]) # 注意和前一种索引出来的值相同,shape不同 print( ...
- 『PyTorch』第五弹_深入理解Tensor对象_上:初始化以及尺寸调整
一.创建Tensor 特殊方法: t.arange(1,6,2)t.linspace(1,10,3)t.randn(2,3) # 标准分布,*size t.randperm(5) # 随机排序,从0到 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
随机推荐
- python3.4学习笔记(十二) python正则表达式的使用,使用pyspider匹配输出带.html结尾的URL
python3.4学习笔记(十二) python正则表达式的使用,使用pyspider匹配输出带.html结尾的URL实战例子:使用pyspider匹配输出带.html结尾的URL:@config(a ...
- 比特币、莱特币钱包下载和把数据迁移到C盘以外其他盘
比特币是目前最热门和价格最高的虚拟币,国内外多个平台可以进行交易,有些商家可以用比特币进行支付有些国家可以在ATM取款. Bitcoin-Qt就是最早的比特币客户端,构建了比特币的骨干网络,具有高度的 ...
- Js基础知识7-JavaScript所有内置对象属性和方法汇总
对象什么的,程序员可是有很多呢... JS三大对象 对象,是任何一个开发者都无法绕开和逃避的话题,她似乎有些深不可测,但如此伟大和巧妙的存在,一定值得你去摸索.发现.征服. 我们都知道,JavaScr ...
- 【运维技术】JENKINS管道部署容器化初探
目标服务器安装docker参考官方文档 https://docs.docker.com/install/linux/docker-ce/centos/ (可选)在目标服务器上安装docker私服 ht ...
- bzoj3505 / P3166 [CQOI2014]数三角形
P3166 [CQOI2014]数三角形 前置知识:某两个点$(x_{1},,y_{1}),(x_{2},y_{2})\quad (x_{1}<x_{2},y_{1}<y_{2})$所连成 ...
- loj10009 P1717 钓鱼
P1717 钓鱼 贪心+优先队列 先枚举最后走到哪个湖,然后用优先队列跑一遍贪心即可 #include<iostream> #include<cstdio> #include& ...
- Python3 异常: name 'basestring' is not defined
Python3 异常: name 'basestring' is not defined 问题分析: python3 里已经没有basestring 类型,用str代替了basestring : 解决 ...
- Eclipse启动Tomcat时,45秒超时解决方式
Eclipse启动Tomcat时,45秒超时解决方式 在Eclipse中启动Tomcat服务器时,经常由于系统初始化项目多,导致出现45秒超时的Tomcat服务器启动错误. 一般通过找到XML配置文 ...
- JDBC批量插入数据优化,使用addBatch和executeBatch
JDBC批量插入数据优化,使用addBatch和executeBatch SQL的批量插入的问题,如果来个for循环,执行上万次,肯定会很慢,那么,如何去优化呢? 解决方案:用 preparedSta ...
- c++中的字符集与中文
就非西欧字符而言,比如中国以及港澳台,在任何编程语言的开发中都不得不考虑字符集及其表示.在c++中,对于超过1个字节的字符,有两种方式可以表示: 1.多字节表示法:通常用于存储(空间效率考虑). 2. ...