0902-用GAN生成动漫头像

pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

一、概述

本节将通过 GAN 实现一个生成动漫人物头像的例子。

在日本的技术博客网站上有个博主,利用 DCGAN 从 20 万张动漫头像中学习,最终能够利用程序自动生成动漫头像。源程序是利用 Chainer 框架实现的,在这里我们将尝试利用 Pytorch 实现。

原始的图片是从网站中采集的,并利用 OpenCV 截取头像,处理起来非常麻烦。因此我们在这里通过之乎用户 何之源 爬取并经过处理的 5 万张图片,想要图片的百度网盘链接的可以加我微信:chenyoudea。需要注意的是,这里图片的分辨率是 3×96×96,而不是论文中的 3×64×64,因此需要相应地调整网络结构,使生成图像的尺寸为 96。

二、代码结构

下面我们首先来看下我们未来的一个代码结构。

  1. checkpoints/ # 无代码,用来保存模型
  2. imgs/ # 无代码,用来保存生成的图片
  3. data/ # 无代码,用来保存训练所需要的图片
  4. main.py # 训练和生成
  5. model.py # 模型定义
  6. visualize.py # 可视化工具 visdom 的开发
  7. requirement.txt # 程序中用到的第三方库
  8. README.MD # 说明

三、model.py

model.py 主要是用来定义生成器和判别器的。

3.1 生成器

  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # Coding by https://www.cnblogs.com/nickchen121/
  4. # Datatime:2021/5/10 10:37
  5. # Filename:model.py
  6. # Toolby: PyCharm
  7. from torch import nn
  8. class NetG(nn.Module):
  9. """
  10. 生成器定义
  11. """
  12. def __init__(self, opt):
  13. super(NetG, self).__init__()
  14. ngf = opt.ngf # 生成器 feature map 数
  15. self.main = nn.Sequential(
  16. # 输入是 nz 维度的噪声,可以认识它是一个 nz*1*1 的 feature map
  17. # H_{out} = (H_{in}-1)*stride - 2*padding + kernel_size
  18. # 以下面一行代码的ConvTranspose2d举例(初始 H_{in}=1):H_{out} = (1-1)*1-2*0+4 = 4
  19. nn.ConvTranspose2d(opt.nz, ngf * 8, (4, 4), (1, 1), (0, 0), bias=False),
  20. nn.BatchNorm2d(ngf * 8),
  21. nn.ReLU(True),
  22. # 上一步的输出形状:(ngf*8)*4*4,其中(ngf*8)是输出通道数,4 为 H_{out} 是通过上述公式计算出来的
  23. # 以下面一行代码的ConvTranspose2d举例(初始 H_{in}=4):H_{out} = (4-1)*2-2*1+4 =8
  24. nn.ConvTranspose2d(ngf * 8, ngf * 4, (4, 4), (2, 2), (1, 1), bias=False),
  25. nn.BatchNorm2d(ngf * 4),
  26. nn.ReLU(True),
  27. # 上一步的输出形状:(ngf*4)*8*8
  28. nn.ConvTranspose2d(ngf * 4, ngf * 2, (4, 4), (2, 2), (1, 1), bias=False),
  29. nn.BatchNorm2d(ngf * 2),
  30. nn.ReLU(True),
  31. # 上一步的输出形状是:(ngf*2)*16*16
  32. nn.ConvTranspose2d(ngf * 2, ngf, (4, 4), (2, 2), (1, 1), bias=False),
  33. nn.BatchNorm2d(ngf),
  34. nn.ReLU(True),
  35. # 上一步的输出形状:(ngf)*32*32
  36. nn.ConvTranspose2d(ngf, 3, (5, 5), (3, 3), (1, 1), bias=False),
  37. nn.Tanh()
  38. # 输出形状:3*96*96
  39. )
  40. def forward(self, inp):
  41. return self.main(inp)

