Residual Attention

文章: Residual Attention: A Simple but Effective Method for Multi-Label Recognition, ICCV2021

下面说一下我对这篇文章的浅陋之见, 如有错误, 请多包涵指正.

文章的核心方法

如下图所示为其处理流程:

图中 X 为CNN骨干网络提取得到的feature, 其大小为 d*h*w , 为1个batch数据. 一般 d*h*w=2048*7*7 .

从图中可以看到, 有2个分支, 一个是 average pooling, 一个是 spatial pooling, 最后二者加权融合得到 residual attention .

Spatial pooling

其过程为:

这里有个 1*1 的卷积操作FC , 其大小为 C*d*1*1 , C 为类别数, 如果直接使用矩阵乘法计算, FC(X) 后的大小为 C*h*w .

但文章中的公式是将其展开为对每个空间点单独计算, 其中 \(\pmb{m_i}\)​​​ 为 FCi 个类别的参数, 其大小为 d*1*1, 计算得到的 \(s^i_j\)​​​ 为第 i 个类别在第 j 个位置的概率, \(\pmb{a^i}\)​​​​ 为第 i 个类别的特征, 其大小为 d*1 .

如果, \(\pmb{m_i}\) 和 \(\pmb{a^i}\) 计算就可以得到第 i 个类别的概率. 这样就可以用到每个空间点的特征, 有利于不同目标不同类别物体的分类识别.

公式中有个温度参数 T 用来控制 \(s^i_j\)​​​​ 的大小, 当 T 趋于无穷时, spatial pooling 就变成了 max pooling

Average pooling

其过程为:

上式其实就是一般分类模型的做法, 全局均值池化.

Residual Attention

如下所示, 将上述2个过程进行加权融合:

其中, \(\pmb{f^i}\) 大小为 d*1, \(\pmb{m_i}^T \pmb{f^i}\) 为第 i 个类别的概率.

至于为什么叫 Residual Attention , 文章中的说法是:

the max pooling among different spatial regions for every class, is in fact a class-specific attention operation, which can be further viewed as a residual component of the class-agnostic global average pooling.

我的理解是, 公式5形式有点像 residual 形式.

文章实验结果

多标签

如下表所示为作者对多个数据集的测试, 除了ImageNet 为单标签外, 其它都为多标签. 可以看到多标签提升还是不错的.

热力图

由于利用到了不同位置空间点的信息, 获得的 heatmap 会更加准确, 文章中给出了一张结果, 如下:

我觉得这里有个遗憾的是, 文中没有进行对比.

个人理解

关于原理

根据流程图, 结合文中作者给出的核心代码, 其基本原理就是 average pooling + max pooling.

上述代码中: y_avg 大小为 C*1, 为 average pooling ; y_max 大小为 C*1, 为 max pooling .

下面是上述代码的一个例子, y_raw 的大小为 1*3*9 , B=1, C=3, H3H, W=3:

可以看到, y_avg 刚好为 average pooling , y_max 刚好为 max pooling .

关于公式

公式中的温度参数 T 用于调整参数大小, 而给出的核心代码中, 只有T趋于无穷的情况(等价于max pooling), 对于多个 Head 的情况, T=2,3,4,5 等, 代码中是如何体现出来的?

关于效果

对于 multi-label , 使用了 spatial poolingmulti-head 来提高效果, 从实验结果来看, 确实有效果, 但对于单标签情况, max pooling 应该改善不大, 从实验结果上看也确实可以看到, 单标签数据集上, 最高提升了0.02个百分点.

测试代码

测试代码如下, 可以参考这里.

import torch
from torch import nn class ResidualAttention(nn.Module):
def __init__(self, channel=512, num_class=1000, la=0.2):
super().__init__()
self.la = la
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False) def forward(self, x):
y_raw = self.fc(x).flatten(2) # b, num_class, h*w
y_avg = torch.mean(y_raw, dim=2) # b, num_class
y_max = torch.max(y_raw, dim=2)[0] # b, num_class
score = y_avg + self.la * y_max
return score if __name__ == '__main__': channel = 4
num_class = 3
batchsize = 1
input = torch.randn(batchsize, channel, 3, 3)
resatt = ResidualAttention(channel=channel, num_class=num_class, la=0.2)
output = resatt(input)
print(output.shape)

