• 文章转自:微信公众号「机器学习炼丹术」
  • 作者:炼丹兄(已授权)
  • 联系方式:微信cyx645016617
  • 论文名称:“Learning From Synthetic Data: Addressing Domain Shift for Segmentation”

「前言」:最近好久没更新公众号了,我一不小心陷入了一个误区:我以为自己看的文章足够多了,用之前的风格迁移和GAN的知识来解决一个domain adaptive的问题,一顿乱拳并没有打死老师傅,反而自己累个够呛。然后找到这样一篇不错的DA framework,来认真学习一下章法,假期结束重新用章法组合拳再来会会。

0 综述

不同于以往的对抗模型或者是超像素信息来实现这个领域迁移,本文使用的是对抗生成网络GAN来将两个领域的特征空间拉近。

本文提出的是语义分割的领域自适应算法。论文特别关注的问题是:目标领域没有label

传统的DA方法包含最小化某些可以衡量source和target两个分布的距离函数。两种常见的度量是:

  • 最大均值差(Maximum Mean Discrepancy, MMD)
  • 通过对抗学习,使用DCNN来学习distance metric

本文的主要贡献在于提出了一种基于生成模型的特征空间源分布与目标分布对齐算法。

1 method

从图片中来初步判断,其实是比较好理解的:

  • 首先,我猜测其做域迁移,可能是仿照GAN领域中做风格迁移的办法;
  • 图片中总共有4个网络,F网络应该是特征提取网络,C网络是做分割的网络,G网络是把F提取的特征再还原成原图的网络,D网络是做分类的网络,和一般GAN不同的是,D中做四个分类,是True source,True target, False source, False targe. 类似于把cycleGAN中的两个二分类的discriminator合并了。

2 细节

原始图片定义为\(X\),source domain的图片定义为\(X^s\),target domain的图片定义为\(X^t\).

  • base network. 架构类似于预训练的VGG16,被分成了两个部分:特征提取部分叫做F网络,做像素分割的叫做C网络。
  • G网络是用来从F生成的embedding特征中,重建原始图像的;D网络不仅要分别出图片是否是real or fake,还会做一个分割任务,类似于C网络。这个分割任务仅仅针对source domain,因为target domain不存在标签。

现在我们假定已经准备好了数据和标签\({X^s,Y^s}\):

  • 首先经过F提取出来feature expression,\(F(X^s)\)
  • C网络生成分割的标签\(\hat{Y}^s\)
  • G网络重建图片\(\hat{X}^s\)

基于最近的相关的成功的研究,不再在G的输入中显式的concatenate一个随机变量,而是在Generator中使用dropout layer

3 损失

作者提出了很多的对抗损失:

  • 在一个domin内的损失有:

    • Discriminator损失,分辨src-real和src-fake;
    • Discriminator损失,分辨tgt-real和tgt-fake;
    • Generator损失,让fake source可以被discriminator判断成src-real的损失;
  • 在不同domain的损失:
    • F网络的损失,可以让fake source的输入被判断为real target;
    • F网络的损失,可以让fake target的输入被判断为real source;

除了上面说到的对抗损失外,还有下面的分割损失

  • \(L_{seg}\):在标准分割网络C中的pixel-wise的交叉熵损失;
  • \(L_{aux}\):D网络也会输出一个分割结果,交叉熵损失;
  • \(L_{rec}\):原始图像和重建图像之间的L1损失。

4 训练过程

在每一个iteration中,一个随机的三元组被输入到模型中:\(\{X^s,Y^s,X^t\}\),然后网络按照下面的顺序进行更新参数:

  1. 先更新参数D,更新策略如下:

    • 对于source input,用\(L_{aux}\)和\(L^s_{adv,D}\);
    • 对于target input,用\(L^t_{adv,D}\)

  1. 然后更新G,更新策略如下:

    • 愚弄discriminator的两个loss,\(L^s_{adv,G}\)和\(L^t_{adv,G}\);
    • 重建损失,\(L^s_{rec}\)和\(L^t_{rec}\);

  1. F网络的更新策略如下:

    • F网络的更新是最关键的!(论文中说的)

- 更新F网络是为了实现domain adaptive,$L^s_{adv,F}$是为了混淆fake source 和real target;
- 类似于G-D之间的min-max game,这里是F和D之间的竞争,只不过前者是为了混淆fake和real,后者是为了混淆source domain和target domain;

5 D的设计动机

我们可以发现,这里面的D其实不是传统的GAN中的D,输出不再是单独的一个scalar,表示图片是fake or real的概率

最近有一篇GAN里面提到了,patch discriminator(这个论文恰好之前读过),这个是让D输出的也是一个二位的量,每一个值表示对应patch的fake or real的概率,这个措施极大的提高了G重建的图片的质量,这里继承延伸了patch discriminator的思想,输出的图片是一个pixel-wise的类似分割的结果,每一个像素有四个类别:fake-src,real-src,fake-tgt,real-tgt;

GAN一般是比较难训练的,尤其是针对大尺度的真实图片数据,一种稳定的方法来训练生成模型的架构是Auxiliary Classifier GAN(ACGAN)(真好,这个论文我之前也看过),简单的说通过增加一个辅助分类损失,可以训练一个更稳定的G,因此这也是为什么D中还会有一个分割损失\(L_{aux}\)

