这一篇博客以代码为主,主要是来介绍如果使用keras构建一个DCGAN,然后基于DCGAN,做一个自动生成动漫头像。训练过程如下(50轮的训练过程)“

关于DCGAN或者GAN的相关知识,可以参考GAN网络入门教程。建议先了解相关知识,再来看这一篇博客。

项目地址:GitHub

使用前准备

首先的首先,我们肯定是需要数据集的,这里使用的数据集来自kaggle——Anime Faces。里面有21551张动漫头像的图片。大家可以到kaggle上面去下载数据集,或者说到我的github上去下载数据集(求个 不过分吧)。部分数据如下:

如果自己电脑计算机资源不是很强的话,比如我,一个mx250小水管(玩玩lol还是可以的,训练这个模型可能要等到下辈子),推荐大家去注册一个kaggle或者colab账号去白嫖GPU资源(1080,2080的玩家请随意)。不过个人更加的推荐kaggle,因为感觉它的资源分配是可见的,且可以后台运行。

数据集

数据集是动漫图片,我们可以将图片的像素点的值变成\([-1,1]\)之间,具体代码如下:

  1. # 数据集的位置
  2. avatar_img_path = "./data"
  3. import imageio
  4. import os
  5. import numpy as np
  6. def load_data():
  7. """
  8. 加载数据集
  9. :return: 返回numpy数组
  10. """
  11. all_images = []
  12. for image_name in os.listdir(avatar_img_path):
  13. # 加载图片
  14. image = imageio.imread(os.path.join(avatar_img_path,image_name))
  15. all_images.append(image)
  16. all_images = np.array(all_images)
  17. # 将图片数值变成[-1,1]
  18. all_images = (all_images - 127.5) / 127.5
  19. # 将数据随机排序
  20. np.random.shuffle(all_images)
  21. return all_images
  22. img_dataset = load_data()

然后定义展示图片的方法:


  1. import matplotlib.pyplot as plt
  2. def show_images(images,index = -1):
  3. """
  4. 展示并保存图片
  5. :param images: 需要show的图片
  6. :param index: 图片名
  7. :return:
  8. """
  9. plt.figure()
  10. for i, image in enumerate(images):
  11. ax = plt.subplot(5, 5, i+1)
  12. plt.axis('off')
  13. plt.imshow(image)
  14. plt.savefig("data_%d.png"%index)
  15. plt.show()
  • 展示数据集中的部分图片:
  1. show_images(img_dataset[0: 25])

定义参数

这里我们只定义两个参数,图片的shape代表生成的图片是\(64 \times 64\)的RGB图片,以及noise的大小是100:

  1. # noise的维度
  2. noise_dim = 100
  3. # 图片的shape
  4. image_shape = (64,64,3)

构建网络

首先导入tensorflow中的keras库,如下:

  1. from tensorflow.keras.models import Sequential,Model
  2. from tensorflow.keras.layers import UpSampling2D, Conv2D, Dense, BatchNormalization, LeakyReLU, Input,Reshape, MaxPooling2D, Flatten, AveragePooling2D, Conv2DTranspose
  3. from tensorflow.keras.optimizers import Adam

下图中的网络结构参照了kaggle中的Anime face generation with DCGAN (beginner)

构建G网络

生成器网络,我们按照如下的结构进行构建:

