Pytorch Guided Backpropgation

Intro

guided backpropgation通过修改RELU的梯度反传,使得小于0的部分不反传,只传播大于0的部分,这样到第一个conv层的时候得到的梯度就是对后面relu激活起作用的梯度,这时候我们对这些梯度进行可视化,得到的就是对网络起作用的区域。(实际上可视化的是梯度)。

简单记一下。用到hook的神经网络可视化方法。

code

import torch
import torch.nn as nn
from torchvision import transforms,models
import re
from models.densenet import densenet121
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
class Guided_Prop():
def __init__(self,model):
self.model = model
self.model.eval()
self.out_img = None
self.activation_maps = [] def register_hooks(self):
def register_first_layer_hook(module,grad_in,grad_out):
self.out_img = grad_in[0] #(b,c,h,w) -> (c,h,w)
def forward_hook_fn(module,input_feature,output_feature):
self.activation_maps.append(output_feature)
def backward_hook_fn(module,grad_in,grad_out):
grad = self.activation_maps.pop()
grad[grad > 0] = 1
g_positive = torch.clamp(grad_out[0],min = 0.)
result_grad = grad * g_positive
return (result_grad,) modules = list(self.model.features.named_children())
for name,module in modules:
if isinstance(module,nn.ReLU):
module.register_forward_hook(forward_hook_fn)
module.register_backward_hook(backward_hook_fn)
first_layer = modules[0][1]
first_layer.register_backward_hook(register_first_layer_hook) def visualize(self,input_image):
softmax = nn.Softmax(dim = 1)
idx_tensor = torch.tensor([float(i) for i in range(61)])
self.register_hooks()
self.model.zero_grad()
out = self.model(input_image) # [[b,n],[b,n],[b,n]]
yaw = softmax(out[0])
yaw = torch.sum(yaw * idx_tensor,dim = 1) * 3 - 90.
pitch = softmax(out[1])
pitch = torch.sum(pitch * idx_tensor,dim = 1) * 3 - 90.
roll = softmax(out[2])
roll = torch.sum(roll * idx_tensor,dim = 1) * 3 - 90. #print(yaw)
out = yaw + pitch + roll
out.backward()
result = self.out_img.data[0].permute(1,2,0) # chw -> hwc(opencv)
return result.numpy()
def normalize(I):
norm = (I-I.mean())/I.std()
norm = norm * 0.1
norm = norm + 0.5
norm = norm.clip(0, 1)
return norm
if __name__ == "__main__":
input_size = 224
model = densenet121(pretrained = False,num_classes = 61)
model.load_state_dict(torch.load("./ckpt/DenseNet/model_2692_.pkl")) img = Image.open("/media/xueaoru/其他/ML/head_pose_work/brick/head_and_heads/test/BIWI00009409_-17_+1_+17.png")
transform = transforms.Compose([
transforms.Resize(input_size),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
tensor = transform(img).unsqueeze(0).requires_grad_() viz = Guided_Prop(model) result = viz.visualize(tensor)
result = normalize(result) plt.imshow(result)
plt.show()

由于是多任务问题,所以直接拿结果反传,对于一般的分类问题,可以给定target来用gt用one-hot反传。

head pose estimation 的梯度可视化。

[NN] Guided Backpropgation 可视化的更多相关文章

  1. 吴裕雄 python神经网络 水果图片识别(3)

    import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...

  2. TF之NN:matplotlib动态演示深度学习之tensorflow将神经网络系统自动学习并优化修正并且将输出结果可视化—Jason niu

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt def add_layer(inputs, in_ ...

  3. MySQL 慢查询日志分析及可视化结果

    MySQL 慢查询日志分析及可视化结果 MySQL 慢查询日志分析 pt-query-digest分析慢查询日志 pt-query-digest --report slow.log 报告最近半个小时的 ...

  4. mininet之miniedit可视化操作

    Mininet 2.2.0之后的版本内置了一个mininet可视化工具miniedit,使用Mininet可视化界面方便了用户自定义拓扑创建,为不熟悉python脚本的使用者创造了更简单的环境,界面直 ...

  5. python之gui-tkinter可视化编辑界面 自动生成代码

    首先提供资源链接 http://pan.baidu.com/s/1kVLOrIn#list/path=%2F

  6. 学习TensorFlow,TensorBoard可视化网络结构和参数

    在学习深度网络框架的过程中,我们发现一个问题,就是如何输出各层网络参数,用于更好地理解,调试和优化网络?针对这个问题,TensorFlow开发了一个特别有用的可视化工具包:TensorBoard,既可 ...

  7. 数据分析之---Python可视化工具

    1. 数据分析基本流程 作为非专业的数据分析人员,在平时的工作中也会遇到一些任务:需要对大量进行分析,然后得出结果,解决问题. 所以了解基本的数据分析流程,数据分析手段对于提高工作效率还是非常有帮助的 ...

  8. AI - TensorFlow - 可视化工具TensorBoard

    TensorBoard TensorFlow自带的可视化工具,能够以直观的流程图的方式,清楚展示出整个神经网络的结构和框架,便于理解模型和发现问题. 可视化学习:https://www.tensorf ...

  9. 【TensorFlow篇】--Tensorflow框架可视化之Tensorboard

    一.前述 TensorBoard是tensorFlow中的可视化界面,可以清楚的看到数据的流向以及各种参数的变化,本文基于一个案例讲解TensorBoard的用法. 二.代码 设计一个MLP多层神经网 ...

随机推荐

  1. Homebrew学习(一)之初认识

    Homebrew Homebrew是一款Mac OS平台下的软件包管理工具,拥有安装.卸载.更新.查看.搜索等很多实用的功能.简单的一条指令,就可以实现包管理,而不用你关心各种依赖和文件路径的情况,会 ...

  2. 082、数据收集利器 cAdvisor (2019-04-30 周二)

    参考https://www.cnblogs.com/CloudMan6/p/7683190.html   cAdvisor 是google 开发的容器监控工具,下面我们开始安装和体验 cAdvisor ...

  3. Linux学习--第八天--acl、SetUID、SetGID、chattr、lsattr、sudo

    acl权限 文件只能有一个所属组 acl就是不管用户什么组了,直接针对某个文件给他特定权限. acl需要所在分区文件系统的支持. df -h #查看分区 dumpe2fs -h /dev/sda3 # ...

  4. AIX中的服务管理

    1.SRC AIX系统使用资源控制器(SRC,system   resource  controller),控制各种服务子系统,包括启动,停止进程,搜集进程状态信息等.   AIX系统中服务有子系统组 ...

  5. 【原】iptables 交叉编译

    防火墙在做数据包过滤决定时,有一套遵循和组成的规则,这些规则存储在专用的数据包过滤表中,而这些表集成在 Linux 内核中.在数据包过滤表中,规则被分组放在我们所谓的链(chain)中.而netfil ...

  6. BZOJ[3252]攻略(长链剖分)

    BZOJ[3252]攻略 Description 题目简述:树版[k取方格数] 众所周知,桂木桂马是攻略之神,开启攻略之神模式后,他可以同时攻略k部游戏.今天他得到了一款新游戏<XX半岛> ...

  7. outlook 使用临时邮箱 使用旧数据

    控制面板-->邮件32位 显示配置文件 删除再添加 具体可参考 https://blog.csdn.net/liuyukuan/article/details/80043840 偷懒,图片从网上 ...

  8. linux命令历史

    本人qq群也有许多的技术文档,希望可以为你提供一些帮助(非技术的勿加). QQ群:   281442983 (点击链接加入群:http://jq.qq.com/?_wv=1027&k=29Lo ...

  9. SpringMVC 中的注解@RequestParam与@PathVariable的区别

    @PathVariable绑定URI模板变量值 @PathVariable是用来获得请求url中的动态参数的 @PathVariable用于将请求URL中的模板变量映射到功能处理方法的参数上.//配置 ...

  10. Css min-height max-height min-width max-height

    Css min-height应用地方解释我们有时设置一个对象盒子时候避免对象没有内容时候不能撑开,但内容多少不能确定所以又不能固定高度,这个时候我们就会需要css来设置min-height最小高度撑高 ...