从上述生成器的代码可以看出生成器的构建比较简单,直接用 nn.Sequential 把上卷积、激活等操作拼接起来就行了。这里稍微注意下 ConvTranspose2d 的使用,当 kernel size 为 4、stride 为 2、padding 为 1 时,根据公式 \(H_{out} = (H_{in}-1)*stride - 2*padding + kernel_size\),输出尺寸刚好变成输入的两倍。

最后一层我们使用了 tanh 把输出图片的像素归一化至 -1~1,如果希望归一化到 0~1,可以使用 sigimoid 方法。

3.2 判别器

  1. class NetD(nn.Module):
  2. """
  3. 判别器定义
  4. """
  5. def __init__(self, opt):
  6. super(NetD, self).__init__()
  7. ndf = opt.ndf
  8. self.main = nn.Sequential(
  9. # 输入 3*96*96
  10. nn.Conv2d(3, ndf, (5, 5), (3, 3), (1, 1), bias=False),
  11. nn.LeakyReLU(0.2, inplace=True),
  12. # 输出 (ndf)*32*32
  13. nn.Conv2d(ndf, ndf * 2, (4, 4), (2, 2), (1, 1), bias=False),
  14. nn.BatchNorm2d(ndf * 2),
  15. nn.LeakyReLU(0.2, inplace=True),
  16. # 输出 (ndf*2)*16*16
  17. nn.Conv2d(ndf * 2, ndf * 4, (4, 4), (2, 2), (1, 1), bias=False),
  18. nn.BatchNorm2d(ndf * 4),
  19. nn.LeakyReLU(0.2, inplace=True),
  20. # 输出 (ndf*4)*8*8
  21. nn.Conv2d(ndf * 4, ndf * 8, (4, 4), (2, 2), (1, 1), bias=False),
  22. nn.BatchNorm2d(ndf * 8),
  23. nn.LeakyReLU(0.2, inplace=True),
  24. # 输出 (ndf*8)*4*4
  25. nn.Conv2d(ndf * 8, 1, (4, 4), (1, 1), (0, 0), bias=False),
  26. nn.Sigmoid() # 输出一个数:概率
  27. )
  28. def forward(self, inp):
  29. return self.main(inp).view(-1)

从上述代码可以看到判别器和生成器的网络结构几乎是对称的,从卷积核大小到 padding、stride 等设置,几乎一模一样。

需要注意的是,生成器的激活函数用的是 ReLU,而判别器使用的是 LeakyReLU,两者其实没有太大的区别,这种选择更多的是经验的总结。

判别器的最终输出是一个 0~1 的数,表示这个样本是真图片的概率。

四、参数配置

在开始写训练函数前,我们可以先配置模型参数

  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # Coding by https://www.cnblogs.com/nickchen121/
  4. # Datatime:2021/5/11 15:14
  5. # Filename:config.py
  6. # Toolby: PyCharm
  7. class Config(object):
  8. data_path = 'data/' # 数据集存放路径
  9. num_workers = 4 # 多进程加载数据所用的进程数
  10. image_size = 96 # 图片尺寸
  11. batch_size = 256
  12. max_epoch = 200
  13. lr1 = 2e-4 # 生成器的学习率
  14. lr2 = 2e-4 # 判别器的学习率
  15. beta1 = 0.5 # Adam 优化器的 beta1 参数
  16. use_gpu = False # 是否使用 GPU
  17. nz = 100 # 噪声维度
  18. ngf = 64 # 生成器的 feature map 数
  19. ndf = 64 # 判别器的 feature map 数
  20. save_path = 'imgs/' # 生成图片保存路径
  21. vis = True # 是否使用 visdom 可视化
  22. env = 'GAN' # visdom 的 env
  23. plot_every = 20 # 每隔 20 个 batch,visdom 画图一次
  24. debug_file = '/tmp/debuggan' # 存在该文件则进入 debug 模式
  25. d_every = 1 # 每 1 个 batch 训练一次判别器
  26. g_every = 5 # 每 5 个 batch 训练一次生成器
  27. decay_everty = 10 # 每 10 个 epoch 保存一次模型
  28. netd_path = 'checkpoints/netd_211.pth' # 预训练模型
  29. netg_path = 'checkpoints/netg_211.pth'
  30. # 测试时用的参数
  31. gen_img = 'result.png'
  32. # 从 512 张生成的图片路径中保存最好的 64 张
  33. gen_num = 64
  34. gen_search_num = 512
  35. gen_mean = 0 # 噪声的均值
  36. gen_std = 1 # 噪声的方差
  37. opt = Config()

