使用Keras编写GAN的入门

GAN

Time: 2017-5-31


前言

主要参考了网页[1]的教程,同时主要算法来自Ian J. Goodfellow 的论文,算法如下:

gan

代码

  1. %matplotlib inline
  2. import numpy as np
  3. import pandas as pd
  4. from keras.models import Model
  5. from keras.layers import Dense, Activation, Input, Reshape
  6. from keras.layers import Conv1D, Flatten, Dropout
  7. from keras.optimizers import SGD, Adam
  8. from tqdm import tqdm_notebook as tqdm # 进度条
  9. # 生成随机正弦曲线的数据
  10. def sample_data(n_samples=10000, x_vals=np.arange(0, 5, .1), max_offset=1000, mul_range=[1, 2]):
  11. vectors = []
  12. for i in range(n_samples):
  13. offset = np.random.random() * max_offset
  14. mul = mul_range[0] + np.random.random() * (mul_range[1] - mul_range[0])
  15. vectors.append(np.sin(offset + x_vals * mul) / 2 + .5)
  16. return np.array(vectors)
  17. # 创建生成模型
  18. def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3):
  19. x = Dense(dense_dim)(G_in)
  20. x = Activation('tanh')(x)
  21. G_out = Dense(out_dim, activation='tanh')(x)
  22. G = Model(G_in, G_out)
  23. opt = SGD(lr=lr)
  24. G.compile(loss='binary_crossentropy', optimizer=opt)
  25. return G, G_out
  26. # 创建判别模型
  27. def get_discriminative(D_in, lr=1e-3, drate = .25, n_channels=50, conv_sz=5, leak=.2):
  28. x = Reshape((-1, 1))(D_in)
  29. x = Conv1D(n_channels, conv_sz, activation='relu')(x)
  30. x = Dropout(drate)(x)
  31. x = Flatten()(x)
  32. x = Dense(n_channels)(x)
  33. D_out = Dense(2, activation='sigmoid')(x)
  34. D = Model(D_in, D_out)
  35. dopt = Adam(lr=lr)
  36. D.compile(loss='binary_crossentropy', optimizer=dopt)
  37. return D, D_out
  38. def set_trainability(model, trainable=False):
  39. model.trainable = trainable
  40. for layer in model.layers:
  41. layer.trainable = trainable
  42. def make_gan(GAN_in, G, D):
  43. set_trainability(D, False)
  44. x = G(GAN_in)
  45. GAN_out = D(x)
  46. GAN = Model(GAN_in, GAN_out)
  47. GAN.compile(loss='binary_crossentropy', optimizer=G.optimizer)
  48. return GAN, GAN_out
  49. # 通过生成数据 预训练判别模型
  50. def sample_data_and_gen(G, noise_dim=10, n_samples=10000):
  51. XT = sample_data(n_samples=n_samples)
  52. XN_noise = np.random.uniform(0, 1, size=[n_samples, noise_dim])
  53. XN = G.predict(XN_noise)
  54. X = np.concatenate((XT, XN))
  55. y = np.zeros((2*n_samples, 2))
  56. y[:n_samples, 1] = 1
  57. y[n_samples:, 0] = 1
  58. return X, y
  59. def pretrain(G, D, noise_dim=10, n_samples=10000, batch_size=32):
  60. X, y = sample_data_and_gen(G, noise_dim=noise_dim, n_samples=n_samples)
  61. set_trainability(D, True)
  62. D.fit(X, y, epochs=1, batch_size=batch_size)
  63. # 开始交叉训练步骤
  64. def sample_noise(G, noise_dim=10, n_samples=10000):
  65. X = np.random.uniform(0, 1, size=[n_samples, noise_dim])
  66. y = np.zeros((n_samples, 2))
  67. y[:, 1] = 1
  68. return X, y
  69. def train(GAN, G, D, epochs=500, n_samples=10000, noise_dim=10, batch_size=32, verbose=False, v_freq=50):
  70. d_loss = []
  71. g_loss = []
  72. e_range = range(epochs)
  73. if verbose:
  74. e_range = tqdm(e_range)
  75. for epoch in e_range:
  76. X, y = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim) # 对D进行训练
  77. set_trainability(D, True)
  78. d_loss.append(D.train_on_batch(X, y))
  79. X, y = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim) # 对G训练
  80. set_trainability(D, False)
  81. g_loss.append(GAN.train_on_batch(X, y))
  82. if verbose and (epoch + 1) % v_freq == 0:
  83. print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(epoch + 1, g_loss[-1], d_loss[-1]))
  84. return d_loss, g_loss
  1. ax = pd.DataFrame(np.transpose(sample_data(5))).plot()
  2. G_in = Input(shape=[10])
  3. G, G_out = get_generative(G_in)
  4. G.summary()
  5. D_in = Input(shape=[50])
  6. D, D_out = get_discriminative(D_in)
  7. D.summary()
  1. _________________________________________________________________
  2. Layer (type) Output Shape Param #
  3. =================================================================
  4. input_9 (InputLayer) (None, 10) 0
  5. _________________________________________________________________
  6. dense_13 (Dense) (None, 200) 2200
  7. _________________________________________________________________
  8. activation_4 (Activation) (None, 200) 0
  9. _________________________________________________________________
  10. dense_14 (Dense) (None, 50) 10050
  11. =================================================================
  12. Total params: 12,250
  13. Trainable params: 12,250
  14. Non-trainable params: 0
  15. _________________________________________________________________
  16. _________________________________________________________________
  17. Layer (type) Output Shape Param #
  18. =================================================================
  19. input_10 (InputLayer) (None, 50) 0
  20. _________________________________________________________________
  21. reshape_4 (Reshape) (None, 50, 1) 0
  22. _________________________________________________________________
  23. conv1d_4 (Conv1D) (None, 46, 50) 300
  24. _________________________________________________________________
  25. dropout_4 (Dropout) (None, 46, 50) 0
  26. _________________________________________________________________
  27. flatten_4 (Flatten) (None, 2300) 0
  28. _________________________________________________________________
  29. dense_15 (Dense) (None, 50) 115050
  30. _________________________________________________________________
  31. dense_16 (Dense) (None, 2) 102
  32. =================================================================
  33. Total params: 115,452
  34. Trainable params: 115,452
  35. Non-trainable params: 0
  36. _________________________________________________________________

