GAN由论文《Ian Goodfellow et al., “Generative Adversarial Networks,” arXiv (2014)》提出。

GAN与VAEs的区别

GANs require differentiation through the visible units, and thus cannot model discrete data,

while VAEs require differentiation through the hidden units, and thus cannot have discrete latent variables.

即GAN不能处理离散数据,VAEs不能处理离散隐空间变量。

训练过程

常见模型是最小化一个loss,GAN里的生成器和鉴别器则是一个minmax操作,即

同时,生成器更新一次后,鉴别器应该更新多次,这样保证鉴别器可以维持在最优解附近。

如果生成器连续多次更新,而鉴别器不更新,则生成器倾向于生成那些“为难”鉴别器的同一批样本,这样生成器就缺乏多样性。

论文中给出的算法流程(简单的一次生成器更新对应多次鉴别器更新):

一些细节:

生成器使用relu和sigmoid激活函数,鉴别器使用maxout激活函数,Dropout只添加于鉴别器。


本文代码使用的一些trick:

  • 生成器最后的激活函数使用tanh代替sigmoid
  • 隐空间中使用正态分布去采样
  • 添加随机性因素。GAN是非常难以训练的,添加一些噪音可以让训练不会轻易卡主。除了Dropout外,此处对鉴别器判断的标签也添加随机噪音。
  • 稀疏梯度(Sparse gradients)在一些网络中通常是渴求的目标。但在GAN中,它会妨碍训练过程。所以将maxpool替换为带stride的卷积层,并使用leakyRELU代替relu激活函数。
  • 为了避免产生的图像如棋盘状(即一个个正方形像素块,而非连续流畅的像素),设定卷积窗口大小为步长的整数倍。
  • 优化器使用的是RMSprop,并使用梯度裁剪和梯度衰减。

训练过程为:

数据集为cifar10

定义生成器网络,输入为隐空间中一个矢量,输出为一个图片。

定义鉴别器网络,输入为生成器网络采样所得的图片和真实图片(以及标签),输出为sigmoid激活函数的标量值,即判断图片为真实还是伪造。

定义生成对抗网络,为D(G(x))即生成网络与鉴别网络的嵌套形式。输入为生成网络的输入,输出为鉴别器网络的输出。

训练时,使用高斯分布从隐空间中采样,经过生成网络得到生成的图片,与真实图片混合后(以及标签)作为鉴别器网络的输入。

先训练鉴别器。然后重新采样生成图片,此时需将这些图片的标签置为真实图片的标签(固定标签后,训练生成器,即让其参数调整到鉴别器都以为确实是真实图片)。再训练GAN(此时冻结鉴别器参数,训练的只是生成器)

可以看到,定义了3个模型,只是因为生成器网络的训练要基于鉴别器网络进行。


代码如下

import numpy as np
from keras.datasets import cifar10
from keras.models import Model
from keras.layers import Input,Dense,LeakyReLU,Reshape,Conv2D,Conv2DTranspose,Flatten,Dropout
from keras.optimizers import RMSprop
from keras.preprocessing import image
import os latent_dim=32
# Cifar10图片尺寸
height,width=(32,32)
channels=3

3个网络定义

