1、结构图

2、知识点

生成器(G):将噪音数据生成一个想要的数据
判别器(D):将生成器的结果进行判别,

3、代码及案例

# coding: utf-8

# ## 对抗生成网络案例 ##
#
#
# <img src="jpg/3.png" alt="FAO" width="590" > # - 判别器 : 火眼金睛,分辨出生成和真实的 <br />
# <br />
# - 生成器 : 瞒天过海,骗过判别器 <br />
# <br />
# - 损失函数定义 : 一方面要让判别器分辨能力更强,另一方面要让生成器更真 <br />
# <br />
#
# <img src="jpg/1.jpg" alt="FAO" width="590" > # In[1]: import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt get_ipython().run_line_magic('matplotlib', 'inline') # # 导入数据 # In[2]: from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/data') # ## 网络架构
#
# ### 输入层 :待生成图像(噪音)和真实数据
#
# ### 生成网络:将噪音图像进行生成
#
# ### 判别网络:
# - (1)判断真实图像输出结果
# - (2)判断生成图像输出结果
#
# ### 目标函数:
# - (1)对于生成网络要使得生成结果通过判别网络为真
# - (2)对于判别网络要使得输入为真实图像时判别为真 输入为生成图像时判别为假
#
# <img src="jpg/2.png" alt="FAO" width="590" > # ## Inputs # In[3]: #真实数据和噪音数据
def get_inputs(real_size, noise_size): real_img = tf.placeholder(tf.float32, [None, real_size])
noise_img = tf.placeholder(tf.float32, [None, noise_size]) return real_img, noise_img # ## 生成器
# * noise_img: 产生的噪音输入
# * n_units: 隐层单元个数
# * out_dim: 输出的大小(28 * 28 * 1) # In[4]: def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01): with tf.variable_scope("generator", reuse=reuse):
# hidden layer
hidden1 = tf.layers.dense(noise_img, n_units)
# leaky ReLU
hidden1 = tf.maximum(alpha * hidden1, hidden1)
# dropout
hidden1 = tf.layers.dropout(hidden1, rate=0.2) # logits & outputs
logits = tf.layers.dense(hidden1, out_dim)
outputs = tf.tanh(logits) return logits, outputs # ## 判别器
# * img:输入
# * n_units:隐层单元数量
# * reuse:由于要使用两次 # In[5]: def get_discriminator(img, n_units, reuse=False, alpha=0.01): with tf.variable_scope("discriminator", reuse=reuse):
# hidden layer
hidden1 = tf.layers.dense(img, n_units)
hidden1 = tf.maximum(alpha * hidden1, hidden1) # logits & outputs
logits = tf.layers.dense(hidden1, 1)
outputs = tf.sigmoid(logits) return logits, outputs # ## 网络参数定义
# * img_size:输入大小
# * noise_size:噪音图像大小
# * g_units:生成器隐层参数
# * d_units:判别器隐层参数
# * learning_rate:学习率 # In[6]: img_size = mnist.train.images[0].shape[0] noise_size = 100 g_units = 128 d_units = 128 learning_rate = 0.001 alpha = 0.01 # ## 构建网络 # In[7]: tf.reset_default_graph() real_img, noise_img = get_inputs(img_size, noise_size) # generator
g_logits, g_outputs = get_generator(noise_img, g_units, img_size) # discriminator
d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True) # ### 目标函数:
# - (1)对于生成网络要使得生成结果通过判别网络为真
# - (2)对于判别网络要使得输入为真实图像时判别为真 输入为生成图像时判别为假
#
# <img src="jpg/2.png" alt="FAO" width="590" > # In[8]: # discriminator的loss
# 识别真实图片
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
labels=tf.ones_like(d_logits_real)))
# 识别生成的图片
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
labels=tf.zeros_like(d_logits_fake)))
# 总体loss
d_loss = tf.add(d_loss_real, d_loss_fake) # generator的loss
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
labels=tf.ones_like(d_logits_fake))) # ## 优化器 # In[9]: train_vars = tf.trainable_variables() # generator
g_vars = [var for var in train_vars if var.name.startswith("generator")]
# discriminator
d_vars = [var for var in train_vars if var.name.startswith("discriminator")] # optimizer
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars) # # 训练 # In[10]: # batch_size
batch_size = 64
# 训练迭代轮数
epochs = 300
# 抽取样本数
n_sample = 25 # 存储测试样例
samples = []
# 存储loss
losses = []
# 保存生成器变量
saver = tf.train.Saver(var_list = g_vars)
# 开始训练
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for e in range(epochs):
for batch_i in range(mnist.train.num_examples//batch_size):
batch = mnist.train.next_batch(batch_size) batch_images = batch[0].reshape((batch_size, 784))
# 对图像像素进行scale,这是因为tanh输出的结果介于(-1,1),real和fake图片共享discriminator的参数
batch_images = batch_images*2 - 1 # generator的输入噪声
batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size)) # Run optimizers
_ = sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img: batch_noise})
_ = sess.run(g_train_opt, feed_dict={noise_img: batch_noise}) # 每一轮结束计算loss
train_loss_d = sess.run(d_loss,
feed_dict = {real_img: batch_images,
noise_img: batch_noise})
# real img loss
train_loss_d_real = sess.run(d_loss_real,
feed_dict = {real_img: batch_images,
noise_img: batch_noise})
# fake img loss
train_loss_d_fake = sess.run(d_loss_fake,
feed_dict = {real_img: batch_images,
noise_img: batch_noise})
# generator loss
train_loss_g = sess.run(g_loss,
feed_dict = {noise_img: batch_noise}) print("Epoch {}/{}...".format(e+1, epochs),
"判别器损失: {:.4f}(判别真实的: {:.4f} + 判别生成的: {:.4f})...".format(train_loss_d, train_loss_d_real, train_loss_d_fake),
"生成器损失: {:.4f}".format(train_loss_g)) losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g)) # 保存样本
sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),
feed_dict={noise_img: sample_noise})
samples.append(gen_samples) saver.save(sess, './checkpoints/generator.ckpt') # 保存到本地
with open('train_samples.pkl', 'wb') as f:
pickle.dump(samples, f) # # loss迭代曲线 # In[11]: fig, ax = plt.subplots(figsize=(20,7))
losses = np.array(losses)
plt.plot(losses.T[0], label='判别器总损失')
plt.plot(losses.T[1], label='判别真实损失')
plt.plot(losses.T[2], label='判别生成损失')
plt.plot(losses.T[3], label='生成器损失')
plt.title("对抗生成网络")
ax.set_xlabel('epoch')
plt.legend() # # 生成结果 # In[12]: # Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:
samples = pickle.load(f) # In[13]: #samples是保存的结果 epoch是第多少次迭代
def view_samples(epoch, samples): fig, axes = plt.subplots(figsize=(7,7), nrows=5, ncols=5, sharey=True, sharex=True)
for ax, img in zip(axes.flatten(), samples[epoch][1]): # 这里samples[epoch][1]代表生成的图像结果,而[0]代表对应的logits
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
im = ax.imshow(img.reshape((28,28)), cmap='Greys_r') return fig, axes # In[14]: _ = view_samples(-1, samples) # 显示最终的生成结果 # # 显示整个生成过程图片 # In[15]: # 指定要查看的轮次
epoch_idx = [10, 30, 60, 90, 120, 150, 180, 210, 240, 290]
show_imgs = []
for i in epoch_idx:
show_imgs.append(samples[i][1]) # In[16]: # 指定图片形状
rows, cols = 10, 25
fig, axes = plt.subplots(figsize=(30,12), nrows=rows, ncols=cols, sharex=True, sharey=True) idx = range(0, epochs, int(epochs/rows)) for sample, ax_row in zip(show_imgs, axes):
for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
ax.imshow(img.reshape((28,28)), cmap='Greys_r')
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False) # # 生成新的图片 # In[17]: # 加载我们的生成器变量
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
sample_noise = np.random.uniform(-1, 1, size=(25, noise_size))
gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),
feed_dict={noise_img: sample_noise}) # In[18]: _ = view_samples(0, [gen_samples])

