pytorch hook学习

register_hook

import torch
x = torch.Tensor([0,1,2,3]).requires_grad_()
y = torch.Tensor([4,5,6,7]).requires_grad_()
w = torch.Tensor([1,2,3,4]).requires_grad_()
z = x+y;
o = w.matmul(z) # o = w(x+y) 中间变量z
o.backward()
print(x.grad,y.grad,z.grad,w.grad,o.grad)

这里的o和z都是中间变量,不是通过指定值来定义的变量,所以是中间变量,所以pytorch并不存储这些变量的梯度。

对于中间变量z,hook的使用方式为: z.register_hook(hook_fn),其中 hook_fn为一个用户自定义的函数,其签名为:hook_fn(grad) -> Tensor or None。

它的输入为变量 z 的梯度,输出为一个 Tensor 或者是 None (None 一般用于直接打印梯度)。反向传播时,梯度传播到变量 z,再继续向前传播之前,将会传入 hook_fn。如果 hook_fn的返回值是 None,那么梯度将不改变,继续向前传播,如果 hook_fn的返回值是 Tensor 类型,则该 Tensor 将取代 z 原有的梯度,向前传播。

import torch
x = torch.Tensor([0,1,2,3]).requires_grad_()
y = torch.Tensor([4,5,6,7]).requires_grad_()
w = torch.Tensor([1,2,3,4]).requires_grad_()
z = x+y;
def hook_fn(grad):
print(grad)
return None z.register_hook(hook_fn)
o = w.matmul(z) # o = w(x+y) 中间变量z
o.backward()
print(x.grad,y.grad,w.grad,z.grad,o.grad)

register_forward_hook

register_forward_hook的作用是获取前向传播过程中,各个网络模块的输入和输出。对于模块 module,其使用方式为:module.register_forward_hook(hook_fn) 。其中 hook_fn的签名为:

hook_fn(module, input, output) -> None

eg

import torch
from torch import nn
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.fc1 = nn.Linear(3,4) # WT * X + bias
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(4,1)
self.init()
def init(self):
with torch.no_grad():
# WT * X + bias,所以W为4*3的矩阵,bias为1*4
self.fc1.weight = torch.nn.Parameter(
torch.Tensor([[1., 2., 3.],
[-4., -5., -6.],
[7., 8., 9.],
[-10., -11., -12.]]))
self.fc1.bias = torch.nn.Parameter(torch.Tensor([1.0, 2.0, 3.0, 4.0]))
self.fc2.weight = torch.nn.Parameter(torch.Tensor([[1.0, 2.0, 3.0, 4.0]]))
self.fc2.bias = torch.nn.Parameter(torch.Tensor([1.0])) def forward(self,x):
o = self.fc1(x)
o = self.relu1(o)
o = self.fc2(o)
return o
def hook_fn_forward(module,input,output):
print(module)
print(input)
print(output) model = Model()
modules = model.named_children()
'''
named_children()
Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
'''
for name,module in modules:
# 这里的name就是自己定义的self.xx的xx。如上面的fc1,fc2.
# module代指的就是fc1代表的module等等
module.register_forward_hook(hook_fn_forward)
x = torch.Tensor([[1.0,1.0,1.0]]).requires_grad_()
o = model(x)
o.backward()
'''
Linear(in_features=3, out_features=4, bias=True)
(tensor([[1., 1., 1.]], requires_grad=True),)
tensor([[ 7., -13., 27., -29.]], grad_fn=<AddmmBackward>)
ReLU()
(tensor([[ 7., -13., 27., -29.]], grad_fn=<AddmmBackward>),)
tensor([[ 7., 0., 27., 0.]], grad_fn=<ReluBackward0>)
Linear(in_features=4, out_features=1, bias=True)
(tensor([[ 7., 0., 27., 0.]], grad_fn=<ReluBackward0>),)
tensor([[89.]], grad_fn=<AddmmBackward>) '''

register_backward_hook

理同前者。得到梯度值。

hook_fn(module, grad_input, grad_output) -> Tensor or None

上面的代码forward全部替换为backward,结果为:

'''
Linear(in_features=4, out_features=1, bias=True)
(tensor([1.]), tensor([[1., 2., 3., 4.]]), tensor([[ 7.],
[ 0.],
[27.],
[ 0.]]))
(tensor([[1.]]),)
ReLU()
(tensor([[1., 0., 3., 0.]]),)
(tensor([[1., 2., 3., 4.]]),)
Linear(in_features=3, out_features=4, bias=True)
(tensor([1., 0., 3., 0.]), tensor([[22., 26., 30.]]), tensor([[1., 0., 3., 0.],
[1., 0., 3., 0.],
[1., 0., 3., 0.]]))
(tensor([[1., 0., 3., 0.]]),)
'''

