作者在进行GAN学习中遇到的问题汇总到下方,并进行解读讲解,下面提到的题目是李宏毅老师机器学习课程的作业6(GAN)

一.GAN

网络上有关GAN和DCGAN的讲解已经很多,在这里不再加以赘述,放几个我认为比较好的讲解

1.GAN概念理解

2.理解GAN网络基本原理

3.李宏毅机器学习课程

4.换个角度看GAN:另一种损失函数

二.DCGAN

1.从头开始GAN【论文】(二) —— DCGAN

2.PyTorch教程之DCGAN

3.pytorch官方DCGAN样例讲解

三.示例代码解读

3.1关于数据集的下载

官方的数据集需要FQ下载,在查找相关网址后,上网找到了数据集,并成功下载,如下是数据集链接:

提取码:ctgr

成功下载并解压,可以删除作业代码中的有关下载和解压的部分

可以打开我最下面放的文件,改变数据地址即可(main函数中的workspace_dir)

  1. You may replace the workspace directory if you want.
  2. workspace_dir = '.'
  3. Training progress bar
  4. !pip install -q qqdm
  5. !gdown --id 1IGrTr308mGAaCKotpkkm8wTKlWs9Jq-p --output "{workspace_dir}/crypko_data.zip"

3.2导入相关包和函数

  1. import random
  2. import torch
  3. import numpy as np
  4. import os
  5. import glob
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import torchvision
  9. import torchvision.transforms as transforms
  10. from torch import optim
  11. from torch.autograd import Variable
  12. from torch.utils.data import Dataset, DataLoader
  13. import matplotlib.pyplot as plt
  14. from qqdm.notebook import qqdm

如果没有qqdm或者matplotlib则需要pip或者conda下载(qqdm是进度条,在训练过程中可以显示训练进度)

3.3DateSet 数据预处理

Transfrom:

1.transforms.Compose():将一系列的transforms有序组合,实现时按照这些方法依次对图像操作。

类似封装函数,依次执行

2.transforms.ToPILImage:将数据转换为PILImage。

3.transforms.Resize:图像变换

4.transforms.ToTensor:转为tensor,并归一化至[0-1]

5.transforms.Normalize:数据归一化处理

  • mean:各通道的均值
  • std:各通道的标准差

关于为什么要进行归一化处理可以参考transforms.Normalize()

6.主函数进行数据加载:

  1. workspace_dir='D://机器学习//Jupyter//GAN学习//函数'
  2. dataset = get_dataset(os.path.join(workspace_dir, 'faces'))

3.4Model-模型的建立-DCGAN

3.4.1权重初始化

DCGAN指出,所有的权重都以均值为0,标准差为0.2的正态分布随机初始化。weights_init 函数读取一个已初始化的模型并重新初始化卷积层,转置卷积层,batch normalization 层。这个函数在模型初始化之后使用。

  1. def weights_init(m):
  2. classname = m.__class__.__name__
  3. if classname.find('Conv') != -1:
  4. m.weight.data.normal_(0.0, 0.02)
  5. elif classname.find('BatchNorm') != -1:
  6. m.weight.data.normal_(1.0, 0.02)
  7. m.bias.data.fill_(0)

3.4.2Generator-生成器模型

生成器的目的是将输入向量z zz 映射到真的数据空间。这儿我们的数据为图片,意味着我们需要将输入向量z zz转换为 3x64x64的RGB图像。实际操作时,我们通过一系列的二维转置卷,每次转置卷积后跟一个二维的batch norm层和一个relu激活层。生成器的输出接入tanh函数以便满足输出范围为[−1,1]。值得一提的是,每个转置卷积后面跟一个 batch norm 层,是DCGAN论文的一个主要贡献。这些网络层有助于训练时的梯度计算。

下面为生成器模型分析

  • 该函数由反卷积+batch norm+relu构成

