Mirza M, Osindero S. Conditional Generative Adversarial Nets. arXiv preprint arXiv:1411.1784, 2014.

@article{mirza2014conditional,
  title={Conditional Generative Adversarial Nets},
  author={Mirza, Mehdi and Osindero, Simon},
  journal={arXiv preprint arXiv:1411.1784},
  year={2014}
}

A GAN (Generative Adversarial Net) can generate data from a latent variable \(z\), but we have no control over what it generates, because \(z\) is completely random. This paper makes the natural extension to a conditional GAN: add an extra input \(y\) (for example, a class label) that steers the output. On MNIST, say, we sample a random \(z\) and supply

\[y=[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
\]

and the output should be the digit 2.
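
For concreteness, this \(y\) is just row 2 of the identity matrix, which is exactly what the transform_y helper in the code below produces:

import torch

labels = torch.tensor([2])   # the digit we want the generator to draw
y = torch.eye(10)[labels]    # one-hot row: [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]]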

Main content

The paper's objective is the usual GAN minimax game, except that both the discriminator and the generator are conditioned on \(y\):

\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z \mid y)))].
\]

In the code below this is implemented with nn.BCELoss: the discriminator is pushed toward 1 on real pairs \((x, y)\) and toward 0 on generated pairs, while the generator is trained with the non-saturating variant that pushes \(D(G(z \mid y), y)\) toward 1.

网络"结构"如下:



Code

"""
这个几乎就是照搬别人的代码
lr=0.0001,
epochs=50
但是10轮就差不多收敛了
""" import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt class Generator(nn.Module):
"""
生成器
"""
def __init__(self, input_size=(100, 10), output_size=784):
super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(input_size[0], 256),
nn.BatchNorm1d(256),
nn.ReLU()
)
self.fc2 = nn.Sequential(
nn.Linear(input_size[1], 256),
nn.BatchNorm1d(256),
nn.ReLU()
)
self.dense = nn.Sequential(
nn.Linear(512, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Linear(1024, output_size),
nn.Tanh()
) def forward(self, z, y):
"""
:param z: 随机隐变量
:param y: 条件隐变量
:return:
"""
z = self.fc1(z)
y = self.fc2(y)
out = self.dense(
torch.cat((z, y), 1)
)
return out class Discriminator(nn.Module): def __init__(self, input_size=(784, 10)):
super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(input_size[0], 1024),
nn.LeakyReLU(0.2)
)
self.fc2 = nn.Sequential(
nn.Linear(input_size[1], 1024),
nn.LeakyReLU(0.2)
)
self.dense = nn.Sequential(
nn.Linear(2048, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
) def forward(self, x, y):
x = self.fc1(x)
y = self.fc2(y)
out = self.dense(
torch.cat((x, y), 1)
)
return out class Train: def __init__(self, z_size=100, y_size=10, x_size=784,
criterion=nn.BCELoss(), lr=1e-4):
self.generator = Generator(input_size=(z_size, y_size), output_size=x_size)
self.discriminator = Discriminator(input_size=(x_size, y_size))
self.criterion = criterion
self.opti1 = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))
self.opti2 = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
self.z_size = z_size
self.y_size = y_size
self.x_size = x_size
self.lr = lr
cpath = os.path.abspath('.')
self.gen_path = os.path.join(cpath, 'generator3.pt')
self.dis_path = os.path.join(cpath, 'discriminator3.pt')
self.imgspath = lambda i: os.path.join(cpath, 'image3', 'fig{0}'.format(i))
#self.loading() def transform_y(self, labels):
return torch.eye(self.y_size)[labels] def sampling_z(self, size):
return torch.randn(size) def showimgs(self, imgs, order):
n = imgs.size(0)
imgs = imgs.data.view(n, 28, 28)
fig, axs = plt.subplots(10, 10)
for i in range(10):
for j in range(10):
axs[i, j].get_xaxis().set_visible(False)
axs[i, j].get_yaxis().set_visible(False) for i in range(10):
for j in range(10):
t = i * 10 + j
img = imgs[t]
axs[i, j].cla()
axs[i, j].imshow(img.data.view(28, 28).numpy(), cmap='gray') fig.savefig(self.imgspath(order))
for i in range(10):
for j in range(10):
t = i * 10 + j
img = imgs[t]
axs[i, j].cla()
axs[i, j].imshow(img.data.view(28, 28).numpy() / 2 + 0.5, cmap='gray') fig.savefig(self.imgspath(order+1))
#plt.show()
#plt.cla() def train(self, trainloader, epochs=50, classes=10):
order = 2
for epoch in range(epochs):
running_loss_d = 0.
running_loss_g = 0.
if (epoch + 1) % 5 is 0.:
self.opti1.param_groups[0]['lr'] /= 10
self.opti2.param_groups[0]['lr'] /= 10
print("learning rate change!") if (epoch + 1) % order is 0.:
self.showimgs(fake_imgs, order=order)
self.showimgs(real_imgs, order=order+2)
order += 4 for i, data in enumerate(trainloader): real_imgs, labels = data
real_imgs = real_imgs.view(real_imgs.size(0), -1)
y = self.transform_y(labels)
d_out = self.discriminator(real_imgs, y).squeeze() z = self.sampling_z((y.size(0), self.z_size))
fake_y = self.transform_y(torch.randint(classes, size=(y.size(0),)))
fake_imgs = self.generator(z, fake_y).squeeze()
g_out = self.discriminator(fake_imgs, fake_y).squeeze() # 训练判别器
loss1 = self.criterion(d_out, torch.ones_like(d_out))
loss2 = self.criterion(g_out, torch.zeros_like(g_out)) d_loss = loss1 + loss2
self.opti2.zero_grad()
d_loss.backward()
self.opti2.step() # 训练生成器
z = self.sampling_z((y.size(0), self.z_size))
fake_y = self.transform_y(torch.randint(classes, size=(y.size(0),)))
fake_imgs = self.generator(z, fake_y).squeeze()
g_out = self.discriminator(fake_imgs, fake_y).squeeze()
g_loss = self.criterion(g_out, torch.ones_like(g_out))
self.opti1.zero_grad()
g_loss.backward()
self.opti1.step() running_loss_d += d_loss
running_loss_g += g_loss
if i % 10 is 0 and i != 0:
print("[epoch {0:<d}: d_loss: {1:<5f} g_loss: {2:<5f}]".format(
epoch, running_loss_d / 10, running_loss_g / 10
))
running_loss_d = 0.
running_loss_g = 0. torch.save(self.generator.state_dict(), self.gen_path)
torch.save(self.discriminator.state_dict(), self.dis_path)
def loading(self):
self.generator.load_state_dict(torch.load(self.gen_path))
self.generator.eval()
self.discriminator.load_state_dict(torch.load(self.dis_path))
self.discriminator.eval()
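
A minimal driver for the Train class, assuming MNIST is normalized to \([-1, 1]\) to match the generator's Tanh output (the batch size and data path here are illustrative):

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))   # map pixels to [-1, 1] for Tanh
])
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True)

trainer = Train()
trainer.train(trainloader, epochs=50)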

Results

(figures of generated samples omitted)

When the discriminator is asked to score these images, most of the scores fall below 0.5; in other words, they are essentially all judged to be fakes.
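
A minimal sketch of that check, assuming the Train class above with weights restored via loading() (the variable names are illustrative):

trainer = Train()
trainer.loading()                                       # also puts both nets in eval mode

z = trainer.sampling_z((100, trainer.z_size))
y = trainer.transform_y(torch.arange(10).repeat(10))    # ten samples of each digit
fake_imgs = trainer.generator(z, y)
scores = trainer.discriminator(fake_imgs, y).squeeze()
print((scores < 0.5).float().mean())                    # fraction judged fake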


"""
lr=0.001,
SGD,
网络结构简化了
"""
class Generator(nn.Module):
"""
生成器
"""
def __init__(self, input_size=(100, 10), output_size=784):
super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(input_size[0], 128),
nn.BatchNorm1d(128),
nn.ReLU()
)
self.fc2 = nn.Sequential(
nn.Linear(input_size[1], 128),
nn.BatchNorm1d(128),
nn.ReLU()
)
self.dense = nn.Sequential(
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, output_size),
nn.BatchNorm1d(output_size),
nn.Tanh()
) def forward(self, z, y):
"""
:param z: 随机隐变量
:param y: 条件隐变量
:return:
"""
z = self.fc1(z)
y = self.fc2(y)
out = self.dense(
torch.cat((z, y), 1)
)
return out class Discriminator(nn.Module): def __init__(self, input_size=(784, 10)):
super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(input_size[0], 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2)
)
self.fc2 = nn.Sequential(
nn.Linear(input_size[1], 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2)
)
self.dense = nn.Sequential(
nn.Linear(2048, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1),
nn.Sigmoid()
) def forward(self, x, y):
x = self.fc1(x)
y = self.fc2(y)
out = self.dense(
torch.cat((x, y), 1)
)
return out class Train: def __init__(self, z_size=100, y_size=10, x_size=784,
criterion=nn.BCELoss(), lr=1e-3, momentum=0.9):
self.generator = Generator(input_size=(z_size, y_size), output_size=x_size)
self.discriminator = Discriminator(input_size=(x_size, y_size))
self.criterion = criterion
self.opti1 = torch.optim.SGD(self.generator.parameters(), lr=lr, momentum=momentum)
self.opti2 = torch.optim.SGD(self.discriminator.parameters(), lr=lr, momentum=momentum)
self.z_size = z_size
self.y_size = y_size
self.x_size = x_size
self.lr = lr
cpath = os.path.abspath('.')
self.gen_path = os.path.join(cpath, 'generator2.pt')
self.dis_path = os.path.join(cpath, 'discriminator2.pt')
self.imgspath = lambda i: os.path.join(cpath, 'image', 'fig{0}'.format(i))
#self.loading() def transform_y(self, labels):
return torch.eye(self.y_size)[labels] def sampling_z(self, size):
return torch.randn(size) def showimgs(self, imgs, order):
n = imgs.size(0)
imgs = imgs.data.view(n, 28, 28)
fig, axs = plt.subplots(10, 10)
for i in range(10):
for j in range(10):
axs[i, j].get_xaxis().set_visible(False)
axs[i, j].get_yaxis().set_visible(False) for i in range(10):
for j in range(10):
t = i * 10 + j
img = imgs[t]
axs[i, j].cla()
axs[i, j].imshow(img.data.view(28, 28).numpy(), cmap='gray')
fig.savefig(self.imgspath(order)) def train(self, trainloader, epochs=5, classes=10):
order = 0
for epoch in range(epochs):
running_loss_d = 0.
running_loss_g = 0.
if (epoch + 1) % 5 is 0.:
self.opti1.param_groups[0]['lr'] /= 10
self.opti2.param_groups[0]['lr'] /= 10
print("learning rate change!")
for i, data in enumerate(trainloader): real_imgs, labels = data
real_imgs = real_imgs.view(real_imgs.size(0), -1)
y = self.transform_y(labels) d_out = self.discriminator(real_imgs, y).squeeze() z = self.sampling_z((y.size(0), self.z_size))
fake_y = self.transform_y(torch.randint(classes, size=(y.size(0),)))
fake_imgs = self.generator(z, fake_y).squeeze()
g_out = self.discriminator(fake_imgs.detach(), fake_y).squeeze() # 训练判别器
loss1 = self.criterion(d_out, torch.ones_like(d_out))
loss2 = self.criterion(g_out, torch.zeros_like(g_out)) d_loss = loss1 + loss2
self.opti2.zero_grad()
d_loss.backward()
self.opti2.step() # 训练生成器
z = self.sampling_z((y.size(0), self.z_size))
fake_y = self.transform_y(torch.randint(classes, size=(y.size(0),)))
fake_imgs = self.generator(z, fake_y).squeeze()
g_out = self.discriminator(fake_imgs, fake_y).squeeze()
g_loss = self.criterion(g_out, torch.ones_like(g_out))
self.opti1.zero_grad()
g_loss.backward()
self.opti1.step() running_loss_d += d_loss
running_loss_g += g_loss
if i % 10 is 0 and i != 0:
print("[epoch {0:<d}: d_loss: {1:<5f} g_loss: {2:<5f}]".format(
epoch, running_loss_d / 10, running_loss_g / 10
))
running_loss_d = 0.
running_loss_g = 0.
if (epoch + 1) % 2 is 0:
self.showimgs(fake_imgs, order=order)
order += 1 torch.save(self.generator.state_dict(), self.gen_path)
torch.save(self.discriminator.state_dict(), self.dis_path)
def loading(self):
self.generator.load_state_dict(torch.load(self.gen_path))
self.generator.eval()
self.discriminator.load_state_dict(torch.load(self.dis_path))
self.discriminator.eval()

Results: not particularly good (figure omitted).

Results after replacing SGD with Adam (trained for all 50 epochs; the results are actually surprisingly decent).
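
The swap only touches the two optimizers in the simplified Train.__init__; a sketch (the learning rate and betas here mirror the first version and are otherwise assumptions):

# inside Train.__init__, replacing the SGD optimizers
self.opti1 = torch.optim.Adam(self.generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
self.opti2 = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))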
