生成对抗网络GAN与DCGAN的理解
作者在进行GAN学习中遇到的问题汇总到下方,并进行解读讲解,下面提到的题目是李宏毅老师机器学习课程的作业6(GAN)
一.GAN
网络上有关GAN和DCGAN的讲解已经很多,在这里不再加以赘述,放几个我认为比较好的讲解
1.GAN概念理解
2.理解GAN网络基本原理
3.李宏毅机器学习课程
4.换个角度看GAN:另一种损失函数
二.DCGAN
1.从头开始GAN【论文】(二) —— DCGAN
2.PyTorch教程之DCGAN
3.pytorch官方DCGAN样例讲解
三.示例代码解读
3.1关于数据集的下载
官方的数据集需要FQ下载,在查找相关网址后,上网找到了数据集,并成功下载,如下是数据集链接:
提取码:ctgr
成功下载并解压,可以删除作业代码中的有关下载和解压的部分
可以打开我最下面放的文件,改变数据地址即可(main函数中的workspace_dir)
You may replace the workspace directory if you want.
workspace_dir = '.'
Training progress bar
!pip install -q qqdm
!gdown --id 1IGrTr308mGAaCKotpkkm8wTKlWs9Jq-p --output "{workspace_dir}/crypko_data.zip"
3.2导入相关包和函数
import random
import torch
import numpy as np
import os
import glob
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from qqdm.notebook import qqdm
如果没有qqdm或者matplotlib则需要pip或者conda下载(qqdm是进度条,在训练过程中可以显示训练进度)
3.3DateSet 数据预处理
Transfrom:
1.transforms.Compose():将一系列的transforms有序组合,实现时按照这些方法依次对图像操作。
类似封装函数,依次执行
2.transforms.ToPILImage:将数据转换为PILImage。
3.transforms.Resize:图像变换
4.transforms.ToTensor:转为tensor,并归一化至[0-1]
5.transforms.Normalize:数据归一化处理
- mean:各通道的均值
- std:各通道的标准差
关于为什么要进行归一化处理可以参考transforms.Normalize()
6.主函数进行数据加载:
workspace_dir='D://机器学习//Jupyter//GAN学习//函数'
dataset = get_dataset(os.path.join(workspace_dir, 'faces'))
3.4Model-模型的建立-DCGAN
3.4.1权重初始化
DCGAN指出,所有的权重都以均值为0,标准差为0.2的正态分布随机初始化。weights_init 函数读取一个已初始化的模型并重新初始化卷积层,转置卷积层,batch normalization 层。这个函数在模型初始化之后使用。
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
3.4.2Generator-生成器模型
生成器的目的是将输入向量z zz 映射到真的数据空间。这儿我们的数据为图片,意味着我们需要将输入向量z zz转换为 3x64x64的RGB图像。实际操作时,我们通过一系列的二维转置卷,每次转置卷积后跟一个二维的batch norm层和一个relu激活层。生成器的输出接入tanh函数以便满足输出范围为[−1,1]。值得一提的是,每个转置卷积后面跟一个 batch norm 层,是DCGAN论文的一个主要贡献。这些网络层有助于训练时的梯度计算。
下面为生成器模型分析
- 该函数由反卷积+batch norm+relu构成
def dconv_bn_relu(in_dim, out_dim):
return nn.Sequential(
nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
padding=2, output_padding=1, bias=False),
nn.BatchNorm2d(out_dim),
nn.ReLU()
)
- 类似dconv_bn_relu,但特殊在l1为网络的第一层,输入输出和后面的l2不同(我其实觉得可以放到一起,但可能分开更加清除一些)
self.l1 = nn.Sequential(
nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
nn.BatchNorm1d(dim * 8 * 4 * 4),
nn.ReLU()
)
- 实例化生成器并调用weights_init函数
self.l2_5 = nn.Sequential(
dconv_bn_relu(dim * 8, dim * 4),
dconv_bn_relu(dim * 4, dim * 2),
dconv_bn_relu(dim * 2, dim),
nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
nn.Tanh()
)
- 构建cnn网络结构,tanh函数使输出在[-1,1]之间
self.apply(weights_init)
- 整体模型合并
def forward(self, x):
y = self.l1(x)
y = y.view(y.size(0), -1, 4, 4)
y = self.l2_5(y)
return y
- 主函数中生成实例,并输出网络结构
netG=Generator(100)
print(netG)
- 运行后得到网络结构
(l1): Sequential(
(0): Linear(in_features=100, out_features=8192, bias=False)
(1): BatchNorm1d(8192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(l2_5): Sequential(
(0): Sequential(
(0): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(1): Sequential(
(0): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(2): Sequential(
(0): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(3): ConvTranspose2d(64, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
(4): Tanh()
)
)
3.4.3 Discriminator-判别器模型
判别器的输入为3 *64 *64,输出为概率(分数),依次通过卷积层,BN层,LeakyReLU层,最后通过sigmoid函数输出得分
from torch import nn
from 函数.weights_inition import weights_init
class Discriminator(nn.Module):
"""
Input shape: (N, 3, 64, 64)
Output shape: (N, )
"""
def __init__(self, in_dim, dim=64):
super(Discriminator, self).__init__()
def conv_bn_lrelu(in_dim, out_dim):
return nn.Sequential(
nn.Conv2d(in_dim, out_dim, 5, 2, 2),
nn.BatchNorm2d(out_dim),
nn.LeakyReLU(0.2),
)
""" Medium: Remove the last sigmoid layer for WGAN. """
self.ls = nn.Sequential(
nn.Conv2d(in_dim, dim, 5, 2, 2),
nn.LeakyReLU(0.2),
conv_bn_lrelu(dim, dim * 2),
conv_bn_lrelu(dim * 2, dim * 4),
conv_bn_lrelu(dim * 4, dim * 8),
nn.Conv2d(dim * 8, 1, 4),
nn.Sigmoid(),
)
self.apply(weights_init)
def forward(self, x):
y = self.ls(x)
y = y.view(-1)
return y
类似生成器,具体框架不再分开说明
生成实例,观察网络结构
Discriminator(
(ls): Sequential(
(0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(1): LeakyReLU(negative_slope=0.2)
(2): Sequential(
(0): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
(3): Sequential(
(0): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
(4): Sequential(
(0): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
(5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
(6): Sigmoid()
)
)
3.5Training-模型的训练-DCGAN
3.5.1 创建网络结构
G = Generator(in_dim=z_dim).to(device)
D = Discriminator(3).to(device)
G.train()
D.train()
# Loss
criterion = nn.BCELoss()
# Optimizer
opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
- in_dim=z_dim=100,z的分布(高斯分布)深度为100
- 因为input的是图片,3channels,所以Discriminator(3)
- 如果模型中有BN层,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差。
- 损失函数使用二元交叉熵损失(BCELoss)
- 这里使用Adam优化器更新参数Adam+pytorch,学习率设置为0.0002 Betal=0.5。
3.5.2加载数据
- Dataloader参考dataloader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
- z为随机生成64*100的高斯分布数据(均值为0,方差为1)也叫噪声。
z为生成器的输入。
3.5.3 训练D(判别器)
- z为随机生成64*100的高斯分布数据(均值为0,方差为1)也叫噪声。
z为生成器的输入。
z = Variable(torch.randn(bs, z_dim)).to(device)
- f_imgs大小为 64 *3 *64 *64(生成64张假图片)
- 将z直接传入G (),可以直接调用forward()函数进行操作,参考forward。
- 下面展示forward()函数的流程
r_imgs = Variable(imgs).to(device)
f_imgs = G(z)
- 进行标签定义,真实图片的label为1,生成的图片的label为0。
r_label = torch.ones((bs)).to(device)
f_label = torch.zeros((bs)).to(device)
- 把两种图片放入判别器,将r_imgs设置为detach(),意为参数不再更新(很好理解,因为图片的数据肯定不可以改变,只能改变网络的参数,所以就锁定了图片数据)。r_logit表示真实图片得分(越高越好),f_logit表示假图片得分(越低越好)。
r_logit = D(r_imgs.detach())
f_logit = D(f_imgs.detach())
- 计算损失,就是将两种损失加起来除以二。
r_loss = criterion(r_logit, r_label)
f_loss = criterion(f_logit, f_label)
loss_D = (r_loss + f_loss) / 2
- module.zero_grad(),每一个batch的训练将参数的梯度清零。
- 对loss进行反向传播算法,.backward()可以计算所有与loss_D有关的参数的梯度,参考backward。
- optimizer.step(),进行参数更新(Adam)。
D.zero_grad()
loss_D.backward()
opt_D.step()
3.5.4 训练G(生成器)
- z为随机生成64*100的高斯分布数据(均值为0,方差为1)也叫噪声。
z为生成器的输入。
z = Variable(torch.randn(bs, z_dim)).to(device)
- 生成假图片,并计算得分(越高越好)。
f_imgs = G(z)
f_logit = D(f_imgs)
loss_G = criterion(f_logit, r_label)
- 更新参数
G.zero_grad()
loss_G.backward()
opt_G.step()
3.6结果展示
- 生成最后的结果(测试图片),每一个epoch生成一张图片。
有关pytorch中dataloader,dataset以及数据显示的学习可以参考PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
G.eval()
f_imgs_sample = (G(z_sample).data + 1) / 2.0
filename = os.path.join(log_dir, f'Epoch_{epoch + 1:03d}.jpg')
torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
print(f' | Save some samples to {filename}.')
# Show generated images in the jupyter notebook.
grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
plt.figure(figsize=(10, 10))
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()
G.train()
if (e + 1) % 5 == 0 or e == 0:
# Save the checkpoints.
torch.save(G.state_dict(), os.path.join(ckpt_dir, 'G.pth'))
torch.save(D.state_dict(), os.path.join(ckpt_dir, 'D.pth'))
- 下图分别是我设置epoch为5,bitch_size为10和64生成的测试图片
bitch_size=10
bitch_size=64
观察图片会有一些问题,比如面部不全,眼睛颜色不对或者人脸模糊等问题,转换GAN类型或者增加数据集等可能使结果更好。
3.7代码文件
李宏毅老师的作业文件是.ipynb文件,我把函数分开写进了PyCharm中,在使用时需要更改各个函数中
import的文件名,还有数据的地址(在main函数中)改为自己的地址
提取码:dkdd
生成对抗网络GAN与DCGAN的理解的更多相关文章
- TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成
生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...
- 用MXNet实现mnist的生成对抗网络(GAN)
用MXNet实现mnist的生成对抗网络(GAN) 生成式对抗网络(Generative Adversarial Network,简称GAN)由一个生成网络与一个判别网络组成.生成网络从潜在空间(la ...
- 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...
- 生成对抗网络GAN介绍
GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...
- 深度学习-生成对抗网络GAN笔记
生成对抗网络(GAN)由2个重要的部分构成: 生成器G(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器 判别器D(Discriminator):判断这张图像是真实的 ...
- 深度学习框架PyTorch一书的学习-第七章-生成对抗网络(GAN)
参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN生成动漫头像 GAN解决了非监督学习中的著名问题:给定一批样本,训 ...
- 科普 | 生成对抗网络(GAN)的发展史
来源:https://en.wikipedia.org/wiki/Edmond_de_Belamy 五年前,Generative Adversarial Networks(GANs)在深度学习领域掀起 ...
- 生成对抗网络(GAN)
基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...
- 利用tensorflow训练简单的生成对抗网络GAN
对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的. 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(di ...
随机推荐
- c学习 - 第四章:顺序程序设计
4.4 字符数据的输入输出 putchar:函数的作用是想终端输出一个字符 putchar(c) getchar:函数的作用是从输入设备获取一个字符 getchar(c) 4.5 格式输入与输出 pr ...
- 【编程思想】【设计模式】【行为模式Behavioral】模板模式Template
Python转载版 https://github.com/faif/python-patterns/blob/master/behavioral/template.py #!/usr/bin/env ...
- CentOS Linux下编译安装MySQL
本文参考张宴的Nginx 0.8.x + PHP 5.2.13(FastCGI)搭建胜过Apache十倍的Web服务器(第6版)[原创]完成.所有操作命令都在CentOS 6.4 64位操作系统下实践 ...
- redis入门到精通系列(五):redis的持久化操作(RDB、AOF)
(一)持久化的概述 持久化顾名思义就是将存储在内存的数据转存到硬盘中.在生活中使用word等应用的时候,如果突然遇到断电的情况,理论上数据应该是都不见的,因为没有保存的word内容都存放在内存里,断电 ...
- Druid数据库监控
一.简介 Druid是阿里开源的一个JDBC应用组件, 其包括三部分: DruidDriver: 代理Driver,能够提供基于Filter-Chain模式的插件体系. DruidDataSource ...
- inode节点
目录 一.简介 二.信息 inode的内容 inode的大小 3.inode号码 三.目录文件 四.硬连接 五.软链接 六.inode的特殊作用 一.简介 理解inode,要从文件储存说起. 文件储存 ...
- Kerberos认证
http://www.cnblogs.com/artech/archive/2011/01/24/kerberos.html 最近一段时间都在折腾安全(Security)方面的东西,比如Windows ...
- 分布式可扩展web体系结构设计实例分析
Web分布式系统设计准则 下面以一个上传和查询图片的例子来说明分布式web结构的设计考虑和常用的提高性能的方法.该例子提供上传图片和下载图片两个简单功能,并且有一下假设条件?: - 可以存储无上限数量 ...
- 文件系统系列学习笔记 - inode/dentry/file/super(2)
此篇文章主要介绍下linux 文件系统下的主要对象及他们之间的关系. 1 inode inode结构中主要包含对文件或者目录原信息的描述,原信息包括但不限于文件大小.文件在磁盘块中的位置信息.权限位. ...
- 【模型推理】量化实现分享二:详解 KL 对称量化算法实现
欢迎关注我的公众号 [极智视界],回复001获取Google编程规范 O_o >_< o_O O_o ~_~ o_O 大家好,我是极智视界,本文剖析一下 K ...