GAN(生成对抗网络)之keras实践
GAN由论文《Ian Goodfellow et al., “Generative Adversarial Networks,” arXiv (2014)》提出。
GAN与VAEs的区别
GANs require differentiation through the visible units, and thus cannot model discrete data,
while VAEs require differentiation through the hidden units, and thus cannot have discrete latent variables.
即GAN不能处理离散数据,VAEs不能处理离散隐空间变量。
训练过程
常见模型是最小化一个loss,GAN里的生成器和鉴别器则是一个minmax操作,即
同时,生成器更新一次后,鉴别器应该更新多次,这样保证鉴别器可以维持在最优解附近。
如果生成器连续多次更新,而鉴别器不更新,则生成器倾向于生成那些“为难”鉴别器的同一批样本,这样生成器就缺乏多样性。
论文中给出的算法流程(简单的一次生成器更新对应多次鉴别器更新):
一些细节:
生成器使用relu和sigmoid激活函数,鉴别器使用maxout激活函数,Dropout只添加于鉴别器。
本文代码使用的一些trick:
- 生成器最后的激活函数使用tanh代替sigmoid
- 隐空间中使用正态分布去采样
- 添加随机性因素。GAN是非常难以训练的,添加一些噪音可以让训练不会轻易卡主。除了Dropout外,此处对鉴别器判断的标签也添加随机噪音。
- 稀疏梯度(Sparse gradients)在一些网络中通常是渴求的目标。但在GAN中,它会妨碍训练过程。所以将maxpool替换为带stride的卷积层,并使用leakyRELU代替relu激活函数。
- 为了避免产生的图像如棋盘状(即一个个正方形像素块,而非连续流畅的像素),设定卷积窗口大小为步长的整数倍。
- 优化器使用的是RMSprop,并使用梯度裁剪和梯度衰减。
训练过程为:
数据集为cifar10
定义生成器网络,输入为隐空间中一个矢量,输出为一个图片。
定义鉴别器网络,输入为生成器网络采样所得的图片和真实图片(以及标签),输出为sigmoid激活函数的标量值,即判断图片为真实还是伪造。
定义生成对抗网络,为D(G(x))即生成网络与鉴别网络的嵌套形式。输入为生成网络的输入,输出为鉴别器网络的输出。
训练时,使用高斯分布从隐空间中采样,经过生成网络得到生成的图片,与真实图片混合后(以及标签)作为鉴别器网络的输入。
先训练鉴别器。然后重新采样生成图片,此时需将这些图片的标签置为真实图片的标签(固定标签后,训练生成器,即让其参数调整到鉴别器都以为确实是真实图片)。再训练GAN(此时冻结鉴别器参数,训练的只是生成器)。
可以看到,定义了3个模型,只是因为生成器网络的训练要基于鉴别器网络进行。
代码如下
import numpy as np
from keras.datasets import cifar10
from keras.models import Model
from keras.layers import Input,Dense,LeakyReLU,Reshape,Conv2D,Conv2DTranspose,Flatten,Dropout
from keras.optimizers import RMSprop
from keras.preprocessing import image
import os latent_dim=32
# Cifar10图片尺寸
height,width=(32,32)
channels=3
3个网络定义
# 生成网络:将隐空间中矢量生成图片,使用Conv2DTranspose
generator_input=Input((latent_dim,))
x=Dense(128*16*16)(generator_input)
# 只添加了一个alpha参数,其他地方跟书上一致,alpha默认0.3
x=LeakyReLU(alpha=0.1)(x)
x=Reshape((16,16,128))(x)
x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x)
# 结果为32*32*256,为避免生成图片呈现棋盘的点阵格式,凡是使用strides的地方,窗口大小为strides的整数倍
x=Conv2DTranspose(256,4,strides=2,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x) x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x)
x=Conv2D(256,5,padding='same')(x)
x=LeakyReLU(alpha=0.1)(x) # 结果为32*32*3,即一个图片正确格式。使用tanh代替sigmoid
x=Conv2D(channels,7,activation='tanh',padding='same')(x)
generator=Model(generator_input,x)#它在包含在GAN里训练的,所以这里不用编译
# generator.summary() # 鉴别网络
discriminator_input=Input((height,width,channels))
x=Conv2D(128,3)(discriminator_input)
x=LeakyReLU(alpha=0.1)(x) x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
# 2*2*128
x=Conv2D(128,4,strides=2)(x)
x=LeakyReLU(alpha=0.1)(x)
x=Flatten()(x)
# Dropout和给标签添加噪声,可以避免GAN卡住
x=Dropout(0.4)(x)
x=Dense(1,activation='sigmoid')(x) discriminator=Model(discriminator_input,x)
# discriminator.summary() # clipvalue,梯度超过这个值就截断,decay,衰减,使得训练稳定
discriminator_optimizer=RMSprop(lr=0.0003,clipvalue=1.0,decay=1e-8)
discriminator.compile(optimizer=discriminator_optimizer,loss='binary_crossentropy') # 最后的生成对抗网络,由生成网络与对抗网络组合而成,此时冻结鉴别网络,训练的只是生成网络
discriminator.trainable=False
# 组成整个生成对抗网络
gan_input=Input((latent_dim,))
# 最终网络形式为鉴别网络作用于生成网络,故生成器也不用compile
gan_output=discriminator(generator(gan_input))
gan_optimizer=RMSprop(lr=0.0004,clipvalue=1.0,decay=1e-8)
gan=Model(gan_input,gan_output)
gan.compile(optimizer=gan_optimizer,loss='binary_crossentropy')
训练过程,此处并未使用多次鉴别器更新一次生成器更新,你可以自己调整(即循环里面开头添加个循环,训练鉴别器)。
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
# 选择frog类别,总共10个类
x_train=x_train[y_train.flatten()==6]
# reshape到输入格式 nums*height*width*channels,像素归一化
x_train=x_train.reshape((x_train.shape[0],)+(height,width,channels)).astype('float32')/255.
iters=10000
batch_size=20
save_dir='frog' start=0
for step in range(iters):
# 选取潜空间中随机矢量(正态分布)
random_latent_vec=np.random.normal(size=(batch_size,latent_dim))
# 生成网络产生图片
generated_images=generator.predict(random_latent_vec)
stop=start+batch_size
# 真实原始图片
real_images=x_train[start:stop]
# mix生成和真实图片
combined_images=np.concatenate([generated_images,real_images])
# mix labels
labels=np.concatenate([np.ones((batch_size,1)),np.zeros((batch_size,1))])
# trick:标签添加随机噪声
labels+=0.05*np.random.random(labels.shape)
# 鉴别loss,可能为负,因为使用的是LeakyReLU
d_loss=discriminator.train_on_batch(combined_images,labels)
# 重新生成随机矢量
random_latent_vec=np.random.normal(size=(batch_size,latent_dim))
# 故意设置标签为真实
misleading_targets=np.zeros((batch_size,1))
a_loss=gan.train_on_batch(random_latent_vec,misleading_targets)
start+=batch_size
if start>len(x_train)-batch_size:
start=0
if step%100==0:
# gan.save_weights('gan.h5')
print('discriminator loss:',d_loss)
print('adversarial loss:',a_loss)
# 保存一个batch里的第一个图片,之前像素归一化了,这里乘以255还原
img=image.array_to_img(generated_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'generated_frog'+str(step)+'.png'))
# 保存一个对比图片
img=image.array_to_img(real_images[0]*255.,scale=False)
img.save(os.path.join(save_dir,'real_frog'+str(step)+'.png'))
loss变化趋势,可以看到是不稳定的
看真实图和生成图片对比,上下2行图片只是同一批保存的,没有相关性。这是训练4000步,也即80000个训练样本后的结果。看起来比较丑陋吧。
GAN(生成对抗网络)之keras实践的更多相关文章
- 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】
本文转载自:https://www.leiphone.com/news/201703/Y5vnDSV9uIJIQzQm.html 生成对抗网络(Generative Adversarial Netwo ...
- 生成对抗网络(Generative Adversarial Networks,GAN)初探
1. 从纳什均衡(Nash equilibrium)说起 我们先来看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句话 ...
- GAN实战笔记——第四章深度卷积生成对抗网络(DCGAN)
深度卷积生成对抗网络(DCGAN) 我们在第3章实现了一个GAN,其生成器和判别器是具有单个隐藏层的简单前馈神经网络.尽管很简单,但GAN的生成器充分训练后得到的手写数字图像的真实性有些还是很具说服力 ...
- GAN实战笔记——第六章渐进式增长生成对抗网络(PGGAN)
渐进式增长生成对抗网络(PGGAN) 使用 TensorFlow和 TensorFlow Hub( TFHUB)构建渐进式增长生成对抗网络( Progressive GAN, PGGAN或 PROGA ...
- [ZZ] Valse 2017 | 生成对抗网络(GAN)研究年度进展评述
Valse 2017 | 生成对抗网络(GAN)研究年度进展评述 https://www.leiphone.com/news/201704/fcG0rTSZWqgI31eY.html?viewType ...
- 生成对抗网络(GAN)
基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...
- TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成
生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...
- AI佳作解读系列(六) - 生成对抗网络(GAN)综述精华
注:本文来自机器之心的PaperWeekly系列:万字综述之生成对抗网络(GAN),如有侵权,请联系删除,谢谢! 前阵子学习 GAN 的过程发现现在的 GAN 综述文章大都是 2016 年 Ian G ...
- 生成对抗网络GAN介绍
GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...
随机推荐
- 前端知识体系:JavaScript基础-原型和原型链-实现继承的几种方式以及他们的优缺点
实现继承的几种方式以及他们的优缺点(参考文档1.参考文档2.参考文档3) 要搞懂JS继承,我们首先要理解原型链:每一个实例对象都有一个__proto__属性(隐式原型),在js内部用来查找原型链:每一 ...
- vue04 总结
""" 1.环境 node:官网下载安装包,傻瓜式安装 - https://nodejs.org/zh-cn/ => 附带按照了npm cnpm:npm insta ...
- 11 git第二部分(未完成)
https://www.cnblogs.com/shangchunhong/p/9444335.html
- $ python manage.py makemigrations You are trying to add a non-nullable field 'name' to course without a default; we can't do that (the database needs something to populate existing rows). Please selec
问题: $ python manage.py makemigrationsYou are trying to add a non-nullable field 'name' to course wit ...
- The Reset Method of Te Philips VTR 5210
Pull down and hold the ON/OFF buttun, Then press the play button
- python音频处理
第一步:先下载ffmpeg-->下载链接 下载好解压到某个文件夹,并将该文件夹中的bin目录添加到系统path. 第二步:安装pydub pip3 install pydub # -*- cod ...
- Python数据挖掘-文本挖掘
文本挖掘概要 搞什么的? 从大量文本数据中,抽取出有价值的知识,并且利用这些知识更好的组织信息的过程. 目的是什么? 把文本信息转化为人们可利用的知识. 举例来说,下面的图表利用文本挖掘技术对库克ip ...
- php原生导出简单word表格(TP为例) (原)
后台: # 菲律宾名单word导出 public function export_word(){ $tids = $_GET['tids']; $userinfo=M("philippi ...
- mybatis之<trim
1.<trim prefix="" suffix="" suffixOverrides="" prefixOverrides=&quo ...
- Qt 字符串QString arg()用法总结
1.QString::arg()//用字符串变量参数依次替代字符串中最小数值 QString i = "iTest"; // current file's nu ...