前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条件的 GAN,和不加约束条件的GAN,我们先来搭建一个简单的 MNIST 数据集上加约束条件的 GAN。

首先下载数据:在  /home/your_name/TensorFlow/DCGAN/ 下建立文件夹 data/mnist,从 http://yann.lecun.com/exdb/mnist/ 网站上下载 mnist 数据集 train-images-idx3-ubyte.gztrain-labels-idx1-ubyte.gzt10k-images-idx3-ubyte.gzt10k-labels-idx1-ubyte.gz 到 mnist 文件夹下得到四个 .gz 文件。

数据下载好之后,在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 read_data.py 读取数据,输入如下代码:

  1. import os
  2. import numpy as np
  3.  
  4. def read_data():
  5.  
  6. # 数据目录
  7. data_dir = '/home/your_name/TensorFlow/DCGAN/data/mnist'
  8.  
  9. # 打开训练数据
  10. fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
  11. # 转化成 numpy 数组
  12. loaded = np.fromfile(file=fd,dtype=np.uint8)
  13. # 根据 mnist 官网描述的数据格式,图像像素从 16 字节开始
  14. trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float)
  15.  
  16. # 训练 label
  17. fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
  18. loaded = np.fromfile(file=fd,dtype=np.uint8)
  19. trY = loaded[8:].reshape((60000)).astype(np.float)
  20.  
  21. # 测试数据
  22. fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
  23. loaded = np.fromfile(file=fd,dtype=np.uint8)
  24. teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float)
  25.  
  26. # 测试 label
  27. fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
  28. loaded = np.fromfile(file=fd,dtype=np.uint8)
  29. teY = loaded[8:].reshape((10000)).astype(np.float)
  30.  
  31. trY = np.asarray(trY)
  32. teY = np.asarray(teY)
  33.  
  34. # 由于生成网络由服从某一分布的噪声生成图片,不需要测试集,
  35. # 所以把训练和测试两部分数据合并
  36. X = np.concatenate((trX, teX), axis=0)
  37. y = np.concatenate((trY, teY), axis=0)
  38.  
  39. # 打乱排序
  40. seed = 547
  41. np.random.seed(seed)
  42. np.random.shuffle(X)
  43. np.random.seed(seed)
  44. np.random.shuffle(y)
  45.  
  46. # 这里,y_vec 表示对网络所加的约束条件,这个条件是类别标签,
  47. # 可以看到,y_vec 实际就是对 y 的独热编码,关于什么是独热编码,
  48. # 请参考 http://www.cnblogs.com/Charles-Wan/p/6207039.html
  49. y_vec = np.zeros((len(y), 10), dtype=np.float)
  50. for i, label in enumerate(y):
  51. y_vec[i,y[i]] = 1.0
  52.  
  53. return X/255., y_vec

这里顺便说明一下,由于 MNIST 数据总体占得内存不大(可以看下载的文件,最大的一个 45M 左右,)所以这样读取数据是允许的,一般情况下,数据特别庞大的时候,建议把数据转化成 tfrecords,用 TensorFlow 标准的数据读取格式,这样能带来比较高的效率。

然后,定义一些基本的操作层,例如卷积,池化,全连接等层,在 /home/your_name/TensorFlow/DCGAN/ 新建文件 ops.py,输入如下代码:

  1. import tensorflow as tf
  2. from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
  3.  
  4. # 常数偏置
  5. def bias(name, shape, bias_start = 0.0, trainable = True):
  6.  
  7. dtype = tf.float32
  8. var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
  9. initializer = tf.constant_initializer(
  10. bias_start, dtype = dtype))
  11. return var
  12.  
  13. # 随机权重
  14. def weight(name, shape, stddev = 0.02, trainable = True):
  15.  
  16. dtype = tf.float32
  17. var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
  18. initializer = tf.random_normal_initializer(
  19. stddev = stddev, dtype = dtype))
  20. return var
  21.  
  22. # 全连接层
  23. def fully_connected(value, output_shape, name = 'fully_connected', with_w = False):
  24.  
  25. shape = value.get_shape().as_list()
  26.  
  27. with tf.variable_scope(name):
  28. weights = weight('weights', [shape[1], output_shape], 0.02)
  29. biases = bias('biases', [output_shape], 0.0)
  30.  
  31. if with_w:
  32. return tf.matmul(value, weights) + biases, weights, biases
  33. else:
  34. return tf.matmul(value, weights) + biases
  35.  
  36. # Leaky-ReLu 层
  37. def lrelu(x, leak=0.2, name = 'lrelu'):
  38.  
  39. with tf.variable_scope(name):
  40. return tf.maximum(x, leak*x, name = name)
  41.  
  42. # ReLu 层
  43. def relu(value, name = 'relu'):
  44. with tf.variable_scope(name):
  45. return tf.nn.relu(value)
  46.  
  47. # 解卷积层
  48. def deconv2d(value, output_shape, k_h = 5, k_w = 5, strides =[1, 2, 2, 1],
  49. name = 'deconv2d', with_w = False):
  50.  
  51. with tf.variable_scope(name):
  52. weights = weight('weights',
  53. [k_h, k_w, output_shape[-1], value.get_shape()[-1]])
  54. deconv = tf.nn.conv2d_transpose(value, weights,
  55. output_shape, strides = strides)
  56. biases = bias('biases', [output_shape[-1]])
  57. deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
  58. if with_w:
  59. return deconv, weights, biases
  60. else:
  61. return deconv
  62.  
  63. # 卷积层
  64. def conv2d(value, output_dim, k_h = 5, k_w = 5,
  65. strides =[1, 2, 2, 1], name = 'conv2d'):
  66.  
  67. with tf.variable_scope(name):
  68. weights = weight('weights',
  69. [k_h, k_w, value.get_shape()[-1], output_dim])
  70. conv = tf.nn.conv2d(value, weights, strides = strides, padding = 'SAME')
  71. biases = bias('biases', [output_dim])
  72. conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
  73.  
  74. return conv
  75.  
  76. # 把约束条件串联到 feature map
  77. def conv_cond_concat(value, cond, name = 'concat'):
  78.  
  79. # 把张量的维度形状转化成 Python 的 list
  80. value_shapes = value.get_shape().as_list()
  81. cond_shapes = cond.get_shape().as_list()
  82.  
  83. # 在第三个维度上(feature map 维度上)把条件和输入串联起来,
  84. # 条件会被预先设为四维张量的形式,假设输入为 [64, 32, 32, 32] 维的张量,
  85. # 条件为 [64, 32, 32, 10] 维的张量,那么输出就是一个 [64, 32, 32, 42] 维张量
  86. with tf.variable_scope(name):
  87. return tf.concat(3, [value,
  88. cond * tf.ones(value_shapes[0:3] + cond_shapes[3:])])
  89.  
  90. # Batch Normalization 层
  91. def batch_norm_layer(value, is_train = True, name = 'batch_norm'):
  92.  
  93. with tf.variable_scope(name) as scope:
  94. if is_train:
  95. return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
  96. is_training = is_train,
  97. updates_collections = None, scope = scope)
  98. else:
  99. return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
  100. is_training = is_train, reuse = True,
  101. updates_collections = None, scope = scope)

