StarGAN的引入是为了解决多领域间的转换问题的,之前的CycleGAN等只能解决两个领域之间的转换,那么对于含有C个领域转换而言,需要学习C*(C-1)个模型,但StarGAN仅需要学习一个,而且效果很棒,如下:

创新点:为了实现可转换到多个领域,StarGAN加入了一个域的控制信息,类似于CGAN的形式。在网络结构设计上,鉴别器不仅仅需要学习鉴别样本是否真实,还需要对真实图片判断来自哪个域。

整个网络的处理流程如下:

  1. 将输入图片x和目标生成域c结合喂入到生成网络G来合成fake图片
  2. 将fake图片和真实图片分别喂入到鉴别器D,D需要判断图片是否真实,还需要判断它来自哪个域
  3. 与CycleGAN类似,还有一个一致性约束,将生成的fake图片和原始图片的域信息c'结合起来喂入到生成器G要求能输出重建出原始输入图片x

下面分析一下各个部分的损失函数:

一:GAN常见的对抗损失:

二:对于给定的输入图片x和目标域标签c,网络的目标是将x转换成输出图片y,输出图片y能够被归类成目标域c。为了实现这一点就需要鉴别器有判别域的功能。所以作者在D的顶端加了一个额外的域分类器,域分类器loss在优化D和G时都会用到,作者将这一损失分为两个方向,分别用来优化G和D。(这很容易理解,因为如下分析可以看到公式(3)没有办法为D提供训练需要的监督信息)

一个是真实图片的域分类损失用来优化D,另一个是fake图片的域分类损失来优化G。

1)

Dcls(c'|x)代表D对真实图片计算得到的域标签概率分布。这一学习目标将会使得D能够将输入图片x识别为对应的域c',这里的(x,c')是训练集给定的。

2)

fake图片的域分类的损失函数定义如(3),它用来优化G,也就是让G尽力去生成图片让它能够被D分类成目标域c。

三:还有一个重建损失

通过最小化对抗损失与分类损失,G努力尝试做到生成目标域中的现实图片。但是这无法保证学习到的转换只会改变输入图片的域相关的信息而不改变图片内容。所以加上了周期一致性损失:

这里就是将G(x,c)和图片x的原始标签c'结合喂入到G中,将生成的图片和x计算1范数差异。

总体损失:

在实际操作上,作者将对抗损失换成了WGAN的对抗损失:

以上对于单个数据集的训练来说已经足够了,但是现在想想另一个问题,假如我要联合训练多个数据集呢?

举例来说,celebA和RaFD数据集,前者有发色和性别信息,后者有面部表情信息,我能将celebA中的人物改变一下面部表情吗?

一个很简单的想法是如果我原来的域标注信息是5位的onehot编码,现在变长为8位不就可以了。但是这存在一个问题就是celebA中的人其实也有表情,只是没有标注,RaFD其实也有性别区别,但对于网络来说没标记就是未知的。简单扩充域标记信息位是肯定不行的。我们希望网络只关注它有明确信息的那一部分标注。

因此,作者加了一个mask。在联合多个数据集训练时把mask向量也输入到生成器。

以上的ci代表第i个数据集的标签,已知标签ci如果是二进制属性则可以表示为二进制向量,如果为类别属性表示一个onehot。剩下的n-1个则指定为0。m则是一个长度为n的onehot编码。这样网络就会只关注已给定的标签。

论文部分到此结束,下面来分析一下代码

主要的代码有model.py和solver.py两个。

在model.py中作者创建了生成器G与鉴别器D。

在生成器中先对模型降维缩小为原来4倍,再使用多个残差网络获得等维度输出,接着使用转置卷积放大4倍,最后通过一层尺寸不变的卷积,取tanh作为输出。

另外一个值得注意的是生成器如何将输入图片与目标域c一起结合作为输入的,代码中可以看出就是直接在第四维度上进行拼接(pytorch一般为N*C*H*W,所以看起来是在第二维)。

对于鉴别器,使用conv1的输出代表域的预测概率,conv2的输出代表图片是否为真的判断。这两个的关系是并行的。

Solver.py比较长,挑选重要的部分来解释:

首先是梯度惩罚,这一部分来自WGAN的改善工作,主要是为了满足Lipschitz连续这个WGAN推导中需要的数学约束。

令人疑惑的是分类loss并不都是交叉熵损失,这是因为CelebA的标签是多属性的,不是一个onehot,所以使用了一个多个二分类的形式,而RaFD则是一个onehot。

下面来看看在多个训练集训练时代码上是怎么操作的。

在数据加载上其实还是单个数据集轮流进行操作的,如下:

以上提到在多数据集训练时,我们需要mask向量,mask向量的形成按如下形式进行拼接,前面是celebA的label后面是RaFD的label,最后是onehot,代表了哪个数据集的标签是已知的。

以生成器为例,计算损失时也是只在输出判断向量中提取该数据集已知的部分进行loss计算。

StarGAN论文及代码理解的更多相关文章

  1. HHL论文及代码理解(Generalizing A Person Retrieval Model Hetero- and Homogeneously ECCV 2018)

    行人再识别Re-ID面临两个特殊的问题: 1)源数据集和目标数据集类别完全不同 2)相机造成的图片差异 因为一般来说传统的域适应问题源域和目标域的类别是相同的,相机之间的不匹配也是造成行人再识别数据集 ...

  2. Context Encoder论文及代码解读

    经过秋招和毕业论文的折磨,提交完论文終稿的那一刻总算觉得有多余的时间来搞自己的事情. 研究论文做的是图像修复相关,这里对基于深度学习的图像修复方面的论文和代码进行整理,也算是研究生方向有一个比较好的结 ...

  3. [ZZ]计算机视觉、机器学习相关领域论文和源代码大集合

    原文地址:[ZZ]计算机视觉.机器学习相关领域论文和源代码大集合作者:计算机视觉与模式 注:下面有project网站的大部分都有paper和相应的code.Code一般是C/C++或者Matlab代码 ...

  4. linux io的cfq代码理解

    内核版本: 3.10内核. CFQ,即Completely Fair Queueing绝对公平调度器,原理是基于时间片的角度去保证公平,其实如果一台设备既有单队列,又有多队列,既有快速的NVME,又有 ...

  5. 10K+,深度学习论文、代码最全汇总!

    我们大部分人是如何查询和搜集深度学习相关论文的?绝大多数情况是根据关键字在谷歌.百度搜索.想寻找相关论文的复现代码又会去 GitHub 上搜索关键词.浪费了很多时间不说,论文.代码通常也不够完整.怎么 ...

  6. (转) AI突破性论文及代码实现汇总

    本文转自:https://zhuanlan.zhihu.com/p/25191377 AI突破性论文及代码实现汇总 极视角 · 2 天前 What Can AI Do For You? “The bu ...

  7. 通过汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的

    秦鼎涛  <Linux内核分析>MOOC课程http://mooc.study.163.com/course/USTC-1000029000 实验一 通过汇编一个简单的C程序,分析汇编代码 ...

  8. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  9. 通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的

    实验一:通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的 学号:20135114 姓名:王朝宪 注: 原创作品转载请注明出处   <Linux内核分析>MOOC课程http: ...

