References

Study notes on GAN principles

A roundup of generative adversarial networks (GANs)

Understanding GANs and a TensorFlow implementation

A first try at TensorFlow (2): generating handwritten digits with a GAN
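
Both code listings below train the original minimax GAN on MNIST. For orientation (this formula is added here for reference and does not appear in the posts above), the objective being approximated is

min_G max_D V(D, G) = E_{x ~ p_data}[log D(x)] + E_{z ~ p_z}[log(1 - D(G(z)))]

The D_loss / d_loss terms in both listings are minibatch estimates of -V(D, G) with G fixed, and both use the non-saturating generator loss -log D(G(z)) rather than log(1 - D(G(z))).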

Reference code 1

#coding=utf-8
#http://blog.csdn.net/u012223913/article/details/75051516?locationNum=1&fps=1
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tensorflow.examples.tutorials.mnist import input_data

mb_size = 128
Z_dim = 100

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def weight_var(shape, name):
    return tf.get_variable(name=name, shape=shape, initializer=tf.contrib.layers.xavier_initializer())

def bias_var(shape, name):
    return tf.get_variable(name=name, shape=shape, initializer=tf.constant_initializer(0))

# discriminator net
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')  # X: [128, 784]
# layer 1: z1 = x * W + b, matrix multiply [128, 784] x [784, 128] = [128, 128]
D_W1 = weight_var([784, 128], 'D_W1')
D_b1 = bias_var([128], 'D_b1')
# layer 2: z2 = z1 * W + b, matrix multiply [128, 128] x [128, 1] = [128, 1];
# one binary real/fake score per image
D_W2 = weight_var([128, 1], 'D_W2')
D_b2 = bias_var([1], 'D_b2')
theta_D = [D_W1, D_W2, D_b1, D_b2]

# generator net
Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')  # Z: [128, 100]
# layer 1: z1 = z * W + b, matrix multiply [128, 100] x [100, 128] = [128, 128]
G_W1 = weight_var([100, 128], 'G_W1')
G_b1 = bias_var([128], 'G_B1')
# layer 2: z2 = z1 * W + b, matrix multiply [128, 128] x [128, 784] = [128, 784];
# one 28*28 image per row
G_W2 = weight_var([128, 784], 'G_W2')
G_b2 = bias_var([784], 'G_B2')
theta_G = [G_W1, G_W2, G_b1, G_b2]

def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)
    return G_prob

def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)
    return D_prob, D_logit

G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

# A discriminator output of 1 means "ground truth" (real); 0 means fake.
# The discriminator wants two things:
# (1) D_real as large as possible, so real samples are recognized correctly;
# (2) D_fake as small as possible, so generated samples are rejected.
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# The generator wants D_fake as large as possible, i.e. to fool the discriminator.
# (A numerically stabler, logit-based variant is sketched after this listing.)
G_loss = -tf.reduce_mean(tf.log(D_fake))

D_optimizer = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_optimizer = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

init = tf.global_variables_initializer()
saver = tf.train.Saver()
# launch the default graph
sess = tf.Session()
# initialize variables
sess.run(init)

def sample_Z(m, n):
    '''Uniform prior for G(Z)'''
    return np.random.uniform(-1., 1., size=[m, n])

def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)
    for i, sample in enumerate(samples):  # at most 4 x 4 = 16 samples
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    return fig

if not os.path.exists('out/'):
    os.makedirs('out/')

i = 0
for it in range(1000000):
    if it % 1000 == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})  # [16, 784]
        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    X_mb, _ = mnist.train.next_batch(mb_size)  # ground-truth images
    _, D_loss_curr = sess.run([D_optimizer, D_loss],
                              feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    _, G_loss_curr = sess.run([G_optimizer, G_loss],
                              feed_dict={Z: sample_Z(mb_size, Z_dim)})

    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()
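
The D_loss and G_loss above apply tf.log directly to sigmoid outputs, which produces NaN as soon as D_real or D_fake saturates at 0. A minimal sketch of a stabler alternative (my substitution, not part of the referenced post), built from the D_logit_real / D_logit_fake tensors the listing already computes:

# Logit-based losses: same intent as D_loss / G_loss above, but avoids log(0).
# (Sketch only; not part of the original post.)
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_logit_real, labels=tf.ones_like(D_logit_real)))   # real images -> label 1
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))  # fake images -> label 0
D_loss = D_loss_real + D_loss_fake
# Non-saturating generator loss: push D's score on fakes toward 1.
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

These two losses drop in for the corresponding lines above; the optimizers and training loop are unchanged.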

Reference code 2

#http://blog.csdn.net/sparkexpert/article/details/70147409
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import scipy.misc
import os
import shutil

img_height = 28
img_width = 28
img_size = img_height * img_width

to_train = True
to_restore = False
output_path = "output"

# total number of training epochs: 500
max_epoch = 500

h1_size = 150
h2_size = 300
z_size = 100
batch_size = 256

# generator (model 1)
def build_generator(z_prior):
    w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)
    h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)
    w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)
    h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
    w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32)
    h3 = tf.matmul(h2, w3) + b3
    x_generate = tf.nn.tanh(h3)
    g_params = [w1, b1, w2, b2, w3, b3]
    return x_generate, g_params

# discriminator (model 2)
def build_discriminator(x_data, x_generated, keep_prob):
    # concatenate the real and generated batches so both pass through one network
    x_in = tf.concat([x_data, x_generated], 0)
    w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
    h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)
    w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
    h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)
    w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
    h3 = tf.matmul(h2, w3) + b3
    # the first batch_size rows score the real data, the rest score the generated data
    y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))
    y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))
    d_params = [w1, b1, w2, b2, w3, b3]
    return y_data, y_generated, d_params

# save a grid of generated samples as a single image
def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):
    # map tanh outputs from [-1, 1] back to [0, 1]
    batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5
    img_h, img_w = batch_res.shape[1], batch_res.shape[2]
    grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
    grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
    for i, res in enumerate(batch_res):
        if i >= grid_size[0] * grid_size[1]:
            break
        img = res * 255
        img = img.astype(np.uint8)
        row = (i // grid_size[0]) * (img_h + grid_pad)
        col = (i % grid_size[1]) * (img_w + grid_pad)
        img_grid[row:row + img_h, col:col + img_w] = img
    scipy.misc.imsave(fname, img_grid)

def train():
    # load the MNIST handwritten-digit dataset
    mnist = input_data.read_data_sets('mnist_data', one_hot=True)

    x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data")
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # build the generator
    x_generated, g_params = build_generator(z_prior)
    # build the discriminator
    y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)

    # loss functions
    d_loss = -tf.reduce_mean(tf.log(y_data) + tf.log(1 - y_generated))
    g_loss = -tf.reduce_mean(tf.log(y_generated))

    optimizer = tf.train.AdamOptimizer(0.0001)
    # one training op per model
    d_trainer = optimizer.minimize(d_loss, var_list=d_params)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    # launch the default graph
    sess = tf.Session()
    # initialize variables
    sess.run(init)

    if to_restore:
        chkpt_fname = tf.train.latest_checkpoint(output_path)
        saver.restore(sess, chkpt_fname)
    else:
        if os.path.exists(output_path):
            shutil.rmtree(output_path)
        os.mkdir(output_path)

    z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
    steps = 60000 // batch_size
    for i in range(sess.run(global_step), max_epoch):
        for j in range(steps):
            print("epoch:%s, iter:%s" % (i, j))
            # each step loads batch_size (256) training images and runs one update
            x_value, _ = mnist.train.next_batch(batch_size)
            x_value = 2 * x_value.astype(np.float32) - 1  # rescale to [-1, 1] to match tanh
            z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
            # train the discriminator
            sess.run(d_trainer,
                     feed_dict={x_data: x_value, z_prior: z_value, keep_prob: 0.7})
            # train the generator (every step)
            if j % 1 == 0:
                sess.run(g_trainer,
                         feed_dict={x_data: x_value, z_prior: z_value, keep_prob: 0.7})
        # at the end of each epoch, save sample grids and a checkpoint
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
        show_result(x_gen_val, "output/sample{0}.jpg".format(i))
        z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
        show_result(x_gen_val, "output/random_sample{0}.jpg".format(i))
        sess.run(tf.assign(global_step, i + 1))
        saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)

def test():
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    x_generated, _ = build_generator(z_prior)
    chkpt_fname = tf.train.latest_checkpoint(output_path)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(init)
    saver.restore(sess, chkpt_fname)
    z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
    x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
    show_result(x_gen_val, "output/test_result.jpg")

if __name__ == '__main__':
    if to_train:
        train()
    else:
        test()
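
A portability note on show_result: scipy.misc.imsave was deprecated in SciPy 1.0 and removed in 1.2. On a newer SciPy, one possible drop-in (assuming the imageio package, which the original post does not use) is:

# Possible replacement for the scipy.misc.imsave call in show_result.
# Assumes `pip install imageio`; not part of the original post.
import imageio

def imsave(fname, img_grid):
    # img_grid is already a uint8 array, so it can be written directly
    imageio.imwrite(fname, img_grid)

with the last line of show_result changed to imsave(fname, img_grid).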
