GAN网络原理介绍和代码
GAN网络的整体公式:
公式各参数介绍如下:
X是真实地图片,而对应的标签是1。
G(Z)是通过给定的噪声Z,生成图片(实际上是通过给定的Z生成一个tensor),对应的标签是0。
D是一个二分类网络,对于给定的图片判别真假。
D和G的参数更新方式:
D通过输入的真假图片,通过BCE(二分类交叉熵)更新自己的参数。
D对G(Z)生成的标签L,G尽可能使L为true,也就是1,通过BCE(二分类交叉熵)更新自己的参数。
公式演变:
对于G来说要使D无法判别自己生成的图片是假的,故而要使G(Z)越大越好,所以就使得V(G,D)越小越好;而对于D,使G(Z)越小D(X)越大,故而使V(G,D)越大越好
为了便于求导,故而加了log,变为如下:
最后对整个batch求期望,变为如下:
基于mnist实现的GAN网络结构对应的代码
import itertools
import math
import time import torch
import torchvision
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from IPython import display
from torch.autograd import Variable
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]) train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True) class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(784, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
) def forward(self, x):
out = self.model(x.view(x.size(0), 784))
out = out.view(out.size(0), -1)
return out class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 784),
nn.Tanh()
) def forward(self, x):
x = x.view(x.size(0), -1)
out = self.model(x)
return out discriminator = Discriminator().cuda()
generator = Generator().cuda()
criterion = nn.BCELoss()
lr = 0.0002
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr) def train_discriminator(discriminator, images, real_labels, fake_images, fake_labels):
discriminator.zero_grad()
outputs = discriminator(images)
real_loss = criterion(outputs, real_labels)
real_score = outputs outputs = discriminator(fake_images)
fake_loss = criterion(outputs, fake_labels)
fake_score = outputs d_loss = real_loss + fake_loss
d_loss.backward()
d_optimizer.step()
return d_loss, real_score, fake_score
def train_generator(generator, discriminator_outputs, real_labels):
generator.zero_grad()
g_loss = criterion(discriminator_outputs, real_labels)
g_loss.backward()
g_optimizer.step()
return g_loss # draw samples from the input distribution to inspect the generation on training
num_test_samples = 16
test_noise = Variable(torch.randn(num_test_samples, 100).cuda())
# create figure for plotting
size_figure_grid = int(math.sqrt(num_test_samples))
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
ax[i, j].get_xaxis().set_visible(False)
ax[i, j].get_yaxis().set_visible(False) # set number of epochs and initialize figure counter
num_epochs = 200
num_batches = len(train_loader)
num_fig = 0 for epoch in range(num_epochs):
for n, (images, _) in enumerate(train_loader):
images = Variable(images.cuda())
real_labels = Variable(torch.ones(images.size(0)).cuda()) # Sample from generator
noise = Variable(torch.randn(images.size(0), 100).cuda())
fake_images = generator(noise)
fake_labels = Variable(torch.zeros(images.size(0)).cuda()) # Train the discriminator
d_loss, real_score, fake_score = train_discriminator(discriminator, images, real_labels, fake_images,
fake_labels) # Sample again from the generator and get output from discriminator
noise = Variable(torch.randn(images.size(0), 100).cuda())
fake_images = generator(noise)
outputs = discriminator(fake_images) # Train the generator
g_loss = train_generator(generator, outputs, real_labels) if (n + 1) % 100 == 0:
test_images = generator(test_noise) for k in range(num_test_samples):
i = k // 4
j = k % 4
ax[i, j].cla()
ax[i, j].imshow(test_images[k, :].data.cpu().numpy().reshape(28, 28), cmap='Greys')
display.clear_output(wait=True)
display.display(plt.gcf()) plt.savefig('results/mnist-gan-%03d.png' % num_fig)
num_fig += 1
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
'D(x): %.2f, D(G(z)): %.2f'
% (epoch + 1, num_epochs, n + 1, num_batches, d_loss.data[0], g_loss.data[0],
real_score.data.mean(), fake_score.data.mean())) fig.close()
GAN网络原理介绍和代码的更多相关文章
- When I see you again(加密原理介绍,代码实现DES、AES、RSA、Base64、MD5)
关于网络安全的数据加密部分,本来打算总结一篇博客搞定,没想到东西太多,这已是第三篇了,而且这篇写了多次,熬了多次夜,真是again and again.起个名字:数据加密三部曲,前两部链接如下: 整体 ...
- 加密原理介绍,代码实现DES、AES、RSA、Base64、MD5
阅读目录 github下载地址 一.DES对称加密 二.AES对称加密 三.RSA非对称加密 四.实际使用 五.关于Padding 关于电脑终端Openssl加密解密命令 关于网络安全的数据加密部分, ...
- TF实战:(Mask R-CNN原理介绍与代码实现)-Chapter-8
二值掩膜输出依据种类预测分支(Faster R-CNN部分)预测结果:当前RoI的物体种类为i第i个二值掩膜输出就是该RoI的损失Lmask 对于预测的二值掩膜输出,我们对每个像素点应用sigmoid ...
- 『TensorFlow』通过代码理解gan网络_中
『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...
- 常见的GAN网络的相关原理及推导
常见的GAN网络的相关原理及推导 在上一篇中我们给大家介绍了GAN的相关原理和推导,GAN是VAE的后一半,再加上一个鉴别网络.这样而导致了完全不同的训练方式. GAN,生成对抗网络,主要有两部分构成 ...
- GAN网络从入门教程(一)之GAN网络介绍
GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...
- GAN网络从入门教程(二)之GAN原理
在一篇博客GAN网络从入门教程(一)之GAN网络介绍中,简单的对GAN网络进行了一些介绍,介绍了其是什么,然后大概的流程是什么. 在这篇博客中,主要是介绍其数学公式,以及其算法流程.当然数学公式只是简 ...
- 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上
GAN网络架构分析 上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake ...
- UIContainerView纯代码实现及原理介绍
UIContainerView纯代码实现及原理介绍 1.1-在StoryBoard中使用UIContainerView 1.2-纯代码使用UIContainerView 1.3-UIContainer ...
随机推荐
- Linux(三)
1.用户与用户组 Linux系统是一个多用户.多任务的操作系统,任何一个要使用系统资源的用户,都必须首先向系统管理员(root)申请一个账号,然后以这个账号的身份进入系统. ...
- pandas 过滤
条件过滤 通过loc进行行过滤,也可对过滤后的行进行赋值 import pandas as pd df = pd.DataFrame({"name": ["yang&qu ...
- JDK性能分析工具-引用于深入理解JVM
1.jps(JVM Process Status Tool) 列出正在运行的虚拟机进程. 2.jstat(JVM Statistics Monitoring Tool) 显示运行状态信息. 3.jin ...
- 字符串模式匹配——KMP算法
KMP算法匹配字符串 朴素匹配算法 字符串的模式匹配的方法刚开始是朴素匹配算法,也就是经常说的暴力匹配,说白了就是用子串去和父串一个一个匹配,从父串的第一个字符开始匹配,如果匹配到某一个失配了,就 ...
- Kali Linux install "Veil-Evasion"
Xx_Step wget https://github.com/ChrisTruncer/Veil/archive/master.zip unzip master.zip cd Veil-Evasio ...
- iOS中nil、 Nil、 NULL和NSNull的区别
参考链接:https://www.jianshu.com/p/c3017ae6684a
- Python读字节某一位的值,设置某一位的值,二进制位操作
Python读字节某一位的值,设置某一位的值,二进制位操作 在物联网实际应用项目开发中,为了提升性能,与设备端配合,往往最终使用的是二进制字节串方式进行的通信协议封装,更会把0和1.True和Fa ...
- Oracle dataguard切换实施步骤
主备库的切换主要在两种情况下切换,Switchover和Failover,这两种切换都需要手工执行完成,不建议自动执行.主库端 192.168.411.20备库端 192.168.411.221 一是 ...
- json属性里面出现特殊字符怎么获取属性
直接上代码. 这样的 获取这个object后需要获取里面的书属性,但是正常情况下是XX.属性名.但是属性名有特殊符号.这时候我们可以这样. XX['属性名']['属性名']....可以一直这样写 XX ...
- 各种数和各种反演(所谓FFT的前置知识?)
每次问NC做多项式的题需要什么知识点. 各种数. 各种反演. 多项式全家桶. 然后我就一个一个地学知识点.然而还差好多,学到后面的前面的已经忘了(可能是我太菜吧不是谁都是NC啊) 然后发现每个知识点基 ...