pytorch hook使用
由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数。
钩子函数包括Variable的钩子和nn.Module钩子,用法相似。
import torch
from torch.autograd import Variable grad_list = []
grad_listx = [] def print_grad(grad):
grad_list.append(grad) def print_gradx(grad):
grad_listx.append(grad) x = Variable(torch.randn(2, 1), requires_grad=True)
y = x*x + 2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
x.register_hook(print_gradx)
z.backward()
x.data -= lr * x.grad.data print("x.grad.data-------------")
print(x.grad.data) print("y-------------")
print(grad_list) print("x-------------")
print(grad_listx)
- 输出: 记录了y的梯度,然后x.data=记录x的梯度
/opt/conda/bin/python2.7 /root/rjw/pytorch_test/pytorch_exe03.py
x.grad.data------------- 32.3585
14.8162
[torch.FloatTensor of size 2x1] y-------------
[Variable containing:
7.1379
4.5970
[torch.FloatTensor of size 2x1]
]
x-------------
[Variable containing:
32.3585
14.8162
[torch.FloatTensor of size 2x1]
] Process finished with exit code 0
register_forward_hook & register_backward_hook
- 这两个函数的功能类似于variable函数的
register_hook,可在module前向传播或反向传播时注册钩子。每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:hook(module, input, output) -> None,而反向传播则具有如下形式:hook(module, grad_input, grad_output) -> Tensor or None。 - 钩子函数不应修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用钩子技术就更合适一些。下面考虑一种场景,有一个预训练好的模型,需要提取模型的某一层(不是最后一层)的输出作为特征进行分类,但又不希望修改其原有的模型定义文件,这时就可以利用钩子函数。
- 『PyTorch』第十六弹_hook技术
pytorch hook使用的更多相关文章
- [pytorch] PyTorch Hook
PyTorch Hook¶ 为什么要引入hook? -> hook可以做什么? 都有哪些hook? 如何使用hook? 1. 为什么引入hook?¶ 参考:Pytorch中autogra ...
- [torch] pytorch hook学习
pytorch hook学习 register_hook import torch x = torch.Tensor([0,1,2,3]).requires_grad_() y = torch.Ten ...
- PyTorch之前向传播函数自动调用forward
参考:1. pytorch学习笔记(九):PyTorch结构介绍 2.pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解 3.Pytorch入门 ...
- [PyTorch 学习笔记] 5.2 Hook 函数与 CAM 算法
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_fmap_vis.py https://gi ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- pytorch Debug —交互式调试工具Pdb (ipdb是增强版的pdb)-1-在pytorch中使用
参考深度学习框架pytorch:入门和实践一书第六章 以深度学习框架PyTorch一书的学习-第六章-实战指南为前提 在pytorch中Debug pytorch作为一个动态图框架,与ipdb结合能为 ...
- 深度学习框架PyTorch一书的学习-第四章-神经网络工具箱nn
参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 本章介绍的nn模块是构建与autogr ...
- 深度学习框架PyTorch一书的学习-第三章-Tensor和autograd-2-autograd
参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 torch.autograd就是为了方 ...
- 『PyTorch × TensorFlow』第十七弹_ResNet快速实现
『TensorFlow』读书笔记_ResNet_V2 对比之前的复杂版本,这次的torch实现其实简单了不少,不过这和上面的代码实现逻辑过于复杂也有关系. 一.PyTorch实现 # Author : ...
随机推荐
- js缓存加密
1.访问A链接就以A链接的特定部分为密码盐,生成一个js跳转配置文件名 aojoweiojoiwjeiof2.PHP在生成js跳转文件名的时候,也是根据数据库中的跳转起始链接特定部分作为盐,生成的文件 ...
- Hadoop整理二(Hadoop分布式存储系统HDFS)
一.背景 当数据集的大小超过一台独立物理计算机的存储能力时,就有必要对它进行分区(partition) 并存储到若干台单独的计算机上.管理网络中跨多台计算机存储的文件系统称为分布式文件系统 (dist ...
- python之路【第十二篇】: MYSQL
一. 概述 Mysql是最流行的关系型数据库管理系统,在WEB应用方面MySQL是最好的RDBMS(Relational Database Management System:关系数据库管理系统)应用 ...
- CSU - 2056 a simple game
Description 这一天,小A和小B在玩一个游戏,他俩每人都有一个整数,然后两人轮流对他们的整数进行操作,每次在下列两个操作任选一个: (1)对整数进行翻转,如1234翻转成4321 ,1200 ...
- HDU - 1022 Train Problem I STL 压栈
Train Problem I Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others) ...
- shell run 屏蔽一切log
&命令: xxx >/dev/null 2>&1 & 屏蔽一切logxxx >/tmp/xxx.log 2 ...
- Xamarin 2017.10.9更新
Xamarin 2017.10.9更新 本次更新主要解决了一些bug.Visual Studio 2017升级到15.4获得新功能.Visual Studio 2015需要工具-选项-Xamarin ...
- Web2.0应用程序的7条原则
个人看好Web的发展潜力,本文字摘自<Collective Intelligence 实战> 网络是平台 使用传统许可模式软件的公司或用户必须运行软件.定期更新至最新版本,以及扩展它来满足 ...
- java设计模式(六)策略模式
适用于同一操作的不同行为,策略模式定义了一系列的算法,并将每一个算法封装起来,而且使它们可以相互替换,让算法独立于使用它的客户而独立变化,具体应用场景如第三方支付对接不同银行的算法. 要点:1)抽象策 ...
- hdu 4548 初始化+二分 *
题意:小明对数的研究比较热爱,一谈到数,脑子里就涌现出好多数的问题,今天,小明想考考你对素数的认识.问题是这样的:一个十进制数,如果是素数,而且它的各位数字和也是素数,则称之为“美素数”,如29,本身 ...