前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条件的 GAN,和不加约束条件的GAN,我们先来搭建一个简单的 MNIST 数据集上加约束条件的 GAN。

首先下载数据:在  /home/your_name/TensorFlow/DCGAN/ 下建立文件夹 data/mnist,从 http://yann.lecun.com/exdb/mnist/ 网站上下载 mnist 数据集 train-images-idx3-ubyte.gztrain-labels-idx1-ubyte.gzt10k-images-idx3-ubyte.gzt10k-labels-idx1-ubyte.gz 到 mnist 文件夹下得到四个 .gz 文件。

数据下载好之后,在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 read_data.py 读取数据,输入如下代码:

  1. import os
  2. import numpy as np
  3.  
  4. def read_data():
  5.  
  6. # 数据目录
  7. data_dir = '/home/your_name/TensorFlow/DCGAN/data/mnist'
  8.  
  9. # 打开训练数据
  10. fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
  11. # 转化成 numpy 数组
  12. loaded = np.fromfile(file=fd,dtype=np.uint8)
  13. # 根据 mnist 官网描述的数据格式,图像像素从 16 字节开始
  14. trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float)
  15.  
  16. # 训练 label
  17. fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
  18. loaded = np.fromfile(file=fd,dtype=np.uint8)
  19. trY = loaded[8:].reshape((60000)).astype(np.float)
  20.  
  21. # 测试数据
  22. fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
  23. loaded = np.fromfile(file=fd,dtype=np.uint8)
  24. teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float)
  25.  
  26. # 测试 label
  27. fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
  28. loaded = np.fromfile(file=fd,dtype=np.uint8)
  29. teY = loaded[8:].reshape((10000)).astype(np.float)
  30.  
  31. trY = np.asarray(trY)
  32. teY = np.asarray(teY)
  33.  
  34. # 由于生成网络由服从某一分布的噪声生成图片,不需要测试集,
  35. # 所以把训练和测试两部分数据合并
  36. X = np.concatenate((trX, teX), axis=0)
  37. y = np.concatenate((trY, teY), axis=0)
  38.  
  39. # 打乱排序
  40. seed = 547
  41. np.random.seed(seed)
  42. np.random.shuffle(X)
  43. np.random.seed(seed)
  44. np.random.shuffle(y)
  45.  
  46. # 这里,y_vec 表示对网络所加的约束条件,这个条件是类别标签,
  47. # 可以看到,y_vec 实际就是对 y 的独热编码,关于什么是独热编码,
  48. # 请参考 http://www.cnblogs.com/Charles-Wan/p/6207039.html
  49. y_vec = np.zeros((len(y), 10), dtype=np.float)
  50. for i, label in enumerate(y):
  51. y_vec[i,y[i]] = 1.0
  52.  
  53. return X/255., y_vec

这里顺便说明一下,由于 MNIST 数据总体占得内存不大(可以看下载的文件,最大的一个 45M 左右,)所以这样读取数据是允许的,一般情况下,数据特别庞大的时候,建议把数据转化成 tfrecords,用 TensorFlow 标准的数据读取格式,这样能带来比较高的效率。

然后,定义一些基本的操作层,例如卷积,池化,全连接等层,在 /home/your_name/TensorFlow/DCGAN/ 新建文件 ops.py,输入如下代码:

  1. import tensorflow as tf
  2. from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
  3.  
  4. # 常数偏置
  5. def bias(name, shape, bias_start = 0.0, trainable = True):
  6.  
  7. dtype = tf.float32
  8. var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
  9. initializer = tf.constant_initializer(
  10. bias_start, dtype = dtype))
  11. return var
  12.  
  13. # 随机权重
  14. def weight(name, shape, stddev = 0.02, trainable = True):
  15.  
  16. dtype = tf.float32
  17. var = tf.get_variable(name, shape, tf.float32, trainable = trainable,
  18. initializer = tf.random_normal_initializer(
  19. stddev = stddev, dtype = dtype))
  20. return var
  21.  
  22. # 全连接层
  23. def fully_connected(value, output_shape, name = 'fully_connected', with_w = False):
  24.  
  25. shape = value.get_shape().as_list()
  26.  
  27. with tf.variable_scope(name):
  28. weights = weight('weights', [shape[1], output_shape], 0.02)
  29. biases = bias('biases', [output_shape], 0.0)
  30.  
  31. if with_w:
  32. return tf.matmul(value, weights) + biases, weights, biases
  33. else:
  34. return tf.matmul(value, weights) + biases
  35.  
  36. # Leaky-ReLu 层
  37. def lrelu(x, leak=0.2, name = 'lrelu'):
  38.  
  39. with tf.variable_scope(name):
  40. return tf.maximum(x, leak*x, name = name)
  41.  
  42. # ReLu 层
  43. def relu(value, name = 'relu'):
  44. with tf.variable_scope(name):
  45. return tf.nn.relu(value)
  46.  
  47. # 解卷积层
  48. def deconv2d(value, output_shape, k_h = 5, k_w = 5, strides =[1, 2, 2, 1],
  49. name = 'deconv2d', with_w = False):
  50.  
  51. with tf.variable_scope(name):
  52. weights = weight('weights',
  53. [k_h, k_w, output_shape[-1], value.get_shape()[-1]])
  54. deconv = tf.nn.conv2d_transpose(value, weights,
  55. output_shape, strides = strides)
  56. biases = bias('biases', [output_shape[-1]])
  57. deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
  58. if with_w:
  59. return deconv, weights, biases
  60. else:
  61. return deconv
  62.  
  63. # 卷积层
  64. def conv2d(value, output_dim, k_h = 5, k_w = 5,
  65. strides =[1, 2, 2, 1], name = 'conv2d'):
  66.  
  67. with tf.variable_scope(name):
  68. weights = weight('weights',
  69. [k_h, k_w, value.get_shape()[-1], output_dim])
  70. conv = tf.nn.conv2d(value, weights, strides = strides, padding = 'SAME')
  71. biases = bias('biases', [output_dim])
  72. conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
  73.  
  74. return conv
  75.  
  76. # 把约束条件串联到 feature map
  77. def conv_cond_concat(value, cond, name = 'concat'):
  78.  
  79. # 把张量的维度形状转化成 Python 的 list
  80. value_shapes = value.get_shape().as_list()
  81. cond_shapes = cond.get_shape().as_list()
  82.  
  83. # 在第三个维度上(feature map 维度上)把条件和输入串联起来,
  84. # 条件会被预先设为四维张量的形式,假设输入为 [64, 32, 32, 32] 维的张量,
  85. # 条件为 [64, 32, 32, 10] 维的张量,那么输出就是一个 [64, 32, 32, 42] 维张量
  86. with tf.variable_scope(name):
  87. return tf.concat(3, [value,
  88. cond * tf.ones(value_shapes[0:3] + cond_shapes[3:])])
  89.  
  90. # Batch Normalization 层
  91. def batch_norm_layer(value, is_train = True, name = 'batch_norm'):
  92.  
  93. with tf.variable_scope(name) as scope:
  94. if is_train:
  95. return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
  96. is_training = is_train,
  97. updates_collections = None, scope = scope)
  98. else:
  99. return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True,
  100. is_training = is_train, reuse = True,
  101. updates_collections = None, scope = scope)