反卷积参考这里

  1. def dconv_bn_relu(in_dim, out_dim):
  2. return nn.Sequential(
  3. nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
  4. padding=2, output_padding=1, bias=False),
  5. nn.BatchNorm2d(out_dim),
  6. nn.ReLU()
  7. )
  • 类似dconv_bn_relu,但特殊在l1为网络的第一层,输入输出和后面的l2不同(我其实觉得可以放到一起,但可能分开更加清除一些)
  1. self.l1 = nn.Sequential(
  2. nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
  3. nn.BatchNorm1d(dim * 8 * 4 * 4),
  4. nn.ReLU()
  5. )
  • 实例化生成器并调用weights_init函数
  1. self.l2_5 = nn.Sequential(
  2. dconv_bn_relu(dim * 8, dim * 4),
  3. dconv_bn_relu(dim * 4, dim * 2),
  4. dconv_bn_relu(dim * 2, dim),
  5. nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
  6. nn.Tanh()
  7. )
  • 构建cnn网络结构,tanh函数使输出在[-1,1]之间
  1. self.apply(weights_init)
  • 整体模型合并
  1. def forward(self, x):
  2. y = self.l1(x)
  3. y = y.view(y.size(0), -1, 4, 4)
  4. y = self.l2_5(y)
  5. return y
  • 主函数中生成实例,并输出网络结构
  1. netG=Generator(100)
  2. print(netG)
  • 运行后得到网络结构
  1. (l1): Sequential(
  2. (0): Linear(in_features=100, out_features=8192, bias=False)
  3. (1): BatchNorm1d(8192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  4. (2): ReLU()
  5. )
  6. (l2_5): Sequential(
  7. (0): Sequential(
  8. (0): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
  9. (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  10. (2): ReLU()
  11. )
  12. (1): Sequential(
  13. (0): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
  14. (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  15. (2): ReLU()
  16. )
  17. (2): Sequential(
  18. (0): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
  19. (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  20. (2): ReLU()
  21. )
  22. (3): ConvTranspose2d(64, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
  23. (4): Tanh()
  24. )
  25. )

3.4.3 Discriminator-判别器模型

判别器的输入为3 *64 *64,输出为概率(分数),依次通过卷积层,BN层,LeakyReLU层,最后通过sigmoid函数输出得分

  1. from torch import nn
  2. from 函数.weights_inition import weights_init
  3. class Discriminator(nn.Module):
  4. """
  5. Input shape: (N, 3, 64, 64)
  6. Output shape: (N, )
  7. """
  8. def __init__(self, in_dim, dim=64):
  9. super(Discriminator, self).__init__()
  10. def conv_bn_lrelu(in_dim, out_dim):
  11. return nn.Sequential(
  12. nn.Conv2d(in_dim, out_dim, 5, 2, 2),
  13. nn.BatchNorm2d(out_dim),
  14. nn.LeakyReLU(0.2),
  15. )
  16. """ Medium: Remove the last sigmoid layer for WGAN. """
  17. self.ls = nn.Sequential(
  18. nn.Conv2d(in_dim, dim, 5, 2, 2),
  19. nn.LeakyReLU(0.2),
  20. conv_bn_lrelu(dim, dim * 2),
  21. conv_bn_lrelu(dim * 2, dim * 4),
  22. conv_bn_lrelu(dim * 4, dim * 8),
  23. nn.Conv2d(dim * 8, 1, 4),
  24. nn.Sigmoid(),
  25. )
  26. self.apply(weights_init)
  27. def forward(self, x):
  28. y = self.ls(x)
  29. y = y.view(-1)
  30. return y

类似生成器,具体框架不再分开说明

生成实例,观察网络结构

  1. Discriminator(
  2. (ls): Sequential(
  3. (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  4. (1): LeakyReLU(negative_slope=0.2)
  5. (2): Sequential(
  6. (0): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  7. (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  8. (2): LeakyReLU(negative_slope=0.2)
  9. )
  10. (3): Sequential(
  11. (0): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  12. (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  13. (2): LeakyReLU(negative_slope=0.2)
  14. )
  15. (4): Sequential(
  16. (0): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
  17. (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  18. (2): LeakyReLU(negative_slope=0.2)
  19. )
  20. (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
  21. (6): Sigmoid()
  22. )
  23. )

3.5Training-模型的训练-DCGAN

3.5.1 创建网络结构

  1. G = Generator(in_dim=z_dim).to(device)
  2. D = Discriminator(3).to(device)
  3. G.train()
  4. D.train()
  5. # Loss
  6. criterion = nn.BCELoss()
  7. # Optimizer
  8. opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
  9. opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
  • in_dim=z_dim=100,z的分布(高斯分布)深度为100
  • 因为input的是图片,3channels,所以Discriminator(3)
  • 如果模型中有BN层,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差。
  • 损失函数使用二元交叉熵损失(BCELoss)
  • 这里使用Adam优化器更新参数Adam+pytorch,学习率设置为0.0002 Betal=0.5。

3.5.2加载数据

  1. dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  • z为随机生成64*100的高斯分布数据(均值为0,方差为1)也叫噪声。

    z为生成器的输入。

3.5.3 训练D(判别器)

  • z为随机生成64*100的高斯分布数据(均值为0,方差为1)也叫噪声。

    z为生成器的输入。
  1. z = Variable(torch.randn(bs, z_dim)).to(device)
  • f_imgs大小为 64 *3 *64 *64(生成64张假图片)
  • 将z直接传入G (),可以直接调用forward()函数进行操作,参考forward
  • 下面展示forward()函数的流程

    avatar
  1. r_imgs = Variable(imgs).to(device)
  2. f_imgs = G(z)
  • 进行标签定义,真实图片的label为1,生成的图片的label为0。
  1. r_label = torch.ones((bs)).to(device)
  2. f_label = torch.zeros((bs)).to(device)
  • 把两种图片放入判别器,将r_imgs设置为detach(),意为参数不再更新(很好理解,因为图片的数据肯定不可以改变,只能改变网络的参数,所以就锁定了图片数据)。r_logit表示真实图片得分(越高越好),f_logit表示假图片得分(越低越好)。
  1. r_logit = D(r_imgs.detach())
  2. f_logit = D(f_imgs.detach())
  • 计算损失,就是将两种损失加起来除以二。
  1. r_loss = criterion(r_logit, r_label)
  2. f_loss = criterion(f_logit, f_label)
  3. loss_D = (r_loss + f_loss) / 2
  • module.zero_grad(),每一个batch的训练将参数的梯度清零。
  • 对loss进行反向传播算法,.backward()可以计算所有与loss_D有关的参数的梯度,参考backward
  • optimizer.step(),进行参数更新(Adam)。
  1. D.zero_grad()
  2. loss_D.backward()
  3. opt_D.step()

3.5.4 训练G(生成器)

  • z为随机生成64*100的高斯分布数据(均值为0,方差为1)也叫噪声。

    z为生成器的输入。
  1. z = Variable(torch.randn(bs, z_dim)).to(device)
  • 生成假图片,并计算得分(越高越好)。
  1. f_imgs = G(z)
  2. f_logit = D(f_imgs)
  3. loss_G = criterion(f_logit, r_label)
  • 更新参数
  1. G.zero_grad()
  2. loss_G.backward()
  3. opt_G.step()

3.6结果展示

  • 生成最后的结果(测试图片),每一个epoch生成一张图片。

有关pytorch中dataloader,dataset以及数据显示的学习可以参考PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

  1. G.eval()
  2. f_imgs_sample = (G(z_sample).data + 1) / 2.0
  3. filename = os.path.join(log_dir, f'Epoch_{epoch + 1:03d}.jpg')
  4. torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
  5. print(f' | Save some samples to {filename}.')
  6. # Show generated images in the jupyter notebook.
  7. grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
  8. plt.figure(figsize=(10, 10))
  9. plt.imshow(grid_img.permute(1, 2, 0))
  10. plt.show()
  11. G.train()
  12. if (e + 1) % 5 == 0 or e == 0:
  13. # Save the checkpoints.
  14. torch.save(G.state_dict(), os.path.join(ckpt_dir, 'G.pth'))
  15. torch.save(D.state_dict(), os.path.join(ckpt_dir, 'D.pth'))
  • 下图分别是我设置epoch为5,bitch_size为10和64生成的测试图片

    bitch_size=10

    avatar
    bitch_size=64

    avatar

观察图片会有一些问题,比如面部不全,眼睛颜色不对或者人脸模糊等问题,转换GAN类型或者增加数据集等可能使结果更好。

3.7代码文件

李宏毅老师的作业文件是.ipynb文件,我把函数分开写进了PyCharm中,在使用时需要更改各个函数中

import的文件名,还有数据的地址(在main函数中)改为自己的地址

提取码:dkdd

生成对抗网络GAN与DCGAN的理解的更多相关文章

  1. TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成

    生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...

  2. 用MXNet实现mnist的生成对抗网络(GAN)

    用MXNet实现mnist的生成对抗网络(GAN) 生成式对抗网络(Generative Adversarial Network,简称GAN)由一个生成网络与一个判别网络组成.生成网络从潜在空间(la ...

  3. 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...

  4. 生成对抗网络GAN介绍

    GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...

  5. 深度学习-生成对抗网络GAN笔记

    生成对抗网络(GAN)由2个重要的部分构成: 生成器G(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器 判别器D(Discriminator):判断这张图像是真实的 ...

  6. 深度学习框架PyTorch一书的学习-第七章-生成对抗网络(GAN)

    参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN生成动漫头像 GAN解决了非监督学习中的著名问题:给定一批样本,训 ...

  7. 科普 | ​生成对抗网络(GAN)的发展史

    来源:https://en.wikipedia.org/wiki/Edmond_de_Belamy 五年前,Generative Adversarial Networks(GANs)在深度学习领域掀起 ...

  8. 生成对抗网络(GAN)

    基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...

  9. 利用tensorflow训练简单的生成对抗网络GAN

    对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的. 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(di ...

随机推荐

  1. c学习 - 第四章:顺序程序设计

    4.4 字符数据的输入输出 putchar:函数的作用是想终端输出一个字符 putchar(c) getchar:函数的作用是从输入设备获取一个字符 getchar(c) 4.5 格式输入与输出 pr ...

  2. 【编程思想】【设计模式】【行为模式Behavioral】模板模式Template

    Python转载版 https://github.com/faif/python-patterns/blob/master/behavioral/template.py #!/usr/bin/env ...

  3. CentOS Linux下编译安装MySQL

    本文参考张宴的Nginx 0.8.x + PHP 5.2.13(FastCGI)搭建胜过Apache十倍的Web服务器(第6版)[原创]完成.所有操作命令都在CentOS 6.4 64位操作系统下实践 ...

  4. redis入门到精通系列(五):redis的持久化操作(RDB、AOF)

    (一)持久化的概述 持久化顾名思义就是将存储在内存的数据转存到硬盘中.在生活中使用word等应用的时候,如果突然遇到断电的情况,理论上数据应该是都不见的,因为没有保存的word内容都存放在内存里,断电 ...

  5. Druid数据库监控

    一.简介 Druid是阿里开源的一个JDBC应用组件, 其包括三部分: DruidDriver: 代理Driver,能够提供基于Filter-Chain模式的插件体系. DruidDataSource ...

  6. inode节点

    目录 一.简介 二.信息 inode的内容 inode的大小 3.inode号码 三.目录文件 四.硬连接 五.软链接 六.inode的特殊作用 一.简介 理解inode,要从文件储存说起. 文件储存 ...

  7. Kerberos认证

    http://www.cnblogs.com/artech/archive/2011/01/24/kerberos.html 最近一段时间都在折腾安全(Security)方面的东西,比如Windows ...

  8. 分布式可扩展web体系结构设计实例分析

    Web分布式系统设计准则 下面以一个上传和查询图片的例子来说明分布式web结构的设计考虑和常用的提高性能的方法.该例子提供上传图片和下载图片两个简单功能,并且有一下假设条件?: - 可以存储无上限数量 ...

  9. 文件系统系列学习笔记 - inode/dentry/file/super(2)

    此篇文章主要介绍下linux 文件系统下的主要对象及他们之间的关系. 1 inode inode结构中主要包含对文件或者目录原信息的描述,原信息包括但不限于文件大小.文件在磁盘块中的位置信息.权限位. ...

  10. 【模型推理】量化实现分享二:详解 KL 对称量化算法实现

      欢迎关注我的公众号 [极智视界],回复001获取Google编程规范   O_o   >_<   o_O   O_o   ~_~   o_O   大家好,我是极智视界,本文剖析一下 K ...