论文信息

论文标题:CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation
论文作者:Tongkun Xu, Weihua Chen, Pichao Wang, Fan Wang, Hao Li, Rong Jin
论文来源:ICLR 2022
论文地址:download 
论文代码:download

1 Introduction

  无监督域自适应(Unsupervised domain adaptation,UDA)的目的是将从标记源域学习到的知识转移到不同的未标记目标域。

  UDA 方法:
  ① Domain-level UDA ,通过将源域和目标域在不同尺度水平上进入相同的分布来缓解源域之间的分布差异;
  ② fine-grained category-level UDA,通过将目标样本推向每个类别中的源样本的分布,对源域数据和目标域数据之间的每个类别分布进行对齐;(仍然存在标签噪声问题)

2 Method

2.1 The cross attention in Transformer

  传统的方法给目标域打伪标签的过程中存在噪声,由于噪声的存在,需要对齐的源域和目标域的图片可能不属于同一类,强行对其可能产生很大的负面影响。而本文经过实验发现 Transformer 中的 CrossAttention 可以有效的避免噪声给对其造成的影响,CrossAttention 更多的关注源域和目标域中图片中的相似信息。换句话说,即使图片对不属于同一类,被拉近的也只会是两者相似的部分。因此,CDTrans 具有一定的抗噪能力。

  由于在 UDA 任务中,目标域是没有标签的。因此只能借鉴伪标签的思路,来生成潜在的可能属于同一个 ID 的样本对。但是,伪标签生成的样本对中不可避免的会存在噪声。这时,本文发现 Cross Attention 对样本对中的噪声有着很强的鲁棒性。本文分析这主要是因为 Attention 机制所决定的,Attention 的 weight 更多的会关注两张图片相似的部分,而忽略其不相似的部分。如果源域图片和目标域图片不属于同一个类别的话,比如Figure 1.a“Car vs. Truck”的例子,Attention 的 weight 主要集中于两个图片中相似部分的对齐(比如轮胎),而对其他部位的对齐会给很小的 weight。

  

  换句话说,Cross Attention 没有在使劲拉近对齐小轿车和卡车,而更多的是在努力对齐两个图片中的轮胎。一方面,Cross Attention 避免了强行拉近小轿车和卡车,减弱了噪声样本对 UDA 训练的影响;另一方面,拉近不同域的轮胎,在一定程度上可能帮助到目标域轮胎的识别。

  自注意力(self-attention):

    $\operatorname{Attn}_{\text {self }}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V})=\operatorname{softmax}\left(\frac{\boldsymbol{Q} \boldsymbol{K}^{T}}{\sqrt{d_{k}}}\right) \boldsymbol{V}\quad\quad(1)$

  交叉注意力(cross-attention):

    $\operatorname{Attn}_{\text {cross }}\left(\boldsymbol{Q}_{s}, \boldsymbol{K}_{t}, \boldsymbol{V}_{t}\right)=\operatorname{softmax}\left(\frac{\boldsymbol{Q}_{s} \boldsymbol{K}_{t}^{T}}{\sqrt{d_{k}}}\right) \boldsymbol{V}_{t}\quad\quad(2)$

2.2 Two way center-aware pseudo labeling

2.2.1 Two way labeling

  为了构建交叉注意模块的训练对,一种直观的方法是,对源域中的每一幅图像,我们设法从目标域找到最相似的图像。所选数据对的设置 $\mathbb{P}_{S}$ 为:

    $\mathbb{P}_{S}=\left\{(s, t) \mid t=\underset{k}{\text{min}} \quad d\left(\boldsymbol{f}_{s}, \boldsymbol{f}_{k}\right), \forall k \in T, \forall s \in S\right\}\quad\quad\quad(3)$

  其中,$S$、$T$ 分别为源数据和目标数据。$d\left(\boldsymbol{f}_{i}, \boldsymbol{f}_{j}\right)$ 表示图像 $i$ 和图像 $j$ 的特征之间的距离。

  这种策略的优点是充分利用源数据,而其弱点显然是只涉及到目标数据的一部分。为了消除目标数据的这种训练偏差,我们从相反的方式引入了更多的对 $\mathbb{P}$,包括所有目标数据及其在源域中最相似的图像。

    $\mathbb{P}_{T}=\left\{(s, t) \mid s=\underset{k}{\text{min}} \quad d\left(\boldsymbol{f}_{t}, \boldsymbol{f}_{k}\right), \forall t \in T, \forall k \in S\right\}\quad\quad\quad(4)$

  因此,最终的 $\mathbb{P}$ 是两个集的并集,即 $\mathbb{P}=\left\{\mathbb{P}_{S} \cup \mathbb{P}_{T}\right\}$,使训练对包括所有的源数据和目标数据。