TensorFlow 里使用 Batch Normalization 层,有很多种方法,这里我们直接使用官方 contrib 里面的层,其中 decay 指的是滑动平均的 decay,epsilon 作用是加到分母 variance 上避免分母为零,scale 是个布尔变量,如果为真值 True, 结果要乘以 gamma,否则 gamma 不使用,is_train 也是布尔变量,为真值代表训练过程,否则代表测试过程(在 BN 层中,训练过程和测试过程是不同的,具体请参考论文:https://arxiv.org/abs/1502.03167)。关于 batch_norm 的其他的参数,请看参考文献2。

参考文献:

1. https://github.com/carpedm20/DCGAN-tensorflow

2. https://github.com/tensorflow/tensorflow/blob/b826b79718e3e93148c3545e7aa3f90891744cc0/tensorflow/contrib/layers/python/layers/layers.py#L100

不要怂,就是GAN (生成式对抗网络) (二)的更多相关文章

  1. GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

    论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...

  2. GAN生成式对抗网络(三)——mnist数据生成

    通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...

  3. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

  4. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  5. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  6. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  7. 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph

    GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...

  8. 不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  9. 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...

随机推荐

  1. ubuntu配置android开发环境和编译源码遇到的一些问题

    ---------------------------------------------环境变量设置--------------------------------------------- 1.设 ...

  2. WCF服务实现客户端Cookie共享,表单验证的解决方案

    基于前几篇的文章,如果理解了通道 拦截器  服务转发的概念,相信你肯定也能理解咋的玩了. 说白了就是创建客户端的拦截器: 实现接口:IClientMessageInspector. 里面的方法就是客户 ...

  3. MyBatis里json型字段到Java类的映射

    一.简介 我们在用MyBatis里,很多时间有这样一个需求:bean里有个属性是非基本数据类型,在DB存储时我们想存的是json格式的字符串,从DB拿出来时想直接映射成目标类型,也即json格式的字符 ...

  4. Dynamic Binding & Static Binding

    Reference: JavaPoint BeginnerBook What is Binding Connecting a method call to the method body is kno ...

  5. Install and configure sql server 2008 express

    http://www.symantec.com/connect/articles/install-and-configure-sql-server-2008-express

  6. 马士兵 Servlet_JSP(2) JSP源代码)

    1.最简单的JSP HelloWorld.jsp <html>     <head>         <title>Hello</title>     ...

  7. RequireJS入门(一)

    RequireJS由James Burke创建,他也是AMD规范的创始人. RequireJS会让你以不同于往常的方式去写JavaScript.你将不再使用script标签在HTML中引入JS文件,以 ...

  8. poj 3176 Cow Bowling(dp基础)

    Description The cows don't use actual bowling balls when they go bowling. They each take a number (i ...

  9. IOS uitableviewcell 向左滑动删除编辑等

    主要实现这个方法就好了 -(NSArray<UITableViewRowAction *> *)tableView:(UITableView *)tableView editActions ...

  10. (转)iOS Wow体验 - 第六章 - 交互模型与创新的产品概念(2)

    本文是<iOS Wow Factor:Apps and UX Design Techniques for iPhone and iPad>第六章译文精选的第二部分,其余章节将陆续放出.上一 ...