Pytorch Guided Backpropgation

Intro

guided backpropgation通过修改RELU的梯度反传,使得小于0的部分不反传,只传播大于0的部分,这样到第一个conv层的时候得到的梯度就是对后面relu激活起作用的梯度,这时候我们对这些梯度进行可视化,得到的就是对网络起作用的区域。(实际上可视化的是梯度)。

简单记一下。用到hook的神经网络可视化方法。

code

import torch
import torch.nn as nn
from torchvision import transforms,models
import re
from models.densenet import densenet121
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
class Guided_Prop():
def __init__(self,model):
self.model = model
self.model.eval()
self.out_img = None
self.activation_maps = [] def register_hooks(self):
def register_first_layer_hook(module,grad_in,grad_out):
self.out_img = grad_in[0] #(b,c,h,w) -> (c,h,w)
def forward_hook_fn(module,input_feature,output_feature):
self.activation_maps.append(output_feature)
def backward_hook_fn(module,grad_in,grad_out):
grad = self.activation_maps.pop()
grad[grad > 0] = 1
g_positive = torch.clamp(grad_out[0],min = 0.)
result_grad = grad * g_positive
return (result_grad,) modules = list(self.model.features.named_children())
for name,module in modules:
if isinstance(module,nn.ReLU):
module.register_forward_hook(forward_hook_fn)
module.register_backward_hook(backward_hook_fn)
first_layer = modules[0][1]
first_layer.register_backward_hook(register_first_layer_hook) def visualize(self,input_image):
softmax = nn.Softmax(dim = 1)
idx_tensor = torch.tensor([float(i) for i in range(61)])
self.register_hooks()
self.model.zero_grad()
out = self.model(input_image) # [[b,n],[b,n],[b,n]]
yaw = softmax(out[0])
yaw = torch.sum(yaw * idx_tensor,dim = 1) * 3 - 90.
pitch = softmax(out[1])
pitch = torch.sum(pitch * idx_tensor,dim = 1) * 3 - 90.
roll = softmax(out[2])
roll = torch.sum(roll * idx_tensor,dim = 1) * 3 - 90. #print(yaw)
out = yaw + pitch + roll
out.backward()
result = self.out_img.data[0].permute(1,2,0) # chw -> hwc(opencv)
return result.numpy()
def normalize(I):
norm = (I-I.mean())/I.std()
norm = norm * 0.1
norm = norm + 0.5
norm = norm.clip(0, 1)
return norm
if __name__ == "__main__":
input_size = 224
model = densenet121(pretrained = False,num_classes = 61)
model.load_state_dict(torch.load("./ckpt/DenseNet/model_2692_.pkl")) img = Image.open("/media/xueaoru/其他/ML/head_pose_work/brick/head_and_heads/test/BIWI00009409_-17_+1_+17.png")
transform = transforms.Compose([
transforms.Resize(input_size),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
tensor = transform(img).unsqueeze(0).requires_grad_() viz = Guided_Prop(model) result = viz.visualize(tensor)
result = normalize(result) plt.imshow(result)
plt.show()

由于是多任务问题,所以直接拿结果反传,对于一般的分类问题,可以给定target来用gt用one-hot反传。

head pose estimation 的梯度可视化。

[NN] Guided Backpropgation 可视化的更多相关文章

  1. 吴裕雄 python神经网络 水果图片识别(3)

    import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...

  2. TF之NN:matplotlib动态演示深度学习之tensorflow将神经网络系统自动学习并优化修正并且将输出结果可视化—Jason niu

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt def add_layer(inputs, in_ ...

  3. MySQL 慢查询日志分析及可视化结果

    MySQL 慢查询日志分析及可视化结果 MySQL 慢查询日志分析 pt-query-digest分析慢查询日志 pt-query-digest --report slow.log 报告最近半个小时的 ...

  4. mininet之miniedit可视化操作

    Mininet 2.2.0之后的版本内置了一个mininet可视化工具miniedit,使用Mininet可视化界面方便了用户自定义拓扑创建,为不熟悉python脚本的使用者创造了更简单的环境,界面直 ...

  5. python之gui-tkinter可视化编辑界面 自动生成代码

    首先提供资源链接 http://pan.baidu.com/s/1kVLOrIn#list/path=%2F

  6. 学习TensorFlow,TensorBoard可视化网络结构和参数

    在学习深度网络框架的过程中,我们发现一个问题,就是如何输出各层网络参数,用于更好地理解,调试和优化网络?针对这个问题,TensorFlow开发了一个特别有用的可视化工具包:TensorBoard,既可 ...

  7. 数据分析之---Python可视化工具

    1. 数据分析基本流程 作为非专业的数据分析人员,在平时的工作中也会遇到一些任务:需要对大量进行分析,然后得出结果,解决问题. 所以了解基本的数据分析流程,数据分析手段对于提高工作效率还是非常有帮助的 ...

  8. AI - TensorFlow - 可视化工具TensorBoard

    TensorBoard TensorFlow自带的可视化工具,能够以直观的流程图的方式,清楚展示出整个神经网络的结构和框架,便于理解模型和发现问题. 可视化学习:https://www.tensorf ...

  9. 【TensorFlow篇】--Tensorflow框架可视化之Tensorboard

    一.前述 TensorBoard是tensorFlow中的可视化界面,可以清楚的看到数据的流向以及各种参数的变化,本文基于一个案例讲解TensorBoard的用法. 二.代码 设计一个MLP多层神经网 ...

随机推荐

  1. java中的集合类详情解析以及集合和数组的区别

    数组和链表 数组:所谓数组就是相同数据类型的元素按照一定顺序排列的集合. 它的存储区间是连续的,占用内存严重,所以空间复杂度很大,为o(n),但是数组的二分查找时间复杂度很小为o(1). 特点是大小固 ...

  2. 手把手带你了解sass

    sass的使用 减少重复的工作 1.变量的声明: 是以$开头给变量命名; $height-color: #F30 2.变量的使用范围: 变量可以在多个地方存在,不一定限制在代码块中.但是如果定义在了代 ...

  3. React结合AntD的upload组件写头像上传

    upload组件里面action就是调upload接口,获取图片url地址 setImg获取url,点击保存传到后台   action 上传头像方法 //上传头像 changeImg = info = ...

  4. Vmware 安装 ghost 版 win 7

    很早就弄过vmware,很可惜一直没有仔细研究过,这次要安装一个win7系统,重新又学一下了一下,下面说一下安装的操作步骤吧. 第一步,下载vmware,原版的下载地址就不说了,上传到百度网盘自己下载 ...

  5. 3.SpringBoot整合Mybatis(一对多)

    前言: Mybatis一对多的处理关系: 一个人有好多本书,每本书的主人只有一个人.当我们查询某个人拥有的所有书籍时,就涉及到了一对多的映射关系. 一.添加数据表: CREATE TABLE `boo ...

  6. zencart产品批量采集伪原创方法,再也不用担心与别人的数据重复了

    首先,请你提供与产品相关的关键词一份,至于关键词如何来,相信做SEO的你很清楚了,SEO关键词搜索工具应该很多,比如谷歌相关关键词搜索,用记事本的形式保存为每行一个关键词.采集产品的时候,我会帮你将关 ...

  7. .NET CORE学习笔记系列 开篇介绍

    ASP.NET Core学习和使用了一段时间了,好记性不如烂笔头,通过查阅官网学习文档和一些大神们的博客总结一下.主要路线先总结一下ASP.NET Core的基础知识,然后是ASP.NET Core  ...

  8. 洛谷P5055 可持久化文艺平衡树 (可持久化treap)

    题目链接 文艺平衡树的可持久化版,可以使用treap实现. 作为序列使用的treap相对splay的优点如下: 1.代码短 2.容易实现可持久化 3.边界处理方便(splay常常需要在左右两端加上保护 ...

  9. luoguP3723 HNOI2017 礼物

    链接 首先,两个手环增加非负整数亮度,等于其中一个增加一个整数亮度,可以为负. 令增加量为\(x\),旋转以后的原数列为,那么在不考虑转圈圈的情况下,现在的费用就是: \[\sum_{i=1}^n\l ...

  10. 【NOIP2016提高A组模拟8.15】Throw

    题目 分析 首先对于一个状态(a,b,c),假定a<=b<=c: 现在考虑一下这个状态,的转移方案: \[1,中间向两边跳(a,b,c)-->(a*2-b,a,c).(a,b,c)- ...