.简介
torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现
Variable和tensor的区别和联系
Variable是篮子,而tensor是鸡蛋,鸡蛋应该放在篮子里才能方便拿走(定义variable时一个参数就是tensor)
Variable这个篮子里除了装了tensor外还有requires_grad参数,表示是否需要对其求导,默认为False
Variable这个篮子呢,自身有一些属性
比如grad,梯度variable.grad是d(y)/d(variable)保存的是变量y对variable变量的梯度值,如果requires_grad参数为False,所以variable.grad返回值为None,如果为True,返回值就为对variable的梯度值
比如grad_fn,对于用户自己创建的变量(Variable())grad_fn是为none的,也就是不能调用backward函数,但对于由计算生成的变量,如果存在一个生成中间变量的requires_grad为true,那其的grad_fn不为none,反则为none
比如data,这个就很简单,这个属性就是装的鸡蛋(tensor)
Varibale包含三个属性:
data:存储了Tensor,是本体的数据
grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
grad_fn:指向Function对象,用于反向传播的梯度计算之用
 
具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。
那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式)。
如果用Variable计算的话,那返回的也是一个同类型的Variable。
【tensor 是一个多维矩阵】
用一个例子说明,Variable的定义:
import torch
from torch.autograd import Variable # torch 中 Variable 模块
tensor = torch.FloatTensor([[,],[,]])
# 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
variable = Variable(tensor, requires_grad=True) print(tensor)
""" [torch.FloatTensor of size 2x2]
""" print(variable)
"""
Variable containing: [torch.FloatTensor of size 2x2]
"""

 注:tensor不能反向传播,variable可以反向传播

二、Variable求梯度

Variable计算时,它会逐渐地生成计算图。这个图就是将所有的计算节点都连接起来,最后进行误差反向传递的时候,一次性将所有Variable里面的梯度都计算出来,而tensor就没有这个能力。

v_out.backward()    # 模拟 v_out 的误差反向传递

print(variable.grad)    # 初始 Variable 的梯度
'''
0.5000 1.0000
1.5000 2.0000
'''

 

三、获取Variable里面的数据

直接print(Variable) 只会输出Variable形式的数据,在很多时候是用不了的。所以需要转换一下,将其变成tensor形式。

print(variable)     #  Variable 形式
"""
Variable containing: [torch.FloatTensor of size 2x2]
""" print(variable.data) # 将variable形式转为tensor 形式
""" [torch.FloatTensor of size 2x2]
"""
print(variable.data.numpy()) # numpy 形式
"""
[[ . .]
[ . .]]
"""

四:关于require_grad对variable的作用

代码一:

import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(,),requires_grad = False)
temp = Variable(torch.zeros(,),requires_grad = True)
y = x + temp +
y = y.mean() #求平均数
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)

none

(因为requires_grad=False)

代码二:

import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(,),requires_grad = False)
temp = Variable(torch.zeros(,),requires_grad = True)
y = x + temp +
y = y.mean() #求平均数
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(temp.grad) # d(y)/d(temp)

tensor([[0.2500, 0.2500],
        [0.2500, 0.2500]])

代码三:

import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(,),requires_grad = False)
temp = Variable(torch.zeros(,),requires_grad = True)
y = x +
y = y.mean() #求平均数
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)
Traceback (most recent call last):
  File "path", line 12, in <module>
    y.backward()
(报错了,因为生成变量y的中间变量只有x,而x的requires_grad是False,所以y的grad_fn是none)
 
代码四:
import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(,),requires_grad = False)
temp = Variable(torch.zeros(,),requires_grad = True)
y = x +
y = y.mean() #求平均数
#y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(y.grad_fn) # d(y)/d(x)

none

五:grad属性

在每次backward后,grad值是会累加的,所以利用BP算法,每次迭代是需要将grad清零的。

x.grad.data.zero_()

(in-place操作需要加上_,即zero_)

六:扩展
在PyTorch中计算图的特点总结如下:

