基于Jittor框架实现LSGAN图像生成对抗网络

生成对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。GAN模型由生成器(Generator)和判别器(Discriminator)两个部分组成。在训练过程中,生成器的目标就是尽量生成真实的图片去欺骗判别器。而判别器的目标就是尽量把生成器生成的图片和真实的图片分别开来。这样,生成器和判别器构成了一个动态的“博弈过程”。许多相关的研究工作表明GAN能够产生效果非常真实的生成效果。

使用Jittor框架实现了一种经典GAN模型LSGANLSGAN将GAN的目标函数由交叉熵损失替换成最小二乘损失,以此拒绝了标准GAN生成的图片质量不高以及训练过程不稳定这两个缺陷。通过LSGAN的实现介绍了Jittor数据加载、模型定义、模型训练的使用方法。

LSGAN论文:https://arxiv.org/abs/1611.04076

1.数据集准备

使用两种数据集进行LSGAN的训练,分别是Jittor自带的数据集MNIST,和用户构建的数据集CelebA。您可以通过以下链接下载CelebA数据集。

使用Jittor自带的MNIST数据加载器方法如下。使用jittor.transform可以进行数据归一化及数据增强,这里通过transform将图片归一化到[0,1]区间,并resize到标准大小112*112。。通过set_attrs函数可以修改数据集的相关参数,如batch_sizeshuffletransform等。

  1. from jittor.dataset.mnist import MNIST
  1. import jittor.transform as transform
  1.  
  1. transform = transform.Compose([
  1.     transform.Resize(size=img_size),
  1.     transform.ImageNormalize(mean=[0.5], std=[0.5])
  1. ])
  1. train_loader = MNIST (train=True, transform=transform)
  1.         .set_attrs(batch_size=batch_size, shuffle=True)
  1. val_loader = MNIST (train=False, transform = transform)
  1.         .set_attrs(batch_size=1, shuffle=True)

使用用户构建的CelebA数据集方法如下,通过通用数据加载器jittor.dataset.dataset.ImageFolder,输入数据集路径即可构建用户数据集。

  1. from jittor.dataset.dataset import ImageFolder
  1. import jittor.transform as transform
  1.  
  1. transform = transform.Compose([
  1.     transform.Resize(size=img_size),
  1.     transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  1. ])
  1. train_dir = './data/celebA_train'
  1. train_loader = ImageFolder(train_dir)
  1.         .set_attrs(transform=transform, batch_size=batch_size, shuffle=True)
  1. val_dir = './data/celebA_eval'
  1. val_loader = ImageFolder(val_dir)
  1.         .set_attrs(transform=transform, batch_size=1, shuffle=True)

2.模型定义

2.1.网络结构

使用LSGAN进行图像生成,下图为LSGAN论文给出的网络架构图,其中(a)为生成器,(b)为判别器。生成器网络输入一个1024维的向量,生成分辨率为112*112的图像;判别器网络输入112*112的图像,输出一个数字表示输入图像为真实图像的可信程度。

受到VGG模型的启发,生成器在与DCGAN的结构基础上在前两个反卷积层之后增加了两个步长=1的反卷积层。除使用最小二乘损失函数外判别器的结构与DCGAN中的结构相同。与DCGAN相同,生成器和判别器分别使用了ReLU激活函数和LeakyReLU激活函数。

下面将介绍如何使用Jittor定义一个网络模型。定义模型需要继承基类jittor.Module,并实现__init__execute函数。__init__函数在模型声明时会被调用,用于进行模型内部op或其他模型的声明及参数的初始化。该模型初始化时输入参数dim表示训练图像的通道数,对于MNIST数据集dim为1,对于CelebA数据集dim为3。