随机推荐

  1. Failed to load resource: the server responded with a status of 404 ()

    今天遇到了一个一开始感觉很莫名其妙的报错 在编写页面的时候把原先写在html页面里的js代码单独拿出来做成一个JavaScriptUtil文件,放在了和html页面同一个目录下.运行之后在对应的页面c ...

  2. cnvd进阶学习

    说明 cnvd相对在src漏洞平台中还是比较具备含金量的.今天证书的申请标准就不说了,总归网上都有,主要是想分享下怎么去挖漏洞. 咱们这里只讲通用型漏洞,事件型的暂时我也没挖到.挖通用型漏洞主要方法就 ...

  3. python学习Day21

    目录 今日内容详细 作业讲解 os模块 知识点进修 创建目录(文件夹) 删除目录(文件夹) 查看某个路径下所有的文件名称(文件.文件夹) 删除文件.重命名文件 获取当前路径.切换路径 软件开发目录规范 ...

  4. 攻防世界-MISC:simple_transfer

    这是攻防世界高手进阶区的题目,题目如下: 点击下载附件一,得到一个流量包,用wireshark打开搜索flag无果,无奈跑去查看WP,说是先查看一下协议分级,但是并没有像WP所说的协议的字节百分比占用 ...

  5. 「BUAA OO Unit 2 HW8」第二单元总结

    「BUAA OO Unit 2 HW8」第二单元总结 目录 「BUAA OO Unit 2 HW8」第二单元总结 Part 0 前言 Part 1 第五次作业 1.1 作业要求 1.2 架构设计 1. ...

  6. Oracle19c单实例数据库配置OGG单用户数据同步测试

    目录 19c单实例配置GoldenGate 并进行用户数据同步测试 一.数据库操作 1.开启数据库附加日志 2.开启数据库归档模式 3.开启goldengate同步 4.创建goldengate管理用 ...

  7. 软件项目管理 ——1.2.PMBOK与软件项目管理知识体系

    软件项目管理 --1.2.PMBOK与软件项目管理知识体系 归档于软件项目管理初级学习路线 第一章 软件项目管理基本概念 <初级学习路线合集 > @ 目录 软件项目管理 --1.2.PMB ...

  8. AngularJS搭建环境

    一.搭建环境 1.1 调试工具:batarang Chrome浏览器插件 主要功能:查看作用域.输出高度信息.性能监控 1.2 依赖软件:Node.js 下载:https://nodejs.org/e ...

  9. iOS全埋点解决方案-数据存储

    前言 ​ SDK 需要把事件数据缓冲到本地,待符合一定策略再去同步数据. 一.数据存储策略 ​ 在 iOS 应用程序中,从 "数据缓冲在哪里" 这个纬度看,缓冲一般分两种类型. 内 ...

  10. python之部分内置函数与迭代器与异常处理

    目录 常见内置函数(部分) 可迭代对象 迭代器对象 for循环内部原理 异常处理 异常信息的组成部分 异常的分类 异常处理实操 异常处理的其他操作 for循环本质 迭代取值与索引取值的区别 常见内置函 ...