initializer总结:

#f.constant_initializer(value)

变量初始化为给定的常量,初始化一切所提供的值。


#tf.random_normal_initializer(mean,stddev)

功能是将变量初始化为满足正态分布的随机值,主要参数(正太分布的均值和标准差),用所给的均值和标准差初始化均匀分布


#tf.truncated_normal_initializer(mean,stddev,seed,dtype)

mean:用于指定均值; stddev用于指定标准差;

seed:用于指定随机数种子; dtype:用于指定随机数的数据类型。

功能:将变量初始化为满足正态分布的随机值,但如果随机出来的值偏离平均值超过2个标准差,那么这个数将会被重新随机,通常只需要设定一个标准差stddev这一个参数就可以。 #tf.random_uniform_initializer(a,b,seed,dtype) #从a到b均匀初始化,将变量初始化为满足平均分布的随机值,主要参数(最大值,最小值)


优化器构造:

1).compute_gradients(loss,var_list=None,gate_gradients=GATE_OP,aggregation_method=None,colocate_gradients_with_ops=False,grad_loss=None)

作用:对于在变量列表(var_list)中的变量计算对于损失函数的梯度,这个函数返回一个(梯度,变量)对的列表,其中梯度就是相对应变量的梯度了。这是minimize()函数的第一个部分,

参数:

loss: 待减小的值

var_list: 默认是在GraphKey.TRAINABLE_VARIABLES.

2).apply_gradients(grads_and_vars,global_step=None,name=None)

作用:把梯度“应用”(Apply)到变量上面去。其实就是按照梯度下降的方式加到上面去。这是minimize()函数的第二个步骤。 返回一个应用的操作。

参数:

grads_and_vars: compute_gradients()函数返回的(gradient, variable)对的列表

global_step: Optional Variable to increment by one after the variables have been updated.

3).minimize(loss,global_step=None,var_list=None,gate_gradients=GATE_OP,aggregation_method=None,colocate_gradients_with_ops=False,name=None,grad_loss=None)

变量初始化:

sess.run(tf.global_variables_initializer())

函数中调用了 variable_initializer() 和 global_variables()

global_variables() 返回一个 Variable list ,里面保存的是 gloabal variables。variable_initializer() 将 Variable list 中的所有 Variable 取出来,将其 variable.initializer 属性做成一个 op group。然后看 Variable 类的源码可以发现, variable.initializer 就是一个 assign op。

所以: sess.run(tf.global_variables_initializer()) 就是 run了所有global Variable 的 assign op,这就是初始化参数的本来面目。

