pytorch hook学习

register_hook

  1. import torch
  2. x = torch.Tensor([0,1,2,3]).requires_grad_()
  3. y = torch.Tensor([4,5,6,7]).requires_grad_()
  4. w = torch.Tensor([1,2,3,4]).requires_grad_()
  5. z = x+y;
  6. o = w.matmul(z) # o = w(x+y) 中间变量z
  7. o.backward()
  8. print(x.grad,y.grad,z.grad,w.grad,o.grad)

这里的o和z都是中间变量,不是通过指定值来定义的变量,所以是中间变量,所以pytorch并不存储这些变量的梯度。

对于中间变量z,hook的使用方式为: z.register_hook(hook_fn),其中 hook_fn为一个用户自定义的函数,其签名为:hook_fn(grad) -> Tensor or None。

它的输入为变量 z 的梯度,输出为一个 Tensor 或者是 None (None 一般用于直接打印梯度)。反向传播时,梯度传播到变量 z,再继续向前传播之前,将会传入 hook_fn。如果 hook_fn的返回值是 None,那么梯度将不改变,继续向前传播,如果 hook_fn的返回值是 Tensor 类型,则该 Tensor 将取代 z 原有的梯度,向前传播。

  1. import torch
  2. x = torch.Tensor([0,1,2,3]).requires_grad_()
  3. y = torch.Tensor([4,5,6,7]).requires_grad_()
  4. w = torch.Tensor([1,2,3,4]).requires_grad_()
  5. z = x+y;
  6. def hook_fn(grad):
  7. print(grad)
  8. return None
  9. z.register_hook(hook_fn)
  10. o = w.matmul(z) # o = w(x+y) 中间变量z
  11. o.backward()
  12. print(x.grad,y.grad,w.grad,z.grad,o.grad)

register_forward_hook

register_forward_hook的作用是获取前向传播过程中,各个网络模块的输入和输出。对于模块 module,其使用方式为:module.register_forward_hook(hook_fn) 。其中 hook_fn的签名为:

  1. hook_fn(module, input, output) -> None

eg

  1. import torch
  2. from torch import nn
  3. class Model(nn.Module):
  4. def __init__(self):
  5. super(Model,self).__init__()
  6. self.fc1 = nn.Linear(3,4) # WT * X + bias
  7. self.relu1 = nn.ReLU()
  8. self.fc2 = nn.Linear(4,1)
  9. self.init()
  10. def init(self):
  11. with torch.no_grad():
  12. # WT * X + bias,所以W为4*3的矩阵,bias为1*4
  13. self.fc1.weight = torch.nn.Parameter(
  14. torch.Tensor([[1., 2., 3.],
  15. [-4., -5., -6.],
  16. [7., 8., 9.],
  17. [-10., -11., -12.]]))
  18. self.fc1.bias = torch.nn.Parameter(torch.Tensor([1.0, 2.0, 3.0, 4.0]))
  19. self.fc2.weight = torch.nn.Parameter(torch.Tensor([[1.0, 2.0, 3.0, 4.0]]))
  20. self.fc2.bias = torch.nn.Parameter(torch.Tensor([1.0]))
  21. def forward(self,x):
  22. o = self.fc1(x)
  23. o = self.relu1(o)
  24. o = self.fc2(o)
  25. return o
  26. def hook_fn_forward(module,input,output):
  27. print(module)
  28. print(input)
  29. print(output)
  30. model = Model()
  31. modules = model.named_children()
  32. '''
  33. named_children()
  34. Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
  35. '''
  36. for name,module in modules:
  37. # 这里的name就是自己定义的self.xx的xx。如上面的fc1,fc2.
  38. # module代指的就是fc1代表的module等等
  39. module.register_forward_hook(hook_fn_forward)
  40. x = torch.Tensor([[1.0,1.0,1.0]]).requires_grad_()
  41. o = model(x)
  42. o.backward()
  43. '''
  44. Linear(in_features=3, out_features=4, bias=True)
  45. (tensor([[1., 1., 1.]], requires_grad=True),)
  46. tensor([[ 7., -13., 27., -29.]], grad_fn=<AddmmBackward>)
  47. ReLU()
  48. (tensor([[ 7., -13., 27., -29.]], grad_fn=<AddmmBackward>),)
  49. tensor([[ 7., 0., 27., 0.]], grad_fn=<ReluBackward0>)
  50. Linear(in_features=4, out_features=1, bias=True)
  51. (tensor([[ 7., 0., 27., 0.]], grad_fn=<ReluBackward0>),)
  52. tensor([[89.]], grad_fn=<AddmmBackward>)
  53. '''

register_backward_hook

理同前者。得到梯度值。

  1. hook_fn(module, grad_input, grad_output) -> Tensor or None

