1.GAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:

  • G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。

  • D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

在训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”

最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。

这样我们的目的就达成了:我们得到了一个生成式的模型G,它可以用来生成图片。

以上只是大致说了一下GAN的核心原理,如何用数学语言描述呢?这里直接摘录论文里的公式:

(1)优化D:

优化第一项是真是样本x输入的时候,结果越大越好;对于噪声等的输入z,生成的假样本G(z)要越小越好

(2)优化G:

优化生成器时和真是样本没关系,故不需要考虑;这时候只有假样本,但生成器希望假样本越逼真越好(接近1),故D(G(z)越大越好,则最小化1-D(G(z))

2.GAN的特点:

(1)相比较传统的模型,他存在两个不同的网络,而不是单一的网络,并且训练方式采用的是对抗训练方式

(2)GAN中G的梯度更新信息来自判别器D,而不是来自数据样本

3. GAN 的优点:

(1) GAN是一种生成式模型,相比较其他生成模型(玻尔兹曼机和GSNs)只用到了反向传播,而不需要复杂的马尔科夫链

(2)相比其他所有模型, GAN可以产生更加清晰,真实的样本

(3)GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域

(4)相比于变分自编码器, GANs没有引入任何决定性偏置( deterministic bias),变分方法引入决定性偏置,因为他们优化对数似然的下界,而不是似然度本身,这看起来导致了VAEs生成的实例比GANs更模糊

(5)相比VAE, GANs没有变分下界,如果鉴别器训练良好,那么生成器可以完美的学习到训练样本的分布.换句话说,GANs是渐进一致的,但是VAE是有偏差的

(6)GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,不管三七二十一,只要有一个的基准,直接上判别器,剩下的就交给对抗训练了。

4. GAN的缺点:

(1)训练GAN需要达到纳什均衡,有时候可以用梯度下降法做到,有时候做不到.我们还没有找到很好的达到纳什均衡的方法,所以训练GAN相比VAE或者PixelRNN是不稳定的,但我认为在实践中它还是比训练玻尔兹曼机稳定的多

(2)GAN不适合处理离散形式的数据,比如文本

(3)GAN存在训练不稳定、梯度消失、模式崩溃的问题(目前已解决)

5.为什么GAN中的优化器不常用SGD

(1)SGD容易震荡,容易使GAN训练不稳定,

(2)GAN的目的是在高维非凸的参数空间中找到纳什均衡点,GAN的纳什均衡点是一个鞍点,但是SGD只会找到局部极小值,因为SGD解决的是一个寻找最小值的问题,GAN是一个博弈问题。

6.训练GAN的一些技巧

(1). 输入规范化到(-1,1)之间,最后一层的激活函数使用tanh(BEGAN除外)

(2). 使用wassertein GAN的损失函数,

(3). 如果有标签数据的话,尽量使用标签,也有人提出使用反转标签效果很好,另外使用标签平滑,单边标签平滑或者双边标签平滑

(4). 使用mini-batch norm, 如果不用batch norm 可以使用instance norm 或者weight norm

(5). 避免使用RELU和pooling层,减少稀疏梯度的可能性,可以使用leakrelu激活函数

(6). 优化器尽量选择ADAM,学习率不要设置太大,初始1e-4可以参考,另外可以随着训练进行不断缩小学习率,

(7). 给D的网络层增加高斯噪声,相当于是一种正则

7.GAN实战

import tensorflow as tf #导入tensorflow
from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集
import numpy as np #导入numpy
import matplotlib.pyplot as plt #plt是绘图工具,在训练过程中用于输出可视化结果
import matplotlib.gridspec as gridspec #gridspec是图片排列工具,在训练过程中用于输出可视化结果
import os #导入os def xavier_init(size): #初始化参数时使用的xavier_init函数
in_dim = size[]
xavier_stddev = . / tf.sqrt(in_dim / .) #初始化标准差
return tf.random_normal(shape=size, stddev=xavier_stddev) #返回初始化的结果 X = tf.placeholder(tf.float32, shape=[None, ]) #X表示真的样本(即真实的手写数字) D_W1 = tf.Variable(xavier_init([, ])) #表示使用xavier方式初始化的判别器的D_W1参数,是一个784行128列的矩阵
D_b1 = tf.Variable(tf.zeros(shape=[])) #表示全零方式初始化的判别器的D_1参数,是一个长度为128的向量
D_W2 = tf.Variable(xavier_init([, ])) #表示使用xavier方式初始化的判别器的D_W2参数,是一个128行1列的矩阵
D_b2 = tf.Variable(tf.zeros(shape=[])) ##表示全零方式初始化的判别器的D_1参数,是一个长度为1的向量
theta_D = [D_W1, D_W2, D_b1, D_b2] #theta_D表示判别器的可训练参数集合 Z = tf.placeholder(tf.float32, shape=[None, ]) #Z表示生成器的输入(在这里是噪声),是一个N列100行的矩阵 G_W1 = tf.Variable(xavier_init([, ])) #表示使用xavier方式初始化的生成器的G_W1参数,是一个100行128列的矩阵
G_b1 = tf.Variable(tf.zeros(shape=[])) #表示全零方式初始化的生成器的G_b1参数,是一个长度为128的向量
G_W2 = tf.Variable(xavier_init([, ])) #表示使用xavier方式初始化的生成器的G_W2参数,是一个128行784列的矩阵
G_b2 = tf.Variable(tf.zeros(shape=[])) #表示全零方式初始化的生成器的G_b2参数,是一个长度为784的向量
theta_G = [G_W1, G_W2, G_b1, G_b2] #theta_G表示生成器的可训练参数集合 def sample_Z(m, n): #生成维度为[m, n]的随机噪声作为生成器G的输入
return np.random.uniform(-., ., size=[m, n]) def generator(z): #生成器,z的维度为[N, ]
G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) #输入的随机噪声乘以G_W1矩阵加上偏置G_b1,G_h1维度为[N, ]
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 #G_h1乘以G_W2矩阵加上偏置G_b2,G_log_prob维度为[N, ]
G_prob = tf.nn.sigmoid(G_log_prob) #G_log_prob经过一个sigmoid函数,G_prob维度为[N, ]
return G_prob #返回G_prob def discriminator(x): #判别器,x的维度为[N, ]
D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) #输入乘以D_W1矩阵加上偏置D_b1,D_h1维度为[N, ]
D_logit = tf.matmul(D_h1, D_W2) + D_b2 #D_h1乘以D_W2矩阵加上偏置D_b2,D_logit维度为[N, ]
D_prob = tf.nn.sigmoid(D_logit) #D_logit经过一个sigmoid函数,D_prob维度为[N, ]
return D_prob, D_logit #返回D_prob, D_logit G_sample = generator(Z) #取得生成器的生成结果
D_real, D_logit_real = discriminator(X) #取得判别器判别的真实手写数字的结果
D_fake, D_logit_fake = discriminator(G_sample) #取得判别器判别的生成的手写数字的结果 #对判别器对真实样本的判别结果计算误差(将结果与1比较)
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, targets=tf.ones_like(D_logit_real)))
#对判别器对虚假样本(即生成器生成的手写数字)的判别结果计算误差(将结果与0比较)
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.zeros_like(D_logit_fake)))
#判别器的误差
D_loss = D_loss_real + D_loss_fake
#生成器的误差(将判别器返回的对虚假样本的判别结果与1比较)
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, targets=tf.ones_like(D_logit_fake))) mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集 D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) #判别器的训练器
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) #生成器的训练器 mb_size = #训练的batch_size
Z_dim = #生成器输入的随机噪声的列的维度 sess = tf.Session() #会话层
sess.run(tf.initialize_all_variables()) #初始化所有可训练参数 def plot(samples): #保存图片时使用的plot函数
fig = plt.figure(figsize=(, )) #初始化一个4行4列包含16张子图像的图片
gs = gridspec.GridSpec(, ) #调整子图的位置
gs.update(wspace=0.05, hspace=0.05) #置子图间的间距
for i, sample in enumerate(samples): #依次将16张子图填充进需要保存的图像
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(, ), cmap='Greys_r')
return fig path = '/data/User/zcc/' #保存可视化结果的路径
i = #训练过程中保存的可视化结果的索引
for it in range(): #训练100万次
if it % == : #每训练1000次就保存一下结果
samples = sess.run(G_sample, feed_dict={Z: sample_Z(, Z_dim)})
fig = plot(samples) #通过plot函数生成可视化结果
plt.savefig(path+'out/{}.png'.format(str(i).zfill()), bbox_inches='tight') #保存可视化结果
i +=
plt.close(fig) X_mb, _ = mnist.train.next_batch(mb_size) #得到训练一个batch所需的真实手写数字(作为判别器的输入) #下面是得到训练一次的结果,通过sess来run出来
_, D_loss_curr, D_loss_real, D_loss_fake, D_loss = sess.run([D_solver, D_loss, D_loss_real, D_loss_fake, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)}) if it % == : #每训练1000次输出一下结果
print('Iter: {}'.format(it))
print('D loss: {:.4}'. format(D_loss_curr))
print('G_loss: {:.4}'.format(G_loss_curr))
print()

