GAN由论文《Ian Goodfellow et al., “Generative Adversarial Networks,” arXiv (2014)》提出。

GAN与VAEs的区别

GANs require differentiation through the visible units, and thus cannot model discrete data,

while VAEs require differentiation through the hidden units, and thus cannot have discrete latent variables.

即GAN不能处理离散数据,VAEs不能处理离散隐空间变量。

训练过程

常见模型是最小化一个loss,GAN里的生成器和鉴别器则是一个minmax操作,即

同时,生成器更新一次后,鉴别器应该更新多次,这样保证鉴别器可以维持在最优解附近。

如果生成器连续多次更新,而鉴别器不更新,则生成器倾向于生成那些“为难”鉴别器的同一批样本,这样生成器就缺乏多样性。

论文中给出的算法流程(简单的一次生成器更新对应多次鉴别器更新):

一些细节:

生成器使用relu和sigmoid激活函数,鉴别器使用maxout激活函数,Dropout只添加于鉴别器。


本文代码使用的一些trick:

  • 生成器最后的激活函数使用tanh代替sigmoid
  • 隐空间中使用正态分布去采样
  • 添加随机性因素。GAN是非常难以训练的,添加一些噪音可以让训练不会轻易卡主。除了Dropout外,此处对鉴别器判断的标签也添加随机噪音。
  • 稀疏梯度(Sparse gradients)在一些网络中通常是渴求的目标。但在GAN中,它会妨碍训练过程。所以将maxpool替换为带stride的卷积层,并使用leakyRELU代替relu激活函数。
  • 为了避免产生的图像如棋盘状(即一个个正方形像素块,而非连续流畅的像素),设定卷积窗口大小为步长的整数倍。
  • 优化器使用的是RMSprop,并使用梯度裁剪和梯度衰减。

训练过程为:

数据集为cifar10

定义生成器网络,输入为隐空间中一个矢量,输出为一个图片。

定义鉴别器网络,输入为生成器网络采样所得的图片和真实图片(以及标签),输出为sigmoid激活函数的标量值,即判断图片为真实还是伪造。

定义生成对抗网络,为D(G(x))即生成网络与鉴别网络的嵌套形式。输入为生成网络的输入,输出为鉴别器网络的输出。

训练时,使用高斯分布从隐空间中采样,经过生成网络得到生成的图片,与真实图片混合后(以及标签)作为鉴别器网络的输入。

先训练鉴别器。然后重新采样生成图片,此时需将这些图片的标签置为真实图片的标签(固定标签后,训练生成器,即让其参数调整到鉴别器都以为确实是真实图片)。再训练GAN(此时冻结鉴别器参数,训练的只是生成器)

可以看到,定义了3个模型,只是因为生成器网络的训练要基于鉴别器网络进行。


代码如下

import numpy as np
from keras.datasets import cifar10
from keras.models import Model
from keras.layers import Input,Dense,LeakyReLU,Reshape,Conv2D,Conv2DTranspose,Flatten,Dropout
from keras.optimizers import RMSprop
from keras.preprocessing import image
import os latent_dim=32
# Cifar10图片尺寸
height,width=(32,32)
channels=3

3个网络定义

# 生成网络:将隐空间中矢量生成图片,使用Conv2DTranspose
generator_input=Input((latent_dim,))
x=Dense(128*16*16)(generator_input)
# 只添加了一个alpha参数,其他地方跟书上一致,alpha默认0.3
x=LeakyReLU(alpha=0.1)(x)
x=Reshape((16,16,128))(x)
x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x)
# 结果为32*32*256,为避免生成图片呈现棋盘的点阵格式,凡是使用strides的地方,窗口大小为strides的整数倍
x=Conv2DTranspose(256,4,strides=2,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x) x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x)
x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x) # 结果为32*32*3,即一个图片正确格式。使用tanh代替sigmoid
x=Conv2D(channels,7,activation='tanh',padding='same')(x)
generator=Model(generator_input,x)#它在包含在GAN里训练的,所以这里不用编译
# generator.summary() # 鉴别网络
discriminator_input=Input((height,width,channels))
x=Conv2D(128,3)(discriminator_input)
x=LeakyReLU(alpha=0.1)(x) x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
# 2*2*128
x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
x=Flatten()(x)
# Dropout和给标签添加噪声,可以避免GAN卡住
x=Dropout(0.4)(x)
x=Dense(1,activation='sigmoid')(x) discriminator=Model(discriminator_input,x)
# discriminator.summary() # clipvalue,梯度超过这个值就截断,decay,衰减,使得训练稳定
discriminator_optimizer=RMSprop(lr=0.0003,clipvalue=1.0,decay=1e-8)
discriminator.compile(optimizer=discriminator_optimizer,loss='binary_crossentropy') # 最后的生成对抗网络,由生成网络与对抗网络组合而成,此时冻结鉴别网络,训练的只是生成网络
discriminator.trainable=False
# 组成整个生成对抗网络
gan_input=Input((latent_dim,))
# 最终网络形式为鉴别网络作用于生成网络,故生成器也不用compile
gan_output=discriminator(generator(gan_input))
gan_optimizer=RMSprop(lr=0.0004,clipvalue=1.0,decay=1e-8)
gan=Model(gan_input,gan_output)
gan.compile(optimizer=gan_optimizer,loss='binary_crossentropy')

训练过程,此处并未使用多次鉴别器更新一次生成器更新,你可以自己调整(即循环里面开头添加个循环,训练鉴别器)。

