生成对抗网络GAN详解与代码
1.GAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:
G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。
D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”
最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。
这样我们的目的就达成了:我们得到了一个生成式的模型G,它可以用来生成图片。
以上只是大致说了一下GAN的核心原理,如何用数学语言描述呢?这里直接摘录论文里的公式:
(1)优化D:
优化第一项是真是样本x输入的时候,结果越大越好;对于噪声等的输入z,生成的假样本G(z)要越小越好
(2)优化G:
优化生成器时和真是样本没关系,故不需要考虑;这时候只有假样本,但生成器希望假样本越逼真越好(接近1),故D(G(z)越大越好,则最小化1-D(G(z))
2.GAN的特点:
(1)相比较传统的模型,他存在两个不同的网络,而不是单一的网络,并且训练方式采用的是对抗训练方式
(2)GAN中G的梯度更新信息来自判别器D,而不是来自数据样本
3. GAN 的优点:
(1) GAN是一种生成式模型,相比较其他生成模型(玻尔兹曼机和GSNs)只用到了反向传播,而不需要复杂的马尔科夫链
(2)相比其他所有模型, GAN可以产生更加清晰,真实的样本
(3)GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域
(4)相比于变分自编码器, GANs没有引入任何决定性偏置( deterministic bias),变分方法引入决定性偏置,因为他们优化对数似然的下界,而不是似然度本身,这看起来导致了VAEs生成的实例比GANs更模糊
(5)相比VAE, GANs没有变分下界,如果鉴别器训练良好,那么生成器可以完美的学习到训练样本的分布.换句话说,GANs是渐进一致的,但是VAE是有偏差的
(6)GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,不管三七二十一,只要有一个的基准,直接上判别器,剩下的就交给对抗训练了。
4. GAN的缺点:
(1)训练GAN需要达到纳什均衡,有时候可以用梯度下降法做到,有时候做不到.我们还没有找到很好的达到纳什均衡的方法,所以训练GAN相比VAE或者PixelRNN是不稳定的,但我认为在实践中它还是比训练玻尔兹曼机稳定的多
(2)GAN不适合处理离散形式的数据,比如文本
(3)GAN存在训练不稳定、梯度消失、模式崩溃的问题(目前已解决)
5.为什么GAN中的优化器不常用SGD
(1)SGD容易震荡,容易使GAN训练不稳定,
(2)GAN的目的是在高维非凸的参数空间中找到纳什均衡点,GAN的纳什均衡点是一个鞍点,但是SGD只会找到局部极小值,因为SGD解决的是一个寻找最小值的问题,GAN是一个博弈问题。
6.训练GAN的一些技巧
(1). 输入规范化到(-1,1)之间,最后一层的激活函数使用tanh(BEGAN除外)
(2). 使用wassertein GAN的损失函数,
(3). 如果有标签数据的话,尽量使用标签,也有人提出使用反转标签效果很好,另外使用标签平滑,单边标签平滑或者双边标签平滑
(4). 使用mini-batch norm, 如果不用batch norm 可以使用instance norm 或者weight norm
(5). 避免使用RELU和pooling层,减少稀疏梯度的可能性,可以使用leakrelu激活函数
(6). 优化器尽量选择ADAM,学习率不要设置太大,初始1e-4可以参考,另外可以随着训练进行不断缩小学习率,
(7). 给D的网络层增加高斯噪声,相当于是一种正则
7.GAN实战
import tensorflow as tf #导入tensorflow
from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集
import numpy as np #导入numpy
import matplotlib.pyplot as plt #plt是绘图工具,在训练过程中用于输出可视化结果
import matplotlib.gridspec as gridspec #gridspec是图片排列工具,在训练过程中用于输出可视化结果
import os #导入os def xavier_init(size): #初始化参数时使用的xavier_init函数
in_dim = size[]
xavier_stddev = . / tf.sqrt(in_dim / .) #初始化标准差
return tf.random_normal(shape=size, stddev=xavier_stddev) #返回初始化的结果 X = tf.placeholder(tf.float32, shape=[None, ]) #X表示真的样本(即真实的手写数字) D_W1 = tf.Variable(xavier_init([, ])) #表示使用xavier方式初始化的判别器的D_W1参数,是一个784行128列的矩阵
D_b1 = tf.Variable(tf.zeros(shape=[])) #表示全零方式初始化的判别器的D_1参数,是一个长度为128的向量
D_W2 = tf.Variable(xavier_init([, ])) #表示使用xavier方式初始化的判别器的D_W2参数,是一个128行1列的矩阵
D_b2 = tf.Variable(tf.zeros(shape=[])) ##表示全零方式初始化的判别器的D_1参数,是一个长度为1的向量
theta_D = [D_W1, D_W2, D_b1, D_b2] #theta_D表示判别器的可训练参数集合 Z = tf.placeholder(tf.float32, shape=[None, ]) #Z表示生成器的输入(在这里是噪声),是一个N列100行的矩阵 G_W1 = tf.Variable(xavier_init([, ])) #表示使用xavier方式初始化的生成器的G_W1参数,是一个100行128列的矩阵
G_b1 = tf.Variable(tf.zeros(shape=[])) #表示全零方式初始化的生成器的G_b1参数,是一个长度为128的向量
G_W2 = tf.Variable(xavier_init([, ])) #表示使用xavier方式初始化的生成器的G_W2参数,是一个128行784列的矩阵
G_b2 = tf.Variable(tf.zeros(shape=[])) #表示全零方式初始化的生成器的G_b2参数,是一个长度为784的向量
theta_G = [G_W1, G_W2, G_b1, G_b2] #theta_G表示生成器的可训练参数集合 def sample_Z(m, n): #生成维度为[m, n]的随机噪声作为生成器G的输入
return np.random.uniform(-., ., size=[m, n]) def generator(z): #生成器,z的维度为[N, ]
G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) #输入的随机噪声乘以G_W1矩阵加上偏置G_b1,G_h1维度为[N, ]
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 #G_h1乘以G_W2矩阵加上偏置G_b2,G_log_prob维度为[N, ]
G_prob = tf.nn.sigmoid(G_log_prob) #G_log_prob经过一个sigmoid函数,G_prob维度为[N, ]
return G_prob #返回G_prob def discriminator(x): #判别器,x的维度为[N, ]
D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) #输入乘以D_W1矩阵加上偏置D_b1,D_h1维度为[N, ]
D_logit = tf.matmul(D_h1, D_W2) + D_b2 #D_h1乘以D_W2矩阵加上偏置D_b2,D_logit维度为[N, ]
D_prob = tf.nn.sigmoid(D_logit) #D_logit经过一个sigmoid函数,D_prob维度为[N, ]
return D_prob, D_logit #返回D_prob, D_logit G_sample = generator(Z) #取得生成器的生成结果
D_real, D_logit_real = discriminator(X) #取得判别器判别的真实手写数字的结果
D_fake, D_logit_fake = discriminator(G_sample) #取得判别器判别的生成的手写数字的结果 #对判别器对真实样本的判别结果计算误差(将结果与1比较)
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, targets=tf.ones_like(D_logit_real)))
#对判别器对虚假样本(即生成器生成的手写数字)的判别结果计算误差(将结果与0比较)
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.zeros_like(D_logit_fake)))
#判别器的误差
D_loss = D_loss_real + D_loss_fake
#生成器的误差(将判别器返回的对虚假样本的判别结果与1比较)
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.ones_like(D_logit_fake))) mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集 D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) #判别器的训练器
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) #生成器的训练器 mb_size = #训练的batch_size
Z_dim = #生成器输入的随机噪声的列的维度 sess = tf.Session() #会话层
sess.run(tf.initialize_all_variables()) #初始化所有可训练参数 def plot(samples): #保存图片时使用的plot函数
fig = plt.figure(figsize=(, )) #初始化一个4行4列包含16张子图像的图片
gs = gridspec.GridSpec(, ) #调整子图的位置
gs.update(wspace=0.05, hspace=0.05) #置子图间的间距
for i, sample in enumerate(samples): #依次将16张子图填充进需要保存的图像
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(, ), cmap='Greys_r')
return fig path = '/data/User/zcc/' #保存可视化结果的路径
i = #训练过程中保存的可视化结果的索引
for it in range(): #训练100万次
if it % == : #每训练1000次就保存一下结果
samples = sess.run(G_sample, feed_dict={Z: sample_Z(, Z_dim)})
fig = plot(samples) #通过plot函数生成可视化结果
plt.savefig(path+'out/{}.png'.format(str(i).zfill()), bbox_inches='tight') #保存可视化结果
i +=
plt.close(fig) X_mb, _ = mnist.train.next_batch(mb_size) #得到训练一个batch所需的真实手写数字(作为判别器的输入) #下面是得到训练一次的结果,通过sess来run出来
_, D_loss_curr, D_loss_real, D_loss_fake, D_loss = sess.run([D_solver, D_loss, D_loss_real, D_loss_fake, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)}) if it % == : #每训练1000次输出一下结果
print('Iter: {}'.format(it))
print('D loss: {:.4}'. format(D_loss_curr))
print('G_loss: {:.4}'.format(G_loss_curr))
print()
参考博客:
https://blog.csdn.net/m0_37407756/article/details/75309670
https://blog.csdn.net/jiongnima/article/details/80033169
生成对抗网络GAN详解与代码的更多相关文章
- 用MXNet实现mnist的生成对抗网络(GAN)
用MXNet实现mnist的生成对抗网络(GAN) 生成式对抗网络(Generative Adversarial Network,简称GAN)由一个生成网络与一个判别网络组成.生成网络从潜在空间(la ...
- TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成
生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...
- 生成对抗网络GAN介绍
GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...
- 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...
- 深度学习-生成对抗网络GAN笔记
生成对抗网络(GAN)由2个重要的部分构成: 生成器G(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器 判别器D(Discriminator):判断这张图像是真实的 ...
- 深度学习框架PyTorch一书的学习-第七章-生成对抗网络(GAN)
参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN生成动漫头像 GAN解决了非监督学习中的著名问题:给定一批样本,训 ...
- 科普 | 生成对抗网络(GAN)的发展史
来源:https://en.wikipedia.org/wiki/Edmond_de_Belamy 五年前,Generative Adversarial Networks(GANs)在深度学习领域掀起 ...
- 生成对抗网络(GAN)
基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...
- 利用tensorflow训练简单的生成对抗网络GAN
对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的. 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(di ...
随机推荐
- 02_View
1.View 1.基于类的视图 Class-based Views REST framework提供APIView是Django的View的子类 发送到View的Request请求:是REST fra ...
- react的登录逻辑
https://blog.csdn.net/qq_36822018/article/details/83028661(先看看这个 https://blog.csdn.net/weixin_342681 ...
- K-D Tree学习笔记
用途 做各种二维三维四维偏序等等. 代替空间巨大的树套树. 数据较弱的时候水分. 思想 我们发现平衡树这种东西功能强大,然而只能做一维上的询问修改,显得美中不足. 于是我们尝试用平衡树的这种二叉树结构 ...
- Codeforces 1182D Complete Mirror [树哈希]
Codeforces 中考考完之后第一个AC,纪念一下qwq 思路 简单理解一下题之后就可以发现其实就是要求一个点,使得把它提为根之后整棵树显得非常对称. 很容易想到树哈希来判结构是否相同,而且由于只 ...
- ros平台下python脚本控制机械臂运动
在使用moveit_setup_assistant生成机械臂的配置文件后可以使用roslaunch demo.launch启动demo,在rviz中可以通过拖动机械臂进行运动学正逆解/轨迹规划等仿真运 ...
- epoll事件模型
事件模型 EPOLL事件有两种模型: Edge Triggered (ET) 边缘触发只有数据到来才触发,不管缓存区中是否还有数据. Level Triggered (LT) 水平触发只要有数据都会触 ...
- 数据结构实验之图论六:村村通公路【Prim算法】(SDUT 3362)
题解:选点,选最小权的边,更新点权.可以手动自行找一遍怎么找到这个最小的生成树,随便选一个点放入我们选的集合中,然后看和这个点相连的点中,与那个点相连的那条边权值是最小的,选择之后,把相连的这个点一起 ...
- 一个Maven项目在eclipse中正常,但在IDEA中启动时报错
这个项目十有八九最初是在ecplise创建的,框架上十有八九整合了Mybatis,报的错误十有八九是 org.apache.ibatis.binding.BindingException: Inval ...
- PHP 之CI框架+GatewayWorker+AmazeUI低仿微信聊天网页版
html5开发的仿微信网页版聊天,采用html5+css3+jquery+websocket+amazeui等技术混合架构开发,实现了微信网页版的主要功能. 一.效果图 二.前端参考代码 <!D ...
- NPM私有包部署到私有仓库
NPM私有包部署到私有仓库1.项目部署到NPM2.私有仓库的搭建1,项目部署到NPM注册NPM账号注册地址:https://www.npmjs.com/ 注册完成后进入邮箱验证 账号登录 npm lo ...