废话不多说,上代码。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tensorflow.examples.tutorials.mnist import input_data os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' sess = tf.InteractiveSession() mb_size = 128
Z_dim = 100
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)#mnist数据集 one_hot是为了让标签二元,即只有0和1. def weight_var(shape, name): #定义权重,传入权重shape和name
return tf.get_variable(name=name, shape=shape, initializer=tf.contrib.layers.xavier_initializer())
def bias_var(shape, name):#定义偏置,传入偏置shape和name
return tf.get_variable(name=name, shape=shape, initializer=tf.constant_initializer(0))
# discriminater net
X = tf.placeholder(tf.float32, shape=[None, 784], name='X') #样本x的shape是【batchsize】【784】
D_W1 = weight_var([784, 128], 'D_W1') #D的中间层的w1
D_b1 = bias_var([128], 'D_b1')
D_W2 = weight_var([128, 1], 'D_W2')#D的输出层的w2
D_b2 = bias_var([1], 'D_b2')
theta_D = [D_W1, D_W2, D_b1, D_b2]#D的参数列表 # generator net
Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')#随机噪声向量z的shape是【batchsize】【100】
G_W1 = weight_var([100, 128], 'G_W1')#D的中间层的w1
G_b1 = bias_var([128], 'G_B1')
G_W2 = weight_var([128, 784], 'G_W2')#D的输出层的w2
G_b2 = bias_var([784], 'G_B2')
theta_G = [G_W1, G_W2, G_b1, G_b2]#G的参数列表 def generator(z): #定义G,传入随机噪声z,返回G的输出。
G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) #G_h1中间层经过激活函数后的输出。
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 #G输出层没有经过激活函数的输出。
G_prob = tf.nn.sigmoid(G_log_prob)#G输出层经过激活函数后的输出。
return G_prob def discriminator(x):#定义D,传入样本x,返回D的输出和没有经过激活函数的输出。
D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)#D_h1中间层经过激活函数后的输出
D_logit = tf.matmul(D_h1, D_W2) + D_b2#D输出层没有经过激活函数的输出
D_prob = tf.nn.sigmoid(D_logit)#D输出层经过激活函数后的输出
return D_prob, D_logit G_sample = generator(Z) #调用generator(z)生成G样本
D_real, D_logit_real = discriminator(X)#discriminator(x)辨别样本
D_fake, D_logit_fake = discriminator(G_sample) # D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# G_loss = -tf.reduce_mean(tf.log(D_fake)) D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_logit_real, labels=tf.ones_like(D_logit_real))) #使用交叉熵代价函数,D的目标:对于真实样本,target=1
tf.summary.scalar("D_loss_real", D_loss_real) D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))#使用交叉熵代价函数,D的目标:对于生成器生成的样本,target=0
tf.summary.scalar("D_loss_fake", D_loss_fake)
D_loss = D_loss_real + D_loss_fake #D最后的损失函数为D(真)+D(假) G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))#使用交叉熵代价函数,G的目标:对于生成器生成的样本,target=1
tf.summary.scalar("G_loss", G_loss) D_optimizer = tf.train.GradientDescentOptimizer(0.002).minimize(D_loss, var_list=theta_D)
G_optimizer = tf.train.GradientDescentOptimizer(0.002).minimize(G_loss, var_list=theta_G)
# init variables
sess.run(tf.global_variables_initializer())
def sample_Z(m, n):#随机噪声向量的生成,维度为m*n
return np.random.uniform(-1., 1., size=[m, n]) def plot(samples):#画图
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples): # [i,samples[i]] imax=16
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
return fig if not os.path.exists('out/'):
os.makedirs('out/')
i = 0 summary_op = tf.summary.merge_all()
writer = tf.summary.FileWriter("E:\Program Files\PycharmProjects\mnist_gan",sess.graph) for it in range(1000000):
if it % 1000 == 0:
samples = sess.run(G_sample, feed_dict={
Z: sample_Z(16, Z_dim)})#生成一个维度为16*100的向量,其值是-1.——1.的随机值。
fig = plot(samples)
plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
i += 1
plt.close(fig) X_mb, _ = mnist.train.next_batch(mb_size)#调用了mnist里的方法,返回x和label _, D_loss_curr = sess.run([D_optimizer, D_loss], feed_dict={#run(D_optimizer),开始进行梯度下降。#run(D_loss),获得d_loss值
X: X_mb, Z: sample_Z(mb_size, Z_dim)}) #D喂入x样本和Z样本 _, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict={
Z: sample_Z(mb_size, Z_dim)})#G喂入x样本和Z样本
result = sess.run(summary_op, feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
writer.add_summary(result, i)
if it % 1000 == 0:
print('Iter: {}'.format(it)) #用format()里的数字来替换“{}”
print('D_loss: {:.4}'.format(D_loss_curr))
print('G_loss: {:.4}'.format(G_loss_curr))
print()

基础Gan代码解析的更多相关文章

  1. [nRF51822] 12、基础实验代码解析大全 · 实验19 - PWM

    一.PWM概述: PWM(Pulse Width Modulation):脉冲宽度调制技术,通过对一系列脉冲的宽度进行调制,来等效地获得所需要波形. PWM 的几个基本概念: 1) 占空比:占空比是指 ...

  2. [nRF51822] 11、基础实验代码解析大全 · 实验16 - 内部FLASH读写

     一.实验内容: 通过串口发送单个字符到NRF51822,NRF51822 接收到字符后将其写入到FLASH 的最后一页,之后将其读出并通过串口打印出数据. 二.nRF51822芯片内部flash知识 ...

  3. [nRF51822] 10、基础实验代码解析大全 · 实验15 - RTC

    一.实验内容: 配置NRF51822 的RTC0 的TICK 频率为8Hz,COMPARE0 匹配事件触发周期为3 秒,并使能了TICK 和COMPARE0 中断. TICK 中断中驱动指示灯D1 翻 ...

  4. [nRF51822] 9、基础实验代码解析大全 · 实验12 - ADC

    一.本实验ADC 配置 分辨率:10 位. 输入通道:5,即使用输入通道AIN5 检测电位器的电压. ADC 基准电压:1.2V. 二.NRF51822 ADC 管脚分布 NRF51822 的ADC ...

  5. [nRF51822] 8、基础实验代码解析大全 · 实验11 - PPI

    前一篇分析了前十个基础实验的代码,从这里开始分析后十个~ 一.PPI原理: PPI(Programmable Peripheral Interconnect),中文翻译为可编程外设互连. 在nRF51 ...

  6. [nRF51822] 7、基础实验代码解析大全(前十)

    实验01 - GPIO输出控制LED 引脚输出配置:nrf_gpio_cfg_output(LED_1); 引脚输出置高:nrf_gpio_pin_set(LED_1); 引脚电平转换:nrf_gpi ...

  7. 【原创】大数据基础之Spark(5)Shuffle实现原理及代码解析

    一 简介 Shuffle,简而言之,就是对数据进行重新分区,其中会涉及大量的网络io和磁盘io,为什么需要shuffle,以词频统计reduceByKey过程为例, serverA:partition ...

  8. 【原创】大数据基础之Spark(4)RDD原理及代码解析

    一 简介 spark核心是RDD,官方文档地址:https://spark.apache.org/docs/latest/rdd-programming-guide.html#resilient-di ...

  9. hadoop概述测试题和基础模版代码

    hadoop概述测试题和基础模版代码 1.Hadoop的创始人是DougCutting?() A.正确 B.错误答对了!正确答案:A解析:参考课程里的文档,这个就不解释了2.下列有关Hadoop的说法 ...

