深度卷积生成对抗网络(DCGAN)

---- 生成 MNIST 手写图片

1、基本原理

生成对抗网络(GAN)由2个重要的部分构成:

  • 生成器(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器
  • 判别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器做的“假数据”

训练过程

    1. 固定判别器,让生成器不断生成假数据,给判别器判别,开始生成器很弱,但是随着不断的训练,生成器不断提升,最终骗过判别器。此时判别器判断假数据的概率为50%
    1. 固定生成器,训练判别器。判别器经过训练,提高鉴别能力,最终能准确判断虽有的假图片
    1. 循环上两个阶段,最终生成器和判别器都越来越强。然后就可以使用生成器来生成我们想要的图片了

2、相关数学原理

  • 判别器在这里是一种分类器,用于区分样本的真伪,因此我们常常使用交叉熵(cross entropy)来进行判别分布的相似性

\[H(p, q) := -\sum_i p_i \log q_i
\]

公式中 \(p_i\) 和 \(q_i\) 为真实的样本分布和生成器的生成分布

假定 \(y_1\) 为正确样本分布,那么对应的( \(1-y_1\) )就是生成样本的分布。\(D\) 表示判别器,则 \(D(x_1)\) 表示判别样本为正确的概率, \(1-D(x_1)\) 则对应着判别为错误样本的概率。则有如下式子(这里仅仅是对当前情况下的交叉熵损失的具体化)。

\[H((x_i, y_i)_{i=1}^N, D) = - \sum_{i=1}^N y_i\log D(x_i) - \sum_{i=1}^N(1-y_i)\log (1 - D(x_i))
\]

对于GAN中的样本点 \(x_i\) ,对应于两个出处,要么来自于真实样本,要么来自于生成器生成的样本 $\tilde{x} - G(z) $ ( 这里的 \(z\) 是服从于投到生成器中噪声的分布)。

对于来自于真实的样本,我们要判别为正确的分布 \(y_i\) 。来自于生成的样本我们要判别其为错误分布( \(1-y_i\) )。将上面式子进一步使用概率分布的期望形式写出(为了表达无限的样本情况,相当于无限样本求和情况),并且让 \(y_i\) 为 1/2 且使用 \(G(z)\) 表示生成样本可以得到如下公式:

\[H \left( (x_i, y_i)_{i=1}^\infty, D \right) = -\frac{1}{2}E_{x-p_{data}}\left[ \log D(x) \right] - \frac{1}{2}E_z\left[ \log (1-D(G(z))) \right] \\\
GAN损失函数期望形式
\]

对于论文中的公式

\[min_G max_D V(D, G) = E_{x-p_{data}(x)}\left[ \log D(x) \right] + E_{z-p_z(z)}\left[ \log (1-D(G(z))) \right] \\\
GAN损失函数的 min max表达
\]

其实是与上面公式一样的,下面做解释

  • 这里的 \(V(D, G)\) 相当于表示真实样本和生成样本的差异程度。
  • \(max_D V(D, G)\) 的意思是固定生成器 \(G\), 尽可能地让判别器能够最大化地判别出样本来自于真实数据还是生成的数据。
  • 再将后面的 $L = max_D V(D, G) $ 看成整体,对于 \(min_G L\)这里是在固定判别器\(D\)的条件下得到生成器 \(G\),这个 \(G\) 要求能够最小化真实样本与生成样本的差异。
  • 通过上述 \(min\) \(max\) 的博弈过程,理想情况下会收敛于生成分布拟合于真实分布。

3、卷积对抗生成网络

卷积对抗生成网络(DCGAN)是在GAN的基础上加入了CNN,主要是改进了网络结构,在训练过程中状态稳定,并且可以有效实现高质量图片的生成以及相关的生成模型应用。DCGAN的生成器网络结构如下图:

DCGAN的改进:

  • 使用步长卷积代替上采样层,卷积在提取图像特征上具有很好的作用,并且使用卷积代替全连接层
  • 生成器G和判别器D中几乎每一层都使用batchnorm层,将特征层的输出归一化到一起,加速了训练,提升了训练的稳定性。
  • 在判别器中使用leakrelu激活函数,而不是RELU,防止梯度稀疏,生成器中仍然采用relu,但是输出层采用tanh。

4、DCGAN代码实现

  1. shenduimport numpy as np
  2. import matplotlib.pyplot as plt
  3. import tensorflow as tf
  4. from tensorflow import keras
  5. from tensorflow.keras import optimizers, losses, layers, Sequential, Model
  1. class DCGAN():
  2. '''
  3. 实现深度对抗神经网络
  4. 生成 MNIST 手写数字图片
  5. 输入的噪声为服从正态分布均值为 0 方差为 1 的分布, shape:(None, 100)
  6. 生成器(G)输入 噪声, 输出为 (None, 28, 28, 1)的图片
  7. 分类器(D)输入为 (None, 28, 28, 1)的图片,输出图片的分类真假
  8. '''
  9. def __init__(self):
  10. self.img_rows = 28
  11. self.img_cols = 28
  12. self.channels = 1
  13. self.img_shape = (self.img_rows, self.img_cols, self.channels)
  14. optimizer = optimizers.Adam(0.0002)
  15. # 构建编译分类器
  16. self.discriminator = self.build_discriminator()
  17. self.discriminator.compile(loss='binary_crossentropy',
  18. optimizer=optimizer,
  19. metrics=['accuracy'])
  20. # 构建编译生成器
  21. self.generator = self.build_generator()
  22. self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)
  23. # 生成器输入为噪音,生成图片
  24. z = layers.Input(shape=(100,))
  25. img = self.generator(z)
  26. # 对于整个对抗网络模型只优化生成器的参数
  27. self.discriminator.trainable = False
  28. # 用生成的图片输入分类器判断
  29. valid = self.discriminator(img)
  30. # 对于整个对抗网络 输入噪音 => 生成图片 => 决定图片是否有效
  31. self.combined = Model(z, valid)
  32. self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
  33. def build_generator(self):
  34. '''
  35. 构建生成器
  36. '''
  37. noise_shape = (100,)
  38. model = tf.keras.Sequential()
  39. # 添加全连接层
  40. model.add(layers.Dense(7*7*256, use_bias=False, input_shape=noise_shape))
  41. # 添加 BatchNormalization 层,对数据进行归一化
  42. model.add(layers.BatchNormalization())
  43. model.add(layers.LeakyReLU())
  44. model.add(layers.Reshape((7, 7, 256)))
  45. # 添加逆卷积层,卷积核大小为 5X5,数量 128, 步长为 1
  46. model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
  47. assert model.output_shape == (None, 7, 7, 128)
  48. model.add(layers.BatchNormalization())
  49. model.add(layers.LeakyReLU())
  50. # 添加逆卷积层,卷积核大小为 5X5,数量 64, 步长为 2
  51. model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
  52. assert model.output_shape == (None, 14, 14, 64)
  53. model.add(layers.BatchNormalization())
  54. model.add(layers.LeakyReLU())
  55. # 添加逆卷积层,卷积核大小为 5X5,数量 1, 步长为 2
  56. model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
  57. assert model.output_shape == (None, 28, 28, 1)
  58. model.summary()
  59. noise = layers.Input(shape=noise_shape)
  60. img = model(noise)
  61. # 返回 Model 对象,输入为 噪声, 输出为 图像
  62. return keras.Model(noise, img)
  63. def build_discriminator(self):
  64. '''
  65. 构建分类器
  66. '''
  67. img_shape = (self.img_rows, self.img_cols, self.channels)
  68. model = tf.keras.Sequential()
  69. model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
  70. input_shape=img_shape))
  71. model.add(layers.LeakyReLU())
  72. # 添加 Dropout 层,减少参数数量
  73. model.add(layers.Dropout(0.3))
  74. model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
  75. model.add(layers.LeakyReLU())
  76. model.add(layers.Dropout(0.3))
  77. # 把数据铺平
  78. model.add(layers.Flatten())
  79. model.add(layers.Dense(1))
  80. model.summary()
  81. img = layers.Input(shape=img_shape)
  82. validity = model(img)
  83. return keras.Model(img, validity)
  84. def train(self, epochs, batch_size=128, save_interval=50):
  85. '''
  86. 网络训练
  87. '''
  88. # 加载 数据集
  89. (X_train, _), (_, _) = keras.datasets.mnist.load_data()
  90. # 把数据缩放到 [-1, 1]
  91. X_train = (X_train.astype(np.float32) - 127.5) / 127.5
  92. # 添加通道维度
  93. X_train = np.expand_dims(X_train, axis=3)
  94. half_batch = int(batch_size / 2)
  95. for epoch in range(epochs):
  96. # ---------------------
  97. # 训练分类器
  98. # ---------------------
  99. # 随机的选择一半的 batch 数量图片
  100. idx = np.random.randint(0, X_train.shape[0], half_batch)
  101. imgs = X_train[idx]
  102. noise = np.random.normal(0, 1, (half_batch, 100))
  103. # 生成一半 batch 数量的 图片
  104. gen_imgs = self.generator.predict(noise)
  105. # 分类器损失
  106. d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
  107. d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
  108. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
  109. # ---------------------
  110. # 训练生成器
  111. # ---------------------
  112. noise = np.random.normal(0, 1, (batch_size, 100))
  113. # The generator wants the discriminator to label the generated samples
  114. # as valid (ones)
  115. # 对于生成器,希望分类器把更多的图片判为 有效 (用 1 表示)
  116. valid_y = np.array([1] * batch_size)
  117. # 训练生成器
  118. g_loss = self.combined.train_on_batch(noise, valid_y)
  119. # 打印训练进度
  120. print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
  121. # 每个 save_interval 周期保存一张图片
  122. if epoch % save_interval == 0:
  123. self.save_imgs(epoch)
  124. def save_imgs(self, epoch):
  125. r, c = 5, 5
  126. noise = np.random.normal(0, 1, (r * c, 100))
  127. gen_imgs = self.generator.predict(noise)
  128. # 把图片数据缩放到 0 - 1
  129. gen_imgs = 0.5 * gen_imgs + 0.5
  130. fig, axs = plt.subplots(r, c)
  131. cnt = 0
  132. for i in range(r):
  133. for j in range(c):
  134. axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
  135. axs[i,j].axis('off')
  136. cnt += 1
  137. fig.savefig("dcgan/images/mnist_%d.png" % epoch)
  138. plt.close()
  139. if __name__ == '__main__':
  140. dcgan = DCGAN()
  141. dcgan.train(epochs=10000, batch_size=32, save_interval=200)