2.2.2 Center-Aware filtering

  $\mathbb{P}$ 中的 pair 是基于两个域图像的特征相似性构建的,因此 pair 的伪标签的准确性高度依赖于特征相似性。

  本文发现,源数据的预训练模型也有助于进一步提高精度。首先,我们通过将所有目标数据送到通过预先训练的模型,从分类器得到它们在源类别上的概率分布 $\delta$。这些分布可以通过加权 k-means 聚类来计算目标域内每个类别的初始中心:

    ${\large \boldsymbol{c}_{k}=\frac{\sum_{t \in T} \delta_{t}^{k} \boldsymbol{f}_{t}}{\sum_{t \in T} \delta_{t}^{k}}}\quad\quad(5) $
  其中,$\delta_{t}^{k}$ 表示图像 $t$ 在类别 $k$ 上的概率。目标数据的伪标签可以通过最近邻分类器产生:

    $y_{t}=\arg   \underset{k}{\text{min}} \; d\left(\boldsymbol{c}_{k}, \boldsymbol{f}_{t}\right)  \quad\quad(6) $

  其中,$t \in T$ 和 $d(i, j)$ 是特征 $i$ 和 $j$ 的距离。基于伪标签,我们可以计算出新的中心:

    ${\normalsize \boldsymbol{c}_{k}^{\prime}=\frac{\sum_{t \in T} \mathbb{1}\left(y_{t}=k\right) \boldsymbol{f}_{t}}{\sum_{t \in T} \mathbb{1}\left(y_{t}=k\right)}}  \quad\quad(7) $

2.3 CDTrans:Cross-Domain Transformer

框架如下:

  

  上述 CDTrans 框架包括三个权重共享的 transformer ,分别是 source branch, source-target branch, target branch 。

  输入对中的源图像和目标图像分别被发送到 source branch 和 target branch 。在这两个分支中,self-attention 涉及到学习特定领域的表示。并利用 softmax cross-entropy loss 进行分类训练。值得注意的是,由于两个图像的相同标签,所有三个分支共享相同的分类器。交叉注意力模块被导入到 source-target branch 中。source-target branch 的输入来自其他两个分支。在第 $N$ 层中,交叉注意模块的 query 来自于source branch 的第 $N$ 层中的查询,而 key 和 value 来自于  target branch 的查询。source-target branch 的特征不仅对齐了两个域的分布,而且由于交叉注意模块,对输入对中的噪声具有鲁棒性。因此,本文使用 source-target branch 的输出来指导目标分支的训练。具体来说,source-target branch 和目标分支分别表示为 teacher 和 student 。本文将分类器在 source-target branch 中的概率分布作为一个软标签,可以通过蒸馏损失来进一步监督目标分支

    ${\large L_{d t l}=\sum\limits_{k} q_{k} \log p_{k}}  \quad\quad(8) $

  其中,$q_{k}$ 和 $p_{k}$ 分别是 source-target branch 和 target branch 的概率。

  在推理期间,只使用目标分支。输入值为来自测试数据的图像,只触发目标数据流,即 Fig.2 中的蓝线。利用其分类器的输出作为最终的预测标签。 
 

