GAN网络架构分析

上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake的就是tf.zeros。

网络具体形状大体如上,具体数值有所调整,生成器过程为:噪声向量-全连接-卷积-卷积-卷积,辨别器过程:图片-卷积-卷积-全连接-全连接。

和预想的不同,实际上数据在生成器中并不是从无到有由小变大的过程,而是由3136(56*56)经过正常卷积步骤下降为28*28的过程。

实现如下:

import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('../../Mnist_data') """测试数据""" # sample_image = mnist.train.next_batch(1)[0]
# print(sample_image.shape)
# sample_image = sample_image.reshape([28, 28])
# plt.imshow(sample_image, cmap='Greys') """分辨器""" def discriminator(images, reuse=None):
with tf.variable_scope(tf.get_variable_scope(), reuse=reuse) as scope:
# 卷积 + 激活 + 池化
d_w1 = tf.get_variable('d_w1',[5,5,1,32],initializer=tf.truncated_normal_initializer(stddev=0.02))
d_b1 = tf.get_variable('d_b1',[32],initializer=tf.constant_initializer(0))
d1 = tf.nn.conv2d(input=images,filter=d_w1,strides=[1,1,1,1],padding='SAME')
d1 = d1 + d_b1
d1 = tf.nn.relu(d1)
d1 = tf.nn.avg_pool(d1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') # 卷积 + 激活 + 池化
d_w2 = tf.get_variable('d_w2',[5,5,32,64],initializer=tf.truncated_normal_initializer(stddev=0.02))
d_b2 = tf.get_variable('d_b2',[64],initializer=tf.constant_initializer(0))
d2 = tf.nn.conv2d(input=d1,filter=d_w2,strides=[1,1,1,1],padding='SAME')
d2 = d2 + d_b2
d2 = tf.nn.relu(d2)
d2 = tf.nn.avg_pool(d2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') # 全连接 + 激活
d_w3 = tf.get_variable('d_w3',[7 * 7 * 64,1024],initializer=tf.truncated_normal_initializer(stddev=0.02))
d_b3 = tf.get_variable('d_b3',[1024],initializer=tf.constant_initializer(0))
d3 = tf.reshape(d2,[-1,7 * 7 * 64])
d3 = tf.matmul(d3,d_w3)
d3 = d3 + d_b3
d3 = tf.nn.relu(d3) # 全连接
d_w4 = tf.get_variable('d_w4',[1024,1],initializer=tf.truncated_normal_initializer(stddev=0.02))
d_b4 = tf.get_variable('d_b4',[1],initializer=tf.constant_initializer(0))
d4 = tf.matmul(d3,d_w4) + d_b4 # 最后输出一个非尺度化的值
return d4 """生成器""" def generator(z, batch_size, z_dim, reuse=False):
'''接收特征向量z,由z生成图片''' with tf.variable_scope(tf.get_variable_scope(),reuse=reuse):
# 全连接 + 批正则化 + 激活
# z_dim -> 3136 -> 56*56*1
g_w1 = tf.get_variable('g_w1', [z_dim, 3136], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b1 = tf.get_variable('g_b1', [3136], initializer=tf.truncated_normal_initializer(stddev=0.02))
g1 = tf.matmul(z, g_w1) + g_b1
g1 = tf.reshape(g1, [-1, 56, 56, 1])
g1 = tf.contrib.layers.batch_norm(g1, epsilon=1e-5, scope='bn1')
g1 = tf.nn.relu(g1) # 卷积 + 批正则化 + 激活
g_w2 = tf.get_variable('g_w2',[3,3,1,z_dim / 2],dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b2 = tf.get_variable('g_b2',[z_dim / 2],initializer=tf.truncated_normal_initializer(stddev=0.02))
g2 = tf.nn.conv2d(g1,g_w2,strides=[1,2,2,1],padding='SAME')
g2 = g2 + g_b2
g2 = tf.contrib.layers.batch_norm(g2,epsilon=1e-5,scope='bn2')
g2 = tf.nn.relu(g2)
g2 = tf.image.resize_images(g2,[56,56]) # 卷积 + 批正则化 + 激活
g_w3 = tf.get_variable('g_w3',[3,3,z_dim / 2,z_dim / 4],dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b3 = tf.get_variable('g_b3',[z_dim / 4],initializer=tf.truncated_normal_initializer(stddev=0.02))
g3 = tf.nn.conv2d(g2,g_w3,strides=[1,2,2,1],padding='SAME')
g3 = g3 + g_b3
g3 = tf.contrib.layers.batch_norm(g3,epsilon=1e-5,scope='bn3')
g3 = tf.nn.relu(g3)
g3 = tf.image.resize_images(g3,[56,56]) # 卷积 + 激活
g_w4 = tf.get_variable('g_w4',[1,1,z_dim / 4,1],dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b4 = tf.get_variable('g_b4',[1],initializer=tf.truncated_normal_initializer(stddev=0.02))
g4 = tf.nn.conv2d(g3,g_w4,strides=[1,2,2,1],padding='SAME')
g4 = g4 + g_b4
g4 = tf.sigmoid(g4) # 输出g4的维度: batch_size x 28 x 28 x 1
return g4

逻辑实现如下,不同组成部分的loss值是分开计算的:

"""逻辑架构"""

tf.reset_default_graph()
batch_size = 50
z_dimensions = 100 z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions], name='z_placeholder')
x_placeholder = tf.placeholder(tf.float32, shape = [None,28,28,1], name='x_placeholder') Gz = generator(z_placeholder, batch_size, z_dimensions) # 根据z生成伪造图片
Dx = discriminator(x_placeholder) # 辨别器辨别真实图片
Dg = discriminator(Gz, reuse=True) # 辨别器辨别伪造图片 #discriminator 的loss 分为两部分
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dx, labels = tf.ones_like(Dx)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.zeros_like(Dg)))
d_loss=d_loss_real + d_loss_fake
# Generator的目标是生成尽可能真实的图像,所以计算Dg和1的loss
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = Dg, labels = tf.ones_like(Dg)))

