深度卷积生成对抗网络(DCGAN)

---- 生成 MNIST 手写图片

1、基本原理

生成对抗网络(GAN)由2个重要的部分构成:

  • 生成器(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器
  • 判别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器做的“假数据”

训练过程

    1. 固定判别器,让生成器不断生成假数据,给判别器判别,开始生成器很弱,但是随着不断的训练,生成器不断提升,最终骗过判别器。此时判别器判断假数据的概率为50%
    1. 固定生成器,训练判别器。判别器经过训练,提高鉴别能力,最终能准确判断虽有的假图片
    1. 循环上两个阶段,最终生成器和判别器都越来越强。然后就可以使用生成器来生成我们想要的图片了

2、相关数学原理

  • 判别器在这里是一种分类器,用于区分样本的真伪,因此我们常常使用交叉熵(cross entropy)来进行判别分布的相似性

\[H(p, q) := -\sum_i p_i \log q_i
\]

公式中 \(p_i\) 和 \(q_i\) 为真实的样本分布和生成器的生成分布

假定 \(y_1\) 为正确样本分布,那么对应的( \(1-y_1\) )就是生成样本的分布。\(D\) 表示判别器,则 \(D(x_1)\) 表示判别样本为正确的概率, \(1-D(x_1)\) 则对应着判别为错误样本的概率。则有如下式子(这里仅仅是对当前情况下的交叉熵损失的具体化)。

\[H((x_i, y_i)_{i=1}^N, D) = - \sum_{i=1}^N y_i\log D(x_i) - \sum_{i=1}^N(1-y_i)\log (1 - D(x_i))
\]

对于GAN中的样本点 \(x_i\) ,对应于两个出处,要么来自于真实样本,要么来自于生成器生成的样本 $\tilde{x} - G(z) $ ( 这里的 \(z\) 是服从于投到生成器中噪声的分布)。

对于来自于真实的样本,我们要判别为正确的分布 \(y_i\) 。来自于生成的样本我们要判别其为错误分布( \(1-y_i\) )。将上面式子进一步使用概率分布的期望形式写出(为了表达无限的样本情况,相当于无限样本求和情况),并且让 \(y_i\) 为 1/2 且使用 \(G(z)\) 表示生成样本可以得到如下公式:

\[H \left( (x_i, y_i)_{i=1}^\infty, D \right) = -\frac{1}{2}E_{x-p_{data}}\left[ \log D(x) \right] - \frac{1}{2}E_z\left[ \log (1-D(G(z))) \right] \\\
GAN损失函数期望形式
\]

对于论文中的公式

\[min_G max_D V(D, G) = E_{x-p_{data}(x)}\left[ \log D(x) \right] + E_{z-p_z(z)}\left[ \log (1-D(G(z))) \right] \\\
GAN损失函数的 min max表达
\]

其实是与上面公式一样的,下面做解释

  • 这里的 \(V(D, G)\) 相当于表示真实样本和生成样本的差异程度。
  • \(max_D V(D, G)\) 的意思是固定生成器 \(G\), 尽可能地让判别器能够最大化地判别出样本来自于真实数据还是生成的数据。
  • 再将后面的 $L = max_D V(D, G) $ 看成整体,对于 \(min_G L\)这里是在固定判别器\(D\)的条件下得到生成器 \(G\),这个 \(G\) 要求能够最小化真实样本与生成样本的差异。
  • 通过上述 \(min\) \(max\) 的博弈过程,理想情况下会收敛于生成分布拟合于真实分布。

3、卷积对抗生成网络

卷积对抗生成网络(DCGAN)是在GAN的基础上加入了CNN,主要是改进了网络结构,在训练过程中状态稳定,并且可以有效实现高质量图片的生成以及相关的生成模型应用。DCGAN的生成器网络结构如下图:

DCGAN的改进:

  • 使用步长卷积代替上采样层,卷积在提取图像特征上具有很好的作用,并且使用卷积代替全连接层
  • 生成器G和判别器D中几乎每一层都使用batchnorm层,将特征层的输出归一化到一起,加速了训练,提升了训练的稳定性。
  • 在判别器中使用leakrelu激活函数,而不是RELU,防止梯度稀疏,生成器中仍然采用relu,但是输出层采用tanh。

4、DCGAN代码实现

  1. shenduimport numpy as np
  2. import matplotlib.pyplot as plt
  3. import tensorflow as tf
  4. from tensorflow import keras
  5. from tensorflow.keras import optimizers, losses, layers, Sequential, Model
  1. class DCGAN():
  2. '''
  3. 实现深度对抗神经网络
  4. 生成 MNIST 手写数字图片
  5. 输入的噪声为服从正态分布均值为 0 方差为 1 的分布, shape:(None, 100)
  6. 生成器(G)输入 噪声, 输出为 (None, 28, 28, 1)的图片
  7. 分类器(D)输入为 (None, 28, 28, 1)的图片,输出图片的分类真假
  8. '''
  9. def __init__(self):
  10. self.img_rows = 28
  11. self.img_cols = 28
  12. self.channels = 1
  13. self.img_shape = (self.img_rows, self.img_cols, self.channels)
  14. optimizer = optimizers.Adam(0.0002)
  15. # 构建编译分类器
  16. self.discriminator = self.build_discriminator()
  17. self.discriminator.compile(loss='binary_crossentropy',
  18. optimizer=optimizer,
  19. metrics=['accuracy'])
  20. # 构建编译生成器
  21. self.generator = self.build_generator()
  22. self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)
  23. # 生成器输入为噪音,生成图片
  24. z = layers.Input(shape=(100,))
  25. img = self.generator(z)
  26. # 对于整个对抗网络模型只优化生成器的参数
  27. self.discriminator.trainable = False
  28. # 用生成的图片输入分类器判断
  29. valid = self.discriminator(img)
  30. # 对于整个对抗网络 输入噪音 => 生成图片 => 决定图片是否有效
  31. self.combined = Model(z, valid)
  32. self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
  33. def build_generator(self):
  34. '''
  35. 构建生成器
  36. '''
  37. noise_shape = (100,)
  38. model = tf.keras.Sequential()
  39. # 添加全连接层
  40. model.add(layers.Dense(7*7*256, use_bias=False, input_shape=noise_shape))
  41. # 添加 BatchNormalization 层,对数据进行归一化
  42. model.add(layers.BatchNormalization())
  43. model.add(layers.LeakyReLU())
  44. model.add(layers.Reshape((7, 7, 256)))
  45. # 添加逆卷积层,卷积核大小为 5X5,数量 128, 步长为 1
  46. model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
  47. assert model.output_shape == (None, 7, 7, 128)
  48. model.add(layers.BatchNormalization())
  49. model.add(layers.LeakyReLU())
  50. # 添加逆卷积层,卷积核大小为 5X5,数量 64, 步长为 2
  51. model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
  52. assert model.output_shape == (None, 14, 14, 64)
  53. model.add(layers.BatchNormalization())
  54. model.add(layers.LeakyReLU())
  55. # 添加逆卷积层,卷积核大小为 5X5,数量 1, 步长为 2
  56. model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
  57. assert model.output_shape == (None, 28, 28, 1)
  58. model.summary()
  59. noise = layers.Input(shape=noise_shape)
  60. img = model(noise)
  61. # 返回 Model 对象,输入为 噪声, 输出为 图像
  62. return keras.Model(noise, img)
  63. def build_discriminator(self):
  64. '''
  65. 构建分类器
  66. '''
  67. img_shape = (self.img_rows, self.img_cols, self.channels)
  68. model = tf.keras.Sequential()
  69. model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
  70. input_shape=img_shape))
  71. model.add(layers.LeakyReLU())
  72. # 添加 Dropout 层,减少参数数量
  73. model.add(layers.Dropout(0.3))
  74. model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
  75. model.add(layers.LeakyReLU())
  76. model.add(layers.Dropout(0.3))
  77. # 把数据铺平
  78. model.add(layers.Flatten())
  79. model.add(layers.Dense(1))
  80. model.summary()
  81. img = layers.Input(shape=img_shape)
  82. validity = model(img)
  83. return keras.Model(img, validity)
  84. def train(self, epochs, batch_size=128, save_interval=50):
  85. '''
  86. 网络训练
  87. '''
  88. # 加载 数据集
  89. (X_train, _), (_, _) = keras.datasets.mnist.load_data()
  90. # 把数据缩放到 [-1, 1]
  91. X_train = (X_train.astype(np.float32) - 127.5) / 127.5
  92. # 添加通道维度
  93. X_train = np.expand_dims(X_train, axis=3)
  94. half_batch = int(batch_size / 2)
  95. for epoch in range(epochs):
  96. # ---------------------
  97. # 训练分类器
  98. # ---------------------
  99. # 随机的选择一半的 batch 数量图片
  100. idx = np.random.randint(0, X_train.shape[0], half_batch)
  101. imgs = X_train[idx]
  102. noise = np.random.normal(0, 1, (half_batch, 100))
  103. # 生成一半 batch 数量的 图片
  104. gen_imgs = self.generator.predict(noise)
  105. # 分类器损失
  106. d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
  107. d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
  108. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
  109. # ---------------------
  110. # 训练生成器
  111. # ---------------------
  112. noise = np.random.normal(0, 1, (batch_size, 100))
  113. # The generator wants the discriminator to label the generated samples
  114. # as valid (ones)
  115. # 对于生成器,希望分类器把更多的图片判为 有效 (用 1 表示)
  116. valid_y = np.array([1] * batch_size)
  117. # 训练生成器
  118. g_loss = self.combined.train_on_batch(noise, valid_y)
  119. # 打印训练进度
  120. print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
  121. # 每个 save_interval 周期保存一张图片
  122. if epoch % save_interval == 0:
  123. self.save_imgs(epoch)
  124. def save_imgs(self, epoch):
  125. r, c = 5, 5
  126. noise = np.random.normal(0, 1, (r * c, 100))
  127. gen_imgs = self.generator.predict(noise)
  128. # 把图片数据缩放到 0 - 1
  129. gen_imgs = 0.5 * gen_imgs + 0.5
  130. fig, axs = plt.subplots(r, c)
  131. cnt = 0
  132. for i in range(r):
  133. for j in range(c):
  134. axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
  135. axs[i,j].axis('off')
  136. cnt += 1
  137. fig.savefig("dcgan/images/mnist_%d.png" % epoch)
  138. plt.close()
  139. if __name__ == '__main__':
  140. dcgan = DCGAN()
  141. dcgan.train(epochs=10000, batch_size=32, save_interval=200)

