【生成对抗网络学习 其一】经典GAN与其存在的问题和相关改进
参考资料:
1、https://github.com/dragen1860/TensorFlow-2.x-Tutorials
2、《Generative Adversarial Net》
直接介绍GAN可能不太容易理解,所以本次会顺着几个具体的问题讨论并介绍GAN(个人理解有限,有错误的希望各位大佬指出),本来想做代码介绍的,但是关于eriklindernoren的GAN系列实现已经有很多博主介绍过了,所以就不写了。
如果你对GAN的基本知识不太了解,建议先看看莫烦的介绍:https://mofanpy.com/tutorials/machine-learning/gan/
注:图片刷不出来可能需要fq,最近jsdelivr代理好像挂了。
1、什么是GAN
GAN是一种生成网络
区别于以往使用的RNN、CNN等网络,GAN不是将数据与结果根据某种关系联系起来,而是使用一堆随机数去生成想要的结果
GAN中同时训练着两个模型:一个是生成器(Generator),另一个是判别器(Discriminator)。
生成器通过随机数来生成结果(有可能是图片,也有可能是其他的),这里我们把生成的结果称为"生成数据",我们的目标结果称为"真实数据"。
之后,生成数据和真实数据会被同时送给判别器进行区分。
GAN的训练
整体结构如下
2、GAN有什么问题
研究数据的特点是解决问题的一个前提,数据分布会直接影响到我们算法的结果
对于“通过训练,从噪声数据生成一副图片”这个问题来说,里面涉及到两种数据:噪声和真实图片
这两种数据在分布上来说是没有重合部分的
例如,真实的图片是“手写数字1”,而还没经过训练的生成器用噪声生成的图片是类似老电视上的那种白色雪花噪点(我的理解,不知道恰不恰当)
显然两者不会有数据上的重合(overlapped),在这种情况下无论生成多少张图片,我们人去看的时候总是能够区分假图片
具体来说就对应成下面的两种分布:P、Q
省略KL与JS的推导过程直接看结论:
当θ≠0(数据分布没有重合时),使用KL散度(相对熵)和JS散度不能够很好地量化训练结果
例如上图,在均值达到某一大小时,两者会变成固定常数(也就是不起作用,并且JS比KL出现时间更早),对应到现象就是出现梯度消失,没办法继续更新梯度。
再换个角度看
蓝色的是分布1,红色的是真实数据分布,不管怎么移动,只要两者没有重合,那么JS永远是log2(同一值)
也就是说,如果刚开始训练时,数据分布处于一个不好的状态,那么训练很难再进行下去
再举个例子
蓝色是真实数据,绿色是生成数据的分布,它们中间部分是没有重叠的
对于橙色线也就是判别器而言,他可以简单的区分出是或者不是真实数据(0/1)
但是如果一开始我们的数据分布就处于绿色区域(不利位置),那么是没办法进行判别的(无法更新)
而WGAN的评价标准(EM)可以解决这个问题,即使在不相交的位置,导数也可以起到引导作用
因此,使用EM距离来衡量训练结果便是对GAN的一个重要改进
Wasserstein距离(Wasserstei Distance)(也叫EM距离,Earth-Mover Distance)可以用于衡量两个分布之间的距离
对于没有重叠的分布同样适用,例如下面这种的
从横轴可以看出,这两个数据按照KL散度或者JS散度的标准,是完全没有重叠的
现在将问题转换一下,即使两种分布没有重叠,但是我们还可以让他们形式尽可能保持一致嘛
那么我们就可以通过"交换"上面的柱状体来实现这个目标
可以通过“交换次数”来衡量这个行为的效率
于是问题转换成了:“从P分布转换到Q分布需要几步?”
转换需要的“步数”就是代价,可以用来衡量P分布与Q分布到底有多相似
所以,即便是完全不重叠的分布,无非就是转换步骤很多而已,就说明他们非常不相关。
我们把柱形的移动转换类比成“铲土”,那么同样是铲,笨方法就铲的次数多,好方法步数少,于是我们会针对每个情况去计算最优的“铲土方法”,用来衡量P与Q的相似度。
所以Wasserstein距离也叫“推土机距离”
计算式如下
可以看到,GAN中使用"D"来衡量分布的接近程度
这个"D"其实就是一个使用JS散度构建的神经网络层(判别器层)
而WGAN中则是使用"f"来衡量分布
"f"则是使用WD来构建的判别器层(因为WD原来是在离散情况下的,所以这里在连续情况下使用相当于通过判别器来逼近理想的"f",因此有约束条件)
但是有个约束条件:在f上取任意的两个梯度x1、x2,他们的差必须小于1(图中经过化简了),即1-Lipschitz function
现在可以做个总结:
GAN与WGAN最主要的区别就是将JS散度换成Wasserstein距离(或者说EM距离),由此解决了GAN早期训练时因为数据重合(overlapped)度低而出现的梯度消失问题。
WGAN的训练
整体结构如下
注意:经典GAN的生成器与编码器均使用简单的全连接层构建,而其他衍生种类GAN一般使用卷积层/反卷积层代替全连接层
3、WGAN有什么问题
实现上面提到的设想的关键在于如何满足1-Lipschitz约束
WGAN 为了实现这个约束,使用了 clip 截断了判别器 weights
但这只有在权重恰好合适时能够实现(具体不推导了),并且这变相限制了这个网络的参数,进而约束了网络的表达能力。
在WGAN-gp论文中,它提到了WGAN使用clip方式所引发的问题。
重点看看下面的右边(b)这张图,很多颜色线条那个是随着判别器层数增加, Clip 方案中梯度传导是有问题的,要么爆炸要么消失了,而 Gradient penalty 方案可以让每一层的梯度都比较稳定。
再来看看最右边的图, Clip 方案网络中 weights 参数都跑到的极端的地方,要么最大,要么最小,而 Gradient penalty 方案可以让 weights 比较均匀地分布。
WGAN-gp的训练
【生成对抗网络学习 其一】经典GAN与其存在的问题和相关改进的更多相关文章
- 【生成对抗网络学习 其三】BiGAN论文阅读笔记及其原理理解
参考资料: 1.https://github.com/dragen1860/TensorFlow-2.x-Tutorials 2.<Adversarial Feature Learning> ...
- 生成对抗网络(GAN)
GAN的全称是 Generative Adversarial Networks,中文名称是生成对抗网络.原始的GAN是一种无监督学习方法,巧妙的利用“博弈”的思想来学习生成式模型. 1 GAN的原理 ...
- 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】
本文转载自:https://www.leiphone.com/news/201703/Y5vnDSV9uIJIQzQm.html 生成对抗网络(Generative Adversarial Netwo ...
- 生成对抗网络(GAN)相关链接汇总
1.基础知识 创始人的介绍: “GANs之父”Goodfellow 38分钟视频亲授:如何完善生成对抗网络?(上) “GAN之父”Goodfellow与网友互动:关于GAN的11个问题(附视频) 进一 ...
- 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...
- 深度学习-生成对抗网络GAN笔记
生成对抗网络(GAN)由2个重要的部分构成: 生成器G(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器 判别器D(Discriminator):判断这张图像是真实的 ...
- 深度学习框架PyTorch一书的学习-第七章-生成对抗网络(GAN)
参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN生成动漫头像 GAN解决了非监督学习中的著名问题:给定一批样本,训 ...
- 生成对抗网络(Generative Adversarial Networks,GAN)初探
1. 从纳什均衡(Nash equilibrium)说起 我们先来看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句话 ...
- AI佳作解读系列(六) - 生成对抗网络(GAN)综述精华
注:本文来自机器之心的PaperWeekly系列:万字综述之生成对抗网络(GAN),如有侵权,请联系删除,谢谢! 前阵子学习 GAN 的过程发现现在的 GAN 综述文章大都是 2016 年 Ian G ...
随机推荐
- 一篇文章带你整明白HTTP缓存知识
最近看了很多关于缓存的文章, 每次看完,看似明白但是实际还是没明白,这次总算搞明白协商缓存是怎么回事了 首先,服务器缓存分强制缓存和协商缓存(也叫对比缓存) 强制缓存一般是服务端在请求头携带字段Exp ...
- python入门-开始
1.为啥要学Python? 各种语言的优劣势对比视频版:https://www.bilibili.com/video/BV1y3411r7pX/?spm_id_from=autoNext 各种语言的优 ...
- 利用js获取不同页面间跳转需要传递的参数
获取参数的js函数如下: function GetQueryValue(queryName) { var query = decodeURI(window.location.search.substr ...
- Java语言学习day35--8月10日
今日内容介绍1.集合2.Iterator迭代器3.增强for循环4.泛型 ###01集合使用的回顾 *A:集合使用的回顾 *a.ArrayList集合存储5个int类型元素 public static ...
- CDN绕过
信息收集_CDN绕过 什么是CDN?为什么要绕过? CDN全称是内容分发网络(content delivery network).其目的是让用户能够更快速的得到请求的数据. 网上找了一张图片, ...
- jmeter并发设置的原理
目录 简介 广义并发 绝对并发 简介 性能测试过程中是否需要进行同步定时器的设置,需要根据实际情况来考虑. 举个栗子来讲是我们的双十一秒杀活动,这时候就必须实现请求数量达到一定数量后同时向服务 ...
- redis:缓存穿透、缓存击穿、缓存雪崩
缓存穿透的解决方案(空标记) 缓存穿透是指,在数据存储系统中不存在的记录,不会被存储到缓存中.这种记录每次的查询流量都会穿透到数据存储层.在高流量的场景下,不断查询空结果会大量消耗数据查询服务的资源, ...
- 攻防世界-MISC:ext3
这是攻防世界新手练习区的第九题,题目如下: 点击下载附件1,通过题目描述可知这是一个Linux系统光盘,用010editor打开,搜索flag,发现存在flag.txt文件 将该文件解压,找到flag ...
- Oracle 数据库表删除重复数据
删除重复数据并保留一条 方法一 1.建立临时表,记录重复的数据 create table 临时表 as select a.字段1,a.字段2,max(a.rowid) as dataid from 原 ...
- JSP标签、JSTL标签、EL表达
JSP页面转发,附带数据 <jsp:forward page="/jsptag2.jsp"> <jsp:param name="name" v ...