png
  1. GAN_in = Input([10])
  2. GAN, GAN_out = make_gan(GAN_in, G, D)
  3. GAN.summary()
  1. _________________________________________________________________
  2. Layer (type) Output Shape Param #
  3. =================================================================
  4. input_11 (InputLayer) (None, 10) 0
  5. _________________________________________________________________
  6. model_9 (Model) (None, 50) 12250
  7. _________________________________________________________________
  8. model_10 (Model) (None, 2) 115452
  9. =================================================================
  10. Total params: 127,702
  11. Trainable params: 12,250
  12. Non-trainable params: 115,452
  13. _________________________________________________________________
  1. pretrain(G, D)
  1. Epoch 1/1
  2. 20000/20000 [==============================] - 3s - loss: 0.0072
  1. d_loss, g_loss = train(GAN, G, D, verbose=True)
  1. Epoch #50: Generative Loss: 4.41527795791626, Discriminative Loss: 0.6733301877975464
  2. Epoch #100: Generative Loss: 3.8898046016693115, Discriminative Loss: 0.09901376813650131
  3. Epoch #150: Generative Loss: 6.2410054206848145, Discriminative Loss: 0.034074194729328156
  4. Epoch #200: Generative Loss: 5.206066608428955, Discriminative Loss: 0.13078376650810242
  5. Epoch #250: Generative Loss: 3.5144925117492676, Discriminative Loss: 0.07160962373018265
  6. Epoch #300: Generative Loss: 3.705162525177002, Discriminative Loss: 0.05893774330615997
  7. Epoch #350: Generative Loss: 3.511479616165161, Discriminative Loss: 0.09775738418102264
  8. Epoch #400: Generative Loss: 4.141300678253174, Discriminative Loss: 0.03169865906238556
  9. Epoch #450: Generative Loss: 3.500260829925537, Discriminative Loss: 0.05957922339439392
  10. Epoch #500: Generative Loss: 2.9797921180725098, Discriminative Loss: 0.10566817969083786
  1. ax = pd.DataFrame(
  2. {
  3. 'Generative Loss': g_loss,
  4. 'Discriminative Loss': d_loss,
  5. }
  6. ).plot(title='Training loss', logy=True)
  7. ax.set_xlabel("Epochs")
  8. ax.set_ylabel("Loss")

png
  1. N_VIEWED_SAMPLES = 2
  2. data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
  3. pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).plot()

png
  1. N_VIEWED_SAMPLES = 2
  2. data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
  3. pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).rolling(5).mean()[5:].plot()

png

reference

[1] http://www.rricard.me/machine/learning/generative/adversarial/networks/keras/tensorflow/2017/04/05/gans-part2.html#Imports