# 生成网络:将隐空间中矢量生成图片,使用Conv2DTranspose
generator_input=Input((latent_dim,))
x=Dense(128*16*16)(generator_input)
# 只添加了一个alpha参数,其他地方跟书上一致,alpha默认0.3
x=LeakyReLU(alpha=0.1)(x)
x=Reshape((16,16,128))(x)
x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x)
# 结果为32*32*256,为避免生成图片呈现棋盘的点阵格式,凡是使用strides的地方,窗口大小为strides的整数倍
x=Conv2DTranspose(256,4,strides=2,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x) x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x)
x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x) # 结果为32*32*3,即一个图片正确格式。使用tanh代替sigmoid
x=Conv2D(channels,7,activation='tanh',padding='same')(x)
generator=Model(generator_input,x)#它在包含在GAN里训练的,所以这里不用编译
# generator.summary() # 鉴别网络
discriminator_input=Input((height,width,channels))
x=Conv2D(128,3)(discriminator_input)
x=LeakyReLU(alpha=0.1)(x) x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
# 2*2*128
x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
x=Flatten()(x)
# Dropout和给标签添加噪声,可以避免GAN卡住
x=Dropout(0.4)(x)
x=Dense(1,activation='sigmoid')(x) discriminator=Model(discriminator_input,x)
# discriminator.summary() # clipvalue,梯度超过这个值就截断,decay,衰减,使得训练稳定
discriminator_optimizer=RMSprop(lr=0.0003,clipvalue=1.0,decay=1e-8)
discriminator.compile(optimizer=discriminator_optimizer,loss='binary_crossentropy') # 最后的生成对抗网络,由生成网络与对抗网络组合而成,此时冻结鉴别网络,训练的只是生成网络
discriminator.trainable=False
# 组成整个生成对抗网络
gan_input=Input((latent_dim,))
# 最终网络形式为鉴别网络作用于生成网络,故生成器也不用compile
gan_output=discriminator(generator(gan_input))
gan_optimizer=RMSprop(lr=0.0004,clipvalue=1.0,decay=1e-8)
gan=Model(gan_input,gan_output)
gan.compile(optimizer=gan_optimizer,loss='binary_crossentropy')

训练过程,此处并未使用多次鉴别器更新一次生成器更新,你可以自己调整(即循环里面开头添加个循环,训练鉴别器)。

(x_train,y_train),(x_test,y_test)=cifar10.load_data()
# 选择frog类别,总共10个类
x_train=x_train[y_train.flatten()==6]
# reshape到输入格式 nums*height*width*channels,像素归一化
x_train=x_train.reshape((x_train.shape[0],)+(height,width,channels)).astype('float32')/255.
iters=10000
batch_size=20
save_dir='frog' start=0
for step in range(iters):
# 选取潜空间中随机矢量(正态分布)
random_latent_vec=np.random.normal(size=(batch_size,latent_dim))
# 生成网络产生图片
generated_images=generator.predict(random_latent_vec)
stop=start+batch_size
# 真实原始图片
real_images=x_train[start:stop]
# mix生成和真实图片
combined_images=np.concatenate([generated_images,real_images])
# mix labels
labels=np.concatenate([np.ones((batch_size,1)),np.zeros((batch_size,1))])
# trick:标签添加随机噪声
labels+=0.05*np.random.random(labels.shape)
# 鉴别loss,可能为负,因为使用的是LeakyReLU
d_loss=discriminator.train_on_batch(combined_images,labels)
# 重新生成随机矢量
random_latent_vec=np.random.normal(size=(batch_size,latent_dim))
# 故意设置标签为真实
misleading_targets=np.zeros((batch_size,1))
a_loss=gan.train_on_batch(random_latent_vec,misleading_targets)
start+=batch_size
if start>len(x_train)-batch_size:
start=0
if step%100==0:
# gan.save_weights('gan.h5')
print('discriminator loss:',d_loss)
print('adversarial loss:',a_loss)
# 保存一个batch里的第一个图片,之前像素归一化了,这里乘以255还原
img=image.array_to_img(generated_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'generated_frog'+str(step)+'.png'))
# 保存一个对比图片
img=image.array_to_img(real_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'real_frog'+str(step)+'.png'))

loss变化趋势,可以看到是不稳定的

看真实图和生成图片对比,上下2行图片只是同一批保存的,没有相关性。这是训练4000步,也即80000个训练样本后的结果。看起来比较丑陋吧。

