先放结果



这是通过GAN迭代训练30W次,耗时3小时生成的手写字图片效果,大部分的还是能看出来是数字的。

实现原理

简单说下原理,生成对抗网络需要训练两个任务,一个叫生成器,一个叫判别器,如字面意思,一个负责生成图片,一个负责判别图片,生成器不断生成新的图片,然后判别器去判断哪儿哪儿不行,生成器再不断去改进,不断的像真实的图片靠近。

这就如同一个造假团伙一样,A负责生产,B负责就鉴定,刚开始的时候,两个人都是菜鸟,A随便画了一幅画拿给B看,B说你这不行,然后A再改进,当然需要改进的不止A,随着A的改进,B也得不断提升,B需要发现更细微的差异,直至他们觉得已经没什么差异了(实际肯定还存在差异),他们便决定停止"训练",开始卖吧。

实现代码
# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-02-22
# @version: python2.7 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from datetime import datetime
import numpy as np
import os
import matplotlib.pyplot as plt os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' class Config:
alpha = 1e-2
drop_rate = 0.5 # 保留比例
steps = 300000 # 迭代次数
batch_size = 128 # 每批次训练样本数
epochs = 100 # 训练轮次 num_units = 128
size = 784
noise_size = 100 smooth = 0.01
learning_rate = 1e-4 print_per_step = 1000 class Gan: def __init__(self):
print('Loading data......')
# 读取MNIST数据集
self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 定义占位符,真实图片和生成的图片
self.real_images = tf.placeholder(tf.float32, [None, Config.size], name='real_images')
self.noise = tf.placeholder(tf.float32, [None, Config.noise_size], name='noise')
self.drop_rate = tf.placeholder('float') self.train_step() def generator_graph(self, noise, n_units, out_dim, alpha, reuse=False): with tf.variable_scope('generator', reuse=reuse):
# Hidden layer
h1 = tf.layers.dense(noise, n_units, activation=None)
# Leaky ReLU
h1 = tf.maximum(alpha * h1, h1)
h1 = tf.layers.dropout(h1, rate=self.drop_rate)
# Logits and tanh output
logits = tf.layers.dense(h1, out_dim, activation=None)
out = tf.tanh(logits) return out @staticmethod
def discriminator_graph(image, n_units, alpha, reuse=False): with tf.variable_scope('discriminator', reuse=reuse):
# Hidden layer
h1 = tf.layers.dense(image, n_units, activation=None)
# Leaky ReLU
h1 = tf.maximum(alpha * h1, h1) logits = tf.layers.dense(h1, 1, activation=None)
# out = tf.sigmoid(logits) return logits def net(self):
# generator
fake_image = self.generator_graph(self.noise, Config.num_units, Config.size, Config.alpha) # discriminator
real_logits = self.discriminator_graph(self.real_images, Config.num_units, Config.alpha)
fake_logits = self.discriminator_graph(fake_image, Config.num_units, Config.alpha, reuse=True) # discriminator的loss
# 识别真实图片
d_loss_real = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits)) * (
1 - Config.smooth))
# 识别生成的图片
d_loss_fake = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))
# 总体loss
d_loss = tf.add(d_loss_real, d_loss_fake) # generator的loss
g_loss = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_logits)) * (
1 - Config.smooth)) net_vars = tf.trainable_variables() # generator中的tensor
g_vars = [var for var in net_vars if var.name.startswith("generator")]
# discriminator中的tensor
d_vars = [var for var in net_vars if var.name.startswith("discriminator")] # optimizer
dis_optimizer = tf.train.AdamOptimizer(Config.learning_rate).minimize(d_loss, var_list=d_vars)
gen_optimizer = tf.train.AdamOptimizer(Config.learning_rate).minimize(g_loss, var_list=g_vars) return dis_optimizer, gen_optimizer, d_loss, g_loss def train_step(self):
dis_optimizer, gen_optimizer, d_loss, g_loss = self.net() print('Training & Evaluating......')
start_time = datetime.now()
sess = tf.Session()
sess.run(tf.global_variables_initializer()) for step in range(Config.steps):
real_image, _ = self.mnist.train.next_batch(Config.batch_size) real_image = real_image * 2 - 1 # generator的输入噪声
batch_noise = np.random.uniform(-1, 1, size=(Config.batch_size, Config.noise_size)) sess.run(gen_optimizer, feed_dict={self.noise: batch_noise, self.drop_rate: Config.drop_rate})
sess.run(dis_optimizer, feed_dict={self.noise: batch_noise, self.real_images: real_image}) if step % Config.print_per_step == 0:
dis_loss = sess.run(d_loss, feed_dict={self.noise: batch_noise, self.real_images: real_image})
gen_loss = sess.run(g_loss, feed_dict={self.noise: batch_noise, self.drop_rate: 1.})
end_time = datetime.now()
time_diff = (end_time - start_time).seconds msg = 'Step {:3}k Dis_Loss:{:6.2f}, Gen_Loss:{:6.2f}, Time_Usage:{:6.2f} mins.'
print(msg.format(int(step / 1000), dis_loss, gen_loss, time_diff / 60.)) self.gen_image(sess) def gen_image(self, sess):
sample_noise = np.random.uniform(-1, 1, size=(25, Config.noise_size))
samples = sess.run(
self.generator_graph(self.noise, Config.num_units, Config.size, Config.alpha, reuse=True),
feed_dict={self.noise: sample_noise}) plt.figure(figsize=(8, 8), dpi=80)
for i in range(25):
img = samples[i]
plt.subplot(5, 5, i + 1)
plt.imshow(img.reshape((28, 28)), cmap='Greys_r')
plt.axis('off')
plt.show() if __name__ == "__main__":
Gan()