网络参数信息



5、训练结果

下面是循环了 10000 次 epoch 后,从开始每隔 2000 个 epoch 生成器生成的图片

  • 可以看到,刚开始全部都是噪声,随着训练的进行,图片逐渐清晰

  • 生成的图片还是不太清晰,一方面的原因是我训练的 epoch 周期太少,因为自己电脑性能问题,太耗时间,所以训练的epoch 周期少,如果有条件后提高训练周期应该会好很多。另一方面或许因为我构建的网络还有不合理之,后期还需要改进。









卷积生成对抗网络(DCGAN)---生成手写数字的更多相关文章

  1. 使用TensorFlow的卷积神经网络识别自己的单个手写数字,填坑总结

    折腾了几天,爬了大大小小若干的坑,特记录如下.代码在最后面. 环境: Python3.6.4 + TensorFlow 1.5.1 + Win7 64位 + I5 3570 CPU 方法: 先用MNI ...

  2. 图片训练:使用卷积神经网络(CNN)识别手写数字

    这篇文章中,我们将使用CNN构建一个Tensorflow.js模型来分辨手写的数字.首先,我们通过使之“查看”数以千计的数字图片以及他们对应的标识来训练分辨器.然后我们再通过此模型从未“见到”过的测试 ...

  3. tensorflow学习之(十)使用卷积神经网络(CNN)分类手写数字0-9

    #卷积神经网络cnn import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #数据包,如 ...

  4. GAN实战笔记——第四章深度卷积生成对抗网络(DCGAN)

    深度卷积生成对抗网络(DCGAN) 我们在第3章实现了一个GAN,其生成器和判别器是具有单个隐藏层的简单前馈神经网络.尽管很简单,但GAN的生成器充分训练后得到的手写数字图像的真实性有些还是很具说服力 ...

  5. 生成对抗网络(Generative Adversarial Networks,GAN)初探

    1. 从纳什均衡(Nash equilibrium)说起 我们先来看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句话 ...

  6. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  7. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  8. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 07 训练Tensorflow识别手写数字

    打开Python Shell,输入以下代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input ...