TensorFlow 里使用 Batch Normalization 层,有很多种方法,这里我们直接使用官方 contrib 里面的层,其中 decay 指的是滑动平均的 decay,epsilon 作用是加到分母 variance 上避免分母为零,scale 是个布尔变量,如果为真值 True, 结果要乘以 gamma,否则 gamma 不使用,is_train 也是布尔变量,为真值代表训练过程,否则代表测试过程(在 BN 层中,训练过程和测试过程是不同的,具体请参考论文:https://arxiv.org/abs/1502.03167)。关于 batch_norm 的其他的参数,请看参考文献2。

参考文献:

1. https://github.com/carpedm20/DCGAN-tensorflow

2. https://github.com/tensorflow/tensorflow/blob/b826b79718e3e93148c3545e7aa3f90891744cc0/tensorflow/contrib/layers/python/layers/layers.py#L100

不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作的更多相关文章

  1. GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构

    论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...

  2. GAN生成式对抗网络(三)——mnist数据生成

    通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...

  3. GAN生成式对抗网络(一)——原理

    生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...

  4. 不要怂,就是GAN (生成式对抗网络) (一)

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  5. 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介

    前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...

  6. 不要怂,就是GAN (生成式对抗网络) (二)

    前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...

  7. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  8. 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph

    GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...

  9. 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

    在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...

随机推荐

  1. test20181019 B君的第二题

    题意 分析 快速子集和变换以及快速超集和变换的裸题. 用\(f(s)\)表示集合s的方案数,初始化为输入中s出现的次数. 做一遍快速子集和变换,此时f(s)表示s及其子集在输入中出现的次数. 对所有f ...

  2. StreamSets 部署 Pipelines 到 SDC Edge

    可以使用如下方法: 下载edge 运行包并包含pipeline定义文件. 直接发布到edge 设备. 在data colelctor 机器配置并配置了edge server 地址(主要需要网络可访问) ...

  3. linq to sql 怎么查询前 11 条数据

    (from 新表 in db.books where 新表.bookid < 400 select 新表).Take(11); storeDB.Albums.OrderByDescending( ...

  4. [转]无网络环境,在Windows Server 2008 R2和SQL Server 2008R2环境安装SharePoint2013 RT

    无网络环境,在Windows Server 2008 R2和SQL Server 2008R2环境安装SharePoint2013 RT,这个还有点麻烦,所以记录一下,下次遇到省得绕弯路.进入正题: ...

  5. linux 使用中括号进行条件判断

       格式 “#”代表空格,不可缺少 [# param1#op# param2 #] 这种带比较操作符的形式,op左右必须使用空格隔开. 如 [# “3”==”2” #]  这种缺少空格的写法会得到结 ...

  6. SQL中利用脚本创建database mail.

    SQL中利用脚本创建database mail   编写人:CC阿爸 2014-6-14 多话不讲,请参考以下脚本 use  

  7. linux CentOS 安装rz和sz命令 lrzsz 实现windows和linux之间的文件上传 下载

    https://blog.nbhao.org/1902.html https://bbs.csdn.net/topics/391989523 https://www.cnblogs.com/zhoul ...

  8. python 多态、多继承、函数重写、迭代器

    用于类的函数 issubclass(cls,class_or_tuple) 判断一个类是否继承自其他的类,如果此类cls是class或tuole中的一个派生(子类)则返回True,否则返回False ...

  9. ubuntu 下出现E: Sub-process /usr/bin/dpkg returned an error code

    在用apt-get安装软件时出现了类似于 install-info: No dir file specified; try –help for more information.dpkg:处理 get ...

  10. MPI 并行奇偶交换排序 + 集合通信函数 Sendrecv() Sendvecv_replace()

    ▶ <并行程序设计导论>第三章的例子程序 ● 代码 #include <stdio.h> #include <mpi.h> #include <stdlib. ...