觉得本文不错的可以点个赞。有问题联系作者微信cyx645016617,之后主要转战公众号,不在博客园和CSDN更新。

论文名称:“Grad-CAM:

Visual Explanations from Deep Networks via Gradient-based Localization”

论文地址:https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf

论文期刊:ICCV International Conference on Computer Vision

1 综述

总的来说,卷积网络的可解释性一直是一个很重要的问题,你要用人工智能去说服别人,但是这是一个“黑箱”问题,卷积网络运行机理是一个black box,说不清楚内在逻辑。

因此很多学者提出了各种各样的可视化来解释的方法。我个人尝试过的、可以从一定角度进行解释的可视化方法有:t-sne降维,attention可视化,可变卷积的可视化等,但是其实这些的可视化方法,并不能直接的对模型进行解释,只是能说明模型分类是准确的

CAM的全称是Class Activation Mapping,对于分类问题,我们可以直观的通过这种方法,来进行解释方向的可视化。

grad-CAM是CAM的进阶版本,更加方便实施、即插即用。

2 CAM

CAM的原理是实现可解释性的根本,所以我通俗易懂的讲一讲。

上面是一个传统CNN的结构,通过卷积和池化层后,把特征图拉平成一维,然后是全连接层进行分类。

那么CAM的网络是什么样子呢?基本和上面的结构相同

图中有一个GAP池化层,全局平均池化层。这个就是求取每一个通道的均值,可以理解为核是和特征图一样大的一般的平均池化层,假如输出特征图是一个8通道的,224x224的特征图,那么经过GAP这个池化层,就会得到8个数字,一个通道贡献一个数字,这个数字是一个通道的代表

然后经过GAP之后的一维向量,再跟上一个全连接层,得到类别的概率。

上图中左边就是经过GAP得到的向量,其数量就是最后一层特征图的通道数,右边的向量的数量就是类别的数量。

关键来了,CAM的可解释性的逻辑在于:假设我们最终预测的类别是羊驼,也就是说,模型给羊驼的打分最高。我们可以得到,左边向量计算出羊驼的权重值,也就是全连接层中的一部分权重值。这个权重值就是!!!就是最后一层特征图每一个通道的权重值。之前也提到了GAP的输出的一个向量代表着GAP输入的特征图的每一个通道嘛

这样我们通过最后一个全连接层获取到最后一个特征图的每一个通道对于某一个类别的贡献的权重值。我们对最后一个特征图的每一个通道的加权平均,就是我们得到的CAM对卷积的解释。之后可以上采样到整个图片那么大小,就像是论文给出的样子:

大家应该明白这个原理了,但是这样要修改模型的结构。之前训练的模型用不了了,这很麻烦,所以才有了Grad-CAM的提出。

3 Grad-CAM

Grad-CAM思路和CAM也是相同的,也是需要得到特征图每一个通道的权重值,然后做一个加权和。

所以关键在于,如何计算这个权重值,论文提出了这样的计算方法:

其中,z是一个特征图的像素量,就是width*height,可以看到,前面就是CAM的GAP的一个过程,后面的\(y^c\)是模型给类别c的打分,\(A_{ij}^k\)就是特征图中ij这个位置的元素值。那么对这个求导,其实就是这个位置的梯度。

所以用pytorch的实现如下:

  1. self.model.features.zero_grad()
  2. self.model.classifier.zero_grad()
  3. one_hot.backward(retain_graph=True)#仅包含有最大概率值,然后进行反向传播
  4. grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy()
  5. weights = np.mean(grads_val, axis=(2, 3))[0, :]#求平均,就是上面这段公式
  6. # 简单的说上面的逻辑就是先反向传播之后,然后获取对应位置的梯度,然后计算平均。

在论文中作者证明了Grad-CAM和CAM的等价的结论,想了解的可以看看。

4 pytorch完整代码

官方提供了github代码:https://github.com/jacobgil/pytorch-grad-cam

