WGAN的改进点和实操
包含三部分:1、WGAN改进点 2、代码修改 3、训练心得
一、WGAN的改进部分:
- 判别器最后一层去掉sigmoid (相当于最后一层做了一个y = x的激活)
- 生成器和判别器的loss不取log
- 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
- 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行 (这部分很玄学)
去掉sigmoid会出现什么问题?
优点: 去掉sigmoid 只要二者存在差值就会学习让他们尽量小
缺点:去掉sigmoid 判别器的输出会到无穷大 生成器也会到无穷大(只要二者的差值很小就满足条件)无法优化。
(公式1)
如何解决(上述)无法优化问题(loss可能一直上升)?
这就是WGAN的第三个改进点。(每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c)
(公式2)(作者用这个公式来表达,证明过程再论文附录中)
详细解读(这部分参看:https://blog.csdn.net/omnispace/article/details/54942668)
分析
首先需要介绍一个概念——Lipschitz连续。它其实就是在一个连续函数上面额外施加了一个限制,要求存在一个常数使得定义域内的任意两个元素和都满足
此时称函数的Lipschitz常数为。
简单理解,比如说的定义域是实数集合,那上面的要求就等价于的导函数绝对值不超过(这里是导数概念(f(x1) - f(x2))/(x1-x2) 为导数)。再比如说就不是Lipschitz连续,因为它的导函数没有上界。Lipschitz连续条件限制了一个连续函数的最大局部变动幅度。
公式2的意思就是在要求函数的Lipschitz常数不超过的条件下,对所有可能满足条件的取到的上界,然后再除以。特别地,我们可以用一组参数来定义一系列可能的函数,此时求解公式2可以近似变成求解如下形式
(公式3)
再用上我们搞深度学习的人最熟悉的那一套,不就可以把用一个带参数的神经网络来表示嘛!由于神经网络的拟合能力足够强大,我们有理由相信,这样定义出来的一系列虽然无法囊括所有可能,但是也足以高度近似公式2要求的那个了。
最后,还不能忘了满足公式3中这个限制。我们其实不关心具体的K是多少,只要它不是正无穷就行,因为它只是会使得梯度变大倍,并不会影响梯度的方向。所以作者采取了一个非常简单的做法,就是限制神经网络的所有参数的不超过某个范围,比如,此时关于输入样本的导数也不会超过某个范围,所以一定存在某个不知道的常数使得的局部变动幅度不会超过它,Lipschitz连续条件得以满足。具体在算法实现中,只需要每次更新完后把它clip回这个范围就可以了。
到此为止,我们可以构造一个含参数、最后一层不是非线性激活层的判别器网络,在限制不超过某个范围的条件下,使得
(公式4)
尽可能取到最大,此时就会近似真实分布与生成分布之间的Wasserstein距离(忽略常数倍数)。注意原始GAN的判别器做的是真假二分类任务,所以最后一层是sigmoid,但是现在WGAN中的判别器做的是近似拟合Wasserstein距离,属于回归任务,所以要把最后一层的sigmoid拿掉。
接下来生成器要近似地最小化Wasserstein距离,可以最小化,由于Wasserstein距离的优良性质,我们不需要担心生成器梯度消失的问题。再考虑到的第一项与生成器无关,就得到了WGAN的两个loss。
二、代码修改:
根据改进的四个部分来修改代码(TF下):
加变量:
CLIP = [-0.01, 0.01] #用来截断w(第三个改进点)
CRITIC_NUM = 5 #权衡训练次数 Discrimnator要训练的比Genenrator多(5 次Discrimnator 一次 G)
① 判别器最后一层去掉sigmoid
return tf.nn.sigmoid(h4), h4
替换后:
return h4, h4
② 生成器和判别器的loss不取log
原始的GAN loss为:
min GmaxD Ex∼q(x)[logD(x)]+Ez∼p(z)[log(1−D(G(z)))
去掉log为 min GmaxD D(x) + 1−D(G(z))
由于最大化D 我们在代码中应该加 “-” D loss: minD -(D(x) + 1−D(G(z)))
G loss minG −D(G(z))
self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_))) \
4 self.d_loss = self.d_loss_real + self.d_loss_fake
修改D loss为:
self.d_loss_real = tf.reduce_mean(self.D_logits)
self.d_loss_fake = -tf.reduce_mean(self.D_logits_)
self.d_loss = -(self.d_loss_real + self.d_loss_fake)
修改G loss为:
self.g_loss = -tf.reduce_mean(self.D_logits_)
③ ④ 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c(放到参数更新后) 修改优化器
原始:
d_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
.minimize(self.d_loss, var_list=self.d_vars)
g_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
.minimize(self.g_loss, var_list=self.g_vars)
修改为:
d_optim = tf.train.RMSPropOptimizer(args.lr, beta1=args.beta1) \
.minimize(self.d_loss, var_list=self.d_vars)
g_optim = tf.train.RMSPropOptimizer(args.lr, beta1=args.beta1) \
.minimize(self.g_loss, var_list=self.g_vars)
clip_d_op = [var.assign(tf.clip_by_value(var, CILP[0], CILP[1])) for var in self.d_vars] #进行截断
三、训练心得:
一、权重
a. 调节Generator loss中GAN loss的权重
G loss和Gan loss在一个尺度上或者G loss比Gan loss大一个尺度。但是千万不能让Gan loss占主导地位, 这样整个网络权重会被带偏。
二、训练次数
b. 调节Generator和Discrimnator的训练次数比
一般来说,Discrimnator要训练的比Genenrator多。比如训练五次Discrimnator,再训练一次Genenrator(WGAN论文 是这么干的)。
三、学习率
c. 调节learning rate
这个学习速率不能过大。一般要比Genenrator的速率小一点。
四、优化器
d. Optimizer的选择不能用基于动量法的
如Adam和momentum。可使用RMSProp或者SGD。
五、结构
e. Discrimnator的结构可以改变
如果用WGAN,判别器的最后一层需要去掉sigmoid。但是用原始的GAN,需要用sigmoid,因为其loss function里面需要取log,所以值必须在[0,1]。这里用的是邓炜的critic模型当作判别器。之前twitter的论文里面的判别器即使去掉了sigmoid也不好训练。
WGAN的改进点和实操的更多相关文章
- Golang的运算符优先级实操案例
Golang的运算符优先级实操案例 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.运算符优先级案例 运算符是用来在程序运行时执行数学或逻辑运算的,在Go语言中,一个表达式可以包 ...
- SFUD+FAL+EasyFlash典型场景需求分析,并记一次实操记录
SFUD+FAL+EasyFlash典型场景需求分析:用整个flash存储数据,上千条数据,读取得时候用easyflash很慢,估计要检索整个flash太慢了. 改进方法:分区检索. 1存数据时,根据 ...
- ABP入门系列(1)——学习Abp框架之实操演练
作为.Net工地搬砖长工一名,一直致力于挖坑(Bug)填坑(Debug),但技术却不见长进.也曾热情于新技术的学习,憧憬过成为技术大拿.从前端到后端,从bootstrap到javascript,从py ...
- 号外号外:9月13号《Speed-BI云平台案例实操--十分钟做报表》开讲了
引言:如何快速分析纷繁复杂的数据?如何快速做出老板满意的报表?如何快速将Speed-BI云平台运用到实际场景中? 本课程将通过各行各业案例背景,将Speed-BI云平台运用到实际场景中 ...
- Mysql MHA(GTID)配置(实操)
实现环境 centos6.7 MYSQL5.6.36 主:192.168.1.191 从1:192.168.1.145 从2:192.168.1.146 监测:放在从2上 192.168.1.146 ...
- Selenium之unittest测试框架详谈及实操
申明:本文是基于python3.x及selenium3.x. unittest,也可以称为PyUnit,可以用来创建全面的测试套件,可以用于单元自动化测试(模块).功能自动化测试(UI)等等. 官方文 ...
- unittest测试框架详谈及实操(二)
类级别的setUp()方法与tearDown()方法 在实操(一)的例子中,通过setUp()方法为每个测试方法都创建了一个Chrome实例,并且在每个测试方法执行结束后要关闭实例.是不是觉得有个多余 ...
- .net基础学java系列(四)Console实操
上一篇文章 .net基础学java系列(三)徘徊反思 本章节没啥营养,请绕路! 看视频,不实操,对于上了年龄的人来说,是记不住的!我已经看了几遍IDEA的教学视频: https://edu.51cto ...
- RTN 实操
创建房间 test-rtn 10001 e2uii6r7r 8LfwOcreM76OiV1V1y8jXrMG_BNa-cmktpWUznRa:kdYdsEpcYLc5ceWEHPaK0ZDI7Qc=: ...
随机推荐
- 前端开发在手机UC浏览器上遇到的坑
1.user-scalable问题 写手机页面都会加一个meta标签 <meta content="width=device-width, initial-scale=1.0, max ...
- sso和oauth2.0的简单了解学习
sso,单点登录,single sign on 缩写.sso多用于多个应用之间的切换,例如百度论坛.百度知道.百度云.百度文库等,在其中一个系统中登录,(登录有效期内)切换到另一个系统的时候,不必再次 ...
- abp(net core)+easyui+efcore仓储系统——创建应用服务(五)
abp(net core)+easyui+efcore仓储系统目录 abp(net core)+easyui+efcore仓储系统——ABP总体介绍(一) abp(net core)+easyui+e ...
- ZooKeeper学习第五期--ZooKeeper管理分布式环境中的数据(转)
转载来源:https://www.cnblogs.com/sunddenly/p/4092654.html 引言 本节本来是要介绍ZooKeeper的实现原理,但是ZooKeeper的原理比较复杂,它 ...
- mpvue 开发小程序接口数据统一管理
mpvue项目里做API与数据分离统一管理 小程序里请求数据接口使用wx:request,因为考虑项目比较大,最好把wx:request封装起来,统一使用管理 utils.js 配置开发环境和线上环境 ...
- lock和synchronized如何选择?
1.lock是一个接口,而synchronized是java关键字,synchronized是内置的语言实现. 2.synchronized在发生异常时,会自动释放线程占有的锁,因此不会导致死锁,而l ...
- tkinter + cefpython 仿美团桌面程序
使用js开发桌面程序目前是一个趋势,Electron是其中一个佼佼者,网上也不乏很多文章.今天主要是来讲一下cefpython. 用python的朋友,特别使用过tkinter开发过界面的,一定会觉得 ...
- Appium+python自动化(十一)- 元素定位秘籍助你打通任督二脉 - 下卷(超详解)
简介 宏哥看你骨骼惊奇,印堂发亮,必是练武之奇才! 按照上一篇的节目预告,这一篇还是继续由宏哥给小伙伴们分享元素定位,是不是按照上一篇的秘籍修炼,是不是感觉到头顶盖好像被掀开,内气从头上冒出去,顿时觉 ...
- Spark学习之路(八)—— Spark SQL 之 DataFrame和Dataset
一.Spark SQL简介 Spark SQL是Spark中的一个子模块,主要用于操作结构化数据.它具有以下特点: 能够将SQL查询与Spark程序无缝混合,允许您使用SQL或DataFrame AP ...
- raft算法解析
一.raft算法引入 在寻找一种易于理解的一致性算法的研究(In Search of an Understandable Consensus Algorithm-extended version) 论 ...