GAN(生成对抗网络)之keras实践的更多相关文章

  1. 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】

    本文转载自:https://www.leiphone.com/news/201703/Y5vnDSV9uIJIQzQm.html 生成对抗网络(Generative Adversarial Netwo ...

  2. 生成对抗网络(Generative Adversarial Networks,GAN)初探

    1. 从纳什均衡(Nash equilibrium)说起 我们先来看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句话 ...

  3. GAN实战笔记——第四章深度卷积生成对抗网络(DCGAN)

    深度卷积生成对抗网络(DCGAN) 我们在第3章实现了一个GAN,其生成器和判别器是具有单个隐藏层的简单前馈神经网络.尽管很简单,但GAN的生成器充分训练后得到的手写数字图像的真实性有些还是很具说服力 ...

  4. GAN实战笔记——第六章渐进式增长生成对抗网络(PGGAN)

    渐进式增长生成对抗网络(PGGAN) 使用 TensorFlow和 TensorFlow Hub( TFHUB)构建渐进式增长生成对抗网络( Progressive GAN, PGGAN或 PROGA ...

  5. [ZZ] Valse 2017 | 生成对抗网络(GAN)研究年度进展评述

    Valse 2017 | 生成对抗网络(GAN)研究年度进展评述 https://www.leiphone.com/news/201704/fcG0rTSZWqgI31eY.html?viewType ...

  6. 生成对抗网络(GAN)

    基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...

  7. TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成

    生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...

  8. AI佳作解读系列(六) - 生成对抗网络(GAN)综述精华

    注:本文来自机器之心的PaperWeekly系列:万字综述之生成对抗网络(GAN),如有侵权,请联系删除,谢谢! 前阵子学习 GAN 的过程发现现在的 GAN 综述文章大都是 2016 年 Ian G ...

  9. 生成对抗网络GAN介绍

    GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...

随机推荐

  1. Java笔记(基础第四篇)

    Java集合类 集合类概述 Java 语言的java.util包中提供了一些集合类,这些集合类又被称为容器.常用的集合有List集合.Set集合.Map集合,其中List与Set实现了Collecti ...

  2. DP(第一版)

    序 任何一种具有递推或者递归形式的计算过程,都叫做动态规划 如果你一开始学的时候就不会DP,那么你在考试的时候就一定不会想到用动态规划! 需要进行掌握的内容 1)DP中的基本概念 2)状态 3)转移方 ...

  3. jquery.qrcode.js生成二维码(前端生成二维码)

    官网地址:http://jeromeetienne.github.io/jquery-qrcode/ 第一步引入插件: <script type='text/javascript' src='h ...

  4. Python基础之range()

    range:指定范围,生成指定数字. 1. range() for i in range(1, 10): print(i) 执行结果为: 1 2 3 4 5 6 7 8 9 2. range()步长 ...

  5. POJ 3177 (Redundant Paths) —— (有重边,边双联通,无向图缩点)

    做到这里以后,总算是觉得tarjan算法已经有点入门了. 这题的题意是,给出若干个点和若干条边连接他们,在这个无向图中,问至少增加多少条边可以使得这个图变成边双联通图(即任意两点间都有至少两条没有重复 ...

  6. OSX 改变PHP安装路径环境变量

    当使用XAMPP来学习Laravel的时候,用composer安装laravel总是报错,说mcrypt is required ,但是当我在终端里打印 which php 显示的是usr/bin/p ...

  7. Go http包执行流程

    Go 语言实现的 Web 服务工作方式与其他形式下的 Web 工作方式并没有什么不同,具体流程如下: -- http包执行流程 Request:来自用户的请求信息,包括 post.get.Cookie ...

  8. H5-Mui框架——修改mui.confirm样式

    问题简述: 使用mui框架默认提示框时,感觉与整体布局不符,因此想要更改其中的样式. 首先,查了一下资料:mui.toast样式风格及位置修改教程 以下是转载过来的文章内容. ============ ...

  9. vue 无法覆盖vant的UI组件的样式

    vue 无法覆盖vant的UI组件的样式 有时候UI组件提供的默认的样式不能满足项目的需要,就需要我们对它的样式进行修改,但是发现加了scoped后修改的样式不起作用. 解决方法: 使用深度选择器,将 ...

  10. JAVA字符串处理函数列表一览

    JAVA字符串处理函数列表一览   Java中的字符串也是一连串的字符.但是与许多其他的计算机语言将字符串作为字符数组处理不同,Java将字符串作为String类型对象来处理.将字符串作为内置的对象处 ...