论文信息

论文标题:PCL: Proxy-based Contrastive Learning for Domain Generalization
论文作者:
论文来源:
论文地址:download 
论文代码:download
引用次数:

1 前言

  域泛化是指从一组不同的源域中训练一个模型,可以直接推广到不可见的目标域的问题。一个很有前途的解决方案是对比学习,它试图通过利用来自不同领域的样本到样本对之间丰富的语义关系来学习领域不变表示。一种简单的方法是将来自不同域的正样本对拉得更近,同时将其他负样本对推得更远。

  在本文中,我们发现直接应用基于对比的方法(如有监督的对比学习)在领域泛化中是无效的。本文认为,由于不同域之间的显著分布差距,对准正样本到样本对往往会阻碍模型的泛化。为了解决这个问题,提出了一种新的基于代理的对比学习方法,它用代理到样本关系代替了原始的样本-样本关系,显著缓解了正对齐问题。

2 方法

  整体框架

    

2.1 启发

  现有对比学习的对比损失大多考虑正对和负对,本文受到 [ 61 ] 损失函数的启发,它只考虑正样本之间的关系,假设 $x_i$、$x_j$ 是从同一类的不同源域进行采样。设 $z=F_{\theta}(\boldsymbol{x})$ 是由特征提取器 $F_{\theta}$ 提取的特征,我们有:

    $\mathcal{L}_{\mathrm{pos}}=\frac{1}{\alpha} \log \left(1+\sum \exp \left(-\boldsymbol{z}_{i}^{\top} \boldsymbol{z}_{j} \cdot \alpha\right)\right)  \quad\quad\quad(1)$

  实验:是否使用 包含正对之间对比的 $\text{instance} - \text{instance}$ 之间的对比学习?

    

  结果:单纯使用交叉熵损失比 交叉熵损失 +  正对之间的对齐 效果还好,所以跨域之间的正对对齐是有害的。

2.2 问题定义

  多源域适应;

  特征提取器:$F_{\theta}: X \rightarrow Z$

  分类器:$G_{\psi}: \mathcal{Z} \rightarrow \mathbb{R}^{C}$

2.3 交叉熵回顾

  交叉熵损失函数:

    $\mathcal{L}_{\mathrm{CE}}=-\log \frac{\exp \left(\boldsymbol{w}_{c}^{\top} \boldsymbol{z}_{i}\right)}{\exp \left(\boldsymbol{w}_{c}^{\top} \boldsymbol{z}_{i}\right)+\sum_{j=1}^{C-1} \exp \left(\boldsymbol{w}_{j}^{\top} \boldsymbol{z}_{i}\right)}  \quad\quad\quad(2)$

  其中,$\boldsymbol{w}_{c}$ 代表目标域的某一类中心;

  $\text{Softmax CE}$ 损失只考虑了代理到样本的关系,而忽略了丰富的语义样本与样本之间的关系。

2.4 对比损失回顾

  对比损失函数:

    $\mathcal{L}_{\mathrm{CL}}=-\log \frac{\exp \left(\boldsymbol{z}_{i}^{\top} \boldsymbol{z}_{+} \cdot \alpha\right)}{\exp \left(\boldsymbol{z}_{i}^{\top} \boldsymbol{z}_{+} \cdot \alpha\right)+\sum \exp \left(\boldsymbol{z}_{i}^{\top} \boldsymbol{z}_{-} \cdot \alpha\right)}$

  基于对比的损失考虑了丰富的样本与样本之间的关系。其关键思想是学习一个距离,将 $\text{positive pairs}$ 拉近,将 $\text{negative pairs}$ 推远。

2.5 困难样本挖掘

  公式:

    $\begin{aligned}\mathcal{L}_{\mathrm{CL}} & =\lim _{\alpha \rightarrow \infty} \frac{1}{\alpha}-\log \left(\frac{\exp \left(\alpha \cdot s_{p}\right)}{\exp \left(\alpha \cdot s_{p}\right)+\sum_{j=1}^{N-1} \exp \left(\alpha \cdot s_{n}^{j}\right)}\right) \\& =\lim _{\alpha \rightarrow \infty} \frac{1}{\alpha} \log \left(1+\sum_{j=1}^{N-1} \exp \left(\alpha\left(s_{n}^{j}-s_{p}\right)\right)\right) \\& =\max \left[s_{n}^{j}-s_{p}\right]_{+} .\end{aligned}$

  理解:由于域之间的域差异很大,简单的拉近正对之间的距离,拉远负对之间的距离是不合适的,这是由于往往存在某些难学的样本,使得模型总是识别错误。

2.6 基于代理的对比学习

  $\text{Softmax}$ 损失 在学习类代理方面是有效的,能够快速、安全地收敛,但不考虑样本与样本之间的关系。基于对比损失利用了丰富的 样本-样本 关系,但在优化密集的 样本-样本 关系方面训练复杂性高。

    $\mathcal{L}_{\mathrm{PCL}}=-\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp \left(\boldsymbol{w}_{c}^{\top} \boldsymbol{z}_{i} \cdot \alpha\right)}{Z}$

  基于代理的对比损失:

    $\mathcal{L}_{\mathrm{PCL}}=-\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp \left(\boldsymbol{w}_{c}^{\top} \boldsymbol{z}_{i} \cdot \alpha\right)}{Z}$

  其中,

    $Z=\exp \left(\boldsymbol{w}_{c}^{\top} \boldsymbol{z}_{i} \cdot \alpha\right)+\sum_{k=1}^{C-1} \exp \left(\boldsymbol{w}_{k}^{\top} \boldsymbol{z}_{j} \cdot \alpha\right)+\sum_{j=1, j \neq i}^{K} \exp \left(\boldsymbol{z}_{i}^{\top} \boldsymbol{z}_{j} \cdot \alpha\right)$

  Note:$N$ 代表的是 $\text{batch_size}$ 的大小,$K$ 代表的是 $x_i$ 负样本的数量。

