【PyTorch深度学习60分钟快速入门 】Part2:Autograd自动化微分
在PyTorch中,集中于所有神经网络的是autograd
包。首先,我们简要地看一下此工具包,然后我们将训练第一个神经网络。
autograd
包为张量的所有操作提供了自动微分。它是一个运行式定义的框架,这意味着你的后向传播是由你的代码运行方式来定义的,并且每一个迭代都可以是不同的。
下面,让我们使用一些更简单的术语和例子来解释这个问题。
0x01 变量(Variable)
autograd.Variable
是autograd
包的核心类,它封装了一个张量,并支持几乎所有在该张量上定义的操作。一旦完成了你的计算,你可以调用.backward()
,它会自动计算所有梯度。
你可以通过.data
属性访问原始的张量,而梯度w.r.t.这个变量被累积到.grad
。
还有一个类对于autograd的实现非常重要——一个函数。
变量和函数是相互联系的,并建立一个非循环图,它编码了计算的一个完整历史。每个变量都有一个.grad_fn
属性,该属性引用了一个创建了该变量的函数(除了由用户创建的变量之外,它们的grad_fn
是None
)。
如果你想计算导数,你可以在一个变量上调用.backward()
。如果变量是一个标量(也就是说它包含一个元素数据),那么你不需要为backward()
指定任何参数,但是如果它有更多元素,那么你需要指定一个grad_output
参数,该参数是一个匹配形状的张量。
import torch
from torch.autograd import Variable
创建一个变量:
x = Variable(torch.ones(2, 2), requires_grad=True)
print(x)
输出结果:
Variable containing:
1 1
1 1
[torch.FloatTensor of size 2x2]
做一个变量操作:
y = x + 2
print(y)
输出结果:
Variable containing:
3 3
3 3
[torch.FloatTensor of size 2x2]
y
是由于操作而创建的,所以它有一个grad_fn
。
print(y.grad_fn)
输出结果:
<AddBackward0 object at 0x7ff91b4f0908>
对y
做更多操作:
z = y * y * 3
out = z.mean()
print(z, out)
输出结果:
Variable containing:
27 27
27 27
[torch.FloatTensor of size 2x2]
Variable containing:
27
[torch.FloatTensor of size 1]
0x02 梯度(Gradients)
现在我们介绍后向传播,out.backward()
等效于做out.backward(torch.Tensor([1.0]))
out.backward()
打印梯度d(out)/dx:
print(x.grad)
输出结果:
Variable containing:
4.5000 4.5000
4.5000 4.5000
[torch.FloatTensor of size 2x2]
你应该得到一个元素为4.5的矩阵。我们将这个变量叫做"o"。此时,我们有:
你可以利用梯度做很多疯狂的事情!
x = torch.randn(3)
x = Variable(x, requires_grad=True)
y = x * 2
while y.data.norm() < 1000:
y = y * 2
print(y)
输出结果:
Variable containing:
164.9539
-511.5981
-1356.4794
[torch.FloatTensor of size 3]
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)
输出结果:
Variable containing:
204.8000
2048.0000
0.2048
[torch.FloatTensor of size 3]
扩展阅读: 变量和函数的文档在这里http://pytorch.org/docs/autograd
以上脚本的总运行时间为:0分0.009秒。
本文中所使用的Python代码:autograd_tutorial.py
【PyTorch深度学习60分钟快速入门 】Part2:Autograd自动化微分的更多相关文章
- 【PyTorch深度学习60分钟快速入门 】Part1:PyTorch是什么?
0x00 PyTorch是什么? PyTorch是一个基于Python的科学计算工具包,它主要面向两种场景: 用于替代NumPy,可以使用GPU的计算力 一种深度学习研究平台,可以提供最大的灵活性 ...
- 【PyTorch深度学习60分钟快速入门 】Part0:系列介绍
说明:本系列教程翻译自PyTorch官方教程<Deep Learning with PyTorch: A 60 Minute Blitz>,基于PyTorch 0.3.0.post4 ...
- 【PyTorch深度学习60分钟快速入门 】Part4:训练一个分类器
太棒啦!到目前为止,你已经了解了如何定义神经网络.计算损失,以及更新网络权重.不过,现在你可能会思考以下几个方面: 0x01 数据集 通常,当你需要处理图像.文本.音频或视频数据时,你可以使用标准 ...
- 【PyTorch深度学习60分钟快速入门 】Part5:数据并行化
在本节中,我们将学习如何利用DataParallel使用多个GPU. 在PyTorch中使用多个GPU非常容易,你可以使用下面代码将模型放在GPU上: model.gpu() 然后,你可以将所有张 ...
- 【PyTorch深度学习60分钟快速入门 】Part3:神经网络
神经网络可以通过使用torch.nn包来构建. 既然你已经了解了autograd,而nn依赖于autograd来定义模型并对其求微分.一个nn.Module包含多个网络层,以及一个返回输出的方法f ...
- pytorch深度学习60分钟闪电战
https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 官方推荐的一篇教程 Tensors #Construct a ...
- Vue.js 60 分钟快速入门
Vue.js 60 分钟快速入门 转载 作者:keepfool 链接:http://www.cnblogs.com/keepfool/p/5619070.html Vue.js介绍 Vue.js是当下 ...
- 不会几个框架,都不好意思说搞过前端: Vue.js - 60分钟快速入门
Vue.js——60分钟快速入门 Vue.js是当下很火的一个JavaScript MVVM库,它是以数据驱动和组件化的思想构建的.相比于Angular.js,Vue.js提供了更加简洁.更易于理 ...
- Vue.js——60分钟快速入门(转)
vue:Vue.js——60分钟快速入门 <!doctype html> <html lang="en"> <head> <meta ch ...
随机推荐
- mencached
是一个免费开源的,分布式内存对象缓存系统数据库. 是一个非关系型数据库形式,属于NOSQL NOT OLNY SQL ,不仅仅是关系数据库 它属于K V 存储 KEY VALUE 相对应的存储 KEY ...
- 安卓ViewStub用法
安卓ViewStub用法 在开发应用程序的时候,经常会遇到这样的情况,在运行时动态根据条件来决定显示哪个View或某个布局. 那么最通常的想法就是把可能用到的View都写在上面,先把它们的可见性都设为 ...
- android studio 关闭SVN关联
<?xml version="1.0" encoding="UTF-8"?> <project version="4"&g ...
- 18. pt-pmp
pt-pmp 是一个非常简单的工具,可以用来获取MySQL的堆栈信息.工具首先获取运行过程中的mysqld堆栈信息,然后将相似的线程进行汇总排序,根据调用频繁程度从高到低打印出来. 查看pt-pmp的 ...
- 锻造(forging)
--九校联考24OI__D1T1 题目背景 勇者虽然武力值很高,但在经历了多次战斗后,发现怪物越来越难打,于是开始思考是不是自己平时锻炼没到位,于是苦练一个月后发现--自己连一个史莱姆都打不过了. 勇 ...
- [Python] 怎么把HTML的报告转换为图片,利用无头浏览器
How to convert HTML Report to picture format in Email? So that we can see the automation report also ...
- Java集合:ConcurrentHashMap原理分析
集合是编程中最常用的数据结构.而谈到并发,几乎总是离不开集合这类高级数据结构的支持.比如两个线程需要同时访问一个中间临界区(Queue),比如常会用缓存作为外部文件的副本(HashMap).这篇文章主 ...
- UML类图中箭头和线条的含义和用法
UML类图中箭头和线条的含义和用法 在学习UML过程中,你经常会遇到UML类图关系,这里就向大家介绍一下UML箭头.线条代表的意义,相信通过本文的介绍你对UML中箭头.线条的意义有更明确的认识. AD ...
- Spring-Data-JPA @Query注解 Sort排序
当我们使用方法名称很难,达到预期的查询结果,就可以使用@Query进行查询,@Query是一种添加自定义查询的便利方式 (方法名称查询见http://blog.csdn.net/niugang0920 ...
- springsecurity 源码解读之 SecurityContext
在springsecurity 中,我们一般可以通过代码: SecurityContext securityContext = SecurityContextHolder.getContext(); ...