上述这些都只是模型的默认参数,还可以利用 Fire 等工具通过命令行传入,覆盖默认值。

除此之外,还可以使用 opt.atrr,还可以利用 IDE/Python 提供的自动补全功能,十分方便。

上述的超参数大多是照搬 DCGAN 论文的默认值,这些默认值都是坐着经过大量的实验,发现这些参数能够更快地去训练出一个不错的模型。

五、数据处理

当我们下载完数据之后,需要把所有图片放在一文件夹,然后把文件夹移动到 data 目录下(并且要确保 data 下没有其他的文件夹)。使用这种方法是为了能够直接使用 pytorchvision 自带的 ImageFolder 读取图片,而没有必要自己写一个 Dataset。

数据读取和加载的代码如下所示。

  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # Coding by https://www.cnblogs.com/nickchen121/
  4. # Datatime:2021/5/12 09:43
  5. # Filename:dataset.py
  6. # Toolby: PyCharm
  7. import torch as t
  8. import torchvision as tv
  9. from torch.utils.data import DataLoader
  10. from config import opt
  11. # 数据处理,输出规模为 -1~1
  12. transforms = tv.transforms.Compose([
  13. tv.transforms.Scale(opt.image_size),
  14. tv.transforms.CenterCrop(opt.image_size),
  15. tv.transforms.ToTensor(),
  16. tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  17. ])
  18. # 加载数据集
  19. dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
  20. dataloader = DataLoader(
  21. dataset,
  22. batch_size=opt.batch_size,
  23. shuffle=True,
  24. num_workers=opt.num_workers,
  25. drop_last=True
  26. )

从上述代码中可以发现,用 ImageFolder 配合 DataLoader 加载图片十分方便。

六、训练

在训练之前,我们还需要定义几个变量:模型、优化器、噪声等。

  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # Coding by https://www.cnblogs.com/nickchen121/
  4. # Datatime:2021/5/10 10:37
  5. # Filename:main.py
  6. # Toolby: PyCharm
  7. import os
  8. import ipdb
  9. import tqdm
  10. import fire
  11. import torch as t
  12. import torchvision as tv
  13. from visualize import Visualizer
  14. from torch.autograd import Variable
  15. from torchnet.meter import AverageValueMeter
  16. from config import opt
  17. from dataset import dataloader
  18. from model import NetD, NetG
  19. def train(**kwargs):
  20. # 定义模型
  21. netd = NetD()
  22. netg = NetG()
  23. # 定义网络
  24. map_location = lambda storage, loc: storage
  25. if opt.netd_path:
  26. netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
  27. if opt.netg_path:
  28. netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
  29. # 定义优化器和损失
  30. optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
  31. optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
  32. criterion = t.nn.BCELoss()
  33. # 真图片 label 为 1,假图片 label 为 0,noises 为生成网络的输入噪声
  34. true_labels = Variable(t.ones(opt.batch_size))
  35. fake_labels = Variable(t.zeros(opt.batch_size))
  36. fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
  37. noises = vars(t.randn(opt.batch_size, opt.nz, 1, 1))
  38. # 如果使用 GPU 训练,把数据转移到 GPU 上
  39. if opt.use_gpu:
  40. netd.cuda()
  41. netg.cuda()
  42. criterion.cuda()
  43. true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
  44. fix_noises, noises = fix_noises.cuda(), noises.cuda()