[论文阅读] Residual Attention(Multi-Label Recognition)的更多相关文章

  1. 论文阅读(Xiang Bai——【PAMI2017】An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition)

    白翔的CRNN论文阅读 1.  论文题目 Xiang Bai--[PAMI2017]An End-to-End Trainable Neural Network for Image-based Seq ...

  2. 论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline

    论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline  如上图所示,本文旨在解决一个问题:给定一张图像, ...

  3. 论文笔记——Deep Residual Learning for Image Recognition

    论文地址:Deep Residual Learning for Image Recognition ResNet--MSRA何凯明团队的Residual Networks,在2015年ImageNet ...

  4. [论文理解]Deep Residual Learning for Image Recognition

    Deep Residual Learning for Image Recognition 简介 这是何大佬的一篇非常经典的神经网络的论文,也就是大名鼎鼎的ResNet残差网络,论文主要通过构建了一种新 ...

  5. 论文阅读:Face Recognition: From Traditional to Deep Learning Methods 《人脸识别综述:从传统方法到深度学习》

     论文阅读:Face Recognition: From Traditional to Deep Learning Methods  <人脸识别综述:从传统方法到深度学习>     一.引 ...

  6. RAM: Residual Attention Module for Single Image Super-Resolution

    1. 摘要 注意力机制是深度神经网络的一个设计趋势,其在各种计算机视觉任务中都表现突出.但是,应用到图像超分辨领域的注意力模型大都没有考虑超分辨和其它高层计算机视觉问题的天然不同. 作者提出了一个新的 ...

  7. [论文阅读]阿里DIEN深度兴趣进化网络之总体解读

    [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 目录 [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 0x00 摘要 0x01论文概要 1.1 文章信息 1.2 基本观点 1.2.1 DIN的 ...

  8. Deep Residual Learning for Image Recognition (ResNet)

    目录 主要内容 代码 He K, Zhang X, Ren S, et al. Deep Residual Learning for Image Recognition[C]. computer vi ...

  9. Deep Reinforcement Learning for Dialogue Generation 论文阅读

    本文来自李纪为博士的论文 Deep Reinforcement Learning for Dialogue Generation. 1,概述 当前在闲聊机器人中的主要技术框架都是seq2seq模型.但 ...

随机推荐

  1. SpringBoot Cache 入门

    首先搭载开发环境,不会的可以参考笔者之前的文章SpringBoot入门 添加依赖 <dependency> <groupId>org.springframework.boot& ...

  2. 「万字图文」史上最姨母级Java继承详解

    摘要:继承是面向对象软件技术中的一个概念.它使得复用以前的代码非常容易,能够大大缩短开发周期,降低开发费用. 本文分享自华为云社区<「万字图文」史上最姨母级Java继承详解丨[奔跑吧!JAVA] ...

  3. fast-poster海报生成器v1.4.0,一分钟完成海报开发

    fast-poster海报生成器v1.4.0,一分钟完成海报开发 介绍 一个快速开发动态海报的工具 在线体验:https://poster.prodapi.cn/ v1.4.0 新特性 为了项目和团队 ...

  4. C# 8.0和.NET Core 3.0高级编程 分享笔记一:C#8.0与NET Core 3.0入门

    在学习C#相关知识的过程中,我们使用Visual Studio Code来入门整个C#. 一.安装Visual Studio Core环境 通过https://code.visualstudio.co ...

  5. wireshark 调试 https/http2和grpc流量

    本文浏览器以 Chrom 为例 平常需要抓包的场景比较少,记录一下防止下次忘记配置 1. 解析 TLS 在本地创建用于保存 ssl logfile 的文件(文件可以存放到任意位置), 并添加到环境变量 ...

  6. 一文看懂HTTPS、证书机构(CA)、证书、数字签名、私钥、公钥(转)

    说到https,我们就不得不说tls/ssl,那说到tls/ssl,我们就不得不说证书机构(CA).证书.数字签名.私钥.公钥.对称加密.非对称加密.这些到底有什么用呢,正所谓存在即合理,这篇文章我就 ...

  7. Python之抖音快手代码舞--字符舞

    先上效果,视频敬上: 字符舞: 代码舞 源代码: video_2_code_video.py 1 import argparse 2 import os 3 import cv2 4 import s ...

  8. Java核心基础第5篇-Java面向对象_类和对象

    Java面向对象之类和对象 一.面向对象概述 Java是完全的面向对象编程(Object Oriented Programming),简称OOP. 面向对象编程的思维方式更加符合大家的日常生活,因为我 ...

  9. 【LeetCode】28. 实现 strStr()

    28. 实现 strStr() 知识点:字符串:KMP算法 题目描述 实现 strStr() 函数. 给你两个字符串 haystack 和 needle ,请你在 haystack 字符串中找出 ne ...

  10. Redis学习——常用小功能

    一.慢查询分析(查询日志:所谓慢查询日志就是系统在命令执行前后计算每条命令的执行时间,当超过预设阀值,就将这条命令的相关信息(例如:发生时间,耗时,命令的详细信息)记录下来,Redis也提供了类似的功 ...