6 总结

作者提高,每一个组件都提供了关键的信息,不多说了,假期回实验室我要开始用章法组合拳来解决问题了。

域迁移DA | Learning From Synthetic Data: Addressing Domain Shift for Se | CVPR2018的更多相关文章

  1. In machine learning, is more data always better than better algorithms?

    In machine learning, is more data always better than better algorithms? No. There are times when mor ...

  2. 多标记学习--Learning from Multi-Label Data

    传统分类问题,即多类分类问题是,假设每个示例仅具有单个标记,且所有样本的标签类别数|L|大于1,然而,在很多现实世界的应用中,往往存在单个示例同时具有多重标记的情况. 而在多分类问题中,每个样本所含标 ...

  3. Coursera, Big Data 4, Machine Learning With Big Data (week 1/2)

    Week 1 Machine Learning with Big Data KNime - GUI based Spark MLlib - inside Spark CRISP-DM Week 2, ...

  4. 不平衡学习 Learning from Imbalanced Data

    问题: ICC警情数据分类不均,30+分类,最多的分类数据数量1w+条,只有10个类别数量超过1k,大部分分类数量少于100条. 解决办法: 下采样:通过非监督学习,找出每个分类中的异常点,减少数据. ...

  5. R8:Learning paths for Data Science[continuous updating…]

    Comprehensive learning path – Data Science in Python Journey from a Python noob to a Kaggler on Pyth ...

  6. 使用ADMT和PES实现window AD账户跨域迁移-介绍篇

    使用 ADMT 和 pwdmig 实现 window AD 账户跨域迁移系列: 介绍篇 ADMT 安装 PES 的安装 ADMT:迁移组 ADMT:迁移用户 ADMT:计算机迁移 ADMT:报告生成 ...

  7. A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2021)

    A Unified Deep Model of Learning from both Data and Queries for Cardinality Estimation 论文解读(SIGMOD 2 ...

  8. Overcoming Forgetting in Federated Learning on Non-IID Data

    郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布! 以下是对本文关键部分的摘抄翻译,详情请参见原文. NeurIPS 2019 Workshop on Federated Learning ...

  9. mysql迁移-----拷贝mysql目录/load data/mysqldump/into outfile

    摘要:本文简单介绍了mysql的三种备份,并解答了有一些实际备份中会遇到的问题.备份恢复有三种(除了用从库做备份之外), 直接拷贝文件,load data 和 mysqldump命令.少量数据使用my ...

随机推荐

  1. Vuex入门、同步异步存取值进阶版

    关注公众号查看原文: 1. vueX介&绍安装 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式.它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方 ...

  2. alpine jdk 中文乱码

    一.概述 使用alpine镜像构建了一个oracle jdk的镜像,运行java业务时,查看日志,显示中文乱码. 但是,基于Alpine Linux的Docker基础镜像的镜像文件很小,也有代价: 把 ...

  3. 更改EFI分区位置

    我是win10 + arch 双系统,并且efi分区用的是win10自动创建的(大小100m),所以这些空间很快就不够用了(内核和initramfs都放在了ESP分区当中) 我原本是直接把win的ef ...

  4. 【转载】KMP入门级别算法详解--终于解决了(next数组详解)

    [转载]https://blog.csdn.net/LEE18254290736/article/details/77278769 对于正常的字符串模式匹配,主串长度为m,子串为n,时间复杂度会到达O ...

  5. javascript处理HTML的Encode(转码)和解码(Decode)

    HTML的Encode(转码)和解码(Decode)在平时的开发中也是经常要处理的,在这里总结了使用javascript处理HTML的Encode(转码)和解码(Decode)的常用方式 一.用浏览器 ...

  6. Java 集合框架 04

    集合框架·Map 和 Collections集合工具类 Map集合的概述和特点 * A:Map接口概述 * 查看API可知: * 将键映射到值的对象 * 一个映射不能包含重复的键 * 每个键最多只能映 ...

  7. springmvc redis @Cacheable扩展(一)

    springmvc 中有自带的cache处理模块,可以是方法级别的缓存处理,那么在实际使用中,很可能自己造轮子,因为实际中永远会有更奇怪的需求点.比如: 1 清除缓存时候,能模糊的进行删除 2 针对不 ...

  8. Linux系统用户与用户组管理

    一.用户和用户组的管理 1.新增组 groupadd 命令 格式:groupadd 组名 2.删除组 groupdel 格式:groupdel 组名 3.增加用用户命令 useradd   格式:us ...

  9. 翻译:《实用的Python编程》05_00_Overview

    目录 | 上一节 (4 类和对象) | 下一节 (6 生成器) 5. Python 对象的内部工作原理 本节介绍 Python 对象的内部工作原理.来自其它语言的程序员通常会发现 Python 的类概 ...

  10. ajax轮询原理及其实现方式

    ajax轮询原理及其实现方式 ajax轮询的两种方式 方式1:设定一个定时器,无论有无结果返回,时间一到就会继续发起请求,这种轮询耗费资源,也不一定能得到想要的数据,这样的轮询是不推荐的. 方式2: ...