上面的代码forward全部替换为backward,结果为:

  1. '''
  2. Linear(in_features=4, out_features=1, bias=True)
  3. (tensor([1.]), tensor([[1., 2., 3., 4.]]), tensor([[ 7.],
  4. [ 0.],
  5. [27.],
  6. [ 0.]]))
  7. (tensor([[1.]]),)
  8. ReLU()
  9. (tensor([[1., 0., 3., 0.]]),)
  10. (tensor([[1., 2., 3., 4.]]),)
  11. Linear(in_features=3, out_features=4, bias=True)
  12. (tensor([1., 0., 3., 0.]), tensor([[22., 26., 30.]]), tensor([[1., 0., 3., 0.],
  13. [1., 0., 3., 0.],
  14. [1., 0., 3., 0.]]))
  15. (tensor([[1., 0., 3., 0.]]),)
  16. '''

register_backward_hook只能操作简单模块,而不能操作包含多个子模块的复杂模块。 如果对复杂模块用了 backward hook,那么我们只能得到该模块最后一次简单操作的梯度信息。

可以这么用,可以得到一个模块的梯度。

  1. class Mymodel(nn.Module):
  2. ......
  3. model = Mymodel()
  4. model.register_backward_hook(hook_fn_backward)

[torch] pytorch hook学习的更多相关文章

  1. pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...

  2. [pytorch] PyTorch Hook

      PyTorch Hook¶ 为什么要引入hook? -> hook可以做什么? 都有哪些hook? 如何使用hook?   1. 为什么引入hook?¶ 参考:Pytorch中autogra ...

  3. 【pytorch】学习笔记(三)-激励函数

    [pytorch]学习笔记-激励函数 学习自:莫烦python 什么是激励函数 一句话概括 Activation: 就是让神经网络可以描述非线性问题的步骤, 是神经网络变得更强大 1.激活函数是用来加 ...

  4. 【pytorch】学习笔记(二)- Variable

    [pytorch]学习笔记(二)- Variable 学习链接自莫烦python 什么是Variable Variable就好像一个篮子,里面装着鸡蛋(Torch 的 Tensor),里面的鸡蛋数不断 ...

  5. PyTorch迁移学习-私人数据集上的蚂蚁蜜蜂分类

    迁移学习的两个主要场景 微调CNN:使用预训练的网络来初始化自己的网络,而不是随机初始化,然后训练即可 将CNN看成固定的特征提取器:固定前面的层,重写最后的全连接层,只有这个新的层会被训练 下面修改 ...

  6. PyTorch深度学习实践——反向传播

    反向传播 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 目录 反向传播 笔记 作业 笔记 在之前课程中介绍的线性 ...

  7. PyTorch深度学习实践——多分类问题

    多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...

  8. PyTorch深度学习实践——处理多维特征的输入

    处理多维特征的输入 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 这一讲介绍输入为多维数据时的分类. 一个数据集 ...

  9. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

随机推荐

  1. MySQL数据库主从同步实战过程

       Linux系统MySQL数据库主从同步实战过程 安装环境说明 系统环境: [root@~]# cat /etc/redhat-release CentOS release 6.5 (Final) ...

  2. 雪花算法生成ID

    前言我们的数据库在设计时一般有两个ID,自增的id为主键,还有一个业务ID使用UUID生成.自增id在需要分表的情况下做为业务主键不太理想,所以我们增加了uuid作为业务ID,有了业务id仍然还存在自 ...

  3. inetd - 因特网“超级服务”

    总览 inetd - [ -d ] [ -q 队列长度 ] [ 配置文件名 ] 描述 inetd通常在系统启动时由/etc/rc.local引导.inetd会监听指定internet端口是否有连接要求 ...

  4. 接口测试断言详解(Jmeter)

    接口测试是目前最主流的自动化测试手段,它向服务器发送请求,接收和解析响应结果,通过验证响应报文是否满足需求规约来验证系统逻辑正确性.接口的响应类型通过Content-Type指定,常见的响应类型有: ...

  5. java知识

    DiskFileUploadhttps://blog.csdn.net/FightingITPanda/article/details/79742631 import java.util.ArrayL ...

  6. Elasticsearch:hanlp 中文分词器

    HanLP 中文分词器是一个开源的分词器,是专为Elasticsearch而设计的.它是基于HanLP,并提供了HanLP中大部分的分词方式.它的源码位于: https://github.com/Ke ...

  7. 【转】encodeURI和decodeURI方法

    为什么要两次调用encodeURI来解决乱码问题 https://blog.csdn.net/howlaa/article/details/12834595 请注意 encodeURIComponen ...

  8. Jenkins自动打包并部署(以java -jar形势运行)

    1.打包 与平常maven项目打包一致,不再赘述 2.杀死原有进程 通过 pid=`ps -ef|grep $APP_NAME|grep -v grep|awk '{print $2}' ` 获取当前 ...

  9. [LOJ 6704] 健身计划

    问题描述 九条可怜是一个肥胖的女孩. 她最近长胖了,她想要通过健身达到减肥的目的,于是她决定每天做n次仰卧起坐以达到健身的目的. 她可以将这n次动作分为若干组完成,每一次完成ai次仰卧起坐,每做完一次 ...

  10. 【leetcode】1213.Intersection of Three Sorted Arrays

    题目如下: Given three integer arrays arr1, arr2 and arr3 sorted in strictly increasing order, return a s ...