网络参数信息



5、训练结果

下面是循环了 10000 次 epoch 后,从开始每隔 2000 个 epoch 生成器生成的图片

  • 可以看到,刚开始全部都是噪声,随着训练的进行,图片逐渐清晰

  • 生成的图片还是不太清晰,一方面的原因是我训练的 epoch 周期太少,因为自己电脑性能问题,太耗时间,所以训练的epoch 周期少,如果有条件后提高训练周期应该会好很多。另一方面或许因为我构建的网络还有不合理之,后期还需要改进。









卷积生成对抗网络(DCGAN)---生成手写数字的更多相关文章

  1. 使用TensorFlow的卷积神经网络识别自己的单个手写数字,填坑总结

    折腾了几天,爬了大大小小若干的坑,特记录如下.代码在最后面. 环境: Python3.6.4 + TensorFlow 1.5.1 + Win7 64位 + I5 3570 CPU 方法: 先用MNI ...

  2. 图片训练:使用卷积神经网络(CNN)识别手写数字

    这篇文章中,我们将使用CNN构建一个Tensorflow.js模型来分辨手写的数字.首先,我们通过使之“查看”数以千计的数字图片以及他们对应的标识来训练分辨器.然后我们再通过此模型从未“见到”过的测试 ...

  3. tensorflow学习之(十)使用卷积神经网络(CNN)分类手写数字0-9

    #卷积神经网络cnn import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #数据包,如 ...

  4. GAN实战笔记——第四章深度卷积生成对抗网络(DCGAN)

    深度卷积生成对抗网络(DCGAN) 我们在第3章实现了一个GAN,其生成器和判别器是具有单个隐藏层的简单前馈神经网络.尽管很简单,但GAN的生成器充分训练后得到的手写数字图像的真实性有些还是很具说服力 ...

  5. 生成对抗网络(Generative Adversarial Networks,GAN)初探

    1. 从纳什均衡(Nash equilibrium)说起 我们先来看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句话 ...

  6. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  7. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  8. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 07 训练Tensorflow识别手写数字

    打开Python Shell,输入以下代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input ...

