GAN

原始GAN中判别器要最小化如下损失函数,尽可能把真实样本分为正例,生成样本分为负例:

其中是真实样本分布,是由生成器产生的样本分布。

第一个式子我们不看梯度符号的话即为判别器的损失函数,logD(xi)为判别器将真实数据判定为真实数据的概率,log(1-D(G(zi)))为判别器将生成器生成的虚假数据判定为真实数据的对立面即将虚假数据仍判定为虚假数据的概率。判别器就相当于警察,在鉴别真伪时,必须要保证鉴别的结果真的就是真的,假的就是假的,所以判别器的总损失即为二者之和,应当最大化该损失。由于判别器(警察)鉴别真伪的能力随着训练次数的增加越来越高,生成器就要与之“对抗”,生成器就要相应地提高“造假”技术,来迷惑判别器。第二个式子为第一个式子的第二项,含义相同,只不过对于生成器应当最小化该项,生成器当然希望辨别器将虚假数据仍判定为虚假数据的概率越低越好,即将虚假数据误判定为真实数据的概率越大越好,即最大化log(D(G(zi)))损失函数。所以二者相互提高或者减小自身的损失,以不断互相对抗。

从公式1可以得到,在生成器G固定参数时最优的判别器D应该是什么。对于一个具体的样本,它可能来自真实分布也可能来自生成分布,它对公式1损失函数的贡献是

令其关于的导数为0,得

化简得最优判别器为:

(公式4)

看一个样本来自真实分布和生成分布的可能性的相对比例。如果,最优判别器就应该非常自信地给出概率0;如果,说明该样本是真是假的可能性刚好一半一半,此时最优判别器也应该给出概率0.5。

 (公式2)

GAN训练的问题之一,就是别把判别器训练得太好,否则在实验中生成器会完全学不动(loss降不下去),为了探究背后的原因,我们就可以看看在极端情况——判别器最优时,生成器的损失函数变成什么。给公式2加上一个不依赖于生成器的项,使之变成

最小化这个损失函数等价于最小化公式2,而且它刚好是判别器损失函数的相反数。代入最优判别器即公式4,再进行简单的变换可以得到

(公式5)

从而可以得到KL散度和JS散度(衡量量两个分布的差异区别)

于是公式5就可以继续写成

(公式8)

目前得到的结论:根据原始GAN定义的判别器loss,我们可以得到最优判别器的形式;而在最优判别器的下,我们可以把原始GAN定义的生成器loss等价变换为最小化真实分布与生成分布之间的JS散度。我们越训练判别器,它就越接近最优,最小化生成器的loss也就会越近似于最小化之间的JS散度。

问题就出在这个JS散度上。我们会希望如果两个分布之间越接近它们的JS散度越小,我们通过优化JS散度就能将“拉向”,最终以假乱真。这个希望在两个分布有所重叠的时候是成立的,但是如果两个分布完全没有重叠的部分,或者它们重叠的部分可忽略(下面解释什么叫可忽略),它们的JS散度是多少呢?

答案是,因为对于任意一个x只有四种可能:

第一种对计算JS散度无贡献,第二种情况由于重叠部分可忽略所以贡献也为0,第三种情况对公式7右边第一个项的贡献是,第四种情况与之类似,所以最终

换句话说,无论是远在天边,还是近在眼前,只要它们俩没有一点重叠或者重叠部分可忽略,JS散度就固定是常数而这对于梯度下降方法意味着——梯度为0!此时对于最优判别器来说,生成器肯定是得不到一丁点梯度信息的;即使对于接近最优的判别器来说,生成器也有很大机会面临梯度消失的问题。

但是不重叠或重叠部分可忽略的可能性有多大?不严谨的答案是:非常大。比较严谨的答案是:的支撑集(support)是高维空间中的低维流形(manifold)时,重叠部分测度(measure)为0的概率为1。

    • 支撑集(support)其实就是函数的非零部分子集,比如ReLU函数的支撑集就是,一个概率分布的支撑集就是所有概率密度非零部分的集合。
    • 流形(manifold)是高维空间中曲线、曲面概念的拓广,我们可以在低维上直观理解这个概念,比如我们说三维空间中的一个曲面是一个二维流形,因为它的本质维度(intrinsic dimension)只有2,一个点在这个二维流形上移动只有两个方向的自由度。同理,三维空间或者二维空间中的一条曲线都是一个一维流形。
    • 测度(measure)是高维空间中长度、面积、体积概念的拓广,可以理解为“超体积”。