参考博客:

https://blog.csdn.net/m0_37407756/article/details/75309670

https://blog.csdn.net/jiongnima/article/details/80033169

生成对抗网络GAN详解与代码的更多相关文章

  1. 用MXNet实现mnist的生成对抗网络(GAN)

    用MXNet实现mnist的生成对抗网络(GAN) 生成式对抗网络(Generative Adversarial Network,简称GAN)由一个生成网络与一个判别网络组成.生成网络从潜在空间(la ...

  2. TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成

    生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...

  3. 生成对抗网络GAN介绍

    GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...

  4. 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...

  5. 深度学习-生成对抗网络GAN笔记

    生成对抗网络(GAN)由2个重要的部分构成: 生成器G(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器 判别器D(Discriminator):判断这张图像是真实的 ...

  6. 深度学习框架PyTorch一书的学习-第七章-生成对抗网络(GAN)

    参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN生成动漫头像 GAN解决了非监督学习中的著名问题:给定一批样本,训 ...

  7. 科普 | ​生成对抗网络(GAN)的发展史

    来源:https://en.wikipedia.org/wiki/Edmond_de_Belamy 五年前,Generative Adversarial Networks(GANs)在深度学习领域掀起 ...

  8. 生成对抗网络(GAN)

    基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...

  9. 利用tensorflow训练简单的生成对抗网络GAN

    对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的. 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(di ...

随机推荐

  1. linux文件共享服务

    linux文件共享配置 Windows访问linux 以下操作都在关闭防火墙和关闭selinux的环境下. 关闭防火墙的命令:service iptables stop关闭SELINUX命令:sete ...

  2. 009——Matlab调用cftool工具的函数方法

    (一)参考文献:http://blog.sina.com.cn/s/blog_4a0a8b5d0100uare.html (二)视频教程:https://v.qq.com/x/page/x30392r ...

  3. 洛谷 P2615 神奇的幻方 题解

    每日一题系列day1 打卡 Analysis 水货模拟,不多说了 #include<iostream> #include<cstdio> #include<cstring ...

  4. MongoDB CRUD 操作

    crud是指在做计算处理时的增加(Create).读取查询(Retrieve).更新(Update)和删除(Delete)几个单词的首字母简写.crud主要被用在描述软件系统中数据库或者持久层的基本操 ...

  5. 2G以上的大文件如何上传

    无法上传大文件是因为php.ini配置有限制了,这样限制了用户默认最大为2MB了,超过了就不能上传了,如果你确实要上传我们可以按下面方法来处理一下. 打开php.ini, 参数  设置  说明 fil ...

  6. Centos7.x 安装 Supervisord

    [环境] 系统:Centos 7.3 软件:supervisord [安装Supervisord] yum install epel-release yum install -y supervisor ...

  7. ckeditor自定义工具栏

    /** * 获取编辑器工具栏自定义参数 * @param type 类型 simple=极简版 basic=基本版 full=完整版 */ function get_ckeditor_toolbar( ...

  8. Linux网络编程四、UDP,广播和组播

    一.UDP UDP:是一个支持无连接的传输协议,全称是用户数据包协议(User Datagram Protocol).UDP协议无需像TCP一样要建立连接后才能发送封装的IP数据报,也是因此UDP相较 ...

  9. (转)kvm初识

    一 虚拟化介绍 1 常见虚拟化软件VMware系列VMware workstation.VMware vsphere(VMware esxi).VMware Fusion(Mac) Xen 开源 半虚 ...

  10. win7下如何根据端口号杀掉进程

    点击windows左下窗口图标按钮.   输入cmd   输入netstat -ano后回车.   左边箭头指向端口号,右边箭头指向为这个端口号对应的进程号pid,我们记下pid号   我们以2001 ...