论文解读(CDTrans)《CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation》
论文信息
论文标题:CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation
论文作者:Tongkun Xu, Weihua Chen, Pichao Wang, Fan Wang, Hao Li, Rong Jin
论文来源:ICLR 2022
论文地址:download
论文代码:download
1 Introduction
无监督域自适应(Unsupervised domain adaptation,UDA)的目的是将从标记源域学习到的知识转移到不同的未标记目标域。
UDA 方法:
① Domain-level UDA ,通过将源域和目标域在不同尺度水平上进入相同的分布来缓解源域之间的分布差异;
② fine-grained category-level UDA,通过将目标样本推向每个类别中的源样本的分布,对源域数据和目标域数据之间的每个类别分布进行对齐;(仍然存在标签噪声问题)
2 Method
2.1 The cross attention in Transformer
传统的方法给目标域打伪标签的过程中存在噪声,由于噪声的存在,需要对齐的源域和目标域的图片可能不属于同一类,强行对其可能产生很大的负面影响。而本文经过实验发现 Transformer 中的 CrossAttention 可以有效的避免噪声给对其造成的影响,CrossAttention 更多的关注源域和目标域中图片中的相似信息。换句话说,即使图片对不属于同一类,被拉近的也只会是两者相似的部分。因此,CDTrans 具有一定的抗噪能力。
由于在 UDA 任务中,目标域是没有标签的。因此只能借鉴伪标签的思路,来生成潜在的可能属于同一个 ID 的样本对。但是,伪标签生成的样本对中不可避免的会存在噪声。这时,本文发现 Cross Attention 对样本对中的噪声有着很强的鲁棒性。本文分析这主要是因为 Attention 机制所决定的,Attention 的 weight 更多的会关注两张图片相似的部分,而忽略其不相似的部分。如果源域图片和目标域图片不属于同一个类别的话,比如Figure 1.a“Car vs. Truck”的例子,Attention 的 weight 主要集中于两个图片中相似部分的对齐(比如轮胎),而对其他部位的对齐会给很小的 weight。
换句话说,Cross Attention 没有在使劲拉近对齐小轿车和卡车,而更多的是在努力对齐两个图片中的轮胎。一方面,Cross Attention 避免了强行拉近小轿车和卡车,减弱了噪声样本对 UDA 训练的影响;另一方面,拉近不同域的轮胎,在一定程度上可能帮助到目标域轮胎的识别。
自注意力(self-attention):
$\operatorname{Attn}_{\text {self }}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V})=\operatorname{softmax}\left(\frac{\boldsymbol{Q} \boldsymbol{K}^{T}}{\sqrt{d_{k}}}\right) \boldsymbol{V}\quad\quad(1)$
交叉注意力(cross-attention):
$\operatorname{Attn}_{\text {cross }}\left(\boldsymbol{Q}_{s}, \boldsymbol{K}_{t}, \boldsymbol{V}_{t}\right)=\operatorname{softmax}\left(\frac{\boldsymbol{Q}_{s} \boldsymbol{K}_{t}^{T}}{\sqrt{d_{k}}}\right) \boldsymbol{V}_{t}\quad\quad(2)$
2.2 Two way center-aware pseudo labeling
2.2.1 Two way labeling
为了构建交叉注意模块的训练对,一种直观的方法是,对源域中的每一幅图像,我们设法从目标域找到最相似的图像。所选数据对的设置 $\mathbb{P}_{S}$ 为:
$\mathbb{P}_{S}=\left\{(s, t) \mid t=\underset{k}{\text{min}} \quad d\left(\boldsymbol{f}_{s}, \boldsymbol{f}_{k}\right), \forall k \in T, \forall s \in S\right\}\quad\quad\quad(3)$
其中,$S$、$T$ 分别为源数据和目标数据。$d\left(\boldsymbol{f}_{i}, \boldsymbol{f}_{j}\right)$ 表示图像 $i$ 和图像 $j$ 的特征之间的距离。
这种策略的优点是充分利用源数据,而其弱点显然是只涉及到目标数据的一部分。为了消除目标数据的这种训练偏差,我们从相反的方式引入了更多的对 $\mathbb{P}$,包括所有目标数据及其在源域中最相似的图像。
$\mathbb{P}_{T}=\left\{(s, t) \mid s=\underset{k}{\text{min}} \quad d\left(\boldsymbol{f}_{t}, \boldsymbol{f}_{k}\right), \forall t \in T, \forall k \in S\right\}\quad\quad\quad(4)$
因此,最终的 $\mathbb{P}$ 是两个集的并集,即 $\mathbb{P}=\left\{\mathbb{P}_{S} \cup \mathbb{P}_{T}\right\}$,使训练对包括所有的源数据和目标数据。
2.2.2 Center-Aware filtering
$\mathbb{P}$ 中的 pair 是基于两个域图像的特征相似性构建的,因此 pair 的伪标签的准确性高度依赖于特征相似性。
本文发现,源数据的预训练模型也有助于进一步提高精度。首先,我们通过将所有目标数据送到通过预先训练的模型,从分类器得到它们在源类别上的概率分布 $\delta$。这些分布可以通过加权 k-means 聚类来计算目标域内每个类别的初始中心:
${\large \boldsymbol{c}_{k}=\frac{\sum_{t \in T} \delta_{t}^{k} \boldsymbol{f}_{t}}{\sum_{t \in T} \delta_{t}^{k}}}\quad\quad(5) $
其中,$\delta_{t}^{k}$ 表示图像 $t$ 在类别 $k$ 上的概率。目标数据的伪标签可以通过最近邻分类器产生:
$y_{t}=\arg \underset{k}{\text{min}} \; d\left(\boldsymbol{c}_{k}, \boldsymbol{f}_{t}\right) \quad\quad(6) $
其中,$t \in T$ 和 $d(i, j)$ 是特征 $i$ 和 $j$ 的距离。基于伪标签,我们可以计算出新的中心:
${\normalsize \boldsymbol{c}_{k}^{\prime}=\frac{\sum_{t \in T} \mathbb{1}\left(y_{t}=k\right) \boldsymbol{f}_{t}}{\sum_{t \in T} \mathbb{1}\left(y_{t}=k\right)}} \quad\quad(7) $
2.3 CDTrans:Cross-Domain Transformer
框架如下:
上述 CDTrans 框架包括三个权重共享的 transformer ,分别是 source branch, source-target branch, target branch 。
输入对中的源图像和目标图像分别被发送到 source branch 和 target branch 。在这两个分支中,self-attention 涉及到学习特定领域的表示。并利用 softmax cross-entropy loss 进行分类训练。值得注意的是,由于两个图像的相同标签,所有三个分支共享相同的分类器。交叉注意力模块被导入到 source-target branch 中。source-target branch 的输入来自其他两个分支。在第 $N$ 层中,交叉注意模块的 query 来自于source branch 的第 $N$ 层中的查询,而 key 和 value 来自于 target branch 的查询。source-target branch 的特征不仅对齐了两个域的分布,而且由于交叉注意模块,对输入对中的噪声具有鲁棒性。因此,本文使用 source-target branch 的输出来指导目标分支的训练。具体来说,source-target branch 和目标分支分别表示为 teacher 和 student 。本文将分类器在 source-target branch 中的概率分布作为一个软标签,可以通过蒸馏损失来进一步监督目标分支
${\large L_{d t l}=\sum\limits_{k} q_{k} \log p_{k}} \quad\quad(8) $
其中,$q_{k}$ 和 $p_{k}$ 分别是 source-target branch 和 target branch 的概率。
论文解读(CDTrans)《CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation》的更多相关文章
- [论文解读] 阿里DIEN整体代码结构
[论文解读] 阿里DIEN整体代码结构 目录 [论文解读] 阿里DIEN整体代码结构 0x00 摘要 0x01 文件简介 0x02 总体架构 0x03 总体代码 0x04 模型基类 4.1 基本逻辑 ...
- CVPR2020论文解读:三维语义分割3D Semantic Segmentation
CVPR2020论文解读:三维语义分割3D Semantic Segmentation xMUDA: Cross-Modal Unsupervised Domain Adaptation for 3 ...
- 自监督学习(Self-Supervised Learning)多篇论文解读(上)
自监督学习(Self-Supervised Learning)多篇论文解读(上) 前言 Supervised deep learning由于需要大量标注信息,同时之前大量的研究已经解决了许多问题.所以 ...
- 论文解读丨表格识别模型TableMaster
摘要:在此解决方案中把表格识别分成了四个部分:表格结构序列识别.文字检测.文字识别.单元格和文字框对齐.其中表格结构序列识别用到的模型是基于Master修改的,文字检测模型用到的是PSENet,文字识 ...
- NLP论文解读:无需模板且高效的语言微调模型(上)
原创作者 | 苏菲 论文题目: Prompt-free and Efficient Language Model Fine-Tuning 论文作者: Rabeeh Karimi Mahabadi 论文 ...
- 论文解读(SR-GNN)《Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data》
论文信息 论文标题:Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data论文作者:Qi Zhu, ...
- itemKNN发展史----推荐系统的三篇重要的论文解读
itemKNN发展史----推荐系统的三篇重要的论文解读 本文用到的符号标识 1.Item-based CF 基本过程: 计算相似度矩阵 Cosine相似度 皮尔逊相似系数 参数聚合进行推荐 根据用户 ...
- CVPR2019 | Mask Scoring R-CNN 论文解读
Mask Scoring R-CNN CVPR2019 | Mask Scoring R-CNN 论文解读 作者 | 文永亮 研究方向 | 目标检测.GAN 推荐理由: 本文解读的是一篇发表于CVPR ...
- AAAI2019 | 基于区域分解集成的目标检测 论文解读
Object Detection based on Region Decomposition and Assembly AAAI2019 | 基于区域分解集成的目标检测 论文解读 作者 | 文永亮 学 ...
- Gaussian field consensus论文解读及MATLAB实现
Gaussian field consensus论文解读及MATLAB实现 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/ 一.Introduction ...
随机推荐
- 组件化开发1-git命令简洁版
1-给项目添加git git init 2-查询当前状态,(红色显示的为在工作区,绿色为暂缓区) git status 3-提交到暂缓区 git add . 4-提交到本地仓库('xxxx'里面为注释 ...
- 努力一周,开源一个超好用的接口Mock工具——Msw-Tools
作为一名前端开发,是不是总有这样的体验:基础功能逻辑和页面UI开发很快速,本来可以提前完成,但是接口数据联调很费劲,耗时又耗力,有时为了保证进度还不得不加加班. 为了摆脱这种痛苦,经过一周的努力,从零 ...
- Docker容器获取宿主机信息
最近在做产品授权的东西,开始宿主机为Window,程序获取机器硬件信息相对简单些,后来部署时发现各种各样的的环境问题,所有后来改用dokcer部署,docker方式获取宿主机信息时花了些时间,特此记录 ...
- aws-cli命令-S3相关的操作及管理
在工作中,我们可能经常会将本地数据上传S3进行备份,或者将S3数据下载到本地 本文主要讲解下,工作中可能经常会用到的与S3相关的操作 1.将本地目录的数据同步到指定的S3位置,及s3资源管理 # 同步 ...
- Mybatis PageHelper 使用的注意事项
什么时候会导致不安全的分页? PageHelper 方法使用了静态的 ThreadLocal 参数,分页参数和线程是绑定的. 只要你可以保证在 PageHelper 方法调用后紧跟 MyBatis 查 ...
- AspNetCore中 使用 Grpc 简单Demo
为什么要用Grpc 跨语言进行,调用服务,获取跨服务器调用等 目前我的需要使用 我的抓取端是go 写的 查询端用 Net6 写的 导致很多时候 我需要把一些临时数据写入到 Redis 在两个服务器进行 ...
- IDEA对数据库、表、记录的(增删改查可视化操作)、数据库安全性问题的演示
对数据库的增删改查 新增数据库 修改数据库 删除数据库 对表的增删改查 新增表 修改表 删除表 对记录的增删改查 数据库安全性问题的演示 演示脏读 一个事物里面读到了另外一个事物没有提交的数据: ...
- Docker_基础知识
容器概述 容器本义:盛装物体.隔离物体. 容器意义:解决虚拟化资源浪费的问题. 容器沿革:1979---2013--- 版本:企业版(EE)/社区版(CE)1. ...
- SQL面试50题------(初始化工作、建立表格)
文章目录 1.建表 1.1 学生表和插入数据 1.2 教师表和数据 1.3 课程表和数据 1.4 成绩表和数据 2.数据库数据 2.1 学生表 2.2 教师表 2.3 课程表 2.4 得分表 1.建表 ...
- 3.版本穿梭&分支概述
版本穿梭 如果我们提交了多个版本到本地仓库,想将工作区恢复到历史版本 可以先使用git reflog查看历史记录,获取到版本号 然后使用git rest --hard 版本号 命令恢复到指定版本 gi ...