优化器部分有一些注意点:

"""优化部分"""

# 由于训练时生成器和辨别器是分开训练的,
# 所以不同的训练过程对应的优化参数是要做区分的
tvars = tf.trainable_variables() d_vars = [var for var in tvars if 'd_' in var.name]
g_vars = [var for var in tvars if 'g_' in var.name] d_trainer_real = tf.train.AdamOptimizer(0.0003).minimize(d_loss_real, var_list=d_vars)
d_trainer_fake = tf.train.AdamOptimizer(0.0003).minimize(d_loss_fake, var_list=d_vars)
d_trainer = tf.train.AdamOptimizer(0.0003).minimize(d_loss, var_list=d_vars)
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

入注释所说,训练不同的位置,优化不同的参数,不可以混淆,所以这里就涉及了tf变量提取的手法,结果展示如下:

import pprint
pp = pprint.PrettyPrinter()
pp.pprint(d_vars)
pp.pprint(g_vars) [<tf.Variable 'd_w1:0' shape=(5, 5, 1, 32) dtype=float32_ref>,
<tf.Variable 'd_b1:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'd_w2:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
<tf.Variable 'd_b2:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'd_w3:0' shape=(3136, 1024) dtype=float32_ref>,
<tf.Variable 'd_b3:0' shape=(1024,) dtype=float32_ref>,
<tf.Variable 'd_w4:0' shape=(1024, 1) dtype=float32_ref>,
<tf.Variable 'd_b4:0' shape=(1,) dtype=float32_ref>] [<tf.Variable 'g_w1:0' shape=(100, 3136) dtype=float32_ref>,
<tf.Variable 'g_b1:0' shape=(3136,) dtype=float32_ref>,
<tf.Variable 'g_w2:0' shape=(3, 3, 1, 50) dtype=float32_ref>,
<tf.Variable 'g_b2:0' shape=(50,) dtype=float32_ref>,
<tf.Variable 'g_w3:0' shape=(3, 3, 50, 25) dtype=float32_ref>,
<tf.Variable 'g_b3:0' shape=(25,) dtype=float32_ref>,
<tf.Variable 'g_w4:0' shape=(1, 1, 25, 1) dtype=float32_ref>,
<tf.Variable 'g_b4:0' shape=(1,) dtype=float32_ref>]

