GAN网络的整体公式:

公式各参数介绍如下:

X是真实地图片,而对应的标签是1。

G(Z)是通过给定的噪声Z,生成图片(实际上是通过给定的Z生成一个tensor),对应的标签是0。

D是一个二分类网络,对于给定的图片判别真假。

D和G的参数更新方式:

D通过输入的真假图片,通过BCE(二分类交叉熵)更新自己的参数。

D对G(Z)生成的标签L,G尽可能使L为true,也就是1,通过BCE(二分类交叉熵)更新自己的参数。

公式演变:

对于G来说要使D无法判别自己生成的图片是假的,故而要使G(Z)越大越好,所以就使得V(G,D)越小越好;而对于D,使G(Z)越小D(X)越大,故而使V(G,D)越大越好

为了便于求导,故而加了log,变为如下:

最后对整个batch求期望,变为如下:

基于mnist实现的GAN网络结构对应的代码

  1. import itertools
  2. import math
  3. import time
  4.  
  5. import torch
  6. import torchvision
  7. import torch.nn as nn
  8. import torchvision.datasets as dsets
  9. import torchvision.transforms as transforms
  10. import matplotlib.pyplot as plt
  11. from IPython import display
  12. from torch.autograd import Variable
  13. transform = transforms.Compose([
  14. transforms.ToTensor(),
  15. transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
  16. ])
  17.  
  18. train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)
  19. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
  20.  
  21. class Discriminator(nn.Module):
  22. def __init__(self):
  23. super().__init__()
  24. self.model = nn.Sequential(
  25. nn.Linear(784, 1024),
  26. nn.LeakyReLU(0.2, inplace=True),
  27. nn.Dropout(0.3),
  28. nn.Linear(1024, 512),
  29. nn.LeakyReLU(0.2, inplace=True),
  30. nn.Dropout(0.3),
  31. nn.Linear(512, 256),
  32. nn.LeakyReLU(0.2, inplace=True),
  33. nn.Dropout(0.3),
  34. nn.Linear(256, 1),
  35. nn.Sigmoid()
  36. )
  37.  
  38. def forward(self, x):
  39. out = self.model(x.view(x.size(0), 784))
  40. out = out.view(out.size(0), -1)
  41. return out
  42.  
  43. class Generator(nn.Module):
  44. def __init__(self):
  45. super().__init__()
  46. self.model = nn.Sequential(
  47. nn.Linear(100, 256),
  48. nn.LeakyReLU(0.2, inplace=True),
  49. nn.Linear(256, 512),
  50. nn.LeakyReLU(0.2, inplace=True),
  51. nn.Linear(512, 1024),
  52. nn.LeakyReLU(0.2, inplace=True),
  53. nn.Linear(1024, 784),
  54. nn.Tanh()
  55. )
  56.  
  57. def forward(self, x):
  58. x = x.view(x.size(0), -1)
  59. out = self.model(x)
  60. return out
  61.  
  62. discriminator = Discriminator().cuda()
  63. generator = Generator().cuda()
  64. criterion = nn.BCELoss()
  65. lr = 0.0002
  66. d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
  67. g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
  68.  
  69. def train_discriminator(discriminator, images, real_labels, fake_images, fake_labels):
  70. discriminator.zero_grad()
  71. outputs = discriminator(images)
  72. real_loss = criterion(outputs, real_labels)
  73. real_score = outputs
  74.  
  75. outputs = discriminator(fake_images)
  76. fake_loss = criterion(outputs, fake_labels)
  77. fake_score = outputs
  78.  
  79. d_loss = real_loss + fake_loss
  80. d_loss.backward()
  81. d_optimizer.step()
  82. return d_loss, real_score, fake_score
  83. def train_generator(generator, discriminator_outputs, real_labels):
  84. generator.zero_grad()
  85. g_loss = criterion(discriminator_outputs, real_labels)
  86. g_loss.backward()
  87. g_optimizer.step()
  88. return g_loss
  89.  
  90. # draw samples from the input distribution to inspect the generation on training
  91. num_test_samples = 16
  92. test_noise = Variable(torch.randn(num_test_samples, 100).cuda())
  93. # create figure for plotting
  94. size_figure_grid = int(math.sqrt(num_test_samples))
  95. fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
  96. for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
  97. ax[i, j].get_xaxis().set_visible(False)
  98. ax[i, j].get_yaxis().set_visible(False)
  99.  
  100. # set number of epochs and initialize figure counter
  101. num_epochs = 200
  102. num_batches = len(train_loader)
  103. num_fig = 0
  104.  
  105. for epoch in range(num_epochs):
  106. for n, (images, _) in enumerate(train_loader):
  107. images = Variable(images.cuda())
  108. real_labels = Variable(torch.ones(images.size(0)).cuda())
  109.  
  110. # Sample from generator
  111. noise = Variable(torch.randn(images.size(0), 100).cuda())
  112. fake_images = generator(noise)
  113. fake_labels = Variable(torch.zeros(images.size(0)).cuda())
  114.  
  115. # Train the discriminator
  116. d_loss, real_score, fake_score = train_discriminator(discriminator, images, real_labels, fake_images,
  117. fake_labels)
  118.  
  119. # Sample again from the generator and get output from discriminator
  120. noise = Variable(torch.randn(images.size(0), 100).cuda())
  121. fake_images = generator(noise)
  122. outputs = discriminator(fake_images)
  123.  
  124. # Train the generator
  125. g_loss = train_generator(generator, outputs, real_labels)
  126.  
  127. if (n + 1) % 100 == 0:
  128. test_images = generator(test_noise)
  129.  
  130. for k in range(num_test_samples):
  131. i = k // 4
  132. j = k % 4
  133. ax[i, j].cla()
  134. ax[i, j].imshow(test_images[k, :].data.cpu().numpy().reshape(28, 28), cmap='Greys')
  135. display.clear_output(wait=True)
  136. display.display(plt.gcf())
  137.  
  138. plt.savefig('results/mnist-gan-%03d.png' % num_fig)
  139. num_fig += 1
  140. print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
  141. 'D(x): %.2f, D(G(z)): %.2f'
  142. % (epoch + 1, num_epochs, n + 1, num_batches, d_loss.data[0], g_loss.data[0],
  143. real_score.data.mean(), fake_score.data.mean()))
  144.  
  145. fig.close()