execute函数在网络前向传播时会被调用,用于定义前向传播的计算图,通过autograd机制在训练时Jittor会自动构建反向计算图。

  1. import jittor as jt
  1. from jittor import nn, Module
  1.  
  1. class generator(Module):
  1.     def __init__(self, dim=3):
  1.         super(generator, self).__init__()
  1.         self.fc = nn.Linear(1024, 7*7*256)
  1.         self.fc_bn = nn.BatchNorm(256)
  1.         self.deconv1 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
  1.         self.deconv1_bn = nn.BatchNorm(256)
  1.         self.deconv2 = nn.ConvTranspose(256, 256, 3, 1, 1)
  1.         self.deconv2_bn = nn.BatchNorm(256)
  1.         self.deconv3 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
  1.         self.deconv3_bn = nn.BatchNorm(256)
  1.         self.deconv4 = nn.ConvTranspose(256, 256, 3, 1, 1)
  1.         self.deconv4_bn = nn.BatchNorm(256)
  1.         self.deconv5 = nn.ConvTranspose(256, 128, 3, 2, 1, 1)
  1.         self.deconv5_bn = nn.BatchNorm(128)
  1.         self.deconv6 = nn.ConvTranspose(128, 64, 3, 2, 1, 1)
  1.         self.deconv6_bn = nn.BatchNorm(64)
  1.         self.deconv7 = nn.ConvTranspose(64 , dim, 3, 1, 1)
  1.         self.relu = nn.ReLU()
  1.         self.tanh = nn.Tanh()
  1.  
  1.     def execute(self, input):
  1.         x = self.fc_bn(self.fc(input).reshape((input.shape[0], 256, 7, 7)))
  1.         x = self.relu(self.deconv1_bn(self.deconv1(x)))
  1.         x = self.relu(self.deconv2_bn(self.deconv2(x)))
  1.         x = self.relu(self.deconv3_bn(self.deconv3(x)))
  1.         x = self.relu(self.deconv4_bn(self.deconv4(x)))
  1.         x = self.relu(self.deconv5_bn(self.deconv5(x)))
  1.         x = self.relu(self.deconv6_bn(self.deconv6(x)))
  1.         x = self.tanh(self.deconv7(x))
  1.         return x
  1. class discriminator(nn.Module):
  1.     def __init__(self, dim=3):
  1.         super(discriminator, self).__init__()
  1.         self.conv1 = nn.Conv(dim, 64, 5, 2, 2)
  1.         self.conv2 = nn.Conv(64, 128, 5, 2, 2)
  1.         self.conv2_bn = nn.BatchNorm(128)
  1.         self.conv3 = nn.Conv(128, 256, 5, 2, 2)
  1.         self.conv3_bn = nn.BatchNorm(256)
  1.         self.conv4 = nn.Conv(256, 512, 5, 2, 2)
  1.         self.conv4_bn = nn.BatchNorm(512)
  1.         self.fc = nn.Linear(512*7*7, 1)
  1.         self.leaky_relu = nn.Leaky_relu()
  1.  
  1.     def execute(self, input):
  1.         x = self.leaky_relu(self.conv1(input), 0.2)
  1.         x = self.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
  1.         x = self.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
  1.         x = self.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
  1.         x = x.reshape((x.shape[0], 512*7*7))
  1.         x = self.fc(x)
  1.         return x

2.2.损失函数

损失函数采用最小二乘损失函数,其中判别器损失函数如下。其中x为真实图像,z为服从正态分布的1024维向量,a取值为1,b取值为0。

生成器损失函数如下。其中z为服从正态分布的1024维向量,c取值为1。

具体实现如下,x为生成器的输出值,b表示该图像是否希望被判别为真。

  1. def ls_loss(x, b):
  1.     mini_batch = x.shape[0]
  1.     y_real_ = jt.ones((mini_batch,))
  1.     y_fake_ = jt.zeros((mini_batch,))
  1.     if b:
  1.         return (x-y_real_).sqr().mean()
  1.     else:
  1.         return (x-y_fake_).sqr().mean()

3.模型训练

3.1.参数设定

参数设定如下。

  1. # 通过use_cuda设置在GPU上进行训练
  1. jt.flags.use_cuda = 1
  1. # 批大小
  1. batch_size = 128
  1. # 学习率
  1. lr = 0.0002
  1. # 训练轮数
  1. train_epoch = 50
  1. # 训练图像标准大小
  1. img_size = 112
  1. # Adam优化器参数
  1. betas = (0.5,0.999)
  1. # 数据集图像通道数,MNIST为1,CelebA为3
  1. dim = 1 if task=="MNIST" else 3

3.2.模型、优化器声明

分别声明生成器和判别器,并使用Adam作为优化器。

  1. # 生成器
  1. G = generator (dim)
  1. # 判别器
  1. D = discriminator (dim)
  1. # 生成器优化器
  1. G_optim = nn.Adam(G.parameters(), lr, betas=betas)
  1. # 判别器优化器
  1. D_optim = nn.Adam(D.parameters(), lr, betas=betas)

3.3.训练

  1. for epoch in range(train_epoch):
  1.     for batch_idx, (x_, target) in enumerate(train_loader):
  1.          mini_batch = x_.shape[0]
  1.         # 判别器训练
  1.         D_result = D(sx)
  1.         D_real_loss = ls_loss(D_result, True)
  1.         z_ = init.gauss((mini_batch, 1024), 'float')
  1.         G_result = G(z_)
  1.         D_result_ = D(G_result)
  1.         D_fake_loss = ls_loss(D_result_, False)
  1.         D_train_loss = D_real_loss + D_fake_loss
  1.         D_train_loss.sync()
  1.         D_optim.step(D_train_loss)
  1.  
  1.         # 生成器训练
  1.         z_ = init.gauss((mini_batch, 1024), 'float')
  1.         G_result = G(z_)
  1.         D_result = D(G_result)
  1.         G_train_loss = ls_loss(D_result, True)
  1.         G_train_loss.sync()
  1.         G_optim.step(G_train_loss)
  1.         if (batch_idx%100==0):
  1.             print('D training loss =', D_train_loss.data.mean())
  1.             print('G training loss =', G_train_loss.data.mean())

4.结果与测试

4.1.生成结果

分别使用MNISTCelebA数据集进行了50个epoch的训练。训练完成后各随机采样了25张图像,结果如下。

4.2.速度对比

