DETR:Facebook提出基于Transformer的目标检测新范式,性能媲美Faster RCNN | ECCV 2020 Oral
DETR基于标准的Transorfmer结构,性能能够媲美Faster RCNN,而论文整体思想十分简洁,希望能像Faster RCNN为后续的很多研究提供了大致的思路
来源:晓飞的算法工程笔记 公众号
论文: End-to-End Object Detection with Transformers
Introduction
之前也看过一些工作研究将self-attention应用到视觉任务中,比如Stand-Alone Self-Attention in Vision Models和On the Relationship between Self-Attention and Convolutional Layers,但这些方法大都只是得到与卷积类似的效果,还没有很出彩的表现,而DETR基于transformer颠覆了主流目标检测的做法,主要有三个亮点:
- Standard Transformer,DETR采用标准的Transformer和前向网络FFN进行特征的处理以及结果的输出,配合精心设计的postion encoding以及object queries,不需要anchor,直接预测bbox坐标以及类别。
- Set prediction,DETR在训练过程中使用匈牙利排序算法将GT和模型预测结果一一对应,使得在推理时的模型预测结果即为最终结果,不需要后续的NMS操作。
- 目标检测性能超越了经典的Faster RCNN,打开了目标检测研究的新路线,并且DETR也能改装应用于全景分割任务,性能也不错。
The DETR model
DETR architecture
DETR的整体架构很简单,如图2所示,包含3个主要部分:CNN主干、encoder-decoder transformer和简单的前向网络(FFN)。
Backbone
定义初始图片\(x_{img} \in \mathbb{R}^{3\times H_o\times W_o}\),使用常规的CNN主干生成低分辨率特征图\(f\in \mathbb{R}^{C\times H\times W}\),论文采用\(C=2048\)以及\(H,W=\frac{H_o}{32}, \frac{W_o}{32}\)。
Transformer encoder
先用\(1\times 1\)卷积将输入降至较小的维度\(d\),得到新特征图\(z_o \in \mathbb{R}^{d\times H\times W}\),再将特征图\(z_o\)空间维度折叠成1维,转换为\(d\times HW\)的序列化输入。DETR包含多个encoder,每个encoder都为标准结构,包含mullti-head self-attention模块和前向网络FFN。由于transformer是排序不变的,为每个attention层补充一个固定的位置encoding输入。
Transformer decoder
decoder也是transformer的标准结构,使用multi-head self-attention模块和encoder-decoder注意力机制输出\(N\)个大小为\(d\)的embedding,唯一不同的是DETR并行地decode \(N\)个目标,不需要自回归的机制。由于decoder也是排序不变的,采用学习到的位置encdoing(等同于anchor)作为输入,称为object queries。类似于encoder,将位置encoding输入到每个attention层,另外还有空间位置encoding,见图10。decoder将\(N\)个object queries转换为\(N\)个输出embedding,然后独立地解码成box坐标和class标签,得到\(N\)个最终的预测结构。由于了使用self-attention以及encoder-decoder注意力机制,模型能够全局地考虑所有的目标。
Prediction feed-forward networks (FFNs)
使用带ReLU激活的3层感知机以及线性映射层来解码得到最终的预测结果,感知机的隐藏层维度为\(d\)。FFN预测\(N个\)归一化的中心坐标、高度、宽度以及softmax后的类别得分,由于\(N\)一般大于目标个数,所以使用特殊的类别\(\emptyset\)来标记无预测目标,类似于背景类。需要注意,最后用于输出的FFN与encoder和decoder里的FFN是不一样的。
Auxiliary decoding losses
论文发现对decoder使用辅助损失进行训练十分有效,特别是帮助模型输出正确的目标个数,所以在每个decoder层添加FFN和Hugarian loss,所有的FFN共享参数,另外也使用了共享的layer-norm来归一化FFN的输入。
Object detection set prediction loss
DETR输出固定的\(N\)个预测结果,最大的困难在于根据GT对预测结果进行评分,需要先找到预测结果和GT的对应关系。定义\(y\)为GT集合,大小为N,缺少的用\(\emptyset\)填充,\(\hat{y}=\{ \hat{y}_i\}^N_{i=1}\)为预测结果,为了最好地匹配GT和预测结果,使用匈牙利算法(二部图匹配方法)找到能够最小化匹配损失的最优排列方法\(\sigma\):
\(\mathcal{L}_{match} (y_i, \hat{y}_{\sigma(i)})=-\Bbb{1}_{\{c_i \ne \emptyset\}}\hat{p}_{\sigma(i)}(c_i)+1_{\{c_i \ne \emptyset \} } \mathcal{L_{box}}(b_i, \hat{b}_{\sigma(i)})\)为排序后GT-预测结果对的匹配损失,匹配损失考虑类别预测以及bbox的相似度。\(y_i=(c_i, b_i)\)为GT,其中\(c_i\)为类别,\(b_i\in [0, 1]^4\)为相对于图片大小的坐标向量(x, y, hetight, weight),\(\hat{p}_{\sigma(i)}(c_i)\)和\(\hat{b}_{\sigma(i)}\)分别为预测的类别置信度和bbox。这里的匹配过程类似于目前检测算法中anchor和GT的匹配逻辑,而区别在于这里的预测结果和GT是一一对应的。
在找到最优排列方法\(\hat{\sigma}\)后,计算Hungarian loss:
在实现时,考虑分类不均衡,对\(c_i=\emptyset\)的分类项降权10倍。
与普通的目标检测方法预测bbox的差值不同,DETR直接预测bbox的坐标。虽然这个方法实现很简单,但计算损失时会受目标尺寸的影响,论文采用线性\(\mathcal{l}_1\)损失和IoU损失来保证尺度不变,bbox损失\(\mathcal{L}_{box}(b_i, \hat{b}_{\sigma (i)})\)为\(\lambda_{iou}\mathcal{L}_{iou}(b_i, \hat{b}_{\sigma(i)})+\lambda_{L1} || b_i - \hat{b}_{\sigma(i)} ||_1\),bbox损失需要用正样本数进行归一化。
Experiments
DETR性能超越了经典的Faster RCNN。
探究encoder层数对性能的影响
每层decoder输出进行预测的准确率,可以看到逐层递增。
位置embedding方式对性能的影响,这里的spatial pos对应图10的spatial positional encoding,而output pos则对应图10的Object queries。
损失函数对性能的影响。
DETR for panoptic segmentation
DETR也可以在decoder的输出接一个mask head来进行全景分割任务,主要利用了DETR模型的特征提取能力。
与当前主流模型的全景分割性能对比。
Conclustion
DETR基于标准的Transorfmer结构,性能能够媲美Faster RCNN,而论文整体思想十分简洁,希望能像Faster RCNN为后续的很多研究提供了大致的思路。
如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】
DETR:Facebook提出基于Transformer的目标检测新范式,性能媲美Faster RCNN | ECCV 2020 Oral的更多相关文章
- 谷歌大脑提出:基于NAS的目标检测模型NAS-FPN,超越Mask R-CNN
谷歌大脑提出:基于NAS的目标检测模型NAS-FPN,超越Mask R-CNN 朱晓霞发表于目标检测和深度学习订阅 235 广告关闭 11.11 智慧上云 云服务器企业新用户优先购,享双11同等价格 ...
- [转]CNN目标检测(一):Faster RCNN详解
https://blog.csdn.net/a8039974/article/details/77592389 Faster RCNN github : https://github.com/rbgi ...
- [炼丹术]基于SwinTransformer的目标检测训练模型学习总结
基于SwinTransformer的目标检测训练模型学习总结 一.简要介绍 Swin Transformer是2021年提出的,是一种基于Transformer的一种深度学习网络结构,在目标检测.实例 ...
- 10分钟内基于gpu的目标检测
10分钟内基于gpu的目标检测 Object Detection on GPUs in 10 Minutes 目标检测仍然是自动驾驶和智能视频分析等应用的主要驱动力.目标检测应用程序需要使用大量数据集 ...
- 目标检测算法(1)目标检测中的问题描述和R-CNN算法
目标检测(object detection)是计算机视觉中非常具有挑战性的一项工作,一方面它是其他很多后续视觉任务的基础,另一方面目标检测不仅需要预测区域,还要进行分类,因此问题更加复杂.最近的5年使 ...
- 目标检测模型的性能评估--MAP(Mean Average Precision)
目标检测模型中性能评估的几个重要参数有精确度,精确度和召回率.本文中我们将讨论一个常用的度量指标:均值平均精度,即MAP. 在二元分类中,精确度和召回率是一个简单直观的统计量,但是在目标检测中有所不同 ...
- AI佳作解读系列(二)——目标检测AI算法集杂谈:R-CNN,faster R-CNN,yolo,SSD,yoloV2,yoloV3
1 引言 深度学习目前已经应用到了各个领域,应用场景大体分为三类:物体识别,目标检测,自然语言处理.本文着重与分析目标检测领域的深度学习方法,对其中的经典模型框架进行深入分析. 目标检测可以理解为是物 ...
- 目标检测算法(一):R-CNN详解
参考博文:https://blog.csdn.net/hjimce/article/details/50187029 R-CNN(Regions with CNN features)--2014年提出 ...
- 基于 RocketMQ 的 Dubbo-go 通信新范式
本文作者:郝洪范 ,Dubbo-go Committer,京东资深研发工程师. 一.MQ Request Reply特性介绍 什么是 RPC 通信? 如上图所示,类似于本地调用,A 服务响应调用 B ...
- 实战小项目之基于yolo的目标检测web api实现
上个月,对微服务及web service有了一些想法,看了一本app后台开发及运维的书,主要是一些概念性的东西,对service有了一些基本了解.互联网最开始的构架多是cs构架,浏览器兴起以后,变成了 ...
随机推荐
- 【Unity3D】动画混合
1 简介 2D动画.人体模型及动画.人物跟随鼠标位置中介绍了 Aniamtion.Animator.人体模型.人体骨骼.人体动画等基础知识及人体动画的应用,本文将进一步介绍动画混合. 实现动画 ...
- 链表--insert
分别是使用了二级指针和一级指针的两种方法,最后会按插入的顺序依次打印1,2,3,4 主要区别在于,使用二级指针,可以在main函数里直接用一个空的Node指针,而一级指针是在main函数里面先添加了一 ...
- 【Android逆向】IDA动态调试JNI_OnLoad 和 .init_array
由于 JNI_OnLoad 和 .init_array 执行时间很早,so一加载到内存中就执行了,所以动态调试步骤会稍微要麻烦一些 1. 进入手机, 执行./android_server (如果是64 ...
- 50从零开始用Rust编写nginx,原来TLS证书还可以这么申请
wmproxy wmproxy已用Rust实现http/https代理, socks5代理, 反向代理, 负载均衡, 静态文件服务器,websocket代理,四层TCP/UDP转发,内网穿透等,会将实 ...
- django学习第十三天--自定义中间件
jquery操作cookie 下载地址 http://plugins.jquery.com/cookie/ 引入 <script type="text/javascript" ...
- 基于Python GDAL为长时间序列遥感图像绘制时相变化曲线图
本文介绍基于Python中gdal模块,对大量多时相栅格图像,批量绘制像元时间序列折线图的方法. 首先,明确一下本文需要实现的需求:现有三个文件夹,其中第一个文件夹存放了某一研究区域原始的多时 ...
- 【Azure Key Vault】客户端获取Key Vault机密信息全部失败问题分析
问题描述 在应用中获取存储在Azure Key Vault的机密信息,全部失败. 报错日志内容如下: [reactor-http-epoll-4] [reactor.netty.http.client ...
- 【Azure API 管理】在 Azure API 管理中使用 OAuth 2.0 授权和 Azure AD 保护 Web API 后端,在请求中携带Token访问后报401的错误
问题描述 在 Azure API 管理中使用 OAuth 2.0 授权和 Azure AD 保护 Web API 后端的文档中操作 "在开发人员门户中启用 OAuth 2.0 用户授权&qu ...
- 基于图数据库 NebulaGraph 实现的欺诈检测方案及代码示例
本文是一个基于 NebulaGraph 图算法.图数据库.机器学习.GNN 的 Fraud Detection 方法综述.在阅读本文了解欺诈检测的基本实现方法之余,也可以在我给大家准备的 Playgr ...
- Go语言的100个错误使用场景(55-60)|并发基础
目录 前言 8. 并发基础 8.1 混淆并发与并行的概念(#55) 8.2 认为并发总是更快(#56) 8.3 分不清何时使用互斥锁或 channel(#57) 8.4 不理解竞态问题(#58) 8. ...