论文解读(CDTrans)《CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation》的更多相关文章

  1. [论文解读] 阿里DIEN整体代码结构

    [论文解读] 阿里DIEN整体代码结构 目录 [论文解读] 阿里DIEN整体代码结构 0x00 摘要 0x01 文件简介 0x02 总体架构 0x03 总体代码 0x04 模型基类 4.1 基本逻辑 ...

  2. CVPR2020论文解读:三维语义分割3D Semantic Segmentation

    CVPR2020论文解读:三维语义分割3D Semantic Segmentation xMUDA: Cross-Modal Unsupervised Domain Adaptation  for 3 ...

  3. 自监督学习(Self-Supervised Learning)多篇论文解读(上)

    自监督学习(Self-Supervised Learning)多篇论文解读(上) 前言 Supervised deep learning由于需要大量标注信息,同时之前大量的研究已经解决了许多问题.所以 ...

  4. 论文解读丨表格识别模型TableMaster

    摘要:在此解决方案中把表格识别分成了四个部分:表格结构序列识别.文字检测.文字识别.单元格和文字框对齐.其中表格结构序列识别用到的模型是基于Master修改的,文字检测模型用到的是PSENet,文字识 ...

  5. NLP论文解读:无需模板且高效的语言微调模型(上)

    原创作者 | 苏菲 论文题目: Prompt-free and Efficient Language Model Fine-Tuning 论文作者: Rabeeh Karimi Mahabadi 论文 ...

  6. 论文解读(SR-GNN)《Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data》

    论文信息 论文标题:Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data论文作者:Qi Zhu, ...

  7. itemKNN发展史----推荐系统的三篇重要的论文解读

    itemKNN发展史----推荐系统的三篇重要的论文解读 本文用到的符号标识 1.Item-based CF 基本过程: 计算相似度矩阵 Cosine相似度 皮尔逊相似系数 参数聚合进行推荐 根据用户 ...

  8. CVPR2019 | Mask Scoring R-CNN 论文解读

    Mask Scoring R-CNN CVPR2019 | Mask Scoring R-CNN 论文解读 作者 | 文永亮 研究方向 | 目标检测.GAN 推荐理由: 本文解读的是一篇发表于CVPR ...

  9. AAAI2019 | 基于区域分解集成的目标检测 论文解读

    Object Detection based on Region Decomposition and Assembly AAAI2019 | 基于区域分解集成的目标检测 论文解读 作者 | 文永亮 学 ...

  10. Gaussian field consensus论文解读及MATLAB实现

    Gaussian field consensus论文解读及MATLAB实现 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/ 一.Introduction ...

随机推荐

  1. Docker 与 Containerd 并用配置

    描述: 事实上,Docker 和 Containerd 是可以同时使用的,只不过 Docker 默认使用的 Containerd 的命名空间不是 default,而是 moby,此处为了更方便我们学习 ...

  2. ECON 模式

    ECON模式通过调节发动机和空调系统的性能,有效提高燃油经济性. 在D行驶档的时候开启

  3. MySQL集群搭建(5)-MHA高可用架构

    1 概述 1.1 MHA 简介 MHA - Master High Availability 是由 Perl 实现的一款高可用程序,出现故障时,MHA 以最小的停机时间(通常10-30秒)执行 mas ...

  4. 使用supervisor管理tomcat,nginx等进程详解

    1,介绍 官网:http://supervisord.org Supervisor是用Python开发的一套通用的进程管理程序,能将一个普通的命令行进程变为后台daemon,并监控进程状态,异常退出时 ...

  5. 项目的依赖包(node_modules)删除

    快速删除依赖包一共分为三部 1.打开命令行(管理员身份),执行 npm i -g npkill 2.cd 进入到想删除的项目中,执行 npkill 3.执行完成会进入到npkill页面,等待搜索完成, ...

  6. WMS 相比于 ERP 系统有哪些优势?

    WMS与ERP系统是两个不同的系统,不存储优势的比较!WMS是仓库管理系统(Warehouse Management System) 的缩写,ERP是Enterprise Resource Plann ...

  7. NSIS检测并统计字符串中某个字符个数

    !include "LogicLib.nsh" OutFile "检查找字符串中c出现的次数.exe" Name "test" Sectio ...

  8. HDU1423 Greatest Common Increasing Subsequence (DP优化)

    LIS和LCS的结合. 容易写出方程,复杂度是nm2,但我们可以去掉一层没有必要的枚举,用一个变量val记录前一阶段的最优解,这样优化成nm. 1<=k<j,j增加1,k的上界也增加1,就 ...

  9. TomCat之负载均衡

    TomCat之负载均衡 本文讲述了tomcat当nginx负载均衡服务器配置步骤 以下是Tomcat负载均衡配置信息 1.修改nginx的nginx.conf文件 添加如下属性:localhost是名 ...

  10. java多线程的两种创建方式

    方式一:继承Thread类 1.创建一个继承于Thread类的子类 2.重写Thread类的run()方法---> 将此线程执行的操作声明在run()中 3.创建Thread类的子类的对象 4. ...