其中关键的地方是:

  1. class FeatureExtractor():
  2. """ Class for extracting activations and
  3. registering gradients from targetted intermediate layers """
  4. def __init__(self, model, target_layers):
  5. self.model = model
  6. self.target_layers = target_layers
  7. self.gradients = []
  8. def save_gradient(self, grad):
  9. self.gradients.append(grad)
  10. def __call__(self, x):
  11. outputs = []
  12. self.gradients = []
  13. for name, module in self.model._modules.items():
  14. x = module(x)
  15. if name in self.target_layers:
  16. x.register_hook(self.save_gradient)
  17. outputs += [x]
  18. return outputs, x
  19. class ModelOutputs():
  20. """ Class for making a forward pass, and getting:
  21. 1. The network output.
  22. 2. Activations from intermeddiate targetted layers.
  23. 3. Gradients from intermeddiate targetted layers. """
  24. def __init__(self, model, feature_module, target_layers):
  25. self.model = model
  26. self.feature_module = feature_module
  27. self.feature_extractor = FeatureExtractor(self.feature_module, target_layers)
  28. def get_gradients(self):
  29. return self.feature_extractor.gradients
  30. def __call__(self, x):
  31. target_activations = []
  32. for name, module in self.model._modules.items():
  33. if module == self.feature_module:
  34. target_activations, x = self.feature_extractor(x)
  35. elif "avgpool" in name.lower():
  36. x = module(x)
  37. x = x.view(x.size(0),-1)
  38. else:
  39. if name is 'classifier':
  40. x = x.view(x.size(0), -1)
  41. x = module(x)
  42. return target_activations, x
  43. class GradCam:
  44. def __init__(self, model, feature_module, target_layer_names, use_cuda):
  45. self.model = model
  46. self.feature_module = feature_module
  47. self.model.eval()
  48. self.cuda = use_cuda
  49. if self.cuda:
  50. self.model = model.cuda()
  51. self.extractor = ModelOutputs(self.model, self.feature_module, target_layer_names)
  52. def forward(self, input_img):
  53. return self.model(input_img)
  54. def __call__(self, input_img, target_category=None):
  55. if self.cuda:
  56. input_img = input_img.cuda()
  57. features, output = self.extractor(input_img)
  58. if target_category == None:
  59. target_category = np.argmax(output.cpu().data.numpy())
  60. one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
  61. one_hot[0][target_category] = 1
  62. one_hot = torch.from_numpy(one_hot).requires_grad_(True)
  63. if self.cuda:
  64. one_hot = one_hot.cuda()
  65. one_hot = torch.sum(one_hot * output)
  66. self.feature_module.zero_grad()
  67. self.model.zero_grad()
  68. one_hot.backward(retain_graph=True)
  69. grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy()
  70. target = features[-1]
  71. target = target.cpu().data.numpy()[0, :]
  72. weights = np.mean(grads_val, axis=(2, 3))[0, :]
  73. cam = np.zeros(target.shape[1:], dtype=np.float32)
  74. for i, w in enumerate(weights):
  75. cam += w * target[i, :, :]
  76. cam = np.maximum(cam, 0)
  77. cam = cv2.resize(cam, input_img.shape[2:])
  78. cam = cam - np.min(cam)
  79. cam = cam / np.max(cam)
  80. return cam

把这一段复制到自己的代码中后,可以参考下面的代码逻辑,简单改写自己的代码即可实现可视化(看不懂的话还是看github):

  1. grad_cam = GradCam(model = model,feature_module = model.features,target_layer_names=['11'],use_cuda=True)
  2. def draw(ax,grayscale_cam,data):
  3. heatmap = cv2.applyColorMap(np.uint8(255 * grayscale_cam), cv2.COLORMAP_JET)
  4. heatmap = heatmap + data.detach().cpu().numpy()[0,0].reshape(28,28,1).repeat(3,axis=2)
  5. heatmap = heatmap / np.max(heatmap)
  6. ax.imshow(heatmap)
  7. for data,target in val_loader:
  8. if torch.cuda.is_available():
  9. data = data.cuda()
  10. target = target.cuda()
  11. # 绘制9张可视化图
  12. fig = plt.figure(figsize=(12,12))
  13. for i in range(9):
  14. d = data[i:i+1]
  15. grayscale_cam = grad_cam(d)
  16. ax = fig.add_subplot(3,3,i+1)
  17. draw(ax,grayscale_cam,d)
  18. break

输出图像为:

有问题欢迎联系作者讨论,请多指教。