在(近似)最优判别器下,最小化生成器的loss等价于最小化之间的JS散度,而由于几乎不可能有不可忽略的重叠,所以无论它们相距多远JS散度都是常数,最终导致生成器的梯度(近似)为0,梯度消失。

    • 首先,之间几乎不可能有不可忽略的重叠,所以无论它们之间的“缝隙”多狭小,都肯定存在一个最优分割曲面把它们隔开,最多就是在那些可忽略的重叠处隔不开而已。
    • 由于判别器作为一个神经网络可以无限拟合这个分隔曲面,所以存在一个最优判别器,对几乎所有真实样本给出概率1,对几乎所有生成样本给出概率0,而那些隔不开的部分就是难以被最优判别器分类的样本,但是它们的测度为0,可忽略。
    • 最优判别器在真实分布和生成分布的支撑集上给出的概率都是常数(1和0),导致生成器的loss梯度为0,梯度消失。

有了这些理论分析,原始GAN不稳定的原因就彻底清楚了:判别器训练得太好,生成器梯度消失,生成器loss降不下去;判别器训练得不好,生成器梯度不准,四处乱跑。只有判别器训练得不好不坏才行,但是这个火候又很难把握,甚至在同一轮训练的前后不同阶段这个火候都可能不一样,所以GAN才那么难训练。

WGAN

引入Wasserstein距离

希望建立一个平滑的,处处可导的cost function。在图中,蓝色为真实分布,绿色为生成数据的分布。红色为discriminator的cost function,我们发现虽然discriminator有效的区分了两个分布,但是当蓝绿两个分布没有交集时,在大量的点上的cost function为常数值,梯度为0,generator 不能更新了。这时看一下wasserstein 距离,它体现为那个草绿色的线,它平滑,可导这就是我们要寻找的cost function。

数学定义如下:

(公式12)

解释如下:组合起来的所有可能的联合分布的集合,反过来说,中每一个分布的边缘分布都是。对于每一个可能的联合分布而言,可以从中采样得到一个真实样本和一个生成样本,并算出这对样本的距离,所以可以计算该联合分布下样本对距离的期望值。在所有可能的联合分布中能够对这个期望值取到的下界,就定义为Wasserstein距离。

Wasserstein距离相比KL散度、JS散度的优越性在于,即便两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。

但是

因为Wasserstein距离定义(公式12)中的没法直接求解,不过没关系,作者用了一个已有的定理把它变换为如下形式

(公式13)

Lipschitz连续。它其实就是在一个连续函数上面额外施加了一个限制,的导函数绝对值不超过。限制了一个连续函数的最大局部变动幅度。

公式13的意思就是在要求函数的Lipschitz常数不超过的条件下,对所有可能满足条件的取到的上界,然后再除以。特别地,我们可以用一组参数来定义一系列可能的函数,此时求解公式13可以近似变成求解如下形式

(公式14)

再用上我们搞深度学习的人最熟悉的那一套,不就可以把用一个带参数的神经网络来表示嘛!由于神经网络的拟合能力足够强大,我们有理由相信,这样定义出来的一系列虽然无法囊括所有可能,但是也足以高度近似公式13要求的那个了。

最后,还不能忘了满足公式14中这个限制。我们其实不关心具体的K是多少,只要它不是正无穷就行,因为它只是会使得梯度变大倍,并不会影响梯度的方向。所以作者采取了一个非常简单的做法,就是限制神经网络的所有参数的不超过某个范围,比如,此时关于输入样本的导数也不会超过某个范围,所以一定存在某个不知道的常数使得的局部变动幅度不会超过它,Lipschitz连续条件得以满足。具体在算法实现中,只需要每次更新完后把它clip回这个范围就可以了。

到此为止,我们可以构造一个含参数、最后一层不是非线性激活层的判别器网络,在限制不超过某个范围的条件下,使得

(公式15)

尽可能取到最大,此时就会近似真实分布与生成分布之间的Wasserstein距离(忽略常数倍数)。注意原始GAN的判别器做的是真假二分类任务,所以最后一层是sigmoid,但是现在WGAN中的判别器做的是近似拟合Wasserstein距离,属于回归任务,所以要把最后一层的sigmoid拿掉。

接下来生成器要近似地最小化Wasserstein距离,可以最小化,由于Wasserstein距离的优良性质,我们不需要担心生成器梯度消失的问题。再考虑到的第一项与生成器无关,就得到了WGAN的两个loss。

(公式16,WGAN生成器loss函数)

(公式17,WGAN判别器loss函数)

公式15是公式17的反,可以指示训练进程,其数值越小,表示真实分布与生成分布的Wasserstein距离越小,GAN训练得越好。

WGAN与原始GAN第一种形式相比,只改了四点:

  • 判别器最后一层去掉sigmoid
  • 生成器和判别器的loss不取log
  • 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
  • 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

