GAN 生成mnist数据
参考资料
GAN原理学习笔记
生成式对抗网络GAN汇总
GAN的理解与TensorFlow的实现
TensorFlow小试牛刀(2):GAN生成手写数字
参考代码之一
#coding=utf-8
#http://blog.csdn.net/u012223913/article/details/75051516?locationNum=1&fps=1
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tensorflow.examples.tutorials.mnist import input_data sess = tf.InteractiveSession() mb_size = 128
Z_dim = 100 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) def weight_var(shape, name):
return tf.get_variable(name=name, shape=shape, initializer=tf.contrib.layers.xavier_initializer()) def bias_var(shape, name):
return tf.get_variable(name=name, shape=shape, initializer=tf.constant_initializer(0)) # discriminater net X = tf.placeholder(tf.float32, shape=[None, 784], name='X') # X [128 784] z1 = W * x + b
# 矩阵乘法 128 * 784 * 784 * 128 = [128 128]
D_W1 = weight_var([784, 128], 'D_W1')
D_b1 = bias_var([128], 'D_b1') #z2 = W * z1 + b
# 矩阵乘法 128 * 128 * 128 * 1 = [128 1] 输出判决结果,二分类
D_W2 = weight_var([128, 1], 'D_W2')
D_b2 = bias_var([1], 'D_b2')
theta_D = [D_W1, D_W2, D_b1, D_b2] # generator net Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z') # z [128 784] z1 = W * x + b
# 矩阵乘法 128 * 100 * 100 * 128 = [128 128]
G_W1 = weight_var([100, 128], 'G_W1')
G_b1 = bias_var([128], 'G_B1') #z2 = W * z1 + b
# 矩阵乘法 128 * 128 * 128 * 784 = [128 784] 输出28*28的图像
G_W2 = weight_var([128, 784], 'G_W2')
G_b2 = bias_var([784], 'G_B2') theta_G = [G_W1, G_W2, G_b1, G_b2] def generator(z):
G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob) return G_prob def discriminator(x):
D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit)
return D_prob, D_logit G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample) #discriminator输出为1表示ground truth
#discriminator输出为0表示非ground truth
#对于生成网络希望两点:
#(2)希望D_real尽可能大,这样保证正确识别真正的样本
#(1)希望D_fake尽可能小,这样可以剔除假的生成样本
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake)) #对于判别网络, 希望D_fake尽可能大,这样可以迷惑生成网络,
G_loss = -tf.reduce_mean(tf.log(D_fake)) D_optimizer = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_optimizer = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) init = tf.initialize_all_variables()
saver = tf.train.Saver()
# 启动默认图
sess = tf.Session()
# 初始化
sess.run(init) def sample_Z(m, n):
'''Uniform prior for G(Z)'''
return np.random.uniform(-1., 1., size=[m, n]) def plot(samples):
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05) for i, sample in enumerate(samples): # [i,samples[i]] imax=16
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r') return fig if not os.path.exists('out/'):
os.makedirs('out/') i = 0 for it in range(1000000):
if it % 1000 == 0:
samples = sess.run(G_sample, feed_dict={
Z: sample_Z(16, Z_dim)}) # 16*784
fig = plot(samples)
plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
i += 1
plt.close(fig) X_mb, _ = mnist.train.next_batch(mb_size)#ground truth _, D_loss_curr = sess.run([D_optimizer, D_loss], feed_dict={
X: X_mb, Z: sample_Z(mb_size, Z_dim)})
_, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict={
Z: sample_Z(mb_size, Z_dim)}) if it % 1000 == 0:
print('Iter: {}'.format(it))
print('D loss: {:.4}'.format(D_loss_curr))
print('G_loss: {:.4}'.format(G_loss_curr))
print()
参考代码之二
#http://blog.csdn.net/sparkexpert/article/details/70147409 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
#from skimage.io import imsave
import scipy
import os
import shutil img_height = 28
img_width = 28
img_size = img_height * img_width to_train = True
to_restore = False
output_path = "output" # 总迭代次数500
max_epoch = 500 h1_size = 150
h2_size = 300
z_size = 100
batch_size = 256 # generate (model 1)
def build_generator(z_prior):
w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)
b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)
h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)
w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)
b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)
h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)
b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32)
h3 = tf.matmul(h2, w3) + b3
x_generate = tf.nn.tanh(h3)
g_params = [w1, b1, w2, b2, w3, b3]
return x_generate, g_params # discriminator (model 2)
def build_discriminator(x_data, x_generated, keep_prob):
# tf.concat
x_in = tf.concat([x_data, x_generated], 0)
w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)
w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)
w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
h3 = tf.matmul(h2, w3) + b3
y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))
y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))
d_params = [w1, b1, w2, b2, w3, b3]
return y_data, y_generated, d_params #
def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):
batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5
img_h, img_w = batch_res.shape[1], batch_res.shape[2]
grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
for i, res in enumerate(batch_res):
if i >= grid_size[0] * grid_size[1]:
break
img = (res) * 255
img = img.astype(np.uint8)
row = (i // grid_size[0]) * (img_h + grid_pad)
col = (i % grid_size[1]) * (img_w + grid_pad)
img_grid[row:row + img_h, col:col + img_w] = img
#imsave(fname, img_grid)
#img.save('output/num.jpg')
scipy.misc.imsave(fname, img_grid) def train():
# load data(mnist手写数据集)
mnist = input_data.read_data_sets('mnist_data', one_hot=True) x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data")
z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
keep_prob = tf.placeholder(tf.float32, name="keep_prob")
global_step = tf.Variable(0, name="global_step", trainable=False) # 创建生成模型
x_generated, g_params = build_generator(z_prior)
# 创建判别模型
y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob) # 损失函数的设置
d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))
g_loss = - tf.log(y_generated) optimizer = tf.train.AdamOptimizer(0.0001) # 两个模型的优化函数
d_trainer = optimizer.minimize(d_loss, var_list=d_params)
g_trainer = optimizer.minimize(g_loss, var_list=g_params) init = tf.initialize_all_variables() saver = tf.train.Saver()
# 启动默认图
sess = tf.Session()
# 初始化
sess.run(init) if to_restore:
chkpt_fname = tf.train.latest_checkpoint(output_path)
saver.restore(sess, chkpt_fname)
else:
if os.path.exists(output_path):
shutil.rmtree(output_path)
os.mkdir(output_path) z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32) steps = 60000 / batch_size
for i in range(sess.run(global_step), max_epoch):
for j in np.arange(steps):
# for j in range(steps):
print("epoch:%s, iter:%s" % (i, j))
# 每一步迭代,我们都会加载256个训练样本,然后执行一次train_step
x_value, _ = mnist.train.next_batch(batch_size)
x_value = 2 * x_value.astype(np.float32) - 1
z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
# 执行生成
sess.run(d_trainer,
feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
# 执行判别
if j % 1 == 0:
sess.run(g_trainer,
feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
show_result(x_gen_val, "output/sample{0}.jpg".format(i))
z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
show_result(x_gen_val, "output/random_sample{0}.jpg".format(i))
sess.run(tf.assign(global_step, i + 1))
saver.save(sess, os.path.join(output_path, "model"), global_step=global_step) def test():
z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
x_generated, _ = build_generator(z_prior)
chkpt_fname = tf.train.latest_checkpoint(output_path) init = tf.initialize_all_variables()
sess = tf.Session()
saver = tf.train.Saver()
sess.run(init)
saver.restore(sess, chkpt_fname)
z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
show_result(x_gen_val, "output/test_result.jpg") if __name__ == '__main__':
if to_train:
train()
else:
test()
GAN 生成mnist数据的更多相关文章
- GAN生成式对抗网络(三)——mnist数据生成
通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...
- 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)
变分自编码器(VAE,variatinal autoencoder) VS 生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...
- 深度学习之 GAN 进行 mnist 图片的生成
深度学习之 GAN 进行 mnist 图片的生成 mport numpy as np import os import codecs import torch from PIL import Imag ...
- 对抗生成网络-图像卷积-mnist数据生成(代码) 1.tf.layers.conv2d(卷积操作) 2.tf.layers.conv2d_transpose(反卷积操作) 3.tf.layers.batch_normalize(归一化操作) 4.tf.maximum(用于lrelu) 5.tf.train_variable(训练中所有参数) 6.np.random.uniform(生成正态数据
1. tf.layers.conv2d(input, filter, kernel_size, stride, padding) # 进行卷积操作 参数说明:input输入数据, filter特征图的 ...
- GAN︱生成模型学习笔记(运行机制、NLP结合难点、应用案例、相关Paper)
我对GAN"生成对抗网络"(Generative Adversarial Networks)的看法: 前几天在公开课听了新加坡国立大学[机器学习与视觉实验室]负责人冯佳时博士在[硬 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- GAN生成的评价指标 Evaluation of GAN
传统方法中,如何衡量一个generator ?-- 用 generator 产生数据的 likelihood,越大越好. 但是 GAN 中的 generator 是隐式建模,所以只能从 P_G 中采样 ...
- Enterprise Solution 生成实体数据访问接口与实现类型 Code Smith 6.5 模板文件下载
数据库表定义为SalesOrder,用LLBL Gen Pro生成的实体定义是SalesOrderEntity,再用Code Smith生成的数据读写接口是ISalesOrderManager,最后是 ...
随机推荐
- 算法笔记_142:无向图的欧拉回路求解(Java)
目录 1 问题描述 2 解决方案 1 问题描述 John's trip Time Limit: 1000MS Memory Limit: 65536K Total Submissions: 8 ...
- jquery获取json对象中的key小技巧,遍历json串所有key,value
比如有一个json var json = {"name" : "Tom", "age" : 18}; 想分别获取它的key 和 value ...
- javascript 新知识
document.compatMode 属性 BackCompat: Standards-compliant mode is not switched on. (Quirks Mode) 标准模式 ...
- ubuntu PATH 出错修复
我的 ubuntu10.10设置交叉编译环境时,PATH 设置错误了,导致无法正常启动,错误情况如下: { PATH:找不到命令ubuntu2010@ubuntu:~$ ls命令 'ls' 可在 '/ ...
- C#调取java接口
1. public class APIRequest { //public static string commonUrl = @"http://192.168.2.186 ...
- EAST 自然场景文本检测
自然场景文本检测是图像处理的核心模块,也是一直想要接触的一个方面. 刚好看到国内的旷视今年在CVPR2017的一篇文章:EAST: An Efficient and Accurate S ...
- MySQL主从不一致的几种故障总结分析、解决和预防
(1).主从不一致故障,从库宕机,从库启动后重复写入数据报错解决与预防:relay_log_info_repository=TABLE(InnoDB)参数解释说明:若relay_log_info_re ...
- 关于php使用基于socket Web消息推送(未完)
转:http://blog.csdn.net/young_phper/article/details/52441143 http://www.workerman.net/ http://blog.cs ...
- unity5,UI Button too small on device than in Game View解决办法
假设测试设备为iphone5(横屏).下面说明如何使真机上ui显示效果与Game View中一致. 1,首先Game View左上角屏幕规格选 iPhone 5 Wide (16:9),如图: 2,在 ...
- Oracle宣布很多其它的Java 9 新特性
随着Oracle确认了其余的4个Java 9特性,下一代Java的计划開始变得更清晰了,Oracle已经发布了第二套Java 9特性.自从Oracle在今年早些时候宣布了3个新的API和模块化源代码后 ...