4、优化目标

深度学习之GAN对抗神经网络的更多相关文章

  1. Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时

    Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时 在论文中,Capsule被Hinton大神定义为这样一组神经元:其活动向量所表示的是特定实体类型 ...

  2. 深度学习之 GAN 进行 mnist 图片的生成

    深度学习之 GAN 进行 mnist 图片的生成 mport numpy as np import os import codecs import torch from PIL import Imag ...

  3. Pytorch_第六篇_深度学习 (DeepLearning) 基础 [2]---神经网络常用的损失函数

    深度学习 (DeepLearning) 基础 [2]---神经网络常用的损失函数 Introduce 在上一篇"深度学习 (DeepLearning) 基础 [1]---监督学习和无监督学习 ...

  4. 【深度学习】--GAN从入门到初始

    一.前述 GAN,生成对抗网络,在2016年基本火爆深度学习,所有有必要学习一下.生成对抗网络直观的应用可以帮我们生成数据,图片. 二.具体 1.生活案例 比如假设真钱 r 坏人定义为G  我们通过 ...

  5. 深度学习-Wasserstein GAN论文理解笔记

    GAN存在问题 训练困难,G和D多次尝试没有稳定性,Loss无法知道能否优化,生成样本单一,改进方案靠暴力尝试 WGAN GAN的Loss函数选择不合适,使模型容易面临梯度消失,梯度不稳定,优化目标不 ...

  6. [DeeplearningAI笔记]神经网络与深度学习2.11_2.16神经网络基础(向量化)

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 2.11向量化 向量化是消除代码中显示for循环语句的艺术,在训练大数据集时,深度学习算法才变得高效,所以代码运行的非常快十分重要.所以在深度学 ...

  7. 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  8. TensorFlow(实战深度学习框架)----深层神经网络(第四章)

    深层神经网络可以解决部分浅层神经网络解决不了的问题. 神经网络的优化目标-----损失函数 深度学习:一类通过多层非线性变化对高复杂性数据建模算法的合集.(两个重要的特性:多层和非线性) 线性模型的最 ...

  9. SIGAI深度学习第二集 人工神经网络1

    讲授神经网络的思想起源.神经元原理.神经网络的结构和本质.正向传播算法.链式求导及反向传播算法.神经网络怎么用于实际问题等 课程大纲: 神经网络的思想起源 神经元的原理 神经网络结构 正向传播算法 怎 ...