在加载预训练模型的时候,最好指定 map_location。因为如果程序之前在 GPU 上运行,那么模型就会被存成 torch.cuda.Tensor,这样加载的时候会默认把数据加载到显存上。如果运行该程序的计算机中没有 GPU,则会报错,因此指定 map_location 把 Tensor 默认加载到内存上,等有需要的时候再加载到显存中。

下面开始训练网络,训练的步骤如下所示:

  1. 训练判别器:

    • 固定生成器
    • 对于真图片,判别器的输出概率值尽可能接近 1
    • 对于生成器生成的图片,判别器尽可能输出 0
  2. 训练生成器
    • 固定判别器
    • 生成器生成图片,尽可能让判别器输出 1
  3. 返回第一步,循环交替训练
  1. epochs = range(opt.max_epoch)
  2. for epoch in iter(epochs):
  3. for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
  4. real_img = Variable(img)
  5. if opt.use_gpu:
  6. real_img = real_img.cuda()
  7. # 训练判别器
  8. if (ii + 1) % opt.d_every == 0:
  9. optimizer_d.zero_grad()
  10. # 尽可能把真图片判别为 1
  11. output = netd(real_img)
  12. error_d_real = criterion(output, true_labels)
  13. error_d_real.backward()
  14. # 尽可能把假图片判别为 0
  15. noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
  16. fake_img = netg(noises).detach() # 根据照片生成假图片
  17. fake_ouput = netd(fake_img)
  18. error_d_fake = criterion(fake_ouput, fake_labels)
  19. error_d_fake.backward()
  20. optimizer_d.step()
  21. # 训练生成器
  22. if (ii + 1) % opt.g_every == 0:
  23. optimizer_g.zero_grad()
  24. noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
  25. fake_img = netg(noises)
  26. fake_output = netd(fake_img)
  27. # 尽可能让判别器把假图片也判别为 1
  28. error_g = criterion(fake_output, true_labels)
  29. error_g.backward()
  30. optimizer_g.step()
  31. # 可视化
  32. if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
  33. # 定义可视化窗口
  34. vis = Visualizer(opt.env)
  35. if os.path.exists(opt.debug_file):
  36. ipdb.set_trace()
  37. global fix_fake_imgs
  38. fix_fake_imgs = netg(fix_noises)
  39. vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
  40. vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
  41. vis.plot('errord', errord_meter.value()[0])
  42. vis.plot('errorg', errorg_meter.value()[0])
  43. if (epoch + 1) % opt.save_every == 0:
  44. # 保存模型、图片
  45. tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
  46. range=(-1, 1))
  47. t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
  48. t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
  49. errord_meter.reset()
  50. errorg_meter.reset()

在上述训练代码中,需要注意以下几点:

  • 训练生成器的时候,不需要调整判别器的参数;训练判别器的时候,也不需要调整生成器的参数
  • 在训练判别器的时候,需要对生成器生成的图片用 detach 操作进行计算图截断,避免反向传播把梯度传到生成器中。因为在训练判别器的时候我们不需要训练生成器,也就不需要生成器的梯度。
  • 在训练分类器的时候,需要反向传播两次,一次是希望把真图片判为 1,一次是希望把假图片判为 0.也可以把这个两者的数据放到一个 batch 中,进行一次前向传播和一次反向传播即可。但是人们发现,在一个 batch 中只包含真图片或者只包含假图片的做法最好。
  • 对于假图片,在训练判别器的时候,我们希望它输出为 0;而在训练生成器的时候,我们希望它输出为 1.因此可以看到一堆相互矛盾的代码:error_d_fake = criterion(fake_output,fake_labels)error_g = criterion(fake_output, true_labels)。其实这也很好理解,判别器希望能够把假图片判别为 fake_label,而生成器希望能把它判别为 true_label,判别器和生成器相互对抗提升。
  • 其中的 Visualize 模块类似于上一章自己的写的模块,可以直接复制粘贴源码中的代码。

