由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数。

钩子函数包括Variable的钩子和nn.Module钩子,用法相似。

一、register_hook

import torch
from torch.autograd import Variable grad_list = [] def print_grad(grad):
grad_list.append(grad) x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data print(grad_list)
[Variable containing:
1.5653
3.5175
[torch.FloatTensor of size 2x1]
]

二、register_forward_hook & register_backward_hook

这两个函数的功能类似于variable函数的register_hook,可在module前向传播或反向传播时注册钩子。

每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:hook(module, input, output) -> None,而反向传播则具有如下形式:hook(module, grad_input, grad_output) -> Tensor or None

钩子函数不应修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用钩子技术就更合适一些。下面考虑一种场景,有一个预训练好的模型,需要提取模型的某一层(不是最后一层)的输出作为特征进行分类,但又不希望修改其原有的模型定义文件,这时就可以利用钩子函数。下面给出实现的伪代码。

model = VGG()
features = t.Tensor()
def hook(module, input, output):
'''把这层的输出拷贝到features中'''
features.copy_(output.data) handle = model.layer8.register_forward_hook(hook)
_ = model(input)
# 用完hook后删除
handle.remove()

测试LeNet网络

import torch as t
import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10) def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

先模拟一下单次的向前传播,

net = LeNet()
img = t.autograd.Variable((t.arange(32*32*1).view(1,1,32,32)))
net(img)
Variable containing:

Columns 0 to 7
27.6373 -13.4590 23.0988 -16.4491 -8.8454 -15.6934 -4.8512 1.3490 Columns 8 to 9
3.7801 -15.9396
[torch.FloatTensor of size 1x10]

仿照上面示意,进行钩子注册,获取第一卷积层输出结果,

def hook(module, inputdata, output):
'''把这层的输出拷贝到features中'''
print(output.data) handle = net.conv2.register_forward_hook(hook)
net(img)
# 用完hook后删除
handle.remove()

……

……

[torch.FloatTensor of size 1x16x10x10]

看看hook能识别什么

import torch
from torch import nn
import torch.functional as F
from torch.autograd import Variable def for_hook(module, input, output):
print(module)
for val in input:
print("input val:",val)
for out_val in output:
print("output val:", out_val) class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x): return x+1 model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()

可见对于目标层,其输入输出都可以获取到,

Model(
)
input val: Variable containing:
1
[torch.FloatTensor of size 1] output val: Variable containing:
2
[torch.FloatTensor of size 1] Variable containing:
2
[torch.FloatTensor of size 1]

『PyTorch』第十六弹_hook技术的更多相关文章

  1. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  2. 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询

    一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...

  3. 『PyTorch』第十四弹_torch.nn.Module类属性

    nn.Module基类的构造函数: def __init__(self): self._parameters = OrderedDict() self._modules = OrderedDict() ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 『MXNet』第十二弹_再谈新建计算节点

    上一节我们已经谈到了计算节点,但是即使是官方文档介绍里面相关内容也过于简略,我们使用Faster-RCNN代码中的新建节点为例,重新介绍一下新建节点的调用栈. 1.调用新建节点 参数分为三部分,op_ ...

  6. 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...

  7. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  8. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  9. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

随机推荐

  1. linux常用命令:iostat 命令

    Linux系统中的 iostat 是I/O statistics(输入/输出统计)的缩写,iostat工具将对系统的磁盘操作活动进行监视.它的特点是汇报磁盘活动统计情况,同时也会 汇报出CPU使用情况 ...

  2. python-自定义异常,with用法

    抛出异常 #coding=utf-8 def  exceptionTest(num): if num<0: print "if num<0" raise Excepti ...

  3. Linux服务器---基础设置

    Centos分辨率      virtualbox里新安装的Centos 7 的分辨率默认的应该是800*600. 如果是‘最小化安装’的Centos7 进入的就是命令模式 .如果安装的是带有GUI的 ...

  4. MySQL笔记(二)数据库对象的创建和管理

    学校用 sqlserver ,记录数据移植到 mysql 过程中的一些问题(对应数据类型,主键外键等). 索引: 查看数据的物理路径 查看表相关的信息(SHOW CREATE TABLE.DESC) ...

  5. OpenCV中HSV颜色模型及颜色分量范围

    HSV颜色模型 HSV(Hue, Saturation, Value)是根据颜色的直观特性由A. R. Smith在1978年创建的一种颜色空间, 也称六角锥体模型(Hexcone Model)..这 ...

  6. Python入门之python可变对象与不可变对象

    本文分为如下几个部分 概念 地址问题 作为函数参数 可变参数在类中使用 函数默认参数 类的实现上的差异 概念 可变对象与不可变对象的区别在于对象本身是否可变. python内置的一些类型中 可变对象: ...

  7. Mysql的基本语句

    Mysql的基本语句 1.查询当前数据库所有表名: -- 方案一: show tables; --方案二:jeesite为数据库 select table_name from information_ ...

  8. Android 实践项目开发二

    在地图开发中项目中,我这周主要完成的任务是和遇到的问题是以下几个方面. 1.在本次的项目中主要是利用百度地图的.jar包实现地图的定位与搜索功能,需要在百度地图开发中心网站取得 密钥,并下载相关.ja ...

  9. UVa 10891 Game of Sum - 动态规划

    因为数的总和一定,所以用一个人得分越高,那么另一个人的得分越低. 用$dp[i][j]$表示从$[i, j]$开始游戏,先手能够取得的最高分. 转移通过枚举取的数的个数$k$来转移.因为你希望先手得分 ...

  10. bzoj 2427 软件安装 - Tarjan - 树形动态规划

    题目描述 现在我们的手头有N个软件,对于一个软件i,它要占用Wi的磁盘空间,它的价值为Vi.我们希望从中选择一些软件安装到一台磁盘容量为M计算机上,使得这些软件的价值尽可能大(即Vi的和最大). 但是 ...