随机推荐

  1. shell脚本基础编写

    shell脚本的格式 名称:Shell 脚本文件的名称可以任意,但为了避免被误以为是普通文件,建议将 .sh 后缀加上,以表示是一个脚本文件. shell 脚本中一般会出现三种不同的元素: 第一行的脚 ...

  2. Appium Desired Capabilities-General Capabilities

    Desired Capabilities are keys and values encoded in a JSON object, sent by Appium clients to the ser ...

  3. JVM命令jinfo

          jinfo也是jvm中参与的一个命令,可以查看运行中jvm的全部参数,还可以设置部分参数.   格式      jinfo [ option ] pid      jinfo [ opti ...

  4. python_tkinter基本属性

    1.外形尺寸 尺寸单位:只用默认的像素或者其他字符类的值!,不要用英寸毫米之类的内容. btn = tkinter.Button(root,text = '按钮') # 设置按钮尺寸,绝大多数默认单位 ...

  5. IDEA中配置Jetty Server

    首先去 Eclipse官网下载Jetty jar包 鼠标移到Jetty上时 点击 Git it (得到它) 点击 .zip等待下载完成 然后 解压出来 接下就让我们 开始 使用IDEA了(创建一个We ...

  6. mysql锁定单个表的方法

    mysql锁定单个表的方法mysql>lock table userstat read;mysql>unlock tables; 本文来自ChinaUnix博客,如果查看原文请点:http ...

  7. hivesql中的concat函数,concat_ws函数,concat_group函数之间的区别

    一.CONCAT()函数CONCAT()函数用于将多个字符串连接成一个字符串.使用数据表Info作为示例,其中SELECT id,name FROM info LIMIT 1;的返回结果为 +---- ...

  8. qt 启动参数 -qws

    运行嵌入式程序 在嵌入式QT版本中,程序需要服务器或自己作为服务器程序. 服务器程序构造的方法是构造一个QApplication::GuiServe类型的QApplication对象.或者使用-qws ...

  9. vue cli 3.x 配置使用 sourceMap

    项目使用vue cli 3.x搭建,没有了配置文件,如何更方便的查找到对应的scss文件,配置项目支持sourceMap方式? 分二步走: 1.项目根目录(不是src目录,不要搞错了)添加vue.co ...

  10. 1 Mybatis

    1 使用Maven导入mybatis依赖 在pom.xml中写上一下代码:这些代码的查找可在https://mvnrepository.com/open-source网站上寻找,导入mybatis时要 ...