Peace~~

使用生成对抗网络(GAN)生成手写字的更多相关文章

  1. TensorFlow从1到2(十二)生成对抗网络GAN和图片自动生成

    生成对抗网络的概念 上一篇中介绍的VAE自动编码器具备了一定程度的创造特征,能够"无中生有"的由一组随机数向量生成手写字符的图片. 这个"创造能力"我们在模型中 ...

  2. 用MXNet实现mnist的生成对抗网络(GAN)

    用MXNet实现mnist的生成对抗网络(GAN) 生成式对抗网络(Generative Adversarial Network,简称GAN)由一个生成网络与一个判别网络组成.生成网络从潜在空间(la ...

  3. 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...

  4. 生成对抗网络GAN介绍

    GAN原理 生成对抗网络GAN由生成器和判别器两部分组成: 判别器是常规的神经网络分类器,一半时间判别器接收来自训练数据中的真实图像,另一半时间收到来自生成器中的虚假图像.训练判别器使得对于真实图像, ...

  5. 生成对抗网络(GAN)

    基本思想 GAN全称生成对抗网络,是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的. 譬如:我要升职加薪,你领导力还不行,我现在领导力有了要升职加薪,你执行力还不行,我现在执行力有了要升职加薪 ...

  6. 深度学习-生成对抗网络GAN笔记

    生成对抗网络(GAN)由2个重要的部分构成: 生成器G(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器 判别器D(Discriminator):判断这张图像是真实的 ...

  7. 深度学习框架PyTorch一书的学习-第七章-生成对抗网络(GAN)

    参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN生成动漫头像 GAN解决了非监督学习中的著名问题:给定一批样本,训 ...

  8. 科普 | ​生成对抗网络(GAN)的发展史

    来源:https://en.wikipedia.org/wiki/Edmond_de_Belamy 五年前,Generative Adversarial Networks(GANs)在深度学习领域掀起 ...

  9. 利用tensorflow训练简单的生成对抗网络GAN

    对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的. 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(di ...

  10. 生成对抗网络GAN详解与代码

    1.GAN的基本原理其实非常简单,这里以生成图片为例进行说明.假设我们有两个网络,G(Generator)和D(Discriminator).正如它的名字所暗示的那样,它们的功能分别是: G是一个生成 ...

随机推荐

  1. C++等号操作符重载

    在新学操作符重载时最令人头疼的可能就是一些堆溢出的问题了,不过呢,只要一步步的写好new 与 delete.绝对不会有类似的问题. 当时我们编译可以通过,但是运行会出错,因为对象s1与s2进行赋值时, ...

  2. 关于路由器漏洞利用,qemu环境搭建,网络配置的总结

    FAT 搭建的坑 1 先按照官方步骤进行,完成后进行如下步骤 2 修改 move /firmadyne into /firmware-analysis-toolkit navigate to the ...

  3. NOIP模拟 19

    最近试考的脑壳疼 晚上还有一场555 T1 count 研究性质题. 研究好了AC,研究不明白就没头绪 首先枚举n的因子d 其次发现因为是树,所以如果合法,贡献只能是1 然后发现如果合法,一定是一棵一 ...

  4. Asp.net Core 系列之--1.事件驱动初探:简单事件总线实现(SimpleEventBus)

    ChuanGoing 2019-08-06  前言 开篇之前,简单说明下随笔原因.在园子里游荡了好久,期间也起过要写一些关于.NET的随笔,因各种原因未能付诸实现. 前段时间拜读daxnet的系列文章 ...

  5. 远程传输命令scp

    Linux scp 命令用于 Linux 之间复制文件和目录. scp 是 secure copy 的缩写, scp 是 linux 系统下基于 ssh 登陆进行安全的远程文件拷贝命令. scp 是加 ...

  6. 羞,Java 字符串拼接竟然有这么多姿势

    二哥,我今年大二,看你分享的<阿里巴巴 Java 开发手册>上有一段内容说:"循环体内,拼接字符串最好使用 StringBuilder 的 append 方法,而不是 + 号操作 ...

  7. 支付宝小程序和微信小程序的区别(部分)

    支付宝小程序和微信小程序之间的互相转换 1.首先是文件名 微信小程序 wxss ------ 支付宝小程序 acss 微信小程序 wxml ------ 支付宝小程序 axml 2.调用方法前缀 微信 ...

  8. Python 基础 常用模块

    Python 为我们提供了很多功能强大的模块,今天就主要使用的到的模块进行整理,方便后面来翻阅学习. 一.时间模块 在时间模块中我们重点介绍几种自己常用的功能,主要方便我们按照自己想要的方式获取时间 ...

  9. sparkContext初始化机制

    sparkContext初始化机制 要点: 1.TaskSchedular如何注册,application.Excutor 如何反向注册 TaskScheduleImpl 即 TaskSchedula ...

  10. 第一篇: openJDK源码编译安装--mac版本

    1.为什么要编译JDK 想要一探JDK内部的实现机制,最便捷的路径之一就是自己编译一套JDK,通过阅读和跟踪调试JDK源码去了解Java技术体系的原理,虽然门槛高一点,但肯定比阅读各种书籍,文章,博客 ...