autograd根据用户对Variable的操作来构建其计算图。
requires_grad
variable默认是不需要被求导的,即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True。
volatile
variable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,那么所有依赖它的节点volatile属性都为True。volatile属性为True的节点不会求导,volatile的优先级比requires_grad高。
retain_graph
多次反向传播(多层监督)时,梯度是累加的。一般来说,单次反向传播后,计算图会free掉,也就是反向传播的中间缓存会被清空【这就是动态度的特点】。为进行多次反向传播需指定retain_graph=True来保存这些缓存。
.backward()
反向传播,求解Variable的梯度。放在中间缓存中。

莫烦pytorch学习笔记(二)——variable的更多相关文章

  1. 莫烦 - Pytorch学习笔记 [ 二 ] CNN ( 1 )

    CNN原理和结构 观点提出 关于照片的三种观点引出了CNN的作用. 局部性:某一特征只出现在一张image的局部位置中. 相同性: 同一特征重复出现.例如鸟的羽毛. 不变性:subsampling下图 ...

  2. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  3. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  4. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  5. 莫烦PyTorch学习笔记(三)——激励函数

    1. sigmod函数 函数公式和图表如下图     在sigmod函数中我们可以看到,其输出是在(0,1)这个开区间内,这点很有意思,可以联想到概率,但是严格意义上讲,不要当成概率.sigmod函数 ...

  6. 莫烦 - Pytorch学习笔记 [ 一 ]

    1. Numpy VS Torch #相互转换 np_data = torch_data.numpy() torch_data = torch.from_numpy(np_data) #abs dat ...

  7. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  8. 莫烦PyTorch学习笔记(五)——分类

    import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.p ...

  9. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

随机推荐

  1. Deep Dive into Neo4j 3.5 Full Text Search

    In this blog we will go over the Full Text Search capabilities available in the latest major release ...

  2. maven中报错Missing artifact com.oracle:ojdbc14:jar:10.2.0.4.0

    在检索完工程后报错Missing artifact com.oracle:ojdbc14:jar:10.2.0.4.0. 由于oracle的ojdbc收费,所以在maven项目导入时没有ojdbc14 ...

  3. Dede没见过的漏洞

    payload:plus/search.php?keyword=xxx&arrs1[]=99&arrs1[]=102&arrs1[]=103&arrs1[]=95&am ...

  4. CSIC_716_20191031【计算机的组成】

    引言 什么是编程语言 语言是人与人之间的沟通介质之一,编程语言是人与机器(包括计算机)之间沟通的介质. 一个完整的计算机系统主要包括     应用程序  .操作系统  和硬件   等. 计算机三大核心 ...

  5. mysql 数据库 内容的增删改查

    /*所有字段插入值*//*注意插入值数目要与字段值一致*/INSERT INTO student VALUES(1,'熊大','123','2019-10-18',1200);INSERT INTO ...

  6. CSS3——伸缩布局及应用

    CSS3在布局方面做了非常大的改进,使得我们对块级元素的布局排列变得十分灵活,适应性非常强,其强大的伸缩性,在响应式开中可以发挥极大的作用. 主轴:Flex容器的主轴主要用来配置Flex项目,默认是水 ...

  7. 重磅发布: 阿里云WAF日志实时分析上线 (含视频)

    摘要: 阿里云WAF与日志服务打通,对外开发Web访问与攻击日志.提供近实时的网站具体的日志自动采集存储.并提供基于日志服务的查询分析.报表报警.下游计算对接与投递的能力. 背景 Web攻击形势 互联 ...

  8. linux watch命令查看网卡流量

    watch命令可以反复的执行一个命令,默认时间间隔为2秒钟.TX是发送(transport),RX是接收(receive)RX bytes:总下行流量TX bytes:总上行流量 可以每隔两秒监视网络 ...

  9. python和go对比字符串的链式处理

    一.什么是链式处理 对数据的操作进行多步骤的处理称为链式处理,链式处理器是一种常见的编程设计,链式处理的开发思想将数据和操作拆分,解耦,让开发者可以根据自己的技术优势和需求,进行系统开发,同时将自己的 ...

  10. POJ-1976-A Mini Locomotive-dp

    A train has a locomotive that pulls the train with its many passenger coaches. If the locomotive bre ...