卷积网络可解释性复现 | Grad-CAM | ICCV | 2017的更多相关文章

  1. CSG:清华大学提出通过分化类特定卷积核来训练可解释的卷积网络 | ECCV 2020 Oral

    论文提出类特定控制门CSG来引导网络学习类特定的卷积核,并且加入正则化方法来稀疏化CSG矩阵,进一步保证类特定.从实验结果来看,CSG的稀疏性能够引导卷积核与类别的强关联,在卷积核层面产生高度类相关的 ...

  2. 复现ICCV 2017经典论文—PyraNet

    . 过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含“伪代码”.这是今年 AAAI 会议上一个严峻的 ...

  3. 使用Caffe完成图像目标检测 和 caffe 全卷积网络

    一.[用Python学习Caffe]2. 使用Caffe完成图像目标检测 标签: pythoncaffe深度学习目标检测ssd 2017-06-22 22:08 207人阅读 评论(0) 收藏 举报 ...

  4. ACNet:用于图像超分的非对称卷积网络

    编辑:Happy 首发:AIWalker Paper:https://arxiv.org/abs/2103.13634 Code:https://github.com/hellloxiaotian/A ...

  5. 基于孪生卷积网络(Siamese CNN)和短时约束度量联合学习的tracklet association方法

    基于孪生卷积网络(Siamese CNN)和短时约束度量联合学习的tracklet association方法 Siamese CNN Temporally Constrained Metrics T ...

  6. PRML读书会第五章 Neural Networks(神经网络、BP误差后向传播链式求导法则、正则化、卷积网络)

    主讲人 网神 (新浪微博:@豆角茄子麻酱凉面) 网神(66707180) 18:55:06 那我们开始了啊,前面第3,4章讲了回归和分类问题,他们应用的主要限制是维度灾难问题.今天的第5章神经网络的内 ...

  7. 学习笔记TF028:实现简单卷积网络

    载入MNIST数据集.创建默认Interactive Session. 初始化函数,权重制造随机噪声打破完全对称.截断正态分布噪声,标准差设0.1.ReLU,偏置加小正值(0.1),避免死亡节点(de ...

  8. 全卷积网络 FCN 详解

    背景 CNN能够对图片进行分类,可是怎么样才能识别图片中特定部分的物体,在2015年之前还是一个世界难题.神经网络大神Jonathan Long发表了<Fully Convolutional N ...

  9. 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec

    人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...

随机推荐

  1. dubbo起停之服务暴露

    由上一节可知带上dubbo@Service注解的对象,在注册成为bean之后会进一步注册一个ServiceBean,服务暴露便是在这里 public void afterPropertiesSet() ...

  2. presto 访问kudu 多schemas配置

    presto需要访问kudu数据源,但是impala可以直接支持多数据库存储,但是presto不能原生支持,按照presto的官网设置了然而并不起作用. 官方文档: 到官方github提问了,然后并没 ...

  3. RocketMq(三):server端处理框架及消费数据查找实现

    rocketmq作为一个高性能的消息中间件,咱们光停留在使用层面,总感觉缺点什么.虽然rocketmq的官方设计文档讲得还是比较详细的,但纸上得来终觉浅!今天我们就来亲自挖一挖rocketmq的实现细 ...

  4. 第11.14节 正则表达式转义符和Python转义符相同引发问题的解决办法

    正则表达式使用反斜杠('\')来把特殊字符转义成普通字符(为了方便称为"正则表达式转义"),而反斜杠在普通的 Python 字符串里也是转义符(称为"字符串转义" ...

  5. PyQt(Python+Qt)学习随笔:QListWidget的访问当前项的currentItem和setCurrentItem方法

    老猿Python博文目录 专栏:使用PyQt开发图形界面Python应用 老猿Python博客地址 currentItem方法返回列表部件当前选择的项,setCurrentItem方法用于设置当前项. ...

  6. PyQt(Python+Qt)学习随笔:toolButton的popupMode属性

    属性介绍 toolButton的popupMode属性为设有菜单集或Action列表的toolButton指定菜单弹出模式,类型为枚举类型ToolButtonPopupMode,有如下三种模式: 1. ...

  7. CTFHub Web题学习笔记(SQL注入题解writeup)

    Web题下的SQL注入 1,整数型注入 使用burpsuite,?id=1%20and%201=1 id=1的数据依旧出现,证明存在整数型注入 常规做法,查看字段数,回显位置 ?id=1%20orde ...

  8. jmeter使用中的问题

    1.响应乱码 step1:指定请求节点下,新建后置控制器"BeanShell PostProcessor" step2:其脚本框中输入以下代码,保存 //获取响应代码Unicode ...

  9. Redis整合MySQL和MyCAT分库组件(来源是我的新书)

    MyCAT是一个开源的分布式数据库组件,在项目里,一般用这个组件实现针对数据库的分库分表功能,从而提升对数据表,尤其是大数据库表的访问性能.而且在实际项目里,MyCAT分库分表组件一般会和MySQL以 ...

  10. 痞子衡嵌入式:了解i.MXRT1060系列ROM中串行NOR Flash启动初始化流程优化点

    大家好,我是痞子衡,是正经搞技术的痞子.今天痞子衡给大家分享的是i.MXRT1060系列ROM中串行NOR Flash启动初始化流程优化点. 前段时间痞子衡写了一篇 <深入i.MXRT1050系 ...