Wasserstein Generative Adversarial Nets (WGAN ) and CGAN
GAN目前是机器学习中非常受欢迎的研究方向。主要包括有两种类型的研究,一种是将GAN用于有趣的问题,另一种是试图增加GAN的模型稳定性。
事实上,稳定性在GAN训练中是非常重要的。起初的GAN模型在训练中存在一些问题,e.g., 模式塌陷(生成器演化成非常窄的分布,只覆盖数据分布中的单一模式)。模式塌陷的含义是发生器只能产生非常相似的样本(例如MNIST中的单个数字),即所产生的样本不是多样的。这当然违反了GAN的初衷。
GAN中的另一个问题是没有指很好的指标或度量说明模型的收敛性。生成器和鉴别器的损失并没有告诉我们关于这方面的任何信息。当然,我们可以通过查看生成器产生的数据来监控训练过程。但是,这是一个愚蠢的手动过程。所以,我们需要一个可解释的指标告诉我们训练过程的好坏。
Wasserstein GAN
Wasserstein GAN(WGAN)是一种新提出的GAN算法,可以在一定程度解决上述两个问题。对于WGAN背后的直觉和理论背景,可以查看相关资料。
整个算法的伪代码如下:

- 损失函数中没有log。判别器D(X)的输出不再是一个概率(标量),同时也就意味着没有sigmoid激活函数
- 对于判别器D(X)的权重W进行裁剪
- 训练判别器的次数多于生成器
- 采用RMSProp优化器,代替原先的ADAM优化器
- 非常低的learning rate, α=0.00005
WGAN TensorFlow implementation
GAN的基本实现可以在上一篇文章中介绍过。 我们只需要稍微修改下传统的GAN。 首先,让我们更新我们的判别器D(X)
""" Vanilla GAN """
def discriminator(x):
D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
out = tf.matmul(D_h1, D_W2) + D_b2
return tf.nn.sigmoid(out) """ WGAN """
def discriminator(x):
D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
out = tf.matmul(D_h1, D_W2) + D_b2
return out
接下来,修改loss函数,去掉log:
""" Vanilla GAN """
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake)) """ WGAN """
D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)
G_loss = -tf.reduce_mean(D_fake)
在每次梯度下降更新后,裁剪判别器D(X)的权重:
# theta_D is list of D's params
clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in theta_D]
然后,只需要训练更多次的判别器D(X)就行了
D_solver = (tf.train.RMSPropOptimizer(learning_rate=5e-5)
.minimize(-D_loss, var_list=theta_D))
G_solver = (tf.train.RMSPropOptimizer(learning_rate=5e-5)
.minimize(G_loss, var_list=theta_G)) for it in range(1000000):
for _ in range(5):
X_mb, _ = mnist.train.next_batch(mb_size) _, D_loss_curr, _ = sess.run([D_solver, D_loss, clip_D], feed_dict={X: X_mb, z: sample_z(mb_size, z_dim)}) _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={z: sample_z(mb_size, z_dim)})
Conditional GAN
这里顺便简短的介绍下CGAN。

只需要在判别器D(X)和生成器G(Z)中的输入层额外拼接上向量y就可以了
额外的输入y:
y = tf.placeholder(tf.float32, shape=[None, y_dim])
再将它加入到判别器D(X)和生成器G(Z)中:
def generator(z, y):
# Concatenate z and y
inputs = tf.concat(concat_dim=1, values=[z, y]) G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob) return G_prob def discriminator(x, y):
# Concatenate x and y
inputs = tf.concat(concat_dim=1, values=[x, y]) D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit) return D_prob, D_logit
改变权重的维数:
# Modify input to hidden weights for discriminator
D_W1 = tf.Variable(shape=[X_dim + y_dim, h_dim])) # Modify input to hidden weights for generator
G_W1 = tf.Variable(shape=[Z_dim + y_dim, h_dim]))
构建新的网络:
# Add additional parameter y into all networks
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
训练时,额外加入y即可:
X_mb, y_mb = mnist.train.next_batch(mb_size) Z_sample = sample_Z(mb_size, Z_dim)
_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: Z_sample, y:y_mb})
_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: Z_sample, y:y_mb})
接下来进行生成器验证的时候,可以固定y的值:
n_sample = 16
Z_sample = sample_Z(n_sample, Z_dim) # Create conditional one-hot vector, with index 5 = 1
y_sample = np.zeros(shape=[n_sample, y_dim])
y_sample[:, 7] = 1 samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample})






