生成对抗网络(Generative Adversarial Networks,GAN)初探
1. 从纳什均衡(Nash equilibrium)说起
我们先来看看纳什均衡的经济学定义:
所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处。换句话说,如果在一个策略组合上,当所有其他人都不改变策略时,没有人会改变自己的策略,则该策略组合就是一个纳什均衡。
B站上有一个关于”海滩2个兄弟卖雪糕“形成纳什均衡的视频,讲的很生动。

0x1:价格战中的纳什均衡
市场上有2家企业A和B,都是卖纸的,纸的成本都是2元钱,A和B都卖5块钱。在最开始,A、B企业都是盈利3块,这种状态叫”社会最优解(Social optimal solution)“。但问题是,社会最优解是一个不稳定的状态,就如同下图中这个优化曲面上那个红球点一样,虽然该小球目前处于曲面最高点,但是只要施加一些轻微的扰动,小球就会立刻向山下滑落:
现在企业A和B准备开展商业竞争:
- 有一天,A企业率先降价到4块钱,于是A销量大增,B销量大减。
- B看到了后,降价到3块钱,于是B销量大增,A销量大减。
- ......
但如果价格战一直这样打下去,这个过程显然不可能无限迭代下去。当A和B都降价到了3块时,双方都达到了成本的临界点,既不敢涨价,也不敢降价。涨价了市场就丢了,降价了,就赚不到钱甚至赔钱。所以A和B都不会再去做改变,这就是纳什均衡。
A和B怎样能够获得最大利润呢,就是A和B坐到一起商量,同时把价格提高,这就叫共谋,但法律为了保障消费者利益,禁止共谋。补充一句,共谋在机器学习中被称作”模型坍塌“,指的对对抗的模型双方都进入了一个互相认可的局部最优区而不再变化,具体的技术细节我们后面会讨论。
0x2:囚徒困境中的纳什均衡
囚徒困境是说:有两个小偷集体作案,然后被警察捉住。警察对两个人分别审讯,并且告诉他们政策:
- 如果两个人都交代坦白,就可以定罪,两个人各判八年。
- 如果一个人交代另一个不交代,那么一样可以定罪。但是交代的人从宽处罚,批评教育就释放。不交代的人从严处罚,判十年。
- 如果两个人都不交代,没法定罪,每个人判一年意思一下。
两个人的收益情况如下所示:
因为A和B是不能互相通信的,因此这是一个静态不完全信息博弈,我们分别考虑双方的决策面:
- A的决策。A会想,我如何才能获得更大收益呢?
- 先考虑最坏的情况:如果B坦白了,那么我坦白就会判8年,我抗拒就会判十年,我应该坦白;
- 再考虑最好的情况:如果B抗拒了,我坦白会判0年,我抗拒会判1年,我还是应该坦白;
- 所以最终A会选择坦白。
- 同样,B也会这样想。
因此最终纳什均衡点在两个人都坦白,各判八年这里。
显然,集体最优解在两个人都抗拒,这样一来每个人都判一年就出来了。但是,纳什均衡点却不在这里。而且,在纳什均衡点上,任何一个人都没有改变自己决策的动力。因为一旦单方面改变决策,那个人的收益就会下降。
0x3:开车加塞现象的纳什均衡
我们知道,在国内开车夹塞很常见。如果大家都不夹塞,是整体的最优解,但是按照纳什均衡理论,任何一个司机都会考虑,无论别人是否夹塞,我夹塞都可以使自己的收益变大。于是最终大家都会夹塞,加剧拥堵,反而不如大家都不加塞走的快。
那么,有没有办法使个人最优变成集体最优呢?方法就是共谋。两个小偷在作案之前可以说好,咱们如果进去了,一定都抗拒。如果你这一次敢反悔,那么以后道上的人再也不会有人跟你一起了。也就是说,在多次博弈过程中,共谋是可能的。但是如果这个小偷想干完这一票就走,共谋就是不牢靠的。
在社会领域,共谋是靠法律完成的。大家约定的共谋结论就是法律,如果有人不按照约定做,就会受到法律的惩罚。通过这种方式保证最终决策从个人最优的纳什均衡点变为集体最优点。
另外一方面,现在很多汽车厂商提出了车联网的概念,在路上的每一辆车都通过物联网连成一个临时网络,所有车按照一个最优的协同算法共同协定最优的行车路线、行车速度、路口等待等行为,这样整体交通可以达到一个整体最优,所有人都节省了时间。
0x3:枪手博弈
彼此痛恨的甲、乙、丙三个枪手准备决斗,他们各自的水平如下:
- 甲枪法最好,十发八中;
- 乙枪法次之,十发六中;
- 丙枪法最差,十发四中;
1. 场景一:三人同时开枪,并且每人只发一枪。每一轮枪战后,谁活下来的机会大一些?
首先明确一点,这是一个静态不完全信息博弈,每个抢手在开枪前都不知道其他对手的策略,只能在猜测其他对手策略的基础上,选择对自己最优的策略。
我们来分析一下第一轮枪战各个枪手的策略。
- 枪手甲一定要对枪手乙先开枪。因为乙对甲的威胁要比丙对甲的威胁更大,甲应该首先干掉乙,这是甲的最佳策略。
- 同样的道理,枪手乙的最佳策略是第一枪瞄准甲。乙一旦将甲干掉,乙和丙进行对决,乙胜算的概率自然大很多。
- 枪手丙的最佳策略也是先对甲开枪。乙的枪法毕竟比甲差一些,丙先把甲干掉再与乙进行对决,丙的存活概率还是要高一些。
第一轮枪战过后,有几种可能的结果:
- 甲乙双亡,丙获胜
- 甲亡,乙丙存活
- 乙亡,甲丙存活
现在进入第二轮枪战:
除非第一轮甲乙双亡,否则丙就一定处于劣势,因为不论甲或乙,他们的命中率都比丙的命中率为高。
这就是枪手丙的悲哀。能力不行的丙玩些花样虽然能在第一轮枪战中暂时获胜。但是,如果甲乙在第一轮枪战中没有双亡的话,在第二轮枪战结束后,丙的存活的几率就一定比甲或乙为低。
这似乎说明,能力差的人在竞争中耍弄手腕能赢一时,但最终往往不能成事。
2. 场景二:三人轮流开枪,没人只发一枪。丙最后发枪。
我们现在改变游戏规则,假定甲乙丙不是同时开枪,而是他们轮流开一枪。先假定开枪的顺序是甲、乙、丙,我们来分析一下枪战过程:
- 甲一枪将乙干掉后(80%的几率),就轮到丙开枪,丙有40%的几率一枪将甲干掉。
- 乙躲过甲的第一枪(20%几率),轮到乙开枪,乙还是会瞄准枪法最好的甲开枪,即使乙这一枪干掉了甲(60%几率),下一轮仍然是轮到丙开枪(40%几率)。无论是甲或者乙先开枪,乙都有在下一轮先开枪的优势。
如果是丙先开枪,情况又如何呢?
3. 场景三:三人轮流开枪,没人只发一枪。丙第一个发枪。
- 丙可以向甲先开枪(40%几率),
- 即使丙打不中甲,甲的最佳策略仍然是向乙开枪。
- 但是,如果丙打中了甲,下一轮可就是乙开枪打丙了。
- 因此,丙的最佳策略是胡乱开一枪,只要丙不打中甲或者乙,在下一轮射击中他就处于有利的形势(先发优势)。
我们通过这个例子,可以理解人们在博弈中能否获胜,不单纯取决于他们的实力,更重要的是取决于博弈方实力对比所形成的关系。
在上面的例子中,乙和丙实际上是一种联盟关系,先把甲干掉,他们的生存几率都上升了。我们现在来判断一下,乙和丙之中,谁更有可能背叛,谁更可能忠诚?
任何一个联盟的成员都会时刻权衡利弊,一旦背叛的好处大于忠诚的好处,联盟就会破裂。在乙和丙的联盟中,乙是最忠诚的。这不是因为乙本身具有更加忠诚的品质,而是利益关系使然。只要甲不死,乙的枪口就一定会瞄准甲。但丙就不是这样了,丙不瞄准甲而胡乱开一枪显然违背了联盟关系,丙这样做的结果,将使乙处于更危险的境地。
合作才能对抗强敌。只有乙丙合作,才能把甲先干掉。如果,乙丙不和,乙或丙单独对甲都不占优,必然被甲先后解决。、
1966年经典电影《黄金三镖客》中的最后一幕,三个主人公手持枪杆站在墓地中,为了宝藏随时准备决一死战。为了活着拿到宝藏,幸存下来的最优策略是什么呢?
0x4:蒙古联合南宋灭金
当时,蒙古军事实力最强,金国次之,南宋武力最弱。本来南宋应该和金国结盟,帮助金国抵御蒙古的入侵才是上策,或者至少保持中立。但是,当时的南宋采取了和蒙古结盟的政策。南宋当局先是糊涂地同意了拖雷借道宋地伐金。1231年,蒙古军队在宋朝的先遣队伍引导下,借道四川等地,北度汉水歼灭了金军有生力量。
1233年,南宋军队与蒙古军队合围蔡州,金朝最后一个皇帝在城破后死于乱兵,金至此灭亡。1279年,南宋正式亡于蒙古。
如果南宋当政者有战略眼光,捐弃前嫌,与世仇金结盟对抗最强大的敌人蒙古,宋和金都不至于那么快就先后灭亡了。
0x5:智猪博弈
猪圈里面有两只猪, 一只大,一只小。猪圈很长,一头有一个踏板,另一头是饲料的出口和食槽。每踩一下踏板,在远离踏板的猪圈的另一边的投食口就会落下少量的食物。如果有一只猪去踩踏板,另一只猪就有机会抢先吃到另一边落下的食物。
- 当小猪踩动踏板时,大猪会在小猪跑到食槽之前刚好吃光所有的食物;
- 若是大猪踩动了踏板,则还有机会在小猪吃完落下的食物之前跑到食槽,争吃到另一半残羹。
那么,两只猪各会采取什么策略?令人出乎意料的是,答案居然是:小猪将选择“搭便车”策略,也就是舒舒服服地等在食槽边;而大猪则为一点残羹不知疲倦地奔忙于踏板和食槽之间。
原因何在呢?我们来分析一下,首先这是一个静态不完全信息博弈:
- 小猪踩踏板:小猪将一无所获,不踩踏板反而能吃上食物。对小猪而言,无论大猪是否踩动踏板,不踩踏板总是好的选择。
- 反观大猪,已明知小猪是不会去踩动踏板的,自己亲自去踩踏板总比不踩强吧,所以只好亲力亲为了。
“智猪博弈”的结论似乎是,在一个双方公平、公正、合理和共享竞争环境中,有时占优势的一方最终得到的结果却有悖于他的初始理性。这种情况在现实中比比皆是。
比如,在某种新产品刚上市,其性能和功用还不为人所熟识的情况下,如果进行新产品生产的不仅是一家小企业,还有其他生产能力和销售能力更强的企业。那么,小企业完全没有必要作出头鸟,自己去投入大量广告做产品宣传,只要采用跟随战略即可。
“智猪博弈”告诉我们,谁先去踩这个踏板,就会造福全体,但多劳却并不一定多得。
在现实生活中,很多人都只想付出最小的代价,得到最大的回报,争着做那只坐享其成的小猪。“一个和尚挑水喝,两个和尚抬水喝,三个和尚没水喝”说的正是这样一个道理。这三个和尚都想做“小猪”,却不想付出劳动,不愿承担起“大猪”的义务,最后导致每个人都无法获得利益。
0x6:证券市场中的“智猪博弈”
金融证券市场是一个群体博弈的场所,其真实情况非常复杂。在证券交易中,其结果不仅依赖于单个参与者自身的策略和市场条件,也依赖其他人的选择及策略。
在“智猪博弈”的情景中,大猪是占据比较优势的,但是,由于小猪别无选择,使得大猪为了自己能吃到食物,不得不辛勤忙碌,反而让小猪搭了便车,而且比大猪还得意。这个博弈中的关键要素是猪圈的设计, 即踩踏板的成本。
证券投资中也是有这种情形的。例如,当庄家在底位买入大量股票后,已经付出了相当多的资金和时间成本,如果不等价格上升就撤退,就只有接受亏损。
所以,基于和大猪一样的贪吃本能,只要大势不是太糟糕,庄家一般都会抬高股价,以求实现手中股票的增值。这时的中小散户,就可以对该股追加资金,当一只聪明的“小猪”,而让 “大猪”庄家力抬股价。当然,这种股票的发觉并不容易,所以当“小猪”所需要的条件,就是发现有这种情况存在的猪圈,并冲进去。这样,你就成为一只聪明的“小猪”。
股市中,散户投资者与小猪的命运有相似之处,没有能力承担炒作成本,所以就应该充分利用资金灵活、成本低和不怕被套的优势,发现并选择那些机构投资者已经或可能坐庄的股票,等着大猪们为自己服务。
由此看到,散户和机构的博弈中,散户并不是总没有优势的,关键是找到有大猪的那个食槽,并等到对自己有利的游戏规则形成时再进入。
0x7:纳什均衡博弈与GAN网络的关系
GAN的主要灵感来源于博弈论中零和博弈的思想。
应用到深度学习神经网络上来说,就是通过生成网络G(Generator)和判别网络D(Discriminator)不断博弈,进而使 G 学习到数据的分布,同时时 D 获得更好的鲁棒性和泛化能力。
举个例子:用在图片生成上,我们想让最后的 G 可以从一段随机数中生成逼真的图像:
上图中:
G是一个生成式的网络,它接收一个随机的噪声 z(随机数),然后通过这个噪声生成图像。
D是一个判别网络,判别一张图片是不是 “真实的”。它的输入是一张图片,输出的 D(x) 代表 x 为真实图片的概率,如果为 1,就代表 100% 是真实的图片,而输出为 0,就代表不可能是真实的图片。
那么这个训练的过程是什么样子的呢?在训练中:
G 的目标就是尽量生成真实的图片去欺骗判别网络 D。
D的目标就是尽量辨别出G生成的假图像和真实的图像。
这样,G 和 D 就构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点。
Relevant Link:
- https://baijiahao.baidu.com/s?id=1611846467821315306&wfr=spider&for=pc
- https://www.jianshu.com/p/fadba906f5d3
2. GAN网络的思想起源
GAN的起源之作鼻祖是 Ian Goodfellow 在 2014 年发表在 ICLR 的论文:Generative Adversarial Networks”。
按照笔者的理解,提出GAN网络的出发点有如下几个:
- 最核心的作用是提高分类器的鲁棒能力,因为生成器不断生成”尽量逼近真实样本“的伪造图像,而分类器为了能正确区分出伪造和真实的样本,就需要不断地挖掘样本中真正蕴含的潜在概率信息,而抛弃无用的多余特征,这就起到了提高鲁棒和泛化能力的作用。从某种程度上来说,GAN起到了和正则化约束的效果。
- 基于随机扰动,有针对性地生成新样本。但是要注意的一点是,GAN生成的样本并不是完全的未知新样本,GAN的generator生成的新样本更多的侧重点是通过增加可控的扰动来尝试躲避discriminator的检测。实际上,GAN对生成0day样本的能力很有限。
为了清楚地阐述这个概念,笔者先从对抗样本这个话题开始说起。
0x1:对抗样本(adversarial example)
对抗样本(adversarial example)是指经过精心计算得到的用于误导分类器的样本。例如下图就是一个例子,左边是一个熊猫,但是添加了少量随机噪声变成右图后,分类器给出的预测类别却是长臂猿,但视觉上左右两幅图片并没有太大改变。
出现这种情况的原因是什么呢?
简单来说,就是预测器发生了过拟合。图像分类器本质上是高维空间的一个复杂的决策函数,在高维空间上,图像分类器过分考虑了全像素区间内的细节信息,导致预测器对图像的细节信息太敏感,微小的扰动就可能导致预测器的预测行为产生很大的变化。
关于这个话题,笔者在另一篇文章中对过拟合现象以及规避方法进行了详细讨论。
除了添加”随机噪声驱动的像素扰动”这种方法之外,还可以通过图像变形的方式,使得新图像和原始图像视觉上一样的情况下,让分类器得到有很高置信度的错误分类结果。这种过程也被称为对抗攻击(adversarial attack)。
0x2:有监督驱动的无监督学习
人类通过观察和体验物理世界来学习,我们的大脑十分擅长预测,不需要显式地经过复杂计算就可以得到正确的答案。监督学习的过程就是学习数据和标签之间的相关关系。
但是在非监督学习中,数据并没有被标记,而且目标通常也不是对新数据进行预测。
在现实世界中,标记数据是十分稀有和昂贵的。生成对抗网络通过生成伪造的/合成的数据并尝试判断生成样本真伪的方法学习,这本质上相当于采用了监督学习的方法来做无监督学习。做分类任务的判别器在这里是一个监督学习的组件,生成器的目标是了解真实数据的模样(概率分布),并根据学到的知识生成新的数据。
Relevant Link:
- https://www.jiqizhixin.com/articles/2018-03-05-4
3. GAN网络基本原理
GAN网络发展到如今已经有很多的变种,在arxiv上每天都会有大量的新的研究论文被提出。但是笔者这里不准备枚举所有的网络结构,而是仅仅讨论GAN中最核心的思想,通过笔者自己的论文阅读,将我认为最精彩的思想和学术创新提炼出来给大家,今后我们也可以根据自己的理解,将其他领域的思想交叉引入进来,继续不断创新发展。
0x1:GAN的组成
经典的GAN网络由两部分组成,分别称之为判别器D和生成器G,两个网络的工作原理可以如下图所示,
D 的目标就是判别真实图片和 G 生成的图片的真假,而 G 是输入一个随机噪声来生成图片,并努力欺骗 D。
简单来说,GAN 的基本思想就是一个最小最大定理,当两个玩家(D 和 G)彼此竞争时(零和博弈),双方都假设对方采取最优的步骤而自己也以最优的策略应对(最小最大策略),那么结果就会进入一个确定的均衡状态(纳什均衡)。
0x2:损失函数分析
1. 生成器(generator)损失函数
生成器网络以随机的噪声z作为输入并试图生成样本数据,并将生成的伪造样本数据提供给判别器网络D,
可以看到,G 网络的训练目标就是让 D(G(z)) 趋近于 1,即完全骗过判别器(判别器将生成器生成的伪造样本全部误判为真)。G 网络通过接受 D 网络的反馈作为梯度改进方向,通过BP过程反向调整自己的网络结构参数。
2. 判别器(discriminator)
判别器网络以真实数据x或者伪造数据G(z)作为输入,并试图预测当前输入是真实数据还是生成的伪造数据,并产生一个【0,1】范围内的预测标量值。
D 网络的训练目标是区分真假数据,D 网络的训练目标是让 D(x) 趋近于 1(真实的样本判真),而 D(G(z)) 趋近于0(伪造的样本判黑)。D 网络同时接受真实样本和 G 网络传入的伪造样本作为梯度改进方向,,通过BP过程反向调整自己的网络结构参数。
3. 综合损失函数
生成器和判别器网络的损失函数结合起来就是生成对抗网络(GAN)的综合损失函数:
两个网络相互对抗,彼此博弈,如上所示,综合损失函数是一个极大极小函数;
- 损失函数第一项:会驱使判别器尽量将真实样本都判真
- 损失函数第二项:会驱使判别器尽量将伪造样本都判黑。但同时,生成器G会对抗这个过程
整个相互对抗的过程,Ian Goodfellow 在论文中用下图来描述:
黑色曲线表示输入数据 x 的实际分布,绿色曲线表示的是 G 网络生成数据的分布,紫色的曲线表示的是生成数据对应于 D 的分布的差异距离(KL散度)
GAN网络训练的目标是希望着实际分布曲线x,和G网络生成的数据,两条曲线可以相互重合,也就是两个数据分布一致(达到纳什均衡)。
- a图:网络刚开始训练,D 的分类能力还不是最好,因此有所波动,而生成数据的分布也自然和真实数据分布不同,毕竟 G 网络输入是随机生成的噪声;
- b图:随着训练的进行,D 网络的分类能力就比较好了,可以看到对于真实数据和生成数据,它是明显可以区分出来,也就是给出的概率是不同的;
- c图:由于 D 网络先行提高的性能,随后 G 网络开始追赶,G 网络的目标是学习真实数据的分布,即绿色的曲线,所以它会往蓝色曲线方向移动。因为 G 和 D 是相互对抗的,当 G 网络提升,也会影响 D 网络的分辨能力;
- d图:当假设 G 网络不变(G已经优化到收敛状态),继续训练 D 网络,最优的情况会是
,也就是当生成数据的分布
趋近于真实数据分布
的时候,D 网络输出的概率
会趋近于 0.5(真实样本和伪造样本各占一半,生成器无法再伪造了,判别器也无法再优化了,也可以说对于判别器来说其无法从样本中区分中真实样本和伪造样本),这也是最终希望达到的训练结果,这时候 G 和 D 网络也就达到一个平衡状态。
0x3:算法伪码流程
论文给出的算法实现过程如下所示:
一些细节需要注意:
- 首先 G 和 D 是同步训练,但两者训练次数不一样,通常是 D 网络训练 k 次后,G 训练一次。主要原因是 GAN 刚开始训练时候会很不稳定,需要让判别器D尽快先进入收敛区间;
- D 的训练是同时输入真实数据和生成数据来计算 loss,而不是采用交叉熵(cross entropy)分开计算。不采用 cross entropy 的原因是这会让 D(G(z)) 变为 0,导致没有梯度提供给 G 更新,而现在 GAN 的做法是会收敛到 0.5;
- 实际训练的时候,作者是采用
来代替
,这是希望在训练初始就可以加大梯度信息,这是因为初始阶段 D 的分类能力会远大于 G 生成足够真实数据的能力,但这种修改也将让整个 GAN 不再是一个完美的零和博弈。
0x4:算法的优点
GAN的巧妙之处在于其目标函数的设定,因为此,GAN有如下几个优点:
- GAN 中的 G 作为生成模型,不需要像传统图模型一样,需要一个严格的生成数据的概率表达式。这就避免了当数据非常复杂的时候,复杂度过度增长导致的不可计算。
- GAN 不需要 inference 模型中的一些庞大计算量的求和计算。它唯一的需要的就是,一个噪音输入,一堆无标准的真实数据,两个可以逼近函数的网络。
0x5:算法的挑战与缺陷
- 启动及初始化的问题:GAN的训练目标是让生成器和判别器最终达到一个纳什均衡状态,此时两个网络都无法继续再往前做任何优化,优化结束。梯度下降的启动会选择一个减小所定义问题损失的方法,但是并没有理论保证GAN一定可以100%进入纳什均衡状态,这是一个高维度的非凸优化目标。网络试图在接下来的步骤中最小化非凸优化目标,但是最终可能导致进入震荡而不是收敛到底层真实目标。
- GAN 过于自由导致训练难以收敛以及不稳定。
- 梯度消失问题:原始 G 的损失函数
没有意义,它是让 G 最小化 D 识别出自己生成的假样本的概率,但实际上它会导致梯度消失问题,这是由于开始训练的时候,G 生成的图片非常糟糕,D 可以轻而易举的识别出来,这样 D 的训练没有任何损失,也就没有有效的梯度信息回传给 G 去优化它自己,这就是梯度消失了。最后,虽然作者意识到这个问题,在实际应用中改用
来代替,这相当于从最小化 D 揪出自己的概率,变成了最大化 D 抓不到自己的概率。虽然直观上感觉是一致的,但其实并不在理论上等价,也更没有了理论保证在这样的替代目标函数训练下,GAN 还会达到平衡。这个结果会进一步导致模式奔溃问题。
- 模型坍塌:基本原理是生成器可能会在某种情况下重复生成完全一致的图像(也可以理解为梯度消失),这其中的原因和博弈论中的启动问题相关。我们可以这样来想象GAN的训练过程,
- 先从判别器的角度试图最大化,再从生成器的角度试图最小化。如果生成器最小化开始之前,判别器已经完全最大化,所有工作还可以正常运行;
- 如果首先最小化生成器,再从判别器的角度试图最大化。如果判别器最大化开始之前,生成器已经完全最小化,那么工作就无法运行。原因在于如果我们保持判别器不变,它会将空间中的某些点标记为最有可能是真的而不是假的(因为生成器已经最小化了),这样生成器就会选择将所有的噪声输入映射到那些最可能为真的点上,这就陷入了局部最优的陷阱中了,优化过程就提前停止了。
0x6:提升GAN训练效果的一些方法
1. 中间层特征驱动损失函数
2. 小批量度量输入样本相似度
3. 引入历史平均
4. 单侧标签平滑
5. 输入规范化
6. 批规范化
7. 利用ReLU和MaxPool避免梯度稀疏
Relevant Link:
- https://arxiv.org/pdf/1406.2661.pdf
- https://juejin.im/post/5bdd70886fb9a049f912028d
- http://www.iterate.site/2018/07/27/gan-%E7%94%9F%E6%88%90%E5%AF%B9%E6%8A%97%E7%BD%91%E7%BB%9C%E4%BB%8B%E7%BB%8D/
4. 从生成模型和判别模型的概率视角看GAN
在阅读了很多GAN衍生论文以及GAN原始论文之后,笔者一直在思考的一个问题是:GAN背后的底层思想是什么?GAN衍生和改进算法的灵感和思路又是从哪里来的?
经过一段时间思考以及和同行同学讨论后,我得出了一些思考,这里分享如下,希望对读者朋友有帮助。
我们先来看什么是判别模型和生成模型:
- 判别式模型学习某种分布下的条件概率p(y|x),即在特定x条件下y发生的概率。判别器模型十分依赖数据的质量,概率分布p(y|x)可以直接将一个特定的x分类到某个标签y上。以逻辑回归为例,我们所需要做的是最小化损失函数。
- 生成式模型学习的是联合分布概率p(x,y),x是输入数据,y是所期望的分类。一个生成模型可以根据当前数据的假设生成更多新样本。
从概率论的视角来看,我们来看一下原始GAN网络的架构:
- 生成器本质上是一个由输入向量和生成器结构所代表的向量组成的联合概率分布P(v_input, v_G_structure)
- v_input:代表一种输入向量,可以是随机噪声向量z
- v_G_structure:网络本质上是对输入向量进行线性和非线性变化,因为可以将其抽象为一个动态变化的向量函数
- 判别器本质上是一个由(真实样本,伪造样本)作为输入x,进行后验预测p(y|x)的概率模型
遵循这种框架进行思考,CGAN只是将v_input中的随机噪声z替换成了另一种向量(文本或者标签向量),而Pix2pixGAN是将一个图像向量作为v_input输入GAN网络。
5. 从原始GAN网络中衍生出的流行GAN架构
GAN的发展离不开goodfellow后来的学者们不断的研究与发展,目前已经提出了很多优秀的新GAN架构,并且这个发展还在继续。为了让本博文能保持一定的环境独立性,笔者这里不做完整的罗列与枚举,相反,笔者希望从两条脉络来展开讨论:
- 解决问题导向:为了解决原始GAN或者当前学术研究中发现的关于GAN网络的性能和架构问题而提出的新理论与新框架
- 新场景应用导向:为了将GAN应用在新的领域中而提出的新的GAN架构
0x1:DCGAN(Deep Convolutional Generative Adversarial Networks)
Alec Radford,Luke Metz,Soumith Chintala等人在“Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks”提出了DCGAN。这是GAN研究的一个重要里程碑,因为它提出了一个重要的架构变化来解决训练不稳定,模式崩溃和内部协变量转换等问题。从那时起,基于DCGAN的架构就被应用到了许多GAN架构。
DCGAN的提出主要是为了解决原始GAN架构的原生架构问题,我们接下来来讨论下。
1. 生成器的架构优化
生成器从潜在空间中得到100维噪声向量z,通过一系列卷积和上采样操作,将z映射到一个像素矩阵对应的空间中,如下图:
DCGAN通过下面的一些架构性约束来固化网络:
- 在判别器中使用步数卷积来取代池化层,在生成器中使用小步数卷积来取代池化层;
- 在生成器和判别器中均使用批规范化,批规范化是一种通过零均值和单位方差的方法进行输入规范化使得学习过程固话的技术。这项技术在实践中被证实可以在许多场合提升训练速度,减少初始化不佳带来的启动问题,并且通常能产生更准确的结果;
- 消除原架构中较深的全连接隐藏层,并且在最后只使用简单的平均值池化;
- 在生成器输出层使用tanh,在其它层均使用ReLU激发;
- 在判别器的所有层中都使用Leaky ReLU激发;
2. 模型训练
生成器和判别器都是通过binary_crossentropy作为损失函数来进行训练的。之后的每个阶段,生成器产生一个MNIST图像,判别器尝试在真实MNIST图像和生成图像的数据集中进行学习。
经过一段时间后,生成器就可以自动学会如何制作伪造的数字。
- from __future__ import print_function, division
- from keras.datasets import mnist
- from keras.layers import Input, Dense, Reshape, Flatten, Dropout
- from keras.layers import BatchNormalization, Activation, ZeroPadding2D
- from keras.layers.advanced_activations import LeakyReLU
- from keras.layers.convolutional import UpSampling2D, Conv2D
- from keras.models import Sequential, Model
- from keras.optimizers import Adam
- import matplotlib.pyplot as plt
- import sys
- import numpy as np
- class DCGAN():
- def __init__(self):
- # Input shape
- self.img_rows =
- self.img_cols =
- self.channels =
- self.img_shape = (self.img_rows, self.img_cols, self.channels)
- self.latent_dim =
- optimizer = Adam(0.0002, 0.5)
- # Build and compile the discriminator
- self.discriminator = self.build_discriminator()
- self.discriminator.compile(
- loss='binary_crossentropy',
- optimizer=optimizer,
- metrics=['accuracy']
- )
- # Build the generator
- self.generator = self.build_generator()
- # The generator takes noise as input and generates imgs
- z = Input(shape=(self.latent_dim,))
- img = self.generator(z)
- # For the combined model we will only train the generator
- self.discriminator.trainable = False
- # The discriminator takes generated images as input and determines validity
- valid = self.discriminator(img)
- # The combined model (stacked generator and discriminator)
- # Trains the generator to fool the discriminator
- self.combined = Model(z, valid)
- self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
- def build_generator(self):
- model = Sequential()
- model.add(Dense( * * , activation="relu", input_dim=self.latent_dim))
- model.add(Reshape((, , )))
- model.add(UpSampling2D())
- model.add(Conv2D(, kernel_size=, padding="same"))
- model.add(BatchNormalization(momentum=0.8))
- model.add(Activation("relu"))
- model.add(UpSampling2D())
- model.add(Conv2D(, kernel_size=, padding="same"))
- model.add(BatchNormalization(momentum=0.8))
- model.add(Activation("relu"))
- model.add(Conv2D(self.channels, kernel_size=, padding="same"))
- model.add(Activation("tanh"))
- model.summary()
- noise = Input(shape=(self.latent_dim,))
- img = model(noise)
- return Model(noise, img)
- def build_discriminator(self):
- model = Sequential()
- model.add(Conv2D(, kernel_size=, strides=, input_shape=self.img_shape, padding="same"))
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.25))
- model.add(Conv2D(, kernel_size=, strides=, padding="same"))
- model.add(ZeroPadding2D(padding=((,),(,))))
- model.add(BatchNormalization(momentum=0.8))
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.25))
- model.add(Conv2D(, kernel_size=, strides=, padding="same"))
- model.add(BatchNormalization(momentum=0.8))
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.25))
- model.add(Conv2D(, kernel_size=, strides=, padding="same"))
- model.add(BatchNormalization(momentum=0.8))
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.25))
- model.add(Flatten())
- model.add(Dense(, activation='sigmoid'))
- model.summary()
- img = Input(shape=self.img_shape)
- validity = model(img)
- return Model(img, validity)
- def train(self, epochs, batch_size=, save_interval=):
- # Load the dataset
- (X_train, _), (_, _) = mnist.load_data()
- # Rescale - to
- X_train = X_train / 127.5 - .
- X_train = np.expand_dims(X_train, axis=)
- # Adversarial ground truths
- valid = np.ones((batch_size, ))
- fake = np.zeros((batch_size, ))
- for epoch in range(epochs):
- # ---------------------
- # Train Discriminator
- # ---------------------
- # Select a random half of images
- idx = np.random.randint(, X_train.shape[], batch_size)
- imgs = X_train[idx]
- # Sample noise and generate a batch of new images
- noise = np.random.normal(, , (batch_size, self.latent_dim))
- gen_imgs = self.generator.predict(noise)
- # Train the discriminator (real classified as ones and generated as zeros)
- d_loss_real = self.discriminator.train_on_batch(imgs, valid)
- d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
- d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
- # ---------------------
- # Train Generator
- # ---------------------
- # Train the generator (wants discriminator to mistake images as real)
- g_loss = self.combined.train_on_batch(noise, valid)
- # Plot the progress
- print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[], *d_loss[], g_loss))
- # If at save interval => save generated image samples
- if epoch % save_interval == :
- self.save_imgs(epoch)
- def save_imgs(self, epoch):
- r, c = ,
- noise = np.random.normal(, , (r * c, self.latent_dim))
- gen_imgs = self.generator.predict(noise)
- # Rescale images -
- gen_imgs = 0.5 * gen_imgs + 0.5
- fig, axs = plt.subplots(r, c)
- cnt =
- for i in range(r):
- for j in range(c):
- axs[i,j].imshow(gen_imgs[cnt, :,:,], cmap='gray')
- axs[i,j].axis('off')
- cnt +=
- fig.savefig("images/mnist_%d.png" % epoch)
- plt.close()
- if __name__ == '__main__':
- dcgan = DCGAN()
- dcgan.train(epochs=, batch_size=, save_interval=)
DCGAN产生的手写数字输出
0x2:CGAN(Conditional GAN,CGAN)
1. 有输入条件约束的生成器网络架构
CGAN由Mehdi Mirza,Simon Osindero在论文“Conditional Generative Adversarial Nets”中首次提出。
在条件GAN中,生成器并不是从一个随机的噪声分布中开始学习,而是通过一个特定的条件或某些特征(例如一个图像标签或者一些文本信息)开始学习如何生成伪造样本。
在CGAN中,生成器和判别器的输入都会增加一些条件变量y,这样判别器D(x,y)和生成器G(z,y)都有了一组联合条件变量。
我们将CGAN的目标函数和GAN进行对比会发现:
GAN目标函数
CGAN目标函数
GAN和CGAN的损失函数区别在于判别器和生成器多出来一个参数y,架构上,CGAN相比于GAN增加了一个输入层条件向量C,同时连接了判别器和生成器网络。
2. 训练过程
在训练过程,我们将y输入给生成器和判别器网络。
- from __future__ import print_function, division
- from keras.datasets import mnist
- from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
- from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
- from keras.layers.advanced_activations import LeakyReLU
- from keras.layers.convolutional import UpSampling2D, Conv2D
- from keras.models import Sequential, Model
- from keras.optimizers import Adam
- import matplotlib.pyplot as plt
- import numpy as np
- class CGAN():
- def __init__(self):
- # Input shape
- self.img_rows =
- self.img_cols =
- self.channels =
- self.img_shape = (self.img_rows, self.img_cols, self.channels)
- self.num_classes =
- self.latent_dim =
- optimizer = Adam(0.0002, 0.5)
- # Build and compile the discriminator
- self.discriminator = self.build_discriminator()
- self.discriminator.compile(
- loss=['binary_crossentropy'],
- optimizer=optimizer,
- metrics=['accuracy']
- )
- # Build the generator
- self.generator = self.build_generator()
- # The generator takes noise and the target label as input
- # and generates the corresponding digit of that label
- noise = Input(shape=(self.latent_dim,))
- label = Input(shape=(,))
- img = self.generator([noise, label])
- # For the combined model we will only train the generator
- self.discriminator.trainable = False
- # The discriminator takes generated image as input and determines validity
- # and the label of that image
- valid = self.discriminator([img, label])
- # The combined model (stacked generator and discriminator)
- # Trains generator to fool discriminator
- self.combined = Model([noise, label], valid)
- self.combined.compile(loss=['binary_crossentropy'],
- optimizer=optimizer)
- def build_generator(self):
- model = Sequential()
- model.add(Dense(, input_dim=self.latent_dim))
- model.add(LeakyReLU(alpha=0.2))
- model.add(BatchNormalization(momentum=0.8))
- model.add(Dense())
- model.add(LeakyReLU(alpha=0.2))
- model.add(BatchNormalization(momentum=0.8))
- model.add(Dense())
- model.add(LeakyReLU(alpha=0.2))
- model.add(BatchNormalization(momentum=0.8))
- model.add(Dense(np.prod(self.img_shape), activation='tanh'))
- model.add(Reshape(self.img_shape))
- model.summary()
- noise = Input(shape=(self.latent_dim,))
- label = Input(shape=(,), dtype='int32')
- label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
- model_input = multiply([noise, label_embedding])
- img = model(model_input)
- return Model([noise, label], img)
- def build_discriminator(self):
- model = Sequential()
- model.add(Dense(, input_dim=np.prod(self.img_shape)))
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dense())
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.4))
- model.add(Dense())
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.4))
- model.add(Dense(, activation='sigmoid'))
- model.summary()
- img = Input(shape=self.img_shape)
- label = Input(shape=(,), dtype='int32')
- label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
- flat_img = Flatten()(img)
- model_input = multiply([flat_img, label_embedding])
- validity = model(model_input)
- return Model([img, label], validity)
- def train(self, epochs, batch_size=, sample_interval=):
- # Load the dataset
- (X_train, y_train), (_, _) = mnist.load_data()
- # Configure input
- X_train = (X_train.astype(np.float32) - 127.5) / 127.5
- X_train = np.expand_dims(X_train, axis=)
- y_train = y_train.reshape(-, )
- # Adversarial ground truths
- valid = np.ones((batch_size, ))
- fake = np.zeros((batch_size, ))
- for epoch in range(epochs):
- # ---------------------
- # Train Discriminator
- # ---------------------
- # Select a random half batch of images
- idx = np.random.randint(, X_train.shape[], batch_size)
- imgs, labels = X_train[idx], y_train[idx]
- # Sample noise as generator input
- noise = np.random.normal(, , (batch_size, ))
- # Generate a half batch of new images
- gen_imgs = self.generator.predict([noise, labels])
- # Train the discriminator
- d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
- d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
- d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
- # ---------------------
- # Train Generator
- # ---------------------
- # Condition on labels
- sampled_labels = np.random.randint(, , batch_size).reshape(-, )
- # Train the generator
- g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)
- # Plot the progress
- print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[], *d_loss[], g_loss))
- # If at save interval => save generated image samples
- if epoch % sample_interval == :
- self.sample_images(epoch)
- def sample_images(self, epoch):
- r, c = ,
- noise = np.random.normal(, , (r * c, ))
- sampled_labels = np.arange(, ).reshape(-, )
- gen_imgs = self.generator.predict([noise, sampled_labels])
- # Rescale images -
- gen_imgs = 0.5 * gen_imgs + 0.5
- fig, axs = plt.subplots(r, c)
- cnt =
- for i in range(r):
- for j in range(c):
- axs[i,j].imshow(gen_imgs[cnt,:,:,], cmap='gray')
- axs[i,j].set_title("Digit: %d" % sampled_labels[cnt])
- axs[i,j].axis('off')
- cnt +=
- fig.savefig("images/%d.png" % epoch)
- plt.close()
- if __name__ == '__main__':
- cgan = CGAN()
- cgan.train(epochs=, batch_size=, sample_interval=)
根据输入数字生成对应的MNIST手写数字图像
0x3:CycleGAN(Cycle Consistent GAN,循环一致生成网络)
CycleGANs 由Jun-Yan Zhu,Taesung Park,Phillip Isola和Alexei A. Efros在题为“Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks”的论文中提出
CycleGAN用来实现不需要其他额外信息,就能将一张图像从源领域映射到目标领域的方法,例如将照片转换为绘画,将夏季拍摄的照片转换为冬季拍摄的照片,或将马的照片转换为斑马照片,或者相反。总结来说,CycleGAN常备用于不同的图像到图像翻译。
1. 循环网络架构
CycleGAN背后的核心思想是两个转换器F和G,其中:
- F会将图像从域A转换到域B;
- G会将图像从域B转换到域A;
因此,
- 对于一个在域A的图像x,我们期望函数G(F(x))的结果与x相同,即 x == G(F(x));
- 对于一个在域B的图像y,我们期望函数F(G(y))的结果与y相同,即 y == F(G(y));
和原始的GAN结构相比,由单个G->D的单向开放结构,变成了由两对G<->D组成的双向循环的封闭结构,但形式上依然是G给D输入伪造样本。但区别在于梯度的反馈是双向循环的。
2. 损失函数
CycleGAN模型有以下两个损失函数:
- 对抗损失(Adversarial Loss):判别器和生成器之间互相对抗的损失,这就是原始GAN网络的损失函数公式:
- 循环一致损失(Cycle Consistency Loss):综合权衡转换器F和G的损失,F和G之间是编码与解码的对抗关系,不可能同时取到最小值,只能得到整体的平衡最优值:
完整的CycleGAN目标函数如下:
- from __future__ import print_function, division
- import scipy
- from keras.datasets import mnist
- from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
- from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
- from keras.layers import BatchNormalization, Activation, ZeroPadding2D
- from keras.layers.advanced_activations import LeakyReLU
- from keras.layers.convolutional import UpSampling2D, Conv2D
- from keras.models import Sequential, Model
- from keras.optimizers import Adam
- import datetime
- import matplotlib.pyplot as plt
- import sys
- from data_loader import DataLoader
- import numpy as np
- import os
- class CycleGAN():
- def __init__(self):
- # Input shape
- self.img_rows =
- self.img_cols =
- self.channels =
- self.img_shape = (self.img_rows, self.img_cols, self.channels)
- # Configure data loader
- self.dataset_name = 'horse2zebra'
- self.data_loader = DataLoader(
- dataset_name=self.dataset_name,
- img_res=(self.img_rows, self.img_cols)
- )
- # Calculate output shape of D (PatchGAN)
- patch = int(self.img_rows / **)
- self.disc_patch = (patch, patch, )
- # Number of filters in the first layer of G and D
- self.gf =
- self.df =
- # Loss weights
- self.lambda_cycle = 10.0 # Cycle-consistency loss
- self.lambda_id = 0.1 * self.lambda_cycle # Identity loss
- optimizer = Adam(0.0002, 0.5)
- # Build and compile the discriminators
- self.d_A = self.build_discriminator()
- self.d_B = self.build_discriminator()
- self.d_A.compile(
- loss='mse',
- optimizer=optimizer,
- metrics=['accuracy']
- )
- self.d_B.compile(
- loss='mse',
- optimizer=optimizer,
- metrics=['accuracy']
- )
- # -------------------------
- # Construct Computational
- # Graph of Generators
- # -------------------------
- # Build the generators
- self.g_AB = self.build_generator()
- self.g_BA = self.build_generator()
- # Input images from both domains
- img_A = Input(shape=self.img_shape)
- img_B = Input(shape=self.img_shape)
- # Translate images to the other domain
- fake_B = self.g_AB(img_A)
- fake_A = self.g_BA(img_B)
- # Translate images back to original domain
- reconstr_A = self.g_BA(fake_B)
- reconstr_B = self.g_AB(fake_A)
- # Identity mapping of images
- img_A_id = self.g_BA(img_A)
- img_B_id = self.g_AB(img_B)
- # For the combined model we will only train the generators
- self.d_A.trainable = False
- self.d_B.trainable = False
- # Discriminators determines validity of translated images
- valid_A = self.d_A(fake_A)
- valid_B = self.d_B(fake_B)
- # Combined model trains generators to fool discriminators
- self.combined = Model(
- inputs=[img_A, img_B],
- outputs=[valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id ]
- )
- self.combined.compile(
- loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
- loss_weights=[, , self.lambda_cycle, self.lambda_cycle, self.lambda_id, self.lambda_id],
- optimizer=optimizer
- )
- def build_generator(self):
- """U-Net Generator"""
- def conv2d(layer_input, filters, f_size=):
- """Layers used during downsampling"""
- d = Conv2D(filters, kernel_size=f_size, strides=, padding='same')(layer_input)
- d = LeakyReLU(alpha=0.2)(d)
- d = InstanceNormalization()(d)
- return d
- def deconv2d(layer_input, skip_input, filters, f_size=, dropout_rate=):
- """Layers used during upsampling"""
- u = UpSampling2D(size=)(layer_input)
- u = Conv2D(filters, kernel_size=f_size, strides=, padding='same', activation='relu')(u)
- if dropout_rate:
- u = Dropout(dropout_rate)(u)
- u = InstanceNormalization()(u)
- u = Concatenate()([u, skip_input])
- return u
- # Image input
- d0 = Input(shape=self.img_shape)
- # Downsampling
- d1 = conv2d(d0, self.gf)
- d2 = conv2d(d1, self.gf*)
- d3 = conv2d(d2, self.gf*)
- d4 = conv2d(d3, self.gf*)
- # Upsampling
- u1 = deconv2d(d4, d3, self.gf*)
- u2 = deconv2d(u1, d2, self.gf*)
- u3 = deconv2d(u2, d1, self.gf)
- u4 = UpSampling2D(size=)(u3)
- output_img = Conv2D(self.channels, kernel_size=, strides=, padding='same', activation='tanh')(u4)
- return Model(d0, output_img)
- def build_discriminator(self):
- def d_layer(layer_input, filters, f_size=, normalization=True):
- """Discriminator layer"""
- d = Conv2D(filters, kernel_size=f_size, strides=, padding='same')(layer_input)
- d = LeakyReLU(alpha=0.2)(d)
- if normalization:
- d = InstanceNormalization()(d)
- return d
- img = Input(shape=self.img_shape)
- d1 = d_layer(img, self.df, normalization=False)
- d2 = d_layer(d1, self.df*)
- d3 = d_layer(d2, self.df*)
- d4 = d_layer(d3, self.df*)
- validity = Conv2D(, kernel_size=, strides=, padding='same')(d4)
- return Model(img, validity)
- def train(self, epochs, batch_size=, sample_interval=):
- start_time = datetime.datetime.now()
- # Adversarial loss ground truths
- valid = np.ones((batch_size,) + self.disc_patch)
- fake = np.zeros((batch_size,) + self.disc_patch)
- for epoch in range(epochs):
- for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
- # ----------------------
- # Train Discriminators
- # ----------------------
- # Translate images to opposite domain
- fake_B = self.g_AB.predict(imgs_A)
- fake_A = self.g_BA.predict(imgs_B)
- # Train the discriminators (original images = real / translated = Fake)
- dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
- dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
- dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)
- dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
- dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
- dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)
- # Total disciminator loss
- d_loss = 0.5 * np.add(dA_loss, dB_loss)
- # ------------------
- # Train Generators
- # ------------------
- # Train the generators
- g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
- [valid, valid,
- imgs_A, imgs_B,
- imgs_A, imgs_B])
- elapsed_time = datetime.datetime.now() - start_time
- # Plot the progress
- print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
- % ( epoch, epochs,
- batch_i, self.data_loader.n_batches,
- d_loss[], *d_loss[],
- g_loss[],
- np.mean(g_loss[:]),
- np.mean(g_loss[:]),
- np.mean(g_loss[:]),
- elapsed_time))
- # If at save interval => save generated image samples
- if batch_i % sample_interval == :
- self.sample_images(epoch, batch_i)
- def sample_images(self, epoch, batch_i):
- if not os.path.exists('images/%s' % self.dataset_name):
- os.makedirs('images/%s' % self.dataset_name)
- r, c = ,
- imgs_A = self.data_loader.load_data(domain="A", batch_size=, is_testing=True)
- imgs_B = self.data_loader.load_data(domain="B", batch_size=, is_testing=True)
- # Demo (for GIF)
- #imgs_A = self.data_loader.load_img('datasets/apple2orange/testA/n07740461_1541.jpg')
- #imgs_B = self.data_loader.load_img('datasets/apple2orange/testB/n07749192_4241.jpg')
- # Translate images to the other domain
- fake_B = self.g_AB.predict(imgs_A)
- fake_A = self.g_BA.predict(imgs_B)
- # Translate back to original domain
- reconstr_A = self.g_BA.predict(fake_B)
- reconstr_B = self.g_AB.predict(fake_A)
- gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])
- # Rescale images -
- gen_imgs = 0.5 * gen_imgs + 0.5
- titles = ['Original', 'Translated', 'Reconstructed']
- fig, axs = plt.subplots(r, c)
- cnt =
- for i in range(r):
- for j in range(c):
- axs[i,j].imshow(gen_imgs[cnt])
- axs[i, j].set_title(titles[j])
- axs[i,j].axis('off')
- cnt +=
- fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
- plt.close()
- if __name__ == '__main__':
- gan = CycleGAN()
- gan.train(epochs=, batch_size=, sample_interval=)
苹果->橙子->苹果
有类似架构思想的还有DiscoGAN,相关论文可以在axiv上找到。
0x4:StackGAN
StackJANs由Han Zhang,Tao Xu,Hongsheng Li还有其他人在题为“StackGAN: Text to Photo-Realistic Image Synthesis with Stacked Generative Adversarial Networks”的论文中提出。他们使用StackGAN来探索文本到图像的合成,得到了非常好的结果。
一个StackGAN由一对网络组成,当提供文本描述时,可以生成逼真的图像。
0x5:Pix2pix
pix2pix网络由Phillip Isola,Jun-Yan Zhu,Tinghui Zhou和Alexei A. Efros在他们的题为“Image-to-Image Translation with Conditional Adversarial Networks”的论文中提出。
对于图像到图像的翻译任务,pix2pix也显示出了令人印象深刻的结果。无论是将夜间图像转换为白天的图像还是给黑白图像着色,或者将草图转换为逼真的照片等等,Pix2pix在这些例子中都表现非常出色。
0x6:Age-cGAN(Age Conditional Generative Adversarial Networks)
Grigory Antipov,Moez Baccouche和Jean-Luc Dugelay在他们的题为“Face Aging with Conditional Generative Adversarial Networks”的论文中提出了用条件GAN进行面部老化。
面部老化有许多行业用例,包括跨年龄人脸识别,寻找失踪儿童,或者用于娱乐,本质上它属于cGAN的一种场景应用。
Relevant Link:
- https://arxiv.org/pdf/1511.06434.pdf
- https://github.com/hindupuravinash/the-gan-zoo
- https://github.com/eriklindernoren/Keras-GAN
- https://zhuanlan.zhihu.com/p/63428113
6. 基于GAN自动生成Webshell样本
0x1:原始GAN结构在NLP领域应用的挑战
我们用DNN架构重写原始GAN代码,并使用一批php webshell作为真实样本,尝试用GAN进行伪造样本生成。
- from keras.layers import Input, Dense, Reshape, Flatten, Dropout
- from keras.layers import BatchNormalization, Activation, ZeroPadding2D
- from keras.layers.advanced_activations import LeakyReLU
- from keras.layers.convolutional import UpSampling2D, Conv2D
- from keras.models import Sequential, Model
- from keras.optimizers import Adam
- from keras.preprocessing import sequence
- from sklearn.externals import joblib
- import re
- import os
- import numpy as np
- # np.set_printoptions(threshold=np.nan)
- class DCGAN():
- def __init__(self):
- # Input shape
- self.charlen =
- self.fileshape = (self.charlen, )
- self.latent_dim =
- self.ENCODER = joblib.load("./CHAR_SEQUENCE_TOKENIZER_INDEX_TABLE_PICKLE.encoder")
- self.rerange_dim = (len(self.ENCODER.word_index) + ) / . - 0.5
- optimizer = Adam(0.0002, 0.5)
- # Build and compile the discriminator
- self.discriminator = self.build_discriminator()
- self.discriminator.compile(
- loss='binary_crossentropy',
- optimizer=optimizer,
- metrics=['accuracy']
- )
- # Build the generator
- self.generator = self.build_generator()
- # The generator takes noise as input and generates imgs
- z = Input(shape=(self.latent_dim,))
- img = self.generator(z)
- # For the combined model we will only train the generator
- self.discriminator.trainable = False
- # The discriminator takes generated images as input and determines validity
- valid = self.discriminator(img)
- # The combined model (stacked generator and discriminator)
- # Trains the generator to fool the discriminator
- self.combined = Model(z, valid)
- self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
- def build_generator(self):
- model = Sequential()
- model.add(Dense(, activation="relu"))
- model.add(Dense(, activation="relu"))
- model.add(Dense(, activation="relu"))
- model.add(Dense(, activation="relu"))
- model.add(Dense(self.charlen, activation="relu"))
- # model.summary()
- noise = Input(shape=(self.latent_dim,))
- img = model(noise)
- return Model(noise, img)
- def build_discriminator(self):
- model = Sequential()
- model.add(Dense(, activation="relu"))
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.5))
- model.add(Dense(, activation="relu"))
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.5))
- model.add(Dense(, activation="relu"))
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.5))
- model.add(Dense(, activation="relu"))
- model.add(LeakyReLU(alpha=0.2))
- model.add(Dropout(0.5))
- model.add(Dense(, activation='sigmoid'))
- # model.summary()
- img = Input(shape=self.fileshape)
- validity = model(img)
- return Model(img, validity)
- def train(self, epochs, batch_size=, save_interval=):
- # Load the dataset
- X_train = self.load_webfile_data()
- # Adversarial ground truths
- valid = np.ones((batch_size, ))
- fake = np.zeros((batch_size, ))
- for epoch in range(epochs):
- # ---------------------
- # Train Discriminator
- # ---------------------
- # Select a random half of images
- idx = np.random.randint(, X_train.shape[], batch_size)
- imgs = X_train[idx]
- # Sample noise and generate a batch of new images
- noise = np.random.normal(, , (batch_size, self.latent_dim))
- gen_imgs = self.generator.predict(noise)
- # print gen_imgs
- # print np.shape(gen_imgs)
- # Train the discriminator (real classified as ones and generated as zeros)
- d_loss_real = self.discriminator.train_on_batch(imgs, valid)
- d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
- d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
- # ---------------------
- # Train Generator
- # ---------------------
- # Train the generator (wants discriminator to mistake images as real)
- g_loss = self.combined.train_on_batch(noise, valid)
- # Plot the progress
- print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[], *d_loss[], g_loss))
- # If at save interval => save generated image samples
- if epoch % save_interval == :
- self.save_imgs(epoch)
- def save_imgs(self, epoch):
- r, c = ,
- noise = np.random.normal(, , (r * c, self.latent_dim))
- gen_imgs = self.generator.predict(noise)
- # Rescale [-,] back to [, ascii_char] range
- gen_imgs = (gen_imgs + .) * self.rerange_dim
- gen_text_vec = gen_imgs.reshape((np.shape(gen_imgs)[], self.charlen))
- gen_text_vec = gen_text_vec.astype(int)
- # reconver back to ascii
- #print "gen_text_vec: ", gen_text_vec
- gen_text = self.ENCODER.sequences_to_texts(gen_text_vec)
- #print "gen_text:", gen_text
- with open('./gen_webfile/{0}.txt'.format(epoch), 'wb') as f:
- for file_vec in gen_text:
- fcontent = ""
- for c in file_vec:
- fcontent += c
- fcontent = re.sub(r"\s+", "", fcontent)
- f.write(fcontent)
- def load_webfile_data(self):
- vec_dict = {
- 'raw_ascii': []
- }
- rootDir = "./webdata"
- for lists in os.listdir(rootDir):
- if lists == '.DS_Store':
- continue
- webpath = os.path.join(rootDir, lists)
- with open(webpath, 'r') as fp:
- fcontent = fp.read()
- # remove space
- fcontent = re.sub(r"\s+", " ", fcontent)
- fcontent_ = ""
- for c in fcontent:
- fcontent_ += c + " "
- vec_dict['raw_ascii'].append(fcontent_)
- # convert to ascii sequence vec
- raw_ascii_sequence_vec = self.ENCODER.texts_to_sequences(vec_dict['raw_ascii'])
- raw_ascii_sequence_vec = sequence.pad_sequences(
- raw_ascii_sequence_vec,
- maxlen=self.charlen, padding='post',
- truncating='post',
- dtype='float32'
- )
- # reshape to 2d array
- raw_ascii_sequence_vec = raw_ascii_sequence_vec.reshape((np.shape(raw_ascii_sequence_vec)[], self.charlen))
- # ascii is range in [, ], we need Rescale - to
- print "rerange_dim: ", self.rerange_dim
- raw_ascii_sequence_vec = raw_ascii_sequence_vec / self.rerange_dim - .
- # raw_ascii_sequence_vec = np.expand_dims(raw_ascii_sequence_vec, axis=)
- print "np.shape(raw_ascii_sequence_vec): ", np.shape(raw_ascii_sequence_vec)
- return raw_ascii_sequence_vec
- if __name__ == '__main__':
- dcgan = DCGAN()
- dcgan.train(epochs=, batch_size=, save_interval=)
- #print dcgan.load_webfile_data()
实验的结果并不理想,GAN很快遇到了模型坍塌问题,从G生成的样本来看,网络很快陷入了一个局部最优区间中。
关于这个问题,学术界已经有比较多的讨论和分析,笔者这里列举如下:
- 原始GAN主要应用实数空间(连续型数据)上,在生成离散数据(texts)这个问题上并不work。最初的 GANs 仅仅定义在实数领域,GANs 通过训练出的生成器来产生合成数据,然后在合成数据上运行判别器,判别器的输出梯度将会产生梯度反馈,告诉生成器如何通过略微改变合成数据而使其更加现实。一般来说只有在数据连续的情况下,生成器才可以略微改变合成的数据,而如果数据是离散的,则不能简单的通过改变合成数据。例如,如果你输出了一张图片,其像素值是1.0,那么接下来你可以将这个值改为1.0001。如果输出了一个单词“penguin”,那么接下来就不能将其改变为“penguin + .001”,因为没有“penguin +.001”这个单词。 因为所有的自然语言处理(NLP)的基础都是离散值,如“单词”、“字母”或者“音节”。
Sparse reward:adversarial training 没起作用很大的一个原因就在于,discriminator 提供的 reward 具备的 guide signal 太少,Classifier-based Discriminator 提供的只是一个为真或者假的概率作为 reward,而这个 reward 在大部分情况下,是 0。这是因为对于 CNN 来说,分出 fake text 和 real text 是非常容易的,CNN 能在 Classification 任务上做到 99% 的 accuracy,而建模 Language Model 来进行生成,是非常困难的。除此以外,即使 generator 在这样的 reward 指导下有一些提升,此后的 reward 依旧很小。
- Search complexity:在以SeqGAN为代表的用RNN作为生成器G的一类的工作中,对于 Reward 的评估都是基于句级别的,也就是会先使用 Monte Carlo Search 的方法将句子进行补全再交给 Discriminator,但是这个采样方法的时间复杂度是 O(nmL2),其中 n 是 batch size,m 是采样的次数,L 是句子的 max len。就 SeqGAN 的实验来说,每次计算 reward 就会来带很大的开销。
0x2:GAN In NLP的主要发展思路和方向
基本上说,学术界对文本的看法是将其是做一个时序依赖的序列,所以主流方向是使用RNN/LSTM这类模型作为生成器来生成伪造文本序列。而接下要要解决的重点问题是,如何有效地将判别器的反馈有效地传递给生成器。
增加reward signal强度和平滑度:从这一点出发,现有不少工作一方法不再使用简单的 fake/true probability 作为 reward。
LeakyGAN(把 CNN 的 feature 泄露给 generator),RankGAN (用 IR 中的排序作为 reward)等工作来提供更加丰富的 reward;
另一个解决的思路是使用 language model-based discriminator,以提供更多的区分度,北大孙栩老师组的 DP-GAN 在使用了 Languag model discrminator 之后,在 true data 和 fake data 中间架起了一座桥梁:
- 使用
离散数据的可导的损失函数:通过改造原始softmax函数,使用新的gumble softmax,它可以代替policy gradient,直接可导了。
- 使用RL提供梯度反馈:使用RL的
policy gradient代替原始gradient,将reward传导回去,这是现在比较主流的做法
Relevant Link:
- https://github.com/LantaoYu/SeqGAN
- https://zhuanlan.zhihu.com/p/25168509
- https://tobiaslee.top/2018/09/30/Text-Generation-with-GAN/
- https://zhuanlan.zhihu.com/p/36880287
- https://www.jianshu.com/p/32e164883eab
生成对抗网络(Generative Adversarial Networks,GAN)初探的更多相关文章
- 生成对抗网络 Generative Adversarial Networks
转自:https://zhuanlan.zhihu.com/p/26499443 生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow教授和他的学生在2014年提出的机器学习架构. 要全面理 ...
- 生成对抗网络(Generative Adversarial Networks, GAN)
生成对抗网络(Generative Adversarial Networks, GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的学习方法之一. GAN 主要包括了两个部分,即 ...
- Generative Adversarial Networks,gan论文的畅想
前天看完Generative Adversarial Networks的论文,不知道有什么用处,总想着机器生成的数据会有机器的局限性,所以百度看了一些别人 的看法和观点,可能我是机器学习小白吧,看完之 ...
- 对抗生成网络 Generative Adversarial Networks
1. Basic idea 基本任务:要得到一个generator,能够模拟想要的数据分布.(一个低维向量到一个高维向量的映射) discriminator就像是一个score function. 如 ...
- 生成对抗网络(GAN)
GAN的全称是 Generative Adversarial Networks,中文名称是生成对抗网络.原始的GAN是一种无监督学习方法,巧妙的利用“博弈”的思想来学习生成式模型. 1 GAN的原理 ...
- 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】
本文转载自:https://www.leiphone.com/news/201703/Y5vnDSV9uIJIQzQm.html 生成对抗网络(Generative Adversarial Netwo ...
- 生成对抗网络(Generative Adversarial Network)阅读笔记
笔记持续更新中,请大家耐心等待 首先需要大概了解什么是生成对抗网络,参考维基百科给出的定义(https://zh.wikipedia.org/wiki/生成对抗网络): 生成对抗网络(英语:Gener ...
- 一文读懂对抗生成学习(Generative Adversarial Nets)[GAN]
一文读懂对抗生成学习(Generative Adversarial Nets)[GAN] 0x00 推荐论文 https://arxiv.org/pdf/1406.2661.pdf 0x01什么是ga ...
- 生成对抗网络(GAN)相关链接汇总
1.基础知识 创始人的介绍: “GANs之父”Goodfellow 38分钟视频亲授:如何完善生成对抗网络?(上) “GAN之父”Goodfellow与网友互动:关于GAN的11个问题(附视频) 进一 ...
随机推荐
- 【POJ - 3723 】Conscription(最小生成树)
Conscription Descriptions 需要征募女兵N人,男兵M人. 每招募一个人需要花费10000美元. 如果已经招募的人中有一些关系亲密的人,那么可以少花一些钱. 给出若干男女之前的1 ...
- Sentinel Cluster流程分析
前面介绍了sentinel-core的流程,提到在进行流控判断时,会判断当前是本地限流,还是集群限流,若是集群模式,则会走另一个分支,这节便对集群模式做分析. 一.基本概念 namespace:限 ...
- jQuery常用方法(二)-事件
ready(fn); $(document).ready()注意在body中没有onload事件,否则该函数不能执行.在每个页面中可以 有很多个函数被加载执行,按照fn的顺序来执行. bind( ty ...
- Flask框架踩坑之ajax跨域请求
业务场景: 前后端分离需要对接数据接口. 接口测试是在postman做的,今天才开始和前端对接,由于这是我第一次做后端接口开发(第一次嘛,问题比较多)所以在此记录分享我的踩坑之旅,以便能更好的理解,应 ...
- Linux之修改系统密码
目录 Linux之修改系统密码 参考 RHEL6修改系统密码 RHEL7修改系统密码 Linux之修改系统密码
- linux分析工具之vmstat详解
一.概述 vmstat命令是最常见的Linux/Unix监控工具,可以展现给定时间间隔的服务器的状态值,包括服务器的CPU使用率,内存使用,虚拟内存交换情况,IO读写情况.首先我们查看下帮助.如下图所 ...
- 由std::once_call 引发的单例模式的再次总结,基于C++11
一个偶然的机会,知道了std::once_call这个东西. 了解了下,std::once_call支持多线程情况下的某函数只执行一次.咦,这个不是恰好符合单例模式的多线程安全的困境吗? 单例模式,经 ...
- mysql库复制
一.使用navicate复制mysql库 二.使用命令 通过命令:1.创建新数据库CREATE DATABASE `newdb` DEFAULT CHARACTER SET UTF8 COLLATE ...
- ELK 学习笔记之 Logstash之filter配置
Logstash之filter: json filter: input{ stdin{ } } filter{ json{ source => "message" } } o ...
- css3:bacground-size
个人博客: https://chenjiahao.xyz CSS3之背景尺寸Background-size是CSS3中新加的一个有关背景的属性,这个属性是改变背景尺寸的通过各种不同是属性值改变背景尺寸 ...