由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数。

钩子函数包括Variable的钩子和nn.Module钩子,用法相似。

一、register_hook

import torch
from torch.autograd import Variable grad_list = [] def print_grad(grad):
grad_list.append(grad) x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data print(grad_list)
[Variable containing:
1.5653
3.5175
[torch.FloatTensor of size 2x1]
]

二、register_forward_hook & register_backward_hook

这两个函数的功能类似于variable函数的register_hook,可在module前向传播或反向传播时注册钩子。

每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:hook(module, input, output) -> None,而反向传播则具有如下形式:hook(module, grad_input, grad_output) -> Tensor or None

钩子函数不应修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用钩子技术就更合适一些。下面考虑一种场景,有一个预训练好的模型,需要提取模型的某一层(不是最后一层)的输出作为特征进行分类,但又不希望修改其原有的模型定义文件,这时就可以利用钩子函数。下面给出实现的伪代码。

model = VGG()
features = t.Tensor()
def hook(module, input, output):
'''把这层的输出拷贝到features中'''
features.copy_(output.data) handle = model.layer8.register_forward_hook(hook)
_ = model(input)
# 用完hook后删除
handle.remove()

测试LeNet网络

import torch as t
import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10) def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

先模拟一下单次的向前传播,

net = LeNet()
img = t.autograd.Variable((t.arange(32*32*1).view(1,1,32,32)))
net(img)
Variable containing:

Columns 0 to 7
27.6373 -13.4590 23.0988 -16.4491 -8.8454 -15.6934 -4.8512 1.3490 Columns 8 to 9
3.7801 -15.9396
[torch.FloatTensor of size 1x10]

仿照上面示意,进行钩子注册,获取第一卷积层输出结果,

def hook(module, inputdata, output):
'''把这层的输出拷贝到features中'''
print(output.data) handle = net.conv2.register_forward_hook(hook)
net(img)
# 用完hook后删除
handle.remove()

……

……

[torch.FloatTensor of size 1x16x10x10]

看看hook能识别什么

import torch
from torch import nn
import torch.functional as F
from torch.autograd import Variable def for_hook(module, input, output):
print(module)
for val in input:
print("input val:",val)
for out_val in output:
print("output val:", out_val) class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x): return x+1 model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()

可见对于目标层,其输入输出都可以获取到,

Model(
)
input val: Variable containing:
1
[torch.FloatTensor of size 1] output val: Variable containing:
2
[torch.FloatTensor of size 1] Variable containing:
2
[torch.FloatTensor of size 1]

『PyTorch』第十六弹_hook技术的更多相关文章

  1. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  2. 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询

    一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...

  3. 『PyTorch』第十四弹_torch.nn.Module类属性

    nn.Module基类的构造函数: def __init__(self): self._parameters = OrderedDict() self._modules = OrderedDict() ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 『MXNet』第十二弹_再谈新建计算节点

    上一节我们已经谈到了计算节点,但是即使是官方文档介绍里面相关内容也过于简略,我们使用Faster-RCNN代码中的新建节点为例,重新介绍一下新建节点的调用栈. 1.调用新建节点 参数分为三部分,op_ ...

  6. 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...

  7. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  8. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  9. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

随机推荐

  1. CSS前叙

    1 css是什么?层叠样式表,修饰网页结构2 如何去使用css?a.在html网页中,加入一个style标签,在这个style标签里面写css代码b.可以直接把style里面的代码放到一个单独的文件中 ...

  2. Windows下使用MakeFile(Mingw)文件

    下面是我基于<C++GUI QT4编程(第二版)> 2.3节快速设计对话框编写例子地址: https://files.cnblogs.com/files/senior-engineer/g ...

  3. linux系统启动顺序及init模式

    磁盘的第一个扇区(512bytes)主要记录了两个重要信息: 主引导分区MBR:master boot record,安装引导加载程序的地方,446bytes 分区表:partition table: ...

  4. 如何用tomcat发布自己的Java项目

    如何用tomcat发布自己的Java项目 tomcat是什么?它是一个免费的开放源代码的Web 应用服务器,属于轻量级应用服务器.我们用Java开发出来的web项目,通过tomcat发布出来,别人就可 ...

  5. 05: 配置yum源

    1.1 将镜像复制到本地创建yum源 1.将准备好的系统镜像放到指定的目录,本次目录指定在:/dawnfs/sourcecode 2.创建挂载目录:mkdir /mnt/yum 3.挂载镜像: mou ...

  6. 前向算法Python实现

    前言 这里的前向算法与神经网络里的前向传播算法没有任何联系...这里的前向算法是自然语言处理领域隐马尔可夫模型第一个基本问题的算法. 前向算法是什么? 这里用一个海藻的例子来描述前向算法是什么.网上有 ...

  7. Python3基础 else 循环完整结束才执行

             Python : 3.7.0          OS : Ubuntu 18.04.1 LTS         IDE : PyCharm 2018.2.4       Conda ...

  8. 三种常用的js数组去重方法

    第一种是比较常规的方法 思路: 1.构建一个新的数组存放结果 2.for循环中每次从原数组中取出一个元素,用这个元素循环与结果数组对比 3.若结果数组中没有该元素,则存到结果数组中 Array.pro ...

  9. ubuntu下安装mkfs.jffs工具

    一.环境 Os: ubuntu 16.04 二.安装 2.1安装依赖库 sudo apt install zlib1g-dev liblzo2-dev uuid-dev 2.2编译安装mtd-util ...

  10. json获取元素数量

    var keleyijson={"plug1":"myslider","plug2":"zonemenu"} funct ...