七、随机生成图片

除了上述所示的代码外,还提供了一个函数,能加载预训练好的模型,并且利用噪声随机生成图片。

  1. @t.no_grad()
  2. def generate():
  3. # 定义噪声和网络
  4. netg, netd = NetG(opt), NetD(opt)
  5. noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
  6. noises = Variable(noises)
  7. # 加载预训练的模型
  8. netd.load_state_dict(t.load(opt.netd_path))
  9. netg.load_state_dict(t.load(opt.netg_path))
  10. # 是否使用 GPU
  11. if opt.use_gpu:
  12. netd.cuda()
  13. netg.cuda()
  14. noises = noises.cuda()
  15. # 生成图片,并计算图片在判别器的分数
  16. fake_img = netg(noises)
  17. scores = netd(fake_img).data
  18. # 挑选最好的某几张
  19. indexs = scores.topk(opt.gen_num)[1]
  20. result = []
  21. for ii in indexs:
  22. result.append(fake_img.data[ii])
  23. # 保存图片
  24. tv.utils.save_image(t.stack(result), opt.gen_num, normalize=True, range=(-1, 1))

八、训练模型并测试

完整的代码可以添加我微信:chenyoudea,其实上述代码已经很完整了,或者去github https://github.com/chenyuntc/pytorch-book/tree/master/chapter07-AnimeGAN下载。

这里假设你是拥有完整的代码,那么准备好数据后,可以用下面的命令开始训练:

  1. python main.py train --gpu=True --vis=True --batch-size=256 --max-epoch=200

如果使用了 visdom,此时打开 http://localhost:8097 就能看到生成的图像。

训练完成后,我们就可以利用生成网络随机生成动漫头像,输入命令如下:

  1. python main.py generate --gen-img='result.5w.png' --gen-search-num=15000