register_backward_hook只能操作简单模块,而不能操作包含多个子模块的复杂模块。 如果对复杂模块用了 backward hook,那么我们只能得到该模块最后一次简单操作的梯度信息。

可以这么用,可以得到一个模块的梯度。

class Mymodel(nn.Module):
...... model = Mymodel()
model.register_backward_hook(hook_fn_backward)

[torch] pytorch hook学习的更多相关文章

  1. pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...

  2. [pytorch] PyTorch Hook

      PyTorch Hook¶ 为什么要引入hook? -> hook可以做什么? 都有哪些hook? 如何使用hook?   1. 为什么引入hook?¶ 参考:Pytorch中autogra ...

  3. 【pytorch】学习笔记(三)-激励函数

    [pytorch]学习笔记-激励函数 学习自:莫烦python 什么是激励函数 一句话概括 Activation: 就是让神经网络可以描述非线性问题的步骤, 是神经网络变得更强大 1.激活函数是用来加 ...

  4. 【pytorch】学习笔记(二)- Variable

    [pytorch]学习笔记(二)- Variable 学习链接自莫烦python 什么是Variable Variable就好像一个篮子,里面装着鸡蛋(Torch 的 Tensor),里面的鸡蛋数不断 ...

  5. PyTorch迁移学习-私人数据集上的蚂蚁蜜蜂分类

    迁移学习的两个主要场景 微调CNN:使用预训练的网络来初始化自己的网络,而不是随机初始化,然后训练即可 将CNN看成固定的特征提取器:固定前面的层,重写最后的全连接层,只有这个新的层会被训练 下面修改 ...

  6. PyTorch深度学习实践——反向传播

    反向传播 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 目录 反向传播 笔记 作业 笔记 在之前课程中介绍的线性 ...

  7. PyTorch深度学习实践——多分类问题

    多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...

  8. PyTorch深度学习实践——处理多维特征的输入

    处理多维特征的输入 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 这一讲介绍输入为多维数据时的分类. 一个数据集 ...

  9. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

随机推荐

  1. Spring Boot任务(定时,异步,邮件)

    一.定时任务 开启定时任务(在Spring Boot项目主程序上添加如下注解) @EnableScheduling //开启定时任务的注解 创建定时任务(创建一个Service如下) @Service ...

  2. python 单引号、双引号和三引号混用

    单引号: 当单引号中存在单引号时,内部的单引号需要使用转义字符,要不然就会报错: 当单引号中存在双引号时,双引号可以不用加转义字符,默认双引号为普通的字符,反之亦然. 双引号: 当双引号中存在双引号时 ...

  3. linux下 设置php的环境变量 php: command not found

    在自己的根目录进行运行phpinfo();     查看php的根目录. 假如自己查询的目录是/www/wdlinux/apache_php-5.6.21/bin, 查询完成后,先进入linux目录查 ...

  4. dedecms织梦调用二级和三级分类标签

    dedecms调用二级.三级以及调用栏目所有子栏目 <!--频道分类具体内容开始--> <div class="channel_sort"> {dede:c ...

  5. Ubuntu利用ROS搭建手机移动网络摄像头(Android)

    所需设备 PC -> Ubuntu 16.04 - > ROS Kinetic Android系统手机 1.Android移动端APP下载安装 配置手机端:(一般默认即可RTSP) 2.源 ...

  6. 【POJ2486】Apple Tree

    题目大意:给定一棵 N 个节点的有根树,点有点权,边权均为1.现允许从根节点出发走 K 步,求可以经过的点权之和最大是多少. 题解:可以将点权看作是价值,将可以走的步数看作是重量,则转化成了一个树上背 ...

  7. Python之网络编程之concurrent.futures模块

    需要注意一下不能无限的开进程,不能无限的开线程最常用的就是开进程池,开线程池.其中回调函数非常重要回调函数其实可以作为一种编程思想,谁好了谁就去掉 只要你用并发,就会有锁的问题,但是你不能一直去自己加 ...

  8. Python之网路编程进程理论基础

    背景知识 顾名思义,进程即一个软件正在进行的过程.进程是对正在运行程序的一个抽象. 进程的概念起源于操作系统,是操作系统最核心的概念,也是操作系统提供的最古老的也是最重要的抽象概念之一.操作系统的其他 ...

  9. Synchronized 失效原因

    Synchronized 同步出现失效 Synchronized ,大家都知道这个是Java 提供的一种原子性内置锁,其实现原理是通过获取对象的监视器monitor进行来实现同步的,只有当线程获取到对 ...

  10. 【ZJOJ5186】【NOIP2017提高组模拟6.30】tty's home

    题目 分析 如果直接求方案数很麻烦. 但是,我们可以反过来做:先求出所有的方案数,在减去不包含的方案数. 由于所有的路径连在一起, 于是\(设f[i]表示以i为根的子树中,连接到i的方案数\) 则\( ...