使用Jittor与主流的深度学习框架PyTorch进行了训练速度的对比,下表为PyTorch(是/否打开benchmark)及Jittor在两种数据集上进行1次训练迭带的使用时间。得益于Jittor特有的元算子融合技术,其训练速度比PyTorch快了40%~55%。

基于Jittor框架实现LSGAN图像生成对抗网络的更多相关文章

  1. 深度学习框架PyTorch一书的学习-第七章-生成对抗网络(GAN)

    参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN生成动漫头像 GAN解决了非监督学习中的著名问题:给定一批样本,训 ...

  2. AI佳作解读系列(六) - 生成对抗网络(GAN)综述精华

    注:本文来自机器之心的PaperWeekly系列:万字综述之生成对抗网络(GAN),如有侵权,请联系删除,谢谢! 前阵子学习 GAN 的过程发现现在的 GAN 综述文章大都是 2016 年 Ian G ...

  3. [ZZ] Valse 2017 | 生成对抗网络(GAN)研究年度进展评述

    Valse 2017 | 生成对抗网络(GAN)研究年度进展评述 https://www.leiphone.com/news/201704/fcG0rTSZWqgI31eY.html?viewType ...

  4. 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】

    本文转载自:https://www.leiphone.com/news/201703/Y5vnDSV9uIJIQzQm.html 生成对抗网络(Generative Adversarial Netwo ...

  5. 渐进结构—条件生成对抗网络(PSGAN)

    Full-body High-resolution Anime Generation with Progressive Structure-conditional Generative Adversa ...

  6. 知物由学 | AI网络安全实战:生成对抗网络

    本文由  网易云发布. “知物由学”是网易云易盾打造的一个品牌栏目,词语出自汉·王充<论衡·实知>.人,能力有高下之分,学习才知道事物的道理,而后才有智慧,不去求问就不会知道.“知物由学” ...

  7. 生成对抗网络(Generative Adversarial Networks,GAN)初探

    1. 从纳什均衡(Nash equilibrium)说起 我们先来看看纳什均衡的经济学定义: 所谓纳什均衡,指的是参与人的这样一种策略组合,在该策略组合上,任何参与人单独改变策略都不会得到好处.换句话 ...

  8. 【超分辨率】—(ESRGAN)增强型超分辨率生成对抗网络-解读与实现

    一.文献解读 我们知道GAN 在图像修复时更容易得到符合视觉上效果更好的图像,今天要介绍的这篇文章——ESRGAN: Enhanced Super-Resolution Generative Adve ...

  9. 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...

随机推荐

  1. hdu1501 记忆化搜索

    题意:       给你三个字符串,问你前两个能不能拼成第三个串. 思路:       直接记忆化神搜就行,思路水,看下代码就知道了.这个题目我感觉最大公共子序列dp的作法是错的,虽然有人ac了,随便 ...

  2. POJ3040给奶牛发工资

    题意:       有n种硬币,每种硬币有mi个,然后让你给奶牛发工资,每周发至少c元(就是不找零钱的意思)然后问你能发几周?(硬币之间都是倍数关系) 思路:       这个题目做了两天,丢脸,看完 ...

  3. SEO优化技术的简介

    严格来讲,seo技术没有所谓的严格的黑帽与白帽之分.即使是正常的301重定向,在某些情况下也能作用于黑帽seo技术.我们能判定一个人是真正的好人还是坏人么?答案是否定的.之所以解密所谓的黑帽seo,是 ...

  4. C++ Socket 简单封装

    以下代码一部分来自于<网络多人游戏架构与编程>, 其它的都是我瞎写的. 备忘. 一个简单的Socket封装,没有做什么高级的操作(比如IO完成端口等等). 1 #pragma once 2 ...

  5. Linux基本内容

    当你学会开发完成一个项目之后,你就可以将项目进行上线,而且其实并不难,你需要先对Linux操作系统了解一下,博客下面的内容是基于CentOs7服务器. 购买服务器 参考链接 Linux宝塔面板 Lin ...

  6. MVC三层架构的功能的简要说明

    MVC 介绍 MVC: Model 模型 ​ View 视图 ​ Controller 控制器 M (Model) : 模型 功能 DAO层 : 对数据库进行操作(CRUD) Service层 : 处 ...

  7. 获取某日期后一周、一月、一年的日期 php

    //获取某日期后三周同一天日期public static function getNextDate($date){ $return = [ date( 'Y-m-d', strtotime(" ...

  8. Pytest自动化测试-简易入门教程(02)

    Pytest框架简介 Pytest是一个非常成熟的全功能的Python测试框架,主要有以下几个特点:1.简单灵活,容易上手,支持参数化2.能够支持简单的单元测试和复杂的功能测试,3.还可以用来做sel ...

  9. uboot1: 启动流程和移植框架

    目录 0 环境 1 移植框架 3 执行流程 3.0 链接地址 3.1 start.S, 入口 3.2 __main 3.3 board_init_f()和init_sequence_f[] 3.4 r ...

  10. Class和ClassLoader的getResource方法对比

    最近在看写Spring的源代码,里面有好多地方都用到了Class和ClassLoader类的getResource方法来加载资源文件.之前对这两个类的这个方法一知半解,概念也很模糊,这边做下整理,加深 ...