随机推荐

  1. ABP安装,应用及二次开发视频

    CSDN课程:http://edu.csdn.net/lecturer/944

  2. idea上手

    IntelliJ Idea 常用快捷键列表 最常用: Ctrl+P,可以显示参数信息 Alt+Insert,可以生成构造器/Getter/Setter等 Ctrl+Enter,导入包,自动修正 Ctr ...

  3. tensorboard No graph definition files were found No scalar data was found 解决方法

    logdir后面的路径不要加引号!!!! tensorboard --logdir='D:\WorkSpace\eclipse\tf_tr\train\my_graph\' 删除引号,改为: tens ...

  4. Unity 3D类结构简介

    趁着周末,再来一发.对于Unity3D,我也是刚开始学习,希望能够与大家多多交流.好了,废话不多说,下面继续. 本篇文章使用C#进行举例和说明.关于Unity 3D编辑器中的各种窗口,网上有很多资料了 ...

  5. JDK安装教程

    打开我的电脑,在D盘中新建一个文件夹,名字为develop 进入develop,创建一个新文件夹,名字叫做jdk 双击JDK的安装包, .4.出如图所示的框,选择下一步 .5.更改安装路径,选择更改 ...

  6. 转 一个oracle11g 使用exp导出空表丢失的问题分析及解决办法

    用exp无法导出空表解决方法 最早的一次使用oracle 11g导出数据发现有的表丢失了,感觉莫名其妙的,后来终于找到原因了. 找到问题以后,再看看解决方案.11GR2中有个新特性,当表无数据时,不分 ...

  7. [蓝桥杯]PREV-25.历届试题_城市建设

    问题描述 栋栋居住在一个繁华的C市中,然而,这个城市的道路大都年久失修.市长准备重新修一些路以方便市民,于是找到了栋栋,希望栋栋能帮助他. C市中有n个比较重要的地点,市长希望这些地点重点被考虑.现在 ...

  8. bzoj5104: Fib数列

    Description Fib数列为1,1,2,3,5,8... 求在Mod10^9+9的意义下,数字N在Fib数列中出现在哪个位置 无解输出-1 Input 一行,一个数字N,N < = 10 ...

  9. SpringCloud注解和配置以及pom依赖说明

    在本文中说明了pom依赖可以支持什么功能,以及支持什么注解,引入该依赖可以在application.properties中添加什么配置. 1.SpringCloud 的pom依赖 序号 pom依赖 说 ...

  10. thinkphp url build 生成localhost.localhost的解决方案

    找到框架核心Url.php的下面一段代码 // 原代码// URL组装$url = $domain . rtrim($this->root ?: $this->app['request'] ...