0902-用GAN生成动漫头像的更多相关文章

  1. GAN网络之入门教程(四)之基于DCGAN动漫头像生成

    目录 使用前准备 数据集 定义参数 构建网络 构建G网络 构建D网络 构建GAN网络 关于GAN的小trick 训练 总结 参考 这一篇博客以代码为主,主要是来介绍如果使用keras构建一个DCGAN ...

  2. GAN生成图像论文总结

    GAN Theory Modifyingthe Optimization of GAN 题目 内容 GAN   DCGAN   WGAN   Least-square GAN   Loss Sensi ...

  3. GAN网络之入门教程(五)之基于条件cGAN动漫头像生成

    目录 Prepare 在上篇博客(AN网络之入门教程(四)之基于DCGAN动漫头像生成)中,介绍了基于DCGAN的动漫头像生成,时隔几月,序属三秋,在这篇博客中,将介绍如何使用条件GAN网络(cond ...

  4. DCGAN in Tensorflow生成动漫人物

    引自:GAN学习指南:从原理入门到制作生成Demo 生成式对抗网络(GAN)是近年来大热的深度学习模型.最近正好有空看了这方面的一些论文,跑了一个GAN的代码,于是写了这篇文章来介绍一下GAN. 本文 ...

  5. GAN︱生成模型学习笔记(运行机制、NLP结合难点、应用案例、相关Paper)

    我对GAN"生成对抗网络"(Generative Adversarial Networks)的看法: 前几天在公开课听了新加坡国立大学[机器学习与视觉实验室]负责人冯佳时博士在[硬 ...

  6. 4.keras实现-->生成式深度学习之用GAN生成图像

    生成式对抗网络(GAN,generative adversarial network)由Goodfellow等人于2014年提出,它可以替代VAE来学习图像的潜在空间.它能够迫使生成图像与真实图像在统 ...

  7. GAN 生成mnist数据

    参考资料 GAN原理学习笔记 生成式对抗网络GAN汇总 GAN的理解与TensorFlow的实现 TensorFlow小试牛刀(2):GAN生成手写数字 参考代码之一 #coding=utf-8 #h ...

  8. 小程序利用canvas 绘制图案 (生成海报, 生成有特色的头像)

    小程序利用canvas 绘制图案 (生成海报, 生成有特色的头像) 微信小程序生成特色头像,海报等是比较常见的.下面我来介绍下实现该类小程序的过程. 首先选择前端来通过 canvas 绘制.这样比较节 ...

  9. 『TensorFlow』DCGAN生成动漫人物头像_下

    『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...

随机推荐

  1. React实用技巧

    取消请求 React 中当前正在发出请求的组件从页面上卸载了,理想情况下这个请求也应该取消掉,那么如何把请求的取消和页面的卸载关联在一起呢? 这里要考虑利用 useEffect 传入函数的返回值: u ...

  2. java注解基础入门

    前言 这篇博客主要是对java注解相关的知识进行入门级的讲解,包括**,核心内容主要体现在对java注解的理解以及如何使用.希望通过写这篇博客的过程中让自己对java注解有更深入的理解,在工作中可以巧 ...

  3. P1036_选数(JAVA语言)

    题目描述 已知 n 个整数x1​,x2​,-,xn​,以及1个整数k(k<n).从n个整数中任选k个整数相加,可分别得到一系列的和.例如当n=4,k=3,4个整数分别为3,7,12,19时,可得 ...

  4. Ubuntu-搭建Clang Static Analyzer环境

    其实也就是一个开源的漏洞扫描器 专门扫描C/C++ 0BJECT-C++这种,实不相瞒我搭建了5天这个环境,最后我发现了一种超级方便的办法 前面怎么走的坑还是不分享了吧,由于没有看到前面很多人的办法或 ...

  5. 敏捷史话(十二):你现在接触的敏捷也许是“黑暗敏捷”——Ron Jeffries

    他很少提起往事,也不再提及二十年前那场引起软件行业变革的会议,他专注于当下,一直活跃在敏捷领域.八十多岁的他依然运营维护着网站和博客,是极限编程网站 XProgramming.com 的作者,该网站是 ...

  6. vue之mixin理解与使用

    使用场景 当有两个非常相似的组件,除了一些个别的异步请求外其余的配置都一样,甚至父组件传的值也是一样的,但他们之间又存在着足够的差异性,这时候就不得不拆分成两个组件,如果拆分成两个组件,你就不得不冒着 ...

  7. 第22 章 : 有状态应用编排 StatefulSet

    有状态应用编排 StatefulSet 本文将主要分享以下四方面的内容: "有状态"需求 用例解读 操作演示 架构设计 "有状态"需求 课程回顾 我们之前讲到过 ...

  8. [.net] 关于Exception的几点思考和在项目中的使用(三)

    本文链接: https://www.cnblogs.com/hubaijia/p/about-exceptions-3.html 系列文章: 关于Exception的几点思考和在项目中的使用(一) 关 ...

  9. Dynamics CRM与ADFS安装到同一台服务器后ADFS服务与Dynamics CRM沙盒服务冲突提示808端口占用问题

    当我们安装Dynamics CRM的产品时如果是单台服务器部署而且部署了IFD的情况会遇到一个问题就是ADFS服务的监听端口和Dynamics CRM沙盒服务的端口冲突了. 这样会导致两个服务中的一个 ...

  10. BUAA_OO_2020_第一单元总结

    BUAA_OO_2020_第一单元总结 OO第一单元作业主题为表达式求导,主要学习目标为熟悉面向对象思想,学会使用类来管理数据,感受分工协作的行为设计,建立程序鲁棒性概念.如今,第一单元的学习已落下帷 ...