随机推荐

  1. throttle和debounce

    遇到的问题 在开发过程中会遇到频率很高的事件或者连续的事件,如果不进行性能的优化,就可能会出现页面卡顿的现象,比如: 鼠标事件:mousemove(拖曳)/mouseover(划过)/mouseWhe ...

  2. vue v-for 渲染input 输入有问题 解决方案

    v-for循环input标签的时候输入信息两个输入框一同显示输入信息 解决方案: <input :placeholder="items.title" v-model = &q ...

  3. 数学分析新讲(1) NOTE

    前言:无聊才翻翻看看来复习啦..所以慢更(●'◡'●) 1.利用求和公式的性质推导: \[\sum^{n}_{k=1}k=n \] \[\sum^{n}_{k=1}k^2=\frac{n(n+1)(2 ...

  4. 详细讲解使用Sublime Text 3进行Markdown编辑和实时预览

    所需安装的插件 Markdown Editing // Markdown编辑和语法高亮 Markdown Preview// Markdown导出html预览 LiveReload// 时时预览 安装 ...

  5. 【python代码】 最大流问题+最小花费问题+python(ortool库)实现

    目录 基本概念 图 邻接矩阵 最大流问题 python解决最大流问题 python解决最大流最小费用问题 基本概念 图 定义: 图G(V,E)是指一个二元组(V(G),E(G)),其中: V(G)={ ...

  6. Django之url反向解析

    在urls.py文件中,在进行url映射时,为请求的url命个名,以便在模板页面或者views.py视图中可以进行反向解析,同时在修改了url映射的请求路径,名称不变的情况下,不再修改模板页面或者视图 ...

  7. php IE中文乱码

    echo mb_convert_encoding("你是我的朋友", "big5", "GB2312"); 详细出处参考:http://ww ...

  8. BZOJ1066 网络流

    拆点,将一个柱子拆成入点和出点,入点出点之间的容量就是柱子的容量    1066: [SCOI2007]蜥蜴 在一个r行c列的网格地图中有一些高度不同的石柱,一些石柱上站着一些蜥蜴,你的任务是让尽量多 ...

  9. 移动端在ios上以及微信浏览器上的兼容性

    1.document.以及window.body在移动h5不能触发点击事件 解决方法:给body加上cursor: pointer;就可以有点击事件了. ios上默认的body是没有点击事件的: 接着 ...

  10. 【Java】几种典型的内存溢出案例,都在这儿了!

    写在前面 作为程序员,多多少少都会遇到一些内存溢出的场景,如果你还没遇到,说明你工作的年限可能比较短,或者你根本就是个假程序员!哈哈,开个玩笑.今天,我们就以Java代码的方式来列举几个典型的内存溢出 ...