GAN和GAN的改进的更多相关文章

  1. 【GAN】GAN的原理及推导

    把GAN的论文看完了, 也确实蛮厉害的懒得写笔记了,转一些较好的笔记,前面先贴一些 原论文里推理部分,进行备忘. GAN的解释 算法流程 GAN的理论推理 转自:https://zhuanlan.zh ...

  2. 【GAN】GAN设计与训练集锦

    以下内容纯属经验之谈,无公式推断!部分内容源自其他博客或课程,并已标注来源. 问题篇[1] 1.模式崩溃 在某个模式(mode)下出现大量重复样本,如左图中,生成的样本分布靠得很近,较聚集,可视化如右 ...

  3. 【生成对抗网络学习 其一】经典GAN与其存在的问题和相关改进

    参考资料: 1.https://github.com/dragen1860/TensorFlow-2.x-Tutorials 2.<Generative Adversarial Net> ...

  4. GAN︱生成模型学习笔记(运行机制、NLP结合难点、应用案例、相关Paper)

    我对GAN"生成对抗网络"(Generative Adversarial Networks)的看法: 前几天在公开课听了新加坡国立大学[机器学习与视觉实验室]负责人冯佳时博士在[硬 ...

  5. 深度学习新星:GAN的基本原理、应用和走向

    深度学习新星:GAN的基本原理.应用和走向 (本文转自雷锋网,转载已获取授权,未经允许禁止转载)原文链接:http://www.leiphone.com/news/201701/Kq6FvnjgbKK ...

  6. Improved GAN

    https://www.bilibili.com/video/av9770302/?p=16 从之前讲的basic gan延伸到unified framework,到WGAN 再到通过WGAN进行Ge ...

  7. [Deep-Learning-with-Python]GAN图片生成

    GAN 由Goodfellow等人于2014年引入的生成对抗网络(GAN)是用于学习图像潜在空间的VAE的替代方案.它们通过强制生成的图像在统计上几乎与真实图像几乎无法区分,从而能够生成相当逼真的合成 ...

  8. Generative Adversarial Nets(原生GAN学习)

    学习总结于国立台湾大学 :李宏毅老师 Author: Ian Goodfellow • Paper: https://arxiv.org/abs/1701.00160 • Video: https:/ ...

  9. GAN综述

    生成式对抗模型GAN (Generativeadversarial networks) 是Goodfellow等[1]在 2014年提出的一种生成式模型,目前已经成为人工智能学界一个热门的研究方向,著 ...

随机推荐

  1. while语句的一个使用技巧

    作用类似于可以输入一个不定量长度的数组,但是严格来说 不可能出现任意长度. 内存是有限的 超出一定长度后,不可能存的下,虽然这个值可能是超级大的,但总有限度. 这里我们利用while(cin>& ...

  2. flask-profiler的使用

    使用 profiler测量在你的Flask 应用程序中定义的端点:并通过web界面提供细粒度的报告. 它给出了这些问题的答案: 应用程序中的瓶颈在哪里? 应用程序中最慢的终结点? 哪些是最常被调用的终 ...

  3. 【Kata Daily 190929】Password Hashes(密码哈希)

    题目: When you sign up for an account somewhere, some websites do not actually store your password in ...

  4. 推动中国制造升级,汽车装配车间生产流水线3D可视化

    前言 随着<中国制造2025>的提出,制造业迎来了全新的发展机遇.更多的企业将制造业信息化技术进行广泛的应用,如 MES 系统.数字孪生以及生产管理可视化等技术的研究应用,已经成为社会各界 ...

  5. EFCore自己用的点东西

    测试一下EF core 的操作 蛀牙总是用导航属性有的时候有点不理解 写个例子给自己用 class Program { static void Main(string[] args) { MyDbCo ...

  6. 线程安全之ConcurrentQueue<T>队列

    最近在弄一个小项目,大概600w行的数据,要进行数据清洗,因数据量偏大,如果单线程去执行,会造成效率偏低,只能用多线程了,但采用多线程存在线程安全问题,于是查了下资料,发现有ConcurrentQue ...

  7. c++的基本点

    C++的要点:    一个思想:抽象(和分类)的思想:    三个概念:类.对象.消息:    三个特征:封装.继承.多态:    一个观念:函数服务于数据.      程序 = 对象 + 消息   ...

  8. arm-linux openssl移植

    从openssl官网下载openssl https://www.openssl.org/source/ 1.解压到linux中 2. ./config no-asm shared --prefix=$ ...

  9. IDEA常用插件汇总

    actiBPM idea的Activiti插件,不知道Activiti的可以百度下,用到了再来安装这个插件. Alibaba Java Coding Guidelines Alibaba开发的Java ...

  10. 【网鼎杯2020朱雀组】Web WriteUp

    nmap nmap语法,很简单. 127.0.0.1' -iL /flag -oN vege.txt ' phpweb 打开,抓包,发现可以传递函数和其参数 试了一下很多函数都被过滤了,不能执行系统命 ...