PS:用下面的loss函数,收敛特别快,效果会更加好。
D_loss_real=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real,labels=tf.ones_like(D_real)))
D_loss_fake=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake,labels=tf.zeros_like(D_fake)))
D_loss=D_loss_real+D_loss_fake
G_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake,labels=tf.ones_like(D_fake)))
Wasserstein Generative Adversarial Nets (WGAN ) and CGAN的更多相关文章
- Generative Adversarial Nets[Wasserstein GAN]
本文来自<Wasserstein GAN>,时间线为2017年1月,本文可以算得上是GAN发展的一个里程碑文献了,其解决了以往GAN训练困难,结果不稳定等问题. 1 引言 本文主要思考的是 ...
- Generative Adversarial Nets[content]
0. Introduction 基于纳什平衡,零和游戏,最大最小策略等角度来作为GAN的引言 1. GAN GAN开山之作 图1.1 GAN的判别器和生成器的结构图及loss 2. Condition ...
- Generative Adversarial Nets[BEGAN]
本文来自<BEGAN: Boundary Equilibrium Generative Adversarial Networks>,时间线为2017年3月.是google的工作. 作者提出 ...
- Generative Adversarial Nets[Pre-WGAN]
本文来自<towards principled methods for training generative adversarial networks>,时间线为2017年1月,第一作者 ...
- (转)Deep Learning Research Review Week 1: Generative Adversarial Nets
Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...
- Generative Adversarial Nets[pix2pix]
本文来自<Image-to-Image Translation with Conditional Adversarial Networks>,是Phillip Isola与朱俊彦等人的作品 ...
- Generative Adversarial Nets(原生GAN学习)
学习总结于国立台湾大学 :李宏毅老师 Author: Ian Goodfellow • Paper: https://arxiv.org/abs/1701.00160 • Video: https:/ ...
- GAN(Generative Adversarial Nets)的发展
GAN(Generative Adversarial Nets),产生式对抗网络 存在问题: 1.无法表示数据分布 2.速度慢 3.resolution太小,大了无语义信息 4.无reference ...
- 论文笔记之:Conditional Generative Adversarial Nets
Conditional Generative Adversarial Nets arXiv 2014 本文是 GANs 的拓展,在产生 和 判别时,考虑到额外的条件 y,以进行更加"激烈 ...
随机推荐
- leetcode 140 单词拆分2 word break II
单词拆分2,递归+dp, 需要使用递归,同时使用记忆化搜索保存下来结果,c++代码如下 class Solution { public: //定义一个子串和子串拆分(如果有的话)的映射 unorder ...
- LinkedList简介
原文:https://blog.csdn.net/GongchuangSu/article/details/51527042 LinkedList简介 LinkedList 是一个继承于Abstrac ...
- JavaScript中二进制与10进制互相转换
webpack打包生成的代码中涉及了一些二进制位与的操作, 所以今天来学习一下JavaScript中的二进制与十进制转换操作吧 十进制转二进制: var num = 100 num.toString( ...
- fixture详细介绍-作为参数传入,error和failed区别
前言 fixture是pytest的核心功能,也是亮点功能,熟练掌握fixture的使用方法,pytest用起来才会得心应手! fixture简介 fixture的目的是提供一个固定基线,在该基线上测 ...
- 【HTTP】四、HTTP协议常见问题
HTTP协议是一个非常重要的应用层协议,在面试中有很多关于这方面的问题,这里做一个总结,大部分都在前面的文章中提到了,没提到的这里做一个介绍. 1.HTTP协议的基本原理.工作流程 HTTP协 ...
- POJ3585 Accumulation Degree【换根dp】
题目传送门 题意 给出一棵树,树上的边都有容量,在树上任意选一个点作为根,使得往外流(到叶节点,叶节点可以接受无限多的流量)的流量最大. 分析 首先,还是从1号点工具人开始$dfs$,可以求出$dp[ ...
- 【Python开发】【神经网络与深度学习】如何利用Python写简单网络爬虫
平时没事喜欢看看freebuf的文章,今天在看文章的时候,无线网总是时断时续,于是自己心血来潮就动手写了这个网络爬虫,将页面保存下来方便查看 先分析网站内容,红色部分即是网站文章内容div,可以看 ...
- (转)在Kubernetes集群中使用JMeter对Company示例进行压力测试
背景 压力测试是评估应用性能的一种有效手段.此外,越来越多的应用被拆分为多个微服务而每个微服务的性能不一,有的微服务是计算密集型,有的是IO密集型. 因此,压力测试在基于微服务架构的网络应用中扮演着越 ...
- flask 重定向详解
from flask import Flask,request,redirect,url_for app = Flask(__name__) @app.route('/') def hello_wor ...
- python 并发编程 多线程 GIL与Lock
GIL与Lock Python已经有一个GIL来保证同一时间只能有一个线程来执行了,为什么这里还需要互斥锁lock? 锁的目的是为了保护共享的数据,同一时间只能有一个线程来修改共享的数据 GIT保证了 ...