使用Keras编写GAN的入门的更多相关文章

  1. BAT脚本编写教程简单入门篇

    BAT脚本编写教程简单入门篇 批处理文件最常用的几个命令: echo表示显示此命令后的字符 echo on  表示在此语句后所有运行的命令都显示命令行本身 echo off 表示在此语句后所有运行的命 ...

  2. keras搭建神经网络快速入门笔记

    之前学习了tensorflow2.0的小伙伴可能会遇到一些问题,就是在读论文中的代码和一些实战项目往往使用keras+tensorflow1.0搭建, 所以本次和大家一起分享keras如何搭建神经网络 ...

  3. 在ubuntu下编写python(python入门)

    在ubuntu下编写python 一般情况下,ubuntu已经安装了python,打开终端,直接输入python,即可进行python编写. 默认为python2 如果想写python3,在终端输入p ...

  4. 【深度学习】--GAN从入门到初始

    一.前述 GAN,生成对抗网络,在2016年基本火爆深度学习,所有有必要学习一下.生成对抗网络直观的应用可以帮我们生成数据,图片. 二.具体 1.生活案例 比如假设真钱 r 坏人定义为G  我们通过 ...

  5. Linux编写Shell脚本入门

    一. 一般编写shell需要分3个步骤 1. 新建一个脚本文件,并编写程序 vi hello.sh #!/bin/bash #注释 #输出 printf '%s\n' "Hello Worl ...

  6. keras人工神经网络构建入门

    //2019.07.29-301.Keras 是提供一些高度可用神经网络框架的 Python API ,能帮助你快速的构建和训练自己的深度学习模型,它的后端是 TensorFlow 或者 Theano ...

  7. keras运行gan的几个bug解决

    http://blog.csdn.net/u012317000/article/details/79211274 https://www.jianshu.com/p/5b1f7004144d

  8. GAN网络之入门教程(四)之基于DCGAN动漫头像生成

    目录 使用前准备 数据集 定义参数 构建网络 构建G网络 构建D网络 构建GAN网络 关于GAN的小trick 训练 总结 参考 这一篇博客以代码为主,主要是来介绍如果使用keras构建一个DCGAN ...

  9. WPF 像素着色器入门:使用 Shazzam Shader Editor 编写 HLSL 像素着色器代码

    原文:WPF 像素着色器入门:使用 Shazzam Shader Editor 编写 HLSL 像素着色器代码 HLSL,High Level Shader Language,高级着色器语言,是 Di ...

随机推荐

  1. web自动化测试—selenium操作游览器属性

    # coding=utf-8'''web游览器属性: 页面最大化 maximize_window() 获取当前页面地址 current_url 代码 page_source title title 后 ...

  2. SpringCloud服务组合

    SpringCloud生态强调微服务,微服务也就意味着将各个功能独立的业务抽象出来,做成一个单独的服务供外部调用.但每个人对服务究竟要有多“微”的理解差异很大,导致微服务的粒度很难掌控,划分规则也不统 ...

  3. 吝啬的国度 ---用vector 来构图

    根据题目可以看出来  有n 个城市 只有 n-1  条路线 那么  就可以确定这个图中  不存在 圆  所以从一个点到另一个点 只有一条唯一的路  所以从一个节点到另一个节点 那么 这个节点只有一个唯 ...

  4. $ST表刷题记录$

    \(st表的题目不太多\) 我做过的就这些吧. https://www.luogu.org/problemnew/show/P3865 https://www.luogu.org/problemnew ...

  5. AndroidManifest.xml详解

    一.关于AndroidManifest.xml AndroidManifest.xml 是每个android程序中必须的文件.它位于整个项目的根目录,描述了package中暴露的组件(activiti ...

  6. JS——switch case

    语法: switch(n) { case 1: 执行代码块 1 break; case 2: 执行代码块 2 break; default: n 与 case 1 和 case 2 不同时执行的代码 ...

  7. MFC CAD控制权问题

    begineditorcommand(); 隐藏对话框  把控制权交给CAD completeeditorcommand(); 完成交互返回到应用程序 canceleditorcommand CAD被 ...

  8. strcmp 与 _tcscmp

    strcmp 用来比较ANSI字符串,而_tcscmp用来比较UNICODE(宽字符)的字符串.ANSI字符串中,1个英文字母为1个字节,1个中文字符为2个字节,遇到0字符表示字符串结束.而在UNIC ...

  9. matplotlib命令与格式:标题(title),标注(annotate),文字说明(text)

      1.title设置图像标题 (1)title常用参数 fontsize设置字体大小,默认12,可选参数 ['xx-small', 'x-small', 'small', 'medium', 'la ...

  10. DP背包问题小总结

    DP的背包问题可谓是最基础的DP了,分为01背包,完全背包,多重背包 01背包 装与不装是一个问题 01背包基本模型,背包的总体积为v,总共有n件物体,每件物品的体积为v[i],价值为w[i],每件物 ...