从GAN到WGAN的来龙去脉
一、原始GAN的理论分析
1.1 数学描述
其实GAN的原理很好理解,网络结构主要包含生成器 (generator) 和鉴别器 (discriminator) ,数据主要包括目标样本 \(x_r \sim P_{r}\), 随机输入样本 \(z \sim P_{z}\) 。生成器的目的就是根据 \(z\) 生成 \(G(z) \sim P_{r}\) ,而鉴别器则尽量区分出来 \(G(z)\) 与 \(x_{r}\) 的不同。生成器和鉴别器采用生成对抗的方式不断优化,最终能通过生成器得到期望输出(比如风格转换,人脸生成等)。联想到电影《无双》的情节,生成器就是造假币的机器,而鉴别器可以类似为鉴别假币的手段。在初始情况下,假币制造机只能生成不是很逼真的假币,此时鉴别器很轻松就能鉴别出来,于是便优化流程和材料,鉴别器鉴别错误之后再改进判别手段......如此往复,最终我们可以得到足以以假乱真的假币。
鉴别器 $D(input;\theta_{d})$ 的目标是对输入的数据做出准确的判断,因此目标函数为:
$$
\mathop{max}\limits_{D}[E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{z\sim P_{z}}[log(1-D(G(z);\theta_{d}))] (公式1)
$$
生成器 $G(input;\theta_{g})$ 的目标就是输出的数据尽可能与目标样本接近骗过鉴别器 $D$,因此:
$$
\mathop{min}\limits_{G}E_{z\sim P_{z}}[log(1-D(G(z);\theta_{d}))] (公式2)
$$
因此总的目标函数可以写为:
$$
\mathop{min}\limits_{G}\mathop{max}\limits_{D}[E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{z\sim P_{z}}[log(1-D(G(z);\theta_{d}))] (公式3)
$$
借用原论文的符号,我们将生成器输出的概率分布记为 $P_{g}$,于是公式 (3) 可以记为:
$$
\mathop{min}\limits_{G}\mathop{max}\limits_{D}[E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{x\sim P_{g}}[log(1-D(x;\theta_{d}))] (公式4)
$$
1.2 求出全局最优解
当固定 \(G\) 的参数时,优化 \(D\) 的参数:
V_{D} = \int P_{r}logD(x)dx+\int P_{g}log(1-D(x))dx = \int (P_{r}logD(x)+P_{g}log(1-D(x))dx \\ (公式5)
\]
因此,最大值为:
\dfrac{\partial{V_{D}}}{\partial{D}} = \int \dfrac{P_{r}}{D(x)}-\dfrac{P_{g}}{1-D(x)}dx = 0\\ (公式6)
\]
解得:
\]
于是,将 \(D^{*}(x)\) 带入到公式 (4) 中,得到:
\]
即:
\]
由于\(P_{r}+P{g} \in [0,2]\),因此公式 (10) 可以写为:
V_{G} = KL(P_{r}|| \dfrac{P_{r}+P_{g}}{2})+log \dfrac{1}{2}+KL(P_{g}|| \dfrac{P_{r}+P_{g}}{2})+log \dfrac{1}{2} \\(公式10)
\]
最终:
\]
因此,当 \(P_{r} = \dfrac{P_{r}+P_{g}}{2} = P_{g}\) 时,存在唯一极小值 \(P_{r} = P_{g}\),此时 \(D^{*}(x) = \dfrac{1}{2}\)。即公式 (4) 存在全局最优解,在全局最优解的情况下,生成器生成的概率分布与目标样本概率分布一样,此时鉴别器无法准确判断生成样本与目标样本的差异,判断正确和错误的概率各为0.5,类似于瞎猜。
1.3 原始GAN到底出了什么问题?
GAN的训练是依靠生成器和鉴别器的相互对抗来完成的,那么直观地思考一下:如果鉴别器过于差劲,给不到生成器任何有用的信息,那么生成器的更新就会没有方向;如果鉴别器太好,那么类似于造假币的机器极其差,而鉴别器直接就是验钞机,那么直观上也无法给生成器提供足够的信息去更新。因此,原始的GAN理论上可行,而实际上却受到鉴别器和生成器状态的影响,不一定能找到最优解,且训练不稳定。
从数学角度上来描述:我们在 1.2节 求全局最优解的过程中,先求出了鉴别器 \(D\) 的最优解,然后得到了公式 (11) ,在这种情况下相当于我用已经训练好的鉴别器来指导生成器的学习,将概率分布从 \(P_{z}\) 拉向 \(P_{r}\)。乍一看没什么问题,但是如果两个分布 \(P_{r}\),\(P_{z}\) 完全没有重叠的部分,或者它们重叠的部分可忽略,会发生什么情况呢?答案是无论换句话说,无论 \(P_{r}\) 跟 \(P_{g}\)是远在天边,还是近在眼前,只要它们俩没有一点重叠或者重叠部分可忽略,公式 (11) 散度就固定是常数 \(log2\),而这对于梯度下降方法意味着——梯度为0!此时对于最优判别器来说,生成器肯定是得不到一丁点梯度信息的;即使对于接近最优的判别器来说,生成器也有很大机会面临梯度消失的问题。与我们直观上的感觉一致。
那么问题就变成了\(P_{r}\),\(P_{z}\) 没有重叠的部分的概率大吗?答案是非常大。首先,\(P_{r}\) 是一个复杂分布,而 \(P_{z}\) 则是一个简单分布,所以在空间上二者不重叠的概率很大。更重要的一个原因是,输入 \(z \sim P_{r}\) 一般是 100 维,而生成的目标往往是一张图片,比如 \(64 \times 64\) 就是 \(4096\) 维,低维与高维相重合本来就很少,因此更加证明了原始GAN不容易训练。总结下来:
原始GAN存在梯度不稳定的问题,即判别器训练得太好,生成器梯度消失,生成器loss降不下去;判别器训练得不好,生成器梯度不准,四处乱跑。只有判别器训练得不好不坏才行,但是这个火候又很难把握,甚至在同一轮训练的前后不同阶段这个火候都可能不一样,所以GAN才那么难训练。 此外,GAN还存在模式崩塌(collapse mode)的问题,即生成样本多样性不足。
二、WGAN的前世今生
为了解决原始GAN梯度不稳定的问题,一个过渡的解决方案是强行对生成样本和真实样本加噪声,使得原本两个分布弥散到整个高维空间,增加重叠部分。当二者出现重叠部分时,再把噪声拿掉,这样也能够继续收敛。这只是一个折中的方案,并没有从本质上解决问题。
2.1 Wasserstein 距离
Wasserstein 距离又叫 Earth-Mover ( EM ) 距离,定义如下:
\]
其中:\(\prod (P_{r}, P_{g})\) 表示从概率 \(P_{g}\) 到 \(P_{r}\) 的所有可能分布,而 \(W(P_{r},P_{g})\) 代表所有可能的分布中, \(||x-y||\) 的最小期望值距离。举个例子:如下图所示,假如将左侧的方块运送到右侧的位置,那么方案有很多种,其中最小的那一种移动所花的消耗即为Wasserstein距离。
**因此,Wasserstein的好处就是无论两个分布是否有重叠部分,Wasserstein距离都是连续的,能够反映两个分布的远近,而JS散度和KL散度既不能反映远近,也提供不了梯度。**所以,EM距离更适合用作GAN的loss function。
2.2 从EM距离到WGAN
由于在Wasserstein中,\(\mathop{inf}\limits_{\gamma \sim \prod (P_{r}, P_{g})}\) 没办法直接求解,因此WGAN的作者通过已有的定理将其转换成如下形式:
\]
式子的证明过程对我来说确实难以理解,因此这里就不作解释了,有兴趣的可以参考WGAN的原论文。最后,WGAN的loss function变成了下面的形式:
\]
于是,可以把函数 \(f\) 用一个参数为 \(w\) 的神经网络来表示。最后,为了满足 \(||f_{w}||_{L}<K\) 的限制,将神经网络的所有参数 \(w\) 都拉伸到 \([-c,c]\) 中,所以一定满足Lipschitz连续条件。
因此,我们可以构造一个含参数 \(w\)、最后一层不是非线性激活层的判别器网络 \(f_{w}\),在限制! \(w\) 不超过某个范围的条件下,使得:
\]
尽可能取到最大,此时的 \(L\) 就可以近似为真实分布 \(P_{r}\) 与生成分布 \(P_{g}\) 之间的Wasserstein距离。注意:原始GAN的判别器做的时二分类任务,所以最后一层采用 \(sigmoid\) 函数,而WGAN中的判别器做的是拟合 Wasserstein 距离,属于回归任务,因此把最后一层的 \(sigmoid\) 去掉。
因此判别器的loss function为:
\]
生成器的loss function为:
\]
所以,不管理论再复杂, WGAN在原始的GAN上只做了三点改进:
- 判别器最后一层去掉sigmoid
- 生成器和判别器的loss不取log
- 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
最后,作者通过经验发现,不要使用Adam优化算法,推荐RMSProp或者SGD。
2.3 模型崩塌(collapse mode)问题的解决方法
上述解决了GAN在训练过程中梯度不稳定的问题,那么模型崩塌(collapse mode)问题的解决方法如下:
2.3.1 在loss function 层面
通常先更新几轮生成器,之后再更新一轮鉴别器。因为GAN的训练是 \(min max\) 的策略,即先更新鉴别器,然后再更新生成器。往往在迭代的过程中,生成器和鉴别器交替优化,容易将问题变成 \(maxmin\) 的问题,这样一来就变成了:生成器先生成一个输出,然后鉴别器对这个输出进行判断,那么生成器最后学习到的往往是最保险的,导致模型崩塌(collapse mode),生成样本多样性不足。
2.3.2 在网络结构方面
1、采用多个生成器和一个鉴别器,类似于旷视“先发散再收敛”的学习策略,通过正则化约束生成器之间的比重,生成多样性的样本。
2、将真实样本通过一个编码器 (Encoder) 后再使用生成器进行重构,如下图所示:
那么 \(D_{M}\) 和 \(R\) 用来指导生成对应的样本,而 \(D_{D}\) 则对 \(G(z)\) 和 \(G(E(x))\) 进行判别,显然二者都是生成的样本,差别越大那么表明生成样本的多样性越高。
3、Mini-batch discrimination在判别器的中间层建立一个mini-batch layer用于计算基于 \(L_{1}\) 距离的样本统计量,通过建立该统计量去判别一个batch内某个样本与其他样本有多接近。这个信息可以被判别器利用到,从而甄别出哪些缺乏多样性的样本。对生成器而言,则要试图生成具有多样性的样本。
2.4 WGAN 部分代码分析
self.G_sample = self.generator(self.z)
self.D_real, _ = self.discriminator(self.X)
self.D_fake, _ = self.discriminator(self.G_sample, reuse = True)
# loss
self.D_loss = - tf.reduce_mean(self.D_real) + tf.reduce_mean(self.D_fake)
self.G_loss = - tf.reduce_mean(self.D_fake)
self.D_solver = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(self.D_loss, var_list=self.discriminator.vars)
self.G_solver = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(self.G_loss, var_list=self.generator.vars)
# clip
self.clip_D = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in self.discriminator.vars]
然后按照正常的GAN训练即可。
从GAN到WGAN的来龙去脉的更多相关文章
- 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码
先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...
- W-GAN系 (Wasserstein GAN、 Improved WGAN)
学习总结于国立台湾大学 :李宏毅老师 WGAN前作:Towards Principled Methods for Training Generative Adversarial Networks W ...
- GAN的文献综述
1.Conditional Generative Adversarial Netwoks Describe GAN: Generative adversarial nets were recently ...
- (转) Read-through: Wasserstein GAN
Sorta Insightful Reviews Projects Archive Research About In a world where everyone has opinions, on ...
- DCGAN、WGAN、WGAN-GP、LSGAN、BEGAN原理总结及对比
DCGAN.WGAN.WGAN-GP.LSGAN.BEGAN原理总结及对比 from:https://blog.csdn.net/qq_25737169/article/details/7885778 ...
- GAN的调研和学习
近期集中学习了GAN,下面记录一下调研的结果,和学习的心得,疏漏的地方,敬请指正. 本文将分为几个部分进行介绍,首先是GAN的由来,其次是GAN的发展,最后是GAN的应用. 先把最近收集的资料列举一下 ...
- GAN生成图像论文总结
GAN Theory Modifyingthe Optimization of GAN 题目 内容 GAN DCGAN WGAN Least-square GAN Loss Sensi ...
- GAN与VAE
经典算法·GAN与VAE Generative Adversarial Networks 及其变体 生成对抗网络是近几年最为经典的生成模型的代表工作,Goodfellow的经典工作.通过两个神经网络结 ...
- 深度学习----现今主流GAN原理总结及对比
原文地址:https://blog.csdn.net/Sakura55/article/details/81514828 1.GAN 先来看看公式: GAN网络主要由两个网络构 ...
随机推荐
- kotlin中的嵌套类与内部类
Java中的内部类和静态内部类在Java中内部类简言之就是在一个类的内部定义的另一个类.当然在如果这个内部类被static修饰符修饰,那就是一个静态内部类.关于内部类 和静态内部类除了修饰符的区别之外 ...
- jenkins配置基于角色的项目权限管理设置步骤
jenkins配置基于角色的项目权限管理设置步骤 本文链接:https://blog.csdn.net/russ44/article/details/52276222 由于jenkins默认的权限管理 ...
- grasshopper之python电池执行逻辑
在grasshopper中,需要导入的包虽然不多,但是相当绕人,所要实现的操作往往找不到,暂时做个分类. 双击输入 python 电池: # 导入rhino 包 import Rhino #Rhino ...
- ImportError:no mudle named 'cv2'
提供一下下载的网址:OpenCV,速度比较慢. 我的anaconda版本Python是3.6的,直接提供百度云下载: 链接:https://pan.baidu.com/s/1Xz9JrE2m-dwPv ...
- C#解决WebClient不能下载https网页内容
在下载之前,执行以下代码即可: if (stUrl.Substring(0, 5) == "https") { // 解决WebClient不能通过https下载内容问题 Serv ...
- JavaScript 中数组 sort() 方法的基本使用
在日常的代码开发中,关于数组排序的操作可不少,JavaScript 中可以调用 sort 方法对数组进行快速排序. 今天,就数组的 sort 方法来学习一下,避免日后踩坑的悲惨遭遇. 概念 sort ...
- C ++基本输入/输出
C ++基本输入/输出 本文将学习如何使用cin对象从用户那里获取输入,并使用cout对象在示例的帮助下向用户显示输出. C ++输出 在C ++中,cout将格式化的输出发送到标准输出设备,例如屏幕 ...
- 人工智能AI智能加速卡技术
人工智能AI智能加速卡技术 一. 可编程AI加速卡 1. 概述: 这款可编程AI加速器卡具备 FPGA 加速的强大性能和多功能性,可部署AI加速器IP(WNN/GNN,直接加速卷积神经网络,直接运行常 ...
- Java真的是白天鹅
前言 我最近越来越真切的感受到,Java真的是白天鹅. 这真的是一种羡慕嫉妒恨的感受. 今天和一个Java技术Leader聊天,我告诉他敏捷开发是以人为本,他居然跟我说敏捷开发在行业内有规范,规范是死 ...
- 可微渲染 SoftRas 实践
SoftRas 是目前主流三角网格可微渲染器之一. 可微渲染通过计算渲染过程的导数,使得从单张图片学习三维结构逐渐成为现实.可微渲染目前被广泛地应用于三维重建,特别是人体重建.人脸重建和三维属性估计等 ...