原理是我们通过全连接层将nosise的向量放大,然后在再使用反卷积等操作将其逐渐变成shape为\((64,64,3)\)的图片。

  1. def build_G():
  2. """
  3. 构建生成器
  4. :return:
  5. """
  6. model = Sequential()
  7. # 全连接层 100 -> 2048
  8. model.add(Dense(2048,input_dim = noise_dim))
  9. # 激活函数
  10. model.add(LeakyReLU(0.2))
  11. # 全连接层 2048 -> 8 * 8 * 256
  12. model.add(Dense(8 * 8 * 256))
  13. # DN层
  14. model.add(BatchNormalization())
  15. model.add(LeakyReLU(0.2))
  16. # 8 * 8 * 256 -> (8,8,256)
  17. model.add(Reshape((8, 8, 256)))
  18. # 卷积层 (8,8,256) -> (8,8,128)
  19. model.add(Conv2D(128, kernel_size=5, padding='same'))
  20. model.add(BatchNormalization())
  21. model.add(LeakyReLU(0.2))
  22. # 反卷积层 (8,8,128) -> (16,16,128)
  23. model.add(Conv2DTranspose(128, kernel_size=5, strides=2, padding='same'))
  24. model.add(LeakyReLU(0.2))
  25. # 反卷积层 (16,16,128) -> (32,32,64)
  26. model.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding='same'))
  27. model.add(LeakyReLU(0.2))
  28. # 反卷积层 (32,32,64) -> (64,64,3) = 图片
  29. model.add(Conv2DTranspose(3, kernel_size=5, strides=2, padding='same', activation='tanh'))
  30. return model
  31. G = build_G()

可以发现,\(G\)网络并没有compile这一步,这是因为\(G\)网络的权重优化并不是直接优化的,而是通过GAN网络进行间接优化的。

构建D网络

D网络的结构示意图如下:

判别器网络就是一个寻常的CNN网络:


  1. def build_D():
  2. """
  3. 构建判别器
  4. :return:
  5. """
  6. model = Sequential()
  7. # 卷积层
  8. model.add(Conv2D(64, kernel_size=5, padding='valid',input_shape = image_shape))
  9. # BN层
  10. model.add(BatchNormalization())
  11. # 激活层
  12. model.add(LeakyReLU(0.2))
  13. # 平均池化层
  14. model.add(AveragePooling2D(pool_size=2))
  15. # 卷积层
  16. model.add(Conv2D(128, kernel_size=3, padding='valid'))
  17. model.add(BatchNormalization())
  18. model.add(LeakyReLU(0.2))
  19. model.add(AveragePooling2D(pool_size=2))
  20. model.add(Conv2D(256, kernel_size=3, padding='valid'))
  21. model.add(BatchNormalization())
  22. model.add(LeakyReLU(0.2))
  23. model.add(AveragePooling2D(pool_size=2))
  24. # 将输入展平
  25. model.add(Flatten())
  26. # 全连接层
  27. model.add(Dense(1024))
  28. model.add(BatchNormalization())
  29. model.add(LeakyReLU(0.2))
  30. # 最终输出1(true img) 0(fake img)的概率大小
  31. model.add(Dense(1, activation='sigmoid'))
  32. model.compile(loss='binary_crossentropy',
  33. optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
  34. return model
  35. D = build_D()

构建GAN网络

由前面的博客,我们知道,GAN网络由G网络和D网络组成,GAN网络的input为nosie,输出为图片真假的概率。因此它的网络结构示意图如下所示:


  1. def build_gan():
  2. """
  3. 构建GAN网络
  4. :return:
  5. """
  6. # 冷冻判别器,也就是在训练的时候只优化G的网络权重,而对D保持不变
  7. D.trainable = False
  8. # GAN网络的输入
  9. gan_input = Input(shape=(noise_dim,))
  10. # GAN网络的输出
  11. gan_out = D(G(gan_input))
  12. # 构建网络
  13. gan = Model(gan_input,gan_out)
  14. # 编译GAN网络,使用Adam优化器,以及加上交叉熵损失函数(一般用于二分类)
  15. gan.compile(loss='binary_crossentropy',optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
  16. return gan
  17. GAN = build_gan()

关于GAN的小trick

我们会将真实的图片的lable标记为1,fake图片的lable标记为0,但是我们训练的时候可以使lable的值在一定的范围内浮动。关于更多的trick,可以参考这篇 GANs training tricks


  1. def sample_noise(batch_size):
  2. """
  3. 随机产生正态分布(0,1)的noise
  4. :param batch_size:
  5. :return: 返回的shape为(batch_size,noise)
  6. """
  7. return np.random.normal(size=(batch_size, noise_dim))
  8. def smooth_pos_labels(y):
  9. """
  10. 使得true label的值的范围为[0.7,1.2]
  11. :param y:
  12. :return:
  13. """
  14. return y - 0.3 + (np.random.random(y.shape) * 0.5)
  15. def smooth_neg_labels(y):
  16. """
  17. 使得fake label的值的范围为[0.0,0.3]
  18. :param y:
  19. :return:
  20. """
  21. return y + np.random.random(y.shape) * 0.3

训练

开始训练之前,我们还介绍一个函数,load_batch,因为我们训练图片不可能说一次将图片全部进行训练而是分批次进行训练(full batch需要大量的内存空间),而load_batch函数就行按批次加载图片。

  1. def load_batch(data, batch_size,index):
  2. """
  3. 按批次加载图片
  4. :param data: 图片数据集
  5. :param batch_size: 批次大小
  6. :param index: 批次序号
  7. :return:
  8. """
  9. return data[index*batch_size: (index+1)*batch_size]

然后我们就需要定义\(train\)函数了:


  1. def train(epochs=100, batch_size=64):
  2. """
  3. 训练函数
  4. :param epochs: 训练的次数
  5. :param batch_size: 批尺寸
  6. :return:
  7. """
  8. # 判别器损失
  9. discriminator_loss = 0
  10. # 生成器损失
  11. generator_loss = 0
  12. # img_dataset.shape[0] / batch_size 代表这个数据可以分为几个批次进行训练
  13. n_batches = int(img_dataset.shape[0] / batch_size)
  14. for i in range(epochs):
  15. for index in range(n_batches):
  16. # 按批次加载数据
  17. x = load_batch(img_dataset, batch_size,index)
  18. # 产生noise
  19. noise = sample_noise(batch_size)
  20. # G网络产生图片
  21. generated_images = G.predict(noise)
  22. # 产生为1的标签
  23. y_real = np.ones(batch_size)
  24. # 将1标签的范围变成[0.7 , 1.2]
  25. y_real = smooth_pos_labels(y_real)
  26. # 产生为0的标签
  27. y_fake = np.zeros(batch_size)
  28. # 将0标签的范围变成[0.0 , 0.3]
  29. y_fake = smooth_neg_labels(y_fake)
  30. # 训练真图片loss
  31. d_loss_real = D.train_on_batch(x, y_real)
  32. # 训练假图片loss
  33. d_loss_fake = D.train_on_batch(generated_images, y_fake)
  34. discriminator_loss = d_loss_real + d_loss_fake
  35. # 产生为1的标签
  36. y_real = np.ones(batch_size)
  37. # 训练GAN网络,input = fake_img ,label = 1
  38. generator_loss = GAN.train_on_batch(noise, y_real)
  39. print('[Epoch {0}]. Discriminator loss : {1}. Generator_loss: {2}.'.format(i, discriminator_loss, generator_loss))
  40. # 随机产生(25,100)的noise
  41. test_noise = sample_noise(25)
  42. # 使用G网络生成25张图偏
  43. test_images = G.predict(test_noise)
  44. # show 预测 img
  45. show_images(test_images,i)

开始训练:

  1. train(epochs=500, batch_size=32)

最后就进入到了漫长的等待结果的时间了。

总结

项目地址:GitHub

参考

GAN网络之入门教程(四)之基于DCGAN动漫头像生成的更多相关文章

  1. GAN网络之入门教程(五)之基于条件cGAN动漫头像生成

    目录 Prepare 在上篇博客(AN网络之入门教程(四)之基于DCGAN动漫头像生成)中,介绍了基于DCGAN的动漫头像生成,时隔几月,序属三秋,在这篇博客中,将介绍如何使用条件GAN网络(cond ...

  2. GAN网络从入门教程(一)之GAN网络介绍

    GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...

  3. GAN网络从入门教程(二)之GAN原理

    在一篇博客GAN网络从入门教程(一)之GAN网络介绍中,简单的对GAN网络进行了一些介绍,介绍了其是什么,然后大概的流程是什么. 在这篇博客中,主要是介绍其数学公式,以及其算法流程.当然数学公式只是简 ...

  4. GAN网络从入门教程(三)之DCGAN原理

    目录 DCGAN简介 DCGAN的特点 几个重要概念 下采样(subsampled) 上采样(upsampling) 反卷积(Deconvolution) 批标准化(Batch Normalizati ...

  5. 【Zigbee技术入门教程-号外】基于Z-Stack协议栈的抢答系统

    [Zigbee技术入门教程-号外]基于Z-Stack协议栈的抢答系统 广东职业技术学院  欧浩源 一.引言    2017年全国职业院校技能大赛"物联网技术应用"赛项中任务三题2的 ...

  6. 无废话ExtJs 入门教程四[表单:FormPanel]

    无废话ExtJs 入门教程四[表单:FormPanel] extjs技术交流,欢迎加群(201926085) 继上一节内容,我们在窗体里加了个表单.如下所示代码区的第28行位置,items:form. ...

  7. PySide——Python图形化界面入门教程(四)

    PySide——Python图形化界面入门教程(四) ——创建自己的信号槽 ——Creating Your Own Signals and Slots 翻译自:http://pythoncentral ...

  8. Elasticsearch入门教程(四):Elasticsearch文档CURD

    原文:Elasticsearch入门教程(四):Elasticsearch文档CURD 版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接: ...

  9. RabbitMQ入门教程(四):工作队列(Work Queues)

    原文:RabbitMQ入门教程(四):工作队列(Work Queues) 版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https:/ ...

随机推荐

  1. C++模板函数只能全特化不能偏特化

    C++模板函数只能全特化不能偏特化

  2. java的方法详解和总结

    一.什么是方法 在日常生活中,我们所说的方法就是为了解决某件事情,而采取的解决办法 java中的方法可以理解为语句的集合,用来完成解决某件事情或实现某个功能的办法 方法的优点: 程序变得更加简短而清晰 ...

  3. 深入了解Netty【六】Netty工作原理

    引言 前面学习了NIO与零拷贝.IO多路复用模型.Reactor主从模型. 服务器基于IO模型管理连接,获取输入数据,又基于线程模型,处理请求. 下面来学习Netty的具体应用. 1.Netty线程模 ...

  4. idea中右击的快捷键都找不到 Diagrams

    今天突然发现了一件很恐怖的事情,那就是我的IDEA的右击中找不到Diagrams了,因为我是用这个东西打开 .bpmn文件生成png的,突然没了.. 说一下解决吧 在FIle -> settin ...

  5. vue-router-next 通过hash模式访问页面不生效,直接刷新页面一直停留在根路由界面的解决办法

    vue3中,配合的vueRouter版本更改为vue-router-next通过 npm i vue-router@next 的方式进行引入添加,随后创建 router.js,在main.js里面引入 ...

  6. 信号、多app应用、flask-script

    信号 Flask 框架中的信号基于blinker,其只要就是让开发者可以在flak请求过程中制定一些用户行为 安装:pip3 install blinker 内置信号 request_started ...

  7. Spring学习(五)bean装配详解之 【XML方式配置】

    一.配置Bean的方式及选择 配置方式 在 XML 文件中显式配置 在 Java 的接口和类中实现配置 隐式 Bean 的发现机制和自动装配原则 方式选择的原则 最优先:通过隐式 Bean 的发现机制 ...

  8. vue学习02-v-text

    vue学习02-v-text 引入环境版本,cdn网络引用或者本地js应用 html的结构,一般是div 创建vue实例 el:挂载点 v-text指令的作用是设置标签的内容 默认写法会替换全部内容, ...

  9. MySQL分区 (分区介绍与实际使用)

    分区介绍: 一.什么是分区? 所谓分区,就是将一个表分成多个区块进行操作和保存,从而降低每次操作的数据,提高性能.而对于应用来说则是透明的,从逻辑上看只有一张表,但在物理上这个表可能是由多个物理分区组 ...

  10. 普转提Day2

    T1 给定一个区间,求这个区间中只有一个数字与其他数组不相同的数的个数. 给出的区间范围较大,但是要求的数比较少.所以我的想法是这样的:因为这些数只有一个数字和每个数字都相同的数不同,所以考虑将所有数 ...