一、原始GAN的理论分析

1.1 数学描述

其实GAN的原理很好理解,网络结构主要包含生成器 (generator) 和鉴别器 (discriminator) ,数据主要包括目标样本 \(x_r \sim P_{r}\), 随机输入样本 \(z \sim P_{z}\) 。生成器的目的就是根据 \(z\) 生成 \(G(z) \sim P_{r}\) ,而鉴别器则尽量区分出来 \(G(z)\) 与 \(x_{r}\) 的不同。生成器和鉴别器采用生成对抗的方式不断优化,最终能通过生成器得到期望输出(比如风格转换,人脸生成等)。联想到电影《无双》的情节,生成器就是造假币的机器,而鉴别器可以类似为鉴别假币的手段。在初始情况下,假币制造机只能生成不是很逼真的假币,此时鉴别器很轻松就能鉴别出来,于是便优化流程和材料,鉴别器鉴别错误之后再改进判别手段......如此往复,最终我们可以得到足以以假乱真的假币。

鉴别器 $D(input;\theta_{d})$ 的目标是对输入的数据做出准确的判断,因此目标函数为:
$$
\mathop{max}\limits_{D}[E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{z\sim P_{z}}[log(1-D(G(z);\theta_{d}))] (公式1)
$$
生成器 $G(input;\theta_{g})$ 的目标就是输出的数据尽可能与目标样本接近骗过鉴别器 $D$,因此:
$$
\mathop{min}\limits_{G}E_{z\sim P_{z}}[log(1-D(G(z);\theta_{d}))] (公式2)
$$
因此总的目标函数可以写为:
$$
\mathop{min}\limits_{G}\mathop{max}\limits_{D}[E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{z\sim P_{z}}[log(1-D(G(z);\theta_{d}))] (公式3)
$$
借用原论文的符号,我们将生成器输出的概率分布记为 $P_{g}$,于是公式 (3) 可以记为:
$$
\mathop{min}\limits_{G}\mathop{max}\limits_{D}[E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{x\sim P_{g}}[log(1-D(x;\theta_{d}))] (公式4)
$$

1.2 求出全局最优解

当固定 \(G\) 的参数时,优化 \(D\) 的参数:

\[V_{D} = [E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{x\sim P_{g}}[log(1-D(x;\theta_{d}))] \\
V_{D} = \int P_{r}logD(x)dx+\int P_{g}log(1-D(x))dx = \int (P_{r}logD(x)+P_{g}log(1-D(x))dx \\ (公式5)
\]

因此,最大值为:

\[\dfrac{\partial{V_{D}}}{\partial{D}} = \dfrac{\partial{}}{\partial{D}}\int (P_{r}logD(x)+P_{g}log(1-D(x))dx \\
\dfrac{\partial{V_{D}}}{\partial{D}} = \int \dfrac{P_{r}}{D(x)}-\dfrac{P_{g}}{1-D(x)}dx = 0\\ (公式6)
\]

解得:

\[D^{*}(x) = \dfrac{P_{r}}{P_{r}+P_{g}} (公式7)
\]

于是,将 \(D^{*}(x)\) 带入到公式 (4) 中,得到:

\[V_{G} = [E_{x\sim P_{r}}log\dfrac{P_{r}}{P_{r}+P_{g}}]+E_{x\sim P_{g}}[log(1-\dfrac{P_{r}}{P_{r}+P_{g}})] \\ (公式8)
\]

即:

\[V_{G} = [E_{x\sim P_{r}}log\dfrac{P_{r}}{P_{r}+P_{g}}]+E_{x\sim P_{g}}[log(\dfrac{P_{g}}{P_{r}+P_{g}})] (公式9)
\]

由于\(P_{r}+P{g} \in [0,2]\),因此公式 (10) 可以写为:

\[V_{G} = [E_{x\sim P_{r}}log\dfrac{P_{r}}{(P_{r}+P_{g})/2}\times \dfrac{1}{2}]+E_{x\sim P_{g}}[log(\dfrac{P_{g}}{(P_{r}+P_{g})/2}\times \dfrac{1}{2})] \\
V_{G} = KL(P_{r}|| \dfrac{P_{r}+P_{g}}{2})+log \dfrac{1}{2}+KL(P_{g}|| \dfrac{P_{r}+P_{g}}{2})+log \dfrac{1}{2} \\(公式10)
\]

最终:

\[V_{G} = KL(P_{r}|| \dfrac{P_{r}+P_{g}}{2})+KL(P_{g}|| \dfrac{P_{r}+P_{g}}{2})-2log2 (公式11)
\]

因此,当 \(P_{r} = \dfrac{P_{r}+P_{g}}{2} = P_{g}\) 时,存在唯一极小值 \(P_{r} = P_{g}\),此时 \(D^{*}(x) = \dfrac{1}{2}\)。即公式 (4) 存在全局最优解,在全局最优解的情况下,生成器生成的概率分布与目标样本概率分布一样,此时鉴别器无法准确判断生成样本与目标样本的差异,判断正确和错误的概率各为0.5,类似于瞎猜。

1.3 原始GAN到底出了什么问题?

GAN的训练是依靠生成器和鉴别器的相互对抗来完成的,那么直观地思考一下:如果鉴别器过于差劲,给不到生成器任何有用的信息,那么生成器的更新就会没有方向;如果鉴别器太好,那么类似于造假币的机器极其差,而鉴别器直接就是验钞机,那么直观上也无法给生成器提供足够的信息去更新。因此,原始的GAN理论上可行,而实际上却受到鉴别器和生成器状态的影响,不一定能找到最优解,且训练不稳定。

从数学角度上来描述:我们在 1.2节 求全局最优解的过程中,先求出了鉴别器 \(D\) 的最优解,然后得到了公式 (11) ,在这种情况下相当于我用已经训练好的鉴别器来指导生成器的学习,将概率分布从 \(P_{z}\) 拉向 \(P_{r}\)。乍一看没什么问题,但是如果两个分布 \(P_{r}\),\(P_{z}\) 完全没有重叠的部分,或者它们重叠的部分可忽略,会发生什么情况呢?答案是无论换句话说,无论 \(P_{r}\) 跟 \(P_{g}\)是远在天边,还是近在眼前,只要它们俩没有一点重叠或者重叠部分可忽略,公式 (11) 散度就固定是常数 \(log2\),而这对于梯度下降方法意味着——梯度为0!此时对于最优判别器来说,生成器肯定是得不到一丁点梯度信息的;即使对于接近最优的判别器来说,生成器也有很大机会面临梯度消失的问题。与我们直观上的感觉一致。

那么问题就变成了\(P_{r}\),\(P_{z}\) 没有重叠的部分的概率大吗?答案是非常大。首先,\(P_{r}\) 是一个复杂分布,而 \(P_{z}\) 则是一个简单分布,所以在空间上二者不重叠的概率很大。更重要的一个原因是,输入 \(z \sim P_{r}\) 一般是 100 维,而生成的目标往往是一张图片,比如 \(64 \times 64\) 就是 \(4096\) 维,低维与高维相重合本来就很少,因此更加证明了原始GAN不容易训练。总结下来:

原始GAN存在梯度不稳定的问题,即判别器训练得太好,生成器梯度消失,生成器loss降不下去;判别器训练得不好,生成器梯度不准,四处乱跑。只有判别器训练得不好不坏才行,但是这个火候又很难把握,甚至在同一轮训练的前后不同阶段这个火候都可能不一样,所以GAN才那么难训练。 此外,GAN还存在模式崩塌(collapse mode)的问题,即生成样本多样性不足。

二、WGAN的前世今生

为了解决原始GAN梯度不稳定的问题,一个过渡的解决方案是强行对生成样本和真实样本加噪声,使得原本两个分布弥散到整个高维空间,增加重叠部分。当二者出现重叠部分时,再把噪声拿掉,这样也能够继续收敛。这只是一个折中的方案,并没有从本质上解决问题。

2.1 Wasserstein 距离

Wasserstein 距离又叫 Earth-Mover ( EM ) 距离,定义如下:

\[W(P_{r},P_{g}) = \mathop{inf}\limits_{\gamma \sim \prod (P_{r}, P_{g})}E_{(x,y)\sim \gamma}[||x-y||] (公式12)
\]

其中:\(\prod (P_{r}, P_{g})\) 表示从概率 \(P_{g}\) 到 \(P_{r}\) 的所有可能分布,而 \(W(P_{r},P_{g})\) 代表所有可能的分布中, \(||x-y||\) 的最小期望值距离。举个例子:如下图所示,假如将左侧的方块运送到右侧的位置,那么方案有很多种,其中最小的那一种移动所花的消耗即为Wasserstein距离。

**因此,Wasserstein的好处就是无论两个分布是否有重叠部分,Wasserstein距离都是连续的,能够反映两个分布的远近,而JS散度和KL散度既不能反映远近,也提供不了梯度。**所以,EM距离更适合用作GAN的loss function。

2.2 从EM距离到WGAN

由于在Wasserstein中,\(\mathop{inf}\limits_{\gamma \sim \prod (P_{r}, P_{g})}\) 没办法直接求解,因此WGAN的作者通过已有的定理将其转换成如下形式:

\[W(P_{r},P_{g}) = \dfrac{1}{K} \mathop{sup}\limits_{||f||_{L}<K} E_{x \sim P_{r}}[f(x)]-E_{x \sim P_{g}} [f(x)] (公式13)
\]

式子的证明过程对我来说确实难以理解,因此这里就不作解释了,有兴趣的可以参考WGAN的原论文。最后,WGAN的loss function变成了下面的形式:

\[W(P_{r},P_{g}) = \dfrac{1}{K} \mathop{max}\limits_{||f_{w}||_{L}<K} E_{x \sim P_{r}}[f_{w}(x)]-E_{x \sim P_{g}} [f_{w}(x)] (公式14)
\]

于是,可以把函数 \(f\) 用一个参数为 \(w\) 的神经网络来表示。最后,为了满足 \(||f_{w}||_{L}<K\) 的限制,将神经网络的所有参数 \(w\) 都拉伸到 \([-c,c]\) 中,所以一定满足Lipschitz连续条件。

因此,我们可以构造一个含参数 \(w\)、最后一层不是非线性激活层的判别器网络 \(f_{w}\),在限制! \(w\) 不超过某个范围的条件下,使得:

\[L = E_{x \sim P_{r}}[f_{w}(x)]-E_{x \sim P_{g}} [f_{w}(x)] (公式15)
\]

尽可能取到最大,此时的 \(L\) 就可以近似为真实分布 \(P_{r}\) 与生成分布 \(P_{g}\) 之间的Wasserstein距离。注意:原始GAN的判别器做的时二分类任务,所以最后一层采用 \(sigmoid\) 函数,而WGAN中的判别器做的是拟合 Wasserstein 距离,属于回归任务,因此把最后一层的 \(sigmoid\) 去掉。

因此判别器的loss function为:

\[E_{x \sim P_{g}}[f_{w}(x)]-E_{x \sim P_{r}} [f_{w}(x)] (公式16)
\]

生成器的loss function为:

\[-E_{x \sim P_{g}}[f_{w}(x)] (公式17)
\]

所以,不管理论再复杂, WGAN在原始的GAN上只做了三点改进:

  • 判别器最后一层去掉sigmoid
  • 生成器和判别器的loss不取log
  • 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c

最后,作者通过经验发现,不要使用Adam优化算法,推荐RMSProp或者SGD。

2.3 模型崩塌(collapse mode)问题的解决方法

上述解决了GAN在训练过程中梯度不稳定的问题,那么模型崩塌(collapse mode)问题的解决方法如下:

2.3.1 在loss function 层面

通常先更新几轮生成器,之后再更新一轮鉴别器。因为GAN的训练是 \(min max\) 的策略,即先更新鉴别器,然后再更新生成器。往往在迭代的过程中,生成器和鉴别器交替优化,容易将问题变成 \(maxmin\) 的问题,这样一来就变成了:生成器先生成一个输出,然后鉴别器对这个输出进行判断,那么生成器最后学习到的往往是最保险的,导致模型崩塌(collapse mode),生成样本多样性不足。

2.3.2 在网络结构方面

1、采用多个生成器和一个鉴别器,类似于旷视“先发散再收敛”的学习策略,通过正则化约束生成器之间的比重,生成多样性的样本。

2、将真实样本通过一个编码器 (Encoder) 后再使用生成器进行重构,如下图所示:

那么 \(D_{M}\) 和 \(R\) 用来指导生成对应的样本,而 \(D_{D}\) 则对 \(G(z)\) 和 \(G(E(x))\) 进行判别,显然二者都是生成的样本,差别越大那么表明生成样本的多样性越高。

3、Mini-batch discrimination在判别器的中间层建立一个mini-batch layer用于计算基于 \(L_{1}\) 距离的样本统计量,通过建立该统计量去判别一个batch内某个样本与其他样本有多接近。这个信息可以被判别器利用到,从而甄别出哪些缺乏多样性的样本。对生成器而言,则要试图生成具有多样性的样本。

2.4 WGAN 部分代码分析

self.G_sample = self.generator(self.z)

self.D_real, _ = self.discriminator(self.X)
self.D_fake, _ = self.discriminator(self.G_sample, reuse = True) # loss
self.D_loss = - tf.reduce_mean(self.D_real) + tf.reduce_mean(self.D_fake)
self.G_loss = - tf.reduce_mean(self.D_fake) self.D_solver = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(self.D_loss, var_list=self.discriminator.vars)
self.G_solver = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(self.G_loss, var_list=self.generator.vars) # clip
self.clip_D = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in self.discriminator.vars]

然后按照正常的GAN训练即可。

从GAN到WGAN的来龙去脉的更多相关文章

  1. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  2. W-GAN系 (Wasserstein GAN、 Improved WGAN)

    学习总结于国立台湾大学 :李宏毅老师 WGAN前作:Towards Principled Methods for Training Generative Adversarial Networks  W ...

  3. GAN的文献综述

    1.Conditional Generative Adversarial Netwoks Describe GAN: Generative adversarial nets were recently ...

  4. (转) Read-through: Wasserstein GAN

    Sorta Insightful Reviews Projects Archive Research About  In a world where everyone has opinions, on ...

  5. DCGAN、WGAN、WGAN-GP、LSGAN、BEGAN原理总结及对比

    DCGAN.WGAN.WGAN-GP.LSGAN.BEGAN原理总结及对比 from:https://blog.csdn.net/qq_25737169/article/details/7885778 ...

  6. GAN的调研和学习

    近期集中学习了GAN,下面记录一下调研的结果,和学习的心得,疏漏的地方,敬请指正. 本文将分为几个部分进行介绍,首先是GAN的由来,其次是GAN的发展,最后是GAN的应用. 先把最近收集的资料列举一下 ...

  7. GAN生成图像论文总结

    GAN Theory Modifyingthe Optimization of GAN 题目 内容 GAN   DCGAN   WGAN   Least-square GAN   Loss Sensi ...

  8. GAN与VAE

    经典算法·GAN与VAE Generative Adversarial Networks 及其变体 生成对抗网络是近几年最为经典的生成模型的代表工作,Goodfellow的经典工作.通过两个神经网络结 ...

  9. 深度学习----现今主流GAN原理总结及对比

    原文地址:https://blog.csdn.net/Sakura55/article/details/81514828 1.GAN 先来看看公式:             GAN网络主要由两个网络构 ...

随机推荐

  1. global nonlocal 闭包构建 变量名解析原则

  2. java并发编程工具类JUC第三篇:DelayQueue延时队列

    DelayQueue 是BlockingQueue接口的实现类,它根据"延时时间"来确定队列内的元素的处理优先级(即根据队列元素的"延时时间"进行排序).另一层 ...

  3. MySQL之数据查询语言(DQL)

    数据查询语言DQL基本结构是由SELECT子句,FROM子句,WHERE子句组成的查询块: SELECT <字段> FROM <表名> WHERE <查询条件> - ...

  4. GO学习-(10) Go语言基础之指针

    Go语言基础之指针 区别于C/C++中的指针,Go语言中的指针不能进行偏移和运算,是安全指针. 要搞明白Go语言中的指针需要先知道3个概念:指针地址.指针类型和指针取值. Go语言中的指针 任何程序数 ...

  5. Go语言网络通信---tcp上传大文件(粘包问题还需优雅解决)

    server端: package main import ( "bufio" "encoding/binary" "fmt" "n ...

  6. 【RMAN】使用RMAN备份将数据库不完全恢复到指定时间点

    RMAN作为Oracle强大的备份恢复工具,可以协助我们恢复数据库到指定时间点,这便是Oracle不完全恢复的一种体现,通过这种方法可以找回我们曾经丢失的数据.这里以找回误TRUNCATE表数据为例给 ...

  7. 解决maven依赖冲突问题

    解决maven依赖冲突问题 1.idea 安装maven helper插件 2.打开pom.xml文件 底部多出一个Dependency Analyzer选项 点开这个选项 找到冲突,点击右键,选择E ...

  8. Python分析离散心率信号(中)

    Python分析离散心率信号(中) 一些理论和背景 心率信号不仅包含有关心脏的信息,还包含有关呼吸,短期血压调节,体温调节和荷尔蒙血压调节(长期)的信息.也(尽管不总是始终如一)与精神努力相关联,这并 ...

  9. 实战SpringBoot Admin

    长话短说哦,直接查看下文吧 目录 声明 先锋 前提 SpringBoot Admin 介绍 服务端的搭建 客户端的搭建 参数的指南 尾声 声明 见名知意,实战SpringBoot Admin,实战!实 ...

  10. linux用户组添加和权限的设置

    1.useradd  添加用户 useradd [选项]... 用户名 -u 用户id.-d 家目录路径.-s 登录Shell(解释器).-G 附加组 /sbin/nologin  :禁止用户登陆系统 ...