随机推荐

  1. JSP+Servlet+JDBC+mysql实现的学生成绩管理系统

    项目简介 项目来源于:https://gitee.com/zzdoreen/SSMS 本系统基于JSP+Servlet+Mysql 一个基于JSP+Servlet+Jdbc的学生成绩管理系统.涉及技术 ...

  2. 真香警告!扩展 swagger支持文档自动列举所有枚举值

    承接上篇文章 <一站式解决使用枚举的各种痛点> 文章最后提到:在使用 swagger 来编写接口文档时,需要告诉前端枚举类型有哪些取值,每次增加取值之后,不仅要改代码,还要找到对应的取值在 ...

  3. .net core kafka 入门实例 一篇看懂

      kafka 相信都有听说过,不管有没有用过,在江湖上可以说是大名鼎鼎,就像天龙八部里的乔峰.国际惯例,先介绍生平事迹   简介 Kafka 是由 Apache软件基金会 开发的一个开源流处理平台, ...

  4. BZOJ1010单调性DP优化

    1010: [HNOI2008]玩具装箱toy Time Limit: 1 Sec  Memory Limit: 162 MBSubmit: 10707  Solved: 4445[Submit][S ...

  5. Kd Tree算法详解

    kd树(k-dimensional树的简称),是一种分割k维数据空间的数据结构,主要应用于多维空间关键数据的近邻查找(Nearest Neighbor)和近似最近邻查找(Approximate Nea ...

  6. 405 - 不允许用于访问此页的 HTTP 谓词的处理办法

    今天介绍的是针对访问html页面时出现此类错误的处理办法,如果你的问题页面是其他类型,可以参考如下信息: IIS 返回 405 - 不允许用于访问此页的 HTTP 谓词.终极解决办法!!!! 1.为什 ...

  7. Mysql与Mysqli的区别及特点

    1)PHP-MySQL 是 PHP 操作 MySQL 资料库最原始的 Extension ,PHP-MySQLi 的 i 代表 Improvement ,提更了相对进阶的功能,就 Extension ...

  8. 把数据写入txt中 open函数中 a与w的区别

    a: 打开一个文件用于追加.如果该文件已存在,文件指针将会放在文件的结尾. 也就是说,新的内容将会被写入到已有内容之后.如果该文件不存在,创建新文件进行写入. w:  打开一个文件只用于写入.如果该文 ...

  9. 【译】OWIN: Open Web Server Interface for .NET

    主要是使用 OAuth 时,它运行在 OWIN 上,然后又出了若干问题,总之,发现对 IIS.ASP.NET 和 OWIN 理解一塌糊涂. 后面看到 OWIN: Open Web Server Int ...

  10. Sping源码+Redis+Nginx+MySQL等七篇实战技术文档,阿里大佬推荐

    JVM JVM是Java Virtual Machine(Java虚拟机)的缩写,JVM是一种用于计算设备的规范,它是一个虚构出来的计算机,是通过在实际的计算机上仿真模拟各种计算机功能来实现的. 引入 ...