2.7 施加投影头的基于代理的对比学习

  公式:

    $\mathcal{L}_{\mathrm{PCL}-\mathrm{in}}=-\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp \left(\boldsymbol{v}_{c}^{\top} \boldsymbol{e}_{i}\right)}{E}$

  其中,

    $E=\exp \left(\boldsymbol{v}_{c}^{\top} \boldsymbol{e}_{i}\right)+\sum_{k=1}^{C-1} \exp \left(\boldsymbol{v}_{k}^{\top} \boldsymbol{e}_{j}\right)+\sum_{j=1, j \neq i}^{B} \exp \left(\boldsymbol{e}_{i}^{\top} \boldsymbol{e}_{j}\right)$

2.8 训练

  训练目标:

    $\mathcal{L}_{\text {final }}=\mathcal{L}_{\mathrm{CE}}+\lambda \cdot \mathcal{L}_{\text {PCL-in }}$

3 实验结果

正对齐实验的细节

  

消融实验

  

超参数实验

  

困难样本分析

  

4 总结

  略

迁移学习(PCL)《PCL: Proxy-based Contrastive Learning for Domain Generalization》的更多相关文章

  1. 论文解读(PCL)《Probabilistic Contrastive Learning for Domain Adaptation》

    论文信息 论文标题:Probabilistic Contrastive Learning for Domain Adaptation论文作者:Junjie Li, Yixin Zhang, Zilei ...

  2. 论文解读(PCL)《Prototypical Contrastive Learning of Unsupervised Representations》

    论文标题:Prototypical Contrastive Learning of Unsupervised Representations 论文方向:图像领域,提出原型对比学习,效果远超MoCo和S ...

  3. 【迁移学习】2010-A Survey on Transfer Learning

    资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...

  4. Google Tensorflow 迁移学习 Inception-v3

    附上代码加数据地址 https://github.com/Liuyubao/transfer-learning ,欢迎参考. 一.Inception-V3模型 1.1 详细了解模型可参考以下论文: [ ...

  5. 迁移学习(Transfer Learning)(转载)

    原文地址:http://blog.csdn.net/miscclp/article/details/6339456 在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型 ...

  6. 迁移学习(Transfer Learning)

    原文地址:http://blog.csdn.net/miscclp/article/details/6339456 在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型 ...

  7. 迁移学习( Transfer Learning )

    在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型:然后利用这个学习到的模型来对测试文档进行分类与预测.然而,我们看到机器学习算法在当前的Web挖掘研究中存在着一个关 ...

  8. 迁移学习-Transfer Learning

    迁移学习两种类型: ConvNet as fixed feature extractor:利用在大数据集(如ImageNet)上预训练过的ConvNet(如AlexNet,VGGNet),移除最后几层 ...

  9. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  10. Domain adaptation:连接机器学习(Machine Learning)与迁移学习(Transfer Learning)

    domain adaptation(域适配)是一个连接机器学习(machine learning)与迁移学习(transfer learning)的新领域.这一问题的提出在于从原始问题(对应一个 so ...

随机推荐

  1. WSL2与ensp的40故障

    在使用ensp做radius认证的时候看到了Linux平台的freeradius认证服务器,于是使用了Windows平台的sub system: WSL2,按照网上的教程安装,并且安装了docker ...

  2. 一个线程池的c++实现

    前面我们实现了CallBack类,实现了对任意可调用对象的封装,且统一了调用接口. 现在利用CallBack类,我们来实现一个线程池,我们的线程池包含: 1. 状态机, 用于控制和管理线程池的运行.停 ...

  3. Java 获取【.jar】文件里的资源文件

    获取jar文件里的图片等文件时,会发现使用相对路径不行了. 因为打包后的jar文件,在获取路径时稍有不同. 下面是获取jar文件中图片的例子: 1 Resource[] resources = new ...

  4. c++ 保存txt文件

    #include <iostream> #include <stdio.h> #include <fstream> #include <queue> # ...

  5. 重写mybatis-plus的saveUpdate方法

    重写mybatis-plus的saveUpdate方法 1.问题出现 同步外部数据的时候,如果需要同步逻辑删除的数据,mybatis-plus的saveOrUpdate||saveOrUpdateBa ...

  6. openwrt扩容

    方法二.三记得先使用Linux系统打开 GParted -- Download 方法三偏移地址获取: 1. 运行的openwrt安装losetup 2. 安装完毕后执行:losetup 获取偏移地址. ...

  7. springboot项目 报错No mapping for GET /css/bootstrap.css,前端无法展示样式

    说来也奇怪,前几天刚写完的项目 写的好好的 现在打开他就加载不了前端的静态资源了 报错No mapping for GET /css/bootstrap.css 解决方法: 新建一个配置类 ,将静态资 ...

  8. js两个数组对象合并去重

  9. 07 HBase操作

    1.理解HBase表模型及四维坐标:行键.列族.列限定符和时间戳. 2.启动HDFS,启动HBase,进入HBaseShell命令行. 3.列出HBase中所有的表信息list 4.创建表create ...

  10. PHP实现斐波那契数列(递归 + 非递归)实现

    非递归写法:function fbnq($n){ //传入数列中数字的个数    if($n <= 0){        return 0;    }    $array[1] = $array ...