GAN网络原理介绍和代码
GAN网络的整体公式:
公式各参数介绍如下:
X是真实地图片,而对应的标签是1。
G(Z)是通过给定的噪声Z,生成图片(实际上是通过给定的Z生成一个tensor),对应的标签是0。
D是一个二分类网络,对于给定的图片判别真假。
D和G的参数更新方式:
D通过输入的真假图片,通过BCE(二分类交叉熵)更新自己的参数。
D对G(Z)生成的标签L,G尽可能使L为true,也就是1,通过BCE(二分类交叉熵)更新自己的参数。
公式演变:
对于G来说要使D无法判别自己生成的图片是假的,故而要使G(Z)越大越好,所以就使得V(G,D)越小越好;而对于D,使G(Z)越小D(X)越大,故而使V(G,D)越大越好
为了便于求导,故而加了log,变为如下:
最后对整个batch求期望,变为如下:
基于mnist实现的GAN网络结构对应的代码
- import itertools
- import math
- import time
- import torch
- import torchvision
- import torch.nn as nn
- import torchvision.datasets as dsets
- import torchvision.transforms as transforms
- import matplotlib.pyplot as plt
- from IPython import display
- from torch.autograd import Variable
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
- ])
- train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
- class Discriminator(nn.Module):
- def __init__(self):
- super().__init__()
- self.model = nn.Sequential(
- nn.Linear(784, 1024),
- nn.LeakyReLU(0.2, inplace=True),
- nn.Dropout(0.3),
- nn.Linear(1024, 512),
- nn.LeakyReLU(0.2, inplace=True),
- nn.Dropout(0.3),
- nn.Linear(512, 256),
- nn.LeakyReLU(0.2, inplace=True),
- nn.Dropout(0.3),
- nn.Linear(256, 1),
- nn.Sigmoid()
- )
- def forward(self, x):
- out = self.model(x.view(x.size(0), 784))
- out = out.view(out.size(0), -1)
- return out
- class Generator(nn.Module):
- def __init__(self):
- super().__init__()
- self.model = nn.Sequential(
- nn.Linear(100, 256),
- nn.LeakyReLU(0.2, inplace=True),
- nn.Linear(256, 512),
- nn.LeakyReLU(0.2, inplace=True),
- nn.Linear(512, 1024),
- nn.LeakyReLU(0.2, inplace=True),
- nn.Linear(1024, 784),
- nn.Tanh()
- )
- def forward(self, x):
- x = x.view(x.size(0), -1)
- out = self.model(x)
- return out
- discriminator = Discriminator().cuda()
- generator = Generator().cuda()
- criterion = nn.BCELoss()
- lr = 0.0002
- d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
- g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
- def train_discriminator(discriminator, images, real_labels, fake_images, fake_labels):
- discriminator.zero_grad()
- outputs = discriminator(images)
- real_loss = criterion(outputs, real_labels)
- real_score = outputs
- outputs = discriminator(fake_images)
- fake_loss = criterion(outputs, fake_labels)
- fake_score = outputs
- d_loss = real_loss + fake_loss
- d_loss.backward()
- d_optimizer.step()
- return d_loss, real_score, fake_score
- def train_generator(generator, discriminator_outputs, real_labels):
- generator.zero_grad()
- g_loss = criterion(discriminator_outputs, real_labels)
- g_loss.backward()
- g_optimizer.step()
- return g_loss
- # draw samples from the input distribution to inspect the generation on training
- num_test_samples = 16
- test_noise = Variable(torch.randn(num_test_samples, 100).cuda())
- # create figure for plotting
- size_figure_grid = int(math.sqrt(num_test_samples))
- fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
- for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
- ax[i, j].get_xaxis().set_visible(False)
- ax[i, j].get_yaxis().set_visible(False)
- # set number of epochs and initialize figure counter
- num_epochs = 200
- num_batches = len(train_loader)
- num_fig = 0
- for epoch in range(num_epochs):
- for n, (images, _) in enumerate(train_loader):
- images = Variable(images.cuda())
- real_labels = Variable(torch.ones(images.size(0)).cuda())
- # Sample from generator
- noise = Variable(torch.randn(images.size(0), 100).cuda())
- fake_images = generator(noise)
- fake_labels = Variable(torch.zeros(images.size(0)).cuda())
- # Train the discriminator
- d_loss, real_score, fake_score = train_discriminator(discriminator, images, real_labels, fake_images,
- fake_labels)
- # Sample again from the generator and get output from discriminator
- noise = Variable(torch.randn(images.size(0), 100).cuda())
- fake_images = generator(noise)
- outputs = discriminator(fake_images)
- # Train the generator
- g_loss = train_generator(generator, outputs, real_labels)
- if (n + 1) % 100 == 0:
- test_images = generator(test_noise)
- for k in range(num_test_samples):
- i = k // 4
- j = k % 4
- ax[i, j].cla()
- ax[i, j].imshow(test_images[k, :].data.cpu().numpy().reshape(28, 28), cmap='Greys')
- display.clear_output(wait=True)
- display.display(plt.gcf())
- plt.savefig('results/mnist-gan-%03d.png' % num_fig)
- num_fig += 1
- print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
- 'D(x): %.2f, D(G(z)): %.2f'
- % (epoch + 1, num_epochs, n + 1, num_batches, d_loss.data[0], g_loss.data[0],
- real_score.data.mean(), fake_score.data.mean()))
- fig.close()
GAN网络原理介绍和代码的更多相关文章
- When I see you again(加密原理介绍,代码实现DES、AES、RSA、Base64、MD5)
关于网络安全的数据加密部分,本来打算总结一篇博客搞定,没想到东西太多,这已是第三篇了,而且这篇写了多次,熬了多次夜,真是again and again.起个名字:数据加密三部曲,前两部链接如下: 整体 ...
- 加密原理介绍,代码实现DES、AES、RSA、Base64、MD5
阅读目录 github下载地址 一.DES对称加密 二.AES对称加密 三.RSA非对称加密 四.实际使用 五.关于Padding 关于电脑终端Openssl加密解密命令 关于网络安全的数据加密部分, ...
- TF实战:(Mask R-CNN原理介绍与代码实现)-Chapter-8
二值掩膜输出依据种类预测分支(Faster R-CNN部分)预测结果:当前RoI的物体种类为i第i个二值掩膜输出就是该RoI的损失Lmask 对于预测的二值掩膜输出,我们对每个像素点应用sigmoid ...
- 『TensorFlow』通过代码理解gan网络_中
『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...
- 常见的GAN网络的相关原理及推导
常见的GAN网络的相关原理及推导 在上一篇中我们给大家介绍了GAN的相关原理和推导,GAN是VAE的后一半,再加上一个鉴别网络.这样而导致了完全不同的训练方式. GAN,生成对抗网络,主要有两部分构成 ...
- GAN网络从入门教程(一)之GAN网络介绍
GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...
- GAN网络从入门教程(二)之GAN原理
在一篇博客GAN网络从入门教程(一)之GAN网络介绍中,简单的对GAN网络进行了一些介绍,介绍了其是什么,然后大概的流程是什么. 在这篇博客中,主要是介绍其数学公式,以及其算法流程.当然数学公式只是简 ...
- 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上
GAN网络架构分析 上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake ...
- UIContainerView纯代码实现及原理介绍
UIContainerView纯代码实现及原理介绍 1.1-在StoryBoard中使用UIContainerView 1.2-纯代码使用UIContainerView 1.3-UIContainer ...
随机推荐
- ES-入门
https://es.xiaoleilu.com/010_Intro/10_Installing_ES.html 1. 安装 https://www.elastic.co/cn/downloads/ ...
- jQuery总结01_jq的基本概念+选择器
jQuery基本概念 学习目标:学会如何使用jQuery,掌握jQuery的常用api,能够使用jQuery实现常见的效果. 为什么要学习jQuery? [01-让div显示与设置内容.html] 使 ...
- maven 解决jar包冲突及简单使用
maven 解决jar包冲突 1.jar包冲突原因 maven中使用坐标导入jar包时会把与之相关的依赖jar包导入(导入spring-context的jar时就会把spring的整个主体导入) ,而 ...
- Viewpager+Fragment 跳转Activity报错android.os.TransactionTooLargeException: data parcel size xxxxx bytes
Viewpager + Fragment 跳转Activity报错android.os.TransactionTooLargeException: data parcel size xxxxx byt ...
- Linux下安装及使用mysql
(注:本人在centos7进行的安装及使用) 1.安装wget yum install wget 2.下载mysql安装包 wget http://repo.mysql.com/mysql57-com ...
- 二、VUE项目BaseCms系列文章:项目目录结构介绍
一. 目录结构截图 二. 目录结构说明 - documents 存放项目相关的文档文件 - api api 数据接口目录 - assets 资源文件目录 - components ...
- C#后台架构师成长之路-Orm篇体系
成为了高工,只是完成体系的熟练,这个时候就要学会啃一些框架了... 常用Orm底层框架的熟悉: 1.轻量泛型的DBHelper,一般高工都自己写的出来的 2.EF-基于Linq的,好好用 3.Keel ...
- SQL Server查看login所授予的具体权限
在SQL Server数据库中如何查看一个登录名(login)的具体权限呢,如果使用SSMS的UI界面查看登录名的具体权限的话,用户数据库非常多的话,要梳理完它所有的权限,操作又耗时又麻烦,个人十分崇 ...
- Tornado—添加请求头允许跨域请求访问
跨域请求访问 如果是前后端分离,那就肯定会遇到cros跨域请求难题,可以设置一个BaseHandler,然后继承即可. class BaseHandler(tornado.web.RequestHan ...
- Ubuntu系统修改资源为阿里云镜像
一般都会推荐使用国内的镜像源,比如163或者阿里云的镜像服务器将下列文本添加到/etc/apt/sources.list文件里 deb http://mirrors.aliyun.com/ubuntu ...