GAN网络原理介绍和代码的更多相关文章

  1. When I see you again(加密原理介绍,代码实现DES、AES、RSA、Base64、MD5)

    关于网络安全的数据加密部分,本来打算总结一篇博客搞定,没想到东西太多,这已是第三篇了,而且这篇写了多次,熬了多次夜,真是again and again.起个名字:数据加密三部曲,前两部链接如下: 整体 ...

  2. 加密原理介绍,代码实现DES、AES、RSA、Base64、MD5

    阅读目录 github下载地址 一.DES对称加密 二.AES对称加密 三.RSA非对称加密 四.实际使用 五.关于Padding 关于电脑终端Openssl加密解密命令 关于网络安全的数据加密部分, ...

  3. TF实战:(Mask R-CNN原理介绍与代码实现)-Chapter-8

    二值掩膜输出依据种类预测分支(Faster R-CNN部分)预测结果:当前RoI的物体种类为i第i个二值掩膜输出就是该RoI的损失Lmask 对于预测的二值掩膜输出,我们对每个像素点应用sigmoid ...

  4. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  5. 常见的GAN网络的相关原理及推导

    常见的GAN网络的相关原理及推导 在上一篇中我们给大家介绍了GAN的相关原理和推导,GAN是VAE的后一半,再加上一个鉴别网络.这样而导致了完全不同的训练方式. GAN,生成对抗网络,主要有两部分构成 ...

  6. GAN网络从入门教程(一)之GAN网络介绍

    GAN网络从入门教程(一)之GAN网络介绍 稍微的开一个新坑,同样也是入门教程(因此教程的内容不会是从入门到精通,而是从入门到入土).主要是为了完成数据挖掘的课程设计,然后就把挖掘榔头挖到了GAN网络 ...

  7. GAN网络从入门教程(二)之GAN原理

    在一篇博客GAN网络从入门教程(一)之GAN网络介绍中,简单的对GAN网络进行了一些介绍,介绍了其是什么,然后大概的流程是什么. 在这篇博客中,主要是介绍其数学公式,以及其算法流程.当然数学公式只是简 ...

  8. 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上

    GAN网络架构分析 上图即为GAN的逻辑架构,其中的noise vector就是特征向量z,real images就是输入变量x,标签的标准比较简单(二分类么),real的就是tf.ones,fake ...

  9. UIContainerView纯代码实现及原理介绍

    UIContainerView纯代码实现及原理介绍 1.1-在StoryBoard中使用UIContainerView 1.2-纯代码使用UIContainerView 1.3-UIContainer ...

随机推荐

  1. ES-入门

    https://es.xiaoleilu.com/010_Intro/10_Installing_ES.html 1. 安装 https://www.elastic.co/cn/downloads/ ...

  2. jQuery总结01_jq的基本概念+选择器

    jQuery基本概念 学习目标:学会如何使用jQuery,掌握jQuery的常用api,能够使用jQuery实现常见的效果. 为什么要学习jQuery? [01-让div显示与设置内容.html] 使 ...

  3. maven 解决jar包冲突及简单使用

    maven 解决jar包冲突 1.jar包冲突原因 maven中使用坐标导入jar包时会把与之相关的依赖jar包导入(导入spring-context的jar时就会把spring的整个主体导入) ,而 ...

  4. Viewpager+Fragment 跳转Activity报错android.os.TransactionTooLargeException: data parcel size xxxxx bytes

    Viewpager + Fragment 跳转Activity报错android.os.TransactionTooLargeException: data parcel size xxxxx byt ...

  5. Linux下安装及使用mysql

    (注:本人在centos7进行的安装及使用) 1.安装wget yum install wget 2.下载mysql安装包 wget http://repo.mysql.com/mysql57-com ...

  6. 二、VUE项目BaseCms系列文章:项目目录结构介绍

    一. 目录结构截图 二. 目录结构说明 - documents    存放项目相关的文档文件 - api   api 数据接口目录 - assets    资源文件目录 - components   ...

  7. C#后台架构师成长之路-Orm篇体系

    成为了高工,只是完成体系的熟练,这个时候就要学会啃一些框架了... 常用Orm底层框架的熟悉: 1.轻量泛型的DBHelper,一般高工都自己写的出来的 2.EF-基于Linq的,好好用 3.Keel ...

  8. SQL Server查看login所授予的具体权限

    在SQL Server数据库中如何查看一个登录名(login)的具体权限呢,如果使用SSMS的UI界面查看登录名的具体权限的话,用户数据库非常多的话,要梳理完它所有的权限,操作又耗时又麻烦,个人十分崇 ...

  9. Tornado—添加请求头允许跨域请求访问

    跨域请求访问 如果是前后端分离,那就肯定会遇到cros跨域请求难题,可以设置一个BaseHandler,然后继承即可. class BaseHandler(tornado.web.RequestHan ...

  10. Ubuntu系统修改资源为阿里云镜像

    一般都会推荐使用国内的镜像源,比如163或者阿里云的镜像服务器将下列文本添加到/etc/apt/sources.list文件里 deb http://mirrors.aliyun.com/ubuntu ...