不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model
在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 utils.py,输入如下代码:
import scipy.misc
import numpy as np # 保存图片函数
def save_images(images, size, path): """
Save the samples images
The best size number is
int(max(sqrt(image.shape[0]),sqrt(image.shape[1]))) + 1
example:
The batch_size is 64, then the size is recommended [8, 8]
The batch_size is 32, then the size is recommended [6, 6]
""" # 图片归一化,主要用于生成器输出是 tanh 形式的归一化
img = (images + 1.0) / 2.0
h, w = img.shape[1], img.shape[2] # 产生一个大画布,用来保存生成的 batch_size 个图像
merge_img = np.zeros((h * size[0], w * size[1], 3)) # 循环使得画布特定地方值为某一幅图像的值
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
merge_img[j*h:j*h+h, i*w:i*w+w, :] = image # 保存画布
return scipy.misc.imsave(path, merge_img)
这个函数的作用是在训练的过程中保存采样生成的图片。
在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 model.py,定义生成器,判别器和训练过程中的采样网络,在 model.py 输入如下代码:
import tensorflow as tf
from ops import * BATCH_SIZE = 64 # 定义生成器
def generator(z, y, train = True):
# y 是一个 [BATCH_SIZE, 10] 维的向量,把 y 转成四维张量
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
# 把 y 作为约束条件和 z 拼接起来
z = tf.concat(1, [z, y], name = 'z_concat_y')
# 经过一个全连接,BN 和激活层 ReLu
h1 = tf.nn.relu(batch_norm_layer(fully_connected(z, 1024, 'g_fully_connected1'),
is_train = train, name = 'g_bn1'))
# 把约束条件和上一层拼接起来
h1 = tf.concat(1, [h1, y], name = 'active1_concat_y') h2 = tf.nn.relu(batch_norm_layer(fully_connected(h1, 128 * 49, 'g_fully_connected2'),
is_train = train, name = 'g_bn2'))
h2 = tf.reshape(h2, [64, 7, 7, 128], name = 'h2_reshape')
# 把约束条件和上一层拼接起来
h2 = conv_cond_concat(h2, yb, name = 'active2_concat_y') h3 = tf.nn.relu(batch_norm_layer(deconv2d(h2, [64,14,14,128],
name = 'g_deconv2d3'),
is_train = train, name = 'g_bn3'))
h3 = conv_cond_concat(h3, yb, name = 'active3_concat_y') # 经过一个 sigmoid 函数把值归一化为 0~1 之间,
h4 = tf.nn.sigmoid(deconv2d(h3, [64, 28, 28, 1],
name = 'g_deconv2d4'), name = 'generate_image') return h4 # 定义判别器
def discriminator(image, y, reuse = False): # 因为真实数据和生成数据都要经过判别器,所以需要指定 reuse 是否可用
if reuse:
tf.get_variable_scope().reuse_variables() # 同生成器一样,判别器也需要把约束条件串联进来
yb = tf.reshape(y, [BATCH_SIZE, 1, 1, 10], name = 'yb')
x = conv_cond_concat(image, yb, name = 'image_concat_y') # 卷积,激活,串联条件。
h1 = lrelu(conv2d(x, 11, name = 'd_conv2d1'), name = 'lrelu1')
h1 = conv_cond_concat(h1, yb, name = 'h1_concat_yb') h2 = lrelu(batch_norm_layer(conv2d(h1, 74, name = 'd_conv2d2'),
name = 'd_bn2'), name = 'lrelu2')
h2 = tf.reshape(h2, [BATCH_SIZE, -1], name = 'reshape_lrelu2_to_2d')
h2 = tf.concat(1, [h2, y], name = 'lrelu2_concat_y') h3 = lrelu(batch_norm_layer(fully_connected(h2, 1024, name = 'd_fully_connected3'),
name = 'd_bn3'), name = 'lrelu3')
h3 = tf.concat(1,[h3, y], name = 'lrelu3_concat_y') # 全连接层,输出以为 loss 值
h4 = fully_connected(h3, 1, name = 'd_result_withouts_sigmoid') return tf.nn.sigmoid(h4, name = 'discriminator_result_with_sigmoid'), h4 # 定义训练过程中的采样函数
def sampler(z, y, train = True):
tf.get_variable_scope().reuse_variables()
return generator(z, y, train = train)
可以看到,生成器由 7 × 7 变为 14 × 14 再变为 28 × 28大小,每一层都加入了约束条件 y,完美的诠释了论文所给出的网络,之所以要加入 is_train 参数,是由于 Batch_norm 层中训练和测试的时候的过程是不同的,用这个参数区分训练和测试,生成器的最后一层,用了一个 sigmoid 函数把值归一化到 0~1 之间,如果是不加约束的网络,则用 tanh 函数,所以在 save_images 函数中要用到语句:img = (images + 1.0) / 2.0。
sampler 函数的作用是在训练过程中对生成器生成的图片进行采样,所以这个函数必须指定 reuse 可用,关于 reuse 说明,请看:http://www.cnblogs.com/Charles-Wan/p/6200446.html。
参考资料:
1. https://github.com/carpedm20/DCGAN-tensorflow
不要怂,就是GAN (生成式对抗网络) (三):判别器和生成器 TensorFlow Model的更多相关文章
- GAN生成式对抗网络(三)——mnist数据生成
通过GAN生成式对抗网络,产生mnist数据 引入包,数据约定等 import numpy as np import matplotlib.pyplot as plt import input_dat ...
- GAN生成式对抗网络(四)——SRGAN超高分辨率图片重构
论文pdf 地址:https://arxiv.org/pdf/1609.04802v1.pdf 我的实际效果 清晰度距离我的期待有距离. 颜色上面存在差距. 解决想法 增加一个颜色判别器.将颜色值反馈 ...
- GAN生成式对抗网络(一)——原理
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型 GAN包括两个核心模块. 1.生成器模块 --generator 2.判别器模块--de ...
- 不要怂,就是GAN (生成式对抗网络) (一)
前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...
- 不要怂,就是GAN (生成式对抗网络) (一): GAN 简介
前面我们用 TensorFlow 写了简单的 cifar10 分类的代码,得到还不错的结果,下面我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现. 自从 Ian Goodf ...
- 不要怂,就是GAN (生成式对抗网络) (二)
前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...
- 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN
在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保 ...
- 不要怂,就是GAN (生成式对抗网络) (二):数据读取和操作
前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条 ...
- 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码
先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...
随机推荐
- streamsets record header 属性
record 的header 属性可以在pipeline 逻辑中使用. 有写stages 会为了特殊目录创建reord header 属性,比如(cdc)需要进行crud 操作类型的区分 你可以使用一 ...
- Jenkins在windows环境下安装无法安装插件
在windos平台下安装jenkins要是无法安装插件,tomcat控制台报以下错误: 解决方法: 进入到jenkins里头,Jenkins -- 管理插件 -- 高级 -- 升级站点,如图所示: 将 ...
- tasks
Edit: F:\wamp\www\tasks Task ID Name Links? Date commit Date Done 9 Read openCV documents F:\wamp\ww ...
- laravel 数据库操作
1 配置信息 1.1配置目录: config/database.php 1.2配置多个数据库 //默认的数据库 'mysql' => [ 'driver' => 'mysql', 'hos ...
- NGINX 添加MP4、FLV视频支持模块
由于公司网站需要放置视频,但是默认的服务器环境是没有编译这个支持的模块,视频文件只能缓冲完了在播放,非常麻烦. 之前呢也安装了一个nginx_mod_h264_streaming来支持,效果很不错 ...
- MySQL性能调优 – 你必须了解的15个重要变量
1.DEFAULT_STORAGE_ENGINE 如果你已经在用MySQL 5.6或者5.7,并且你的数据表都是InnoDB,那么表示你已经设置好了.如果没有,确保把你的表转换为InnoDB并且设置d ...
- java get post 请求
package wzh.Http; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStr ...
- 除去a标签target="_blank"的方法
用Jquery:$(function(){$("div>a").attr("target","_blank");});先查找页面上的d ...
- 图搜索——使用DFS和BFS耗时比较
图测试数据生成代码: #include<bits/stdc++.h> using namespace std; int random(int mod) { return rand() % ...
- 超简单的制作win7 U盘启动
我感觉真的太简单,操作so简单 第一个下载这个工具,这是微软官方提供的,用这个工具可以把win7的iso文件刻录到u盘中,u盘就可以作为系统启动盘来使用了 Windows 7 USB DVD Down ...