(x_train,y_train),(x_test,y_test)=cifar10.load_data()
# 选择frog类别,总共10个类
x_train=x_train[y_train.flatten()==6]
# reshape到输入格式 nums*height*width*channels,像素归一化
x_train=x_train.reshape((x_train.shape[0],)+(height,width,channels)).astype('float32')/255.
iters=10000
batch_size=20
save_dir='frog' start=0
for step in range(iters):
# 选取潜空间中随机矢量(正态分布)
random_latent_vec=np.random.normal(size=(batch_size,latent_dim))
# 生成网络产生图片
generated_images=generator.predict(random_latent_vec)
stop=start+batch_size
# 真实原始图片
real_images=x_train[start:stop]
# mix生成和真实图片
combined_images=np.concatenate([generated_images,real_images])
# mix labels
labels=np.concatenate([np.ones((batch_size,1)),np.zeros((batch_size,1))])
# trick:标签添加随机噪声
labels+=0.05*np.random.random(labels.shape)
# 鉴别loss,可能为负,因为使用的是LeakyReLU
d_loss=discriminator.train_on_batch(combined_images,labels)
# 重新生成随机矢量
random_latent_vec=np.random.normal(size=(batch_size,latent_dim))
# 故意设置标签为真实
misleading_targets=np.zeros((batch_size,1))
a_loss=gan.train_on_batch(random_latent_vec,misleading_targets)
start+=batch_size
if start>len(x_train)-batch_size:
start=0
if step%100==0:
# gan.save_weights('gan.h5')
print('discriminator loss:',d_loss)
print('adversarial loss:',a_loss)
# 保存一个batch里的第一个图片,之前像素归一化了,这里乘以255还原
img=image.array_to_img(generated_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'generated_frog'+str(step)+'.png'))
# 保存一个对比图片
img=image.array_to_img(real_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'real_frog'+str(step)+'.png'))

loss变化趋势,可以看到是不稳定的

看真实图和生成图片对比,上下2行图片只是同一批保存的,没有相关性。这是训练4000步,也即80000个训练样本后的结果。看起来比较丑陋吧。

GAN(生成对抗网络)之keras实践的更多相关文章

  1. 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】

    本文转载自:https://www.leiphone.com/news/201703/Y5vnDSV9uIJIQzQm.html 生成对抗网络(Generative Adversarial Netwo ...

  2. 生成对抗网络(Generative Adversarial Networks,GAN)初探

    1. 从纳什均衡(Nash equilibrium)说起 我们先来看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句话 ...

  3. GAN实战笔记——第四章深度卷积生成对抗网络(DCGAN)

    深度卷积生成对抗网络(DCGAN) 我们在第3章实现了一个GAN,其生成器和判别器是具有单个隐藏层的简单前馈神经网络.尽管很简单,但GAN的生成器充分训练后得到的手写数字图像的真实性有些还是很具说服力 ...

  4. GAN实战笔记——第六章渐进式增长生成对抗网络(PGGAN)

    渐进式增长生成对抗网络(PGGAN) 使用 TensorFlow和 TensorFlow Hub( TFHUB)构建渐进式增长生成对抗网络( Progressive GAN, PGGAN或 PROGA ...

  5. [ZZ] Valse 2017 | 生成对抗网络(GAN)研究年度进展评述

    Valse 2017 | 生成对抗网络(GAN)研究年度进展评述 https://www.leiphone.com/news/201704/fcG0rTSZWqgI31eY.html?viewType ...

  6. 生成对抗网络(GAN)

    基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...

  7. TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成

    生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...

  8. AI佳作解读系列(六) - 生成对抗网络(GAN)综述精华

    注:本文来自机器之心的PaperWeekly系列:万字综述之生成对抗网络(GAN),如有侵权,请联系删除,谢谢! 前阵子学习 GAN 的过程发现现在的 GAN 综述文章大都是 2016 年 Ian G ...

  9. 生成对抗网络GAN介绍

    GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...

随机推荐

  1. 第40题:组合总和II

    一.问题描述: 给定一个数组 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合. candidates 中的每个数字在每个组合 ...

  2. Beep调用系统声音

    using System.Runtime.InteropServices; 引用命名空间 [DllImport("kernel32.dll")]public   static   ...

  3. CodeForces 839D - Winter is here | Codeforces Round #428 (Div. 2)

    赛后听 Forever97 讲的思路,强的一匹- - /* CodeForces 839D - Winter is here [ 数论,容斥 ] | Codeforces Round #428 (Di ...

  4. docker管理

    查看容器名 [root@docker ~]# docker inspect -f "{{.Name}}" a2f /u1 停止/启动终止状态的容器 [root@docker ~]# ...

  5. 编码问题2 utf-8和Unicode的区别

    utf-8和Unicode到底有什么区别?是存储方式不同?编码方式不同?它们看起来似乎很相似,但是实际上他们并不是同一个层次的概念 要想先讲清楚他们的区别,首先应该讲讲Unicode的来由. 众所周知 ...

  6. NVMe Windows 支持情况

    From NVMe 官网: Windows Driver – Microsoft Inbox • Closed source driver (Microsoft)• Inbox driver to W ...

  7. 2019.6.20 校内测试 NOIP模拟 Day 1 分析+题解

    这次是zay神仙给我们出的NOIP模拟题,不得不说好难啊QwQ,又倒数了~ T1 大美江湖 这个题是一个简单的模拟题.   ----zay 唯一的坑点就是打怪的时候计算向上取整时,如果用ceil函数一 ...

  8. spoj5973

    SP5973 SELTEAM - Selecting Teams #include <bits/stdc++.h> using namespace std; typedef long lo ...

  9. vuecli3.0 webpack4 配置vuex

    废话不说,直接写步骤 1. npm install vux --save 2. npm install less less-loader --save-dev 3. npm install @vux/ ...

  10. mac使用frida

    mac使用frida 安装 https://github.com/frida/frida/releases 根据手机的cpu的版本,选择相应的文件,一般通过手机信息可以看到 我这里是frida-ser ...