之后是训练过程:

"""迭代训练"""

sess = tf.Session()
sess.run(tf.global_variables_initializer()) # 对discriminator的预训练
for i in range(300):
print('.',end='')
z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])
real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1])
# 用real and fake images分别对discriminator训练
_, __, dLossReal, dLossFake = sess.run([d_trainer_real, d_trainer_fake, d_loss_real, d_loss_fake],
{x_placeholder: real_image_batch, z_placeholder: z_batch}) if (i % 100 == 0):
print("\rdLossReal:",dLossReal,"dLossFake:",dLossFake) # 交替训练 generator和discriminator
for i in range(100000):
print('.',end='')
real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1])
z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions]) # 用real and fake images同时对discriminator训练
_,dLossReal,dLossFake = sess.run([d_trainer,d_loss_real,d_loss_fake],
{x_placeholder: real_image_batch,z_placeholder: z_batch})
# 训练generator
z_batch = np.random.normal(0,1,size=[batch_size,z_dimensions])
_ = sess.run(g_trainer,feed_dict={z_placeholder: z_batch}) if i % 100 == 0:
# 每 100 iterations, 输出一个生成的图像
print("\rIteration:",i,"at",datetime.datetime.now())
z_batch = np.random.normal(0,1,size=[1,z_dimensions])
generated_images = generator(z_placeholder,1,z_dimensions, reuse=True)
images = sess.run(generated_images,{z_placeholder: z_batch})
plt.imshow(images[0].reshape([28,28]),cmap='Greys')
plt.show()
# 输出discriminator的值
im = images[0].reshape([1,28,28,1])
result = discriminator(x_placeholder, reuse=True)
estimate = sess.run(result,{x_placeholder: im})
print("Estimate:",np.squeeze(estimate))

先预训练分辨器,

然后交替训练分辨器和生成器。

其实是有一点图片可以展示的,但是我的电脑性能太渣(苏菲4),跑了600轮左右的迭代我实在于心不忍了,先搁置吧... 以后有机会回实验室在说,至少原理是体会到了。

共享变量

『TensorFlow』线程控制器类&变量作用域理解加深

之前看文档时体会不深,现在大体明白共享变量的存在意义了,它是在设计计算图时考虑的:

同一个变量如果有不同的数据流(计算图中不同的节点在不同的时刻去给同一个节点的同一个输入位置提供数据),

  • Variable变量会之间创建两个不同的变量节点去接收不同的数据流
  • get_variable变量在reuse为True时会使用同一个变量应付不同的数据流

这也就是共享变量的应用之处。这在上面的程序中体现在判别器的任务,如果接收到的是生成器生成的图像,判别器就尝试优化自己的网络结构来使自己输出0,如果接收到的是来自真实数据的图像,那么就尝试优化自己的网络结构来使自己输出1。也就是说,fake图像和real图像经过判别器的时候,要共享同一套变量,所以TensorFlow引入了变量共享机制,而和正常的卷积网络不同的是这里的fake和real变量并不在同一个计算图节点位置(real图片在x节点处输入,而fake图则在生成器输出节点位置计入计算图)。

『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上的更多相关文章

  1. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  2. 『cs231n』通过代码理解风格迁移

    『cs231n』卷积神经网络的可视化应用 文件目录 vgg16.py import os import numpy as np import tensorflow as tf from downloa ...

  3. 『cs231n』RNN之理解LSTM网络

    概述 LSTM是RNN的增强版,1.RNN能完成的工作LSTM也都能胜任且有更好的效果:2.LSTM解决了RNN梯度消失或爆炸的问题,进而可以具有比RNN更为长时的记忆能力.LSTM网络比较复杂,而恰 ...

  4. 『cs231n』计算机视觉基础

    线性分类器损失函数明细: 『cs231n』线性分类器损失函数 最优化Optimiz部分代码: 1.随机搜索 bestloss = float('inf') # 无穷大 for num in range ...

  5. 『cs231n』作业3问题2选讲_通过代码理解LSTM网络

    LSTM神经元行为分析 LSTM 公式可以描述如下: itftotgtctht=sigmoid(Wixxt+Wihht−1+bi)=sigmoid(Wfxxt+Wfhht−1+bf)=sigmoid( ...

  6. 『cs231n』作业2选讲_通过代码理解Dropout

    Dropout def dropout_forward(x, dropout_param): p, mode = dropout_param['p'], dropout_param['mode'] i ...

  7. 『cs231n』作业3问题3选讲_通过代码理解图像梯度

    Saliency Maps 这部分想探究一下 CNN 内部的原理,参考论文 Deep Inside Convolutional Networks: Visualising Image Classifi ...

  8. 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练

    一份不错的作业3资料(含答案) RNN神经元理解 单个RNN神经元行为 括号中表示的是维度 向前传播 def rnn_step_forward(x, prev_h, Wx, Wh, b): " ...

  9. 『cs231n』作业2选讲_通过代码理解优化器

    1).Adagrad一种自适应学习率算法,实现代码如下: cache += dx**2 x += - learning_rate * dx / (np.sqrt(cache) + eps) 这种方法的 ...

随机推荐

  1. ZOJ 4027 Sequence Swapping(DP)题解

    题意:一串括号,每个括号代表一个值,当有相邻括号组成()时,可以交换他们两个并得到他们值的乘积,问你最大能得到多少 思路:DP题,注定想得掉头发. 显然一个左括号( 的最远交换距离由他右边的左括号的最 ...

  2. (转) Supercharging Style Transfer

      Supercharging Style Transfer Wednesday, October 26, 2016 Posted by Vincent Dumoulin*, Jonathon Shl ...

  3. 【.Net】结合项目谈谈多线程

    提到多线程, 大家都知道, 在进程中启用多个线程进行工作, 会提升程序的效率等等. 本篇文章旨在解释多线程的基础概念之外, 还要结合实际的项目来谈多线程的具体使用. Thread 我们知道启动一个线程 ...

  4. Twitter开发2

    There are different API families The standard (free) Twitter APIs consist of REST APIs and Streaming ...

  5. IDEA入门级使用教程----你怎么还在用eclipse?

    http://blog.csdn.net/qq_31655965/article/details/52788374

  6. 键盘控制div移动并且解决停顿问题(原生js)

    <html> <head> <title>键盘控制div移动,解决停顿问题</title> <meta charset="utf-8&q ...

  7. 理解 Redis(8) - Ordered set 值

    ordered set 是根据 score值有序排列的数据集合. 首先还是清空数据, 并清屏, 此步骤省略~~~~ 新建一条 ordered set 数据 myset1, 并存入4个字符串, scor ...

  8. [原]windows sdk版本不对

    系统硬盘换了,重新安装一堆软件,SVN. 之前的SVN地址直接能找到 在编译vs项目的时候出现问题: windows sdk 10.0.14393.0 版本找不到 发现自己按照vs时候更新不了最新sd ...

  9. 学习笔记41—ttest误区

    1.grapPad软件里面双T结果和matlab,EXCEl里面双T结果一致时,设置如下:

  10. JS self=this

    1.每个函数都会有自己的this和arguments:this对象绑定运行环境,arguments绑定调用参数. 2.全局函数:this和全局环境绑定,浏览器指向全局window对象(node.js中 ...