定义model的函数有:

1.gan_model

函数原型:

def gan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# Real data and conditioning.
real_data,
generator_inputs,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
# Options.
check_shapes=True)

参数:

generator_fn:预先定义好的生成器网络的函数名称。预先定义好的生成器函数的输入参数应该是接下来要说明的第四个参数generator_inputs,生成网络的返回值是网络的输出(因为是GAN,所以生成器的输出一般是一幅机器生成的图像)。

discriminator_fn:预先定义好的判别器网络的函数名称。预先定义好的判别器函数的输入参数有两个:第一个是“真实数据(图像)”/“机器生成的图像(generator_fn的返回值)”;第二个是生成器的输入,即此函数的第四个参数(在普通的gan当中,判别器只需要第一个参数。即使不需要第二个参数,也必须显式地定义出第二个参数,只不过定义了之后在判别器函数中可以不使用)。判别器的返回值必须在负无穷到正无穷之间([-inf, +inf])。

real_data:真实图像。一般传入真实图像batch化后的引用。

generator_inputs:生成器的输入。对于vallina gan,是tensor类型的噪声。除此之外,如果是c-gan,还可以传入一个list或tuple作为参数(在下方的“其他说明“里详细说明c-gan(conditional-gan)的情况)。

generator_scope:传入这个参数可以定义生成器内参数的变量命名空间(variable_scope)。默认为"Generator"。

discriminator_scope:传入这个参数可以定义判别器内参数的变量命名空间(variable_scope)。默认为"Discriminator"。

check_shapes:如果为真,将检查生成器生成的数据与真实数据是否有相同的shape。如果为假,则跳过检查。

返回值:

返回一个“GANModel 命名管道”。实际上就是一个由生成器函数、判别器函数、生成的数据、变量空间等东西组成的一个List。这个返回值不需要我们写程序的时候用,就不过多解释了(具体用法见本系列上一篇文档:传送门)。

函数内部实现:

generator_fn和discriminator_fn在gan_model函数里这样调用:

# 由机器生成数据
generated_data = generator_fn(generator_inputs) # 判别器判断机器生成图片的真实性
discriminator_gen_outputs = discriminator_fn(generated_data, generator_inputs) # 判别器判断真实图片的真实性
discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

其他说明:

  • gan_model支持conditional-gan。若需要训练c-gan,要通过generator_inputs额外传入标签信息。如:generator_inputs=(noise, one_hot_label)。同时,判别器网络与生成器网络应该按照c-gan论文中的模型重新定义。
  • real_data一般为一个next_batch。如:next_batch = tf.compat.v1.data.make_one_shot_iterator(image_ds).get_next()

2.infogan_model

函数原型:

def infogan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# Real data and conditioning.
real_data,
unstructured_generator_inputs,
structured_generator_inputs,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator')

参数:

generator_fn:预先定义好的生成器网络的函数名称。预先定义好的生成器函数的输入参数应该是接下来要说明的unstructrued_generator_inputs与structured_generator_inputs共同组成的列表,列表中的每一项是一个Tensor,生成网络的返回值是生成器的输出。

discriminator_fn:预先定义好的判别器网络的函数名称。预先定义好的判别器函数的输入参数应该有两个:第一个是“真实数据(图像)”/“机器生成的图像(generator_fn的返回值)”;第二个是生成器的输入,即(unstructrued_generator_inputs与structured_generator_inputs共同组成的列表)。预先定义好的判别器函数的输出应是一个二维Tuple。Tuple的第一维是生成器网络输出层的logits,范围在[-inf, +inf]。Tuple的第二维是分布的列表:此分布的第i个列表元素代表的是第i个structure noise 的分布。

real_data:真实图像。一般传入真实图像batch化后的引用。

unstructured_generator_inputs:Tensor的列表。表示非结构化的noise或条件。

structured_generator_inputs:Tensor的列表。这些Tensor必须与识别器具有较高的相互信息。

generator_scope:传入这个参数可以定义生成器内参数的变量命名空间(variable_scope)。默认为"Generator"。

discriminator_scope:传入这个参数可以定义判别器内参数的变量命名空间(variable_scope)。默认为"Discriminator"。

返回值:

返回一个“InfoGANModel 命名管道”。同“GANModel 命名管道”一样,我们无需关心管道中的具体内容。

函数内部实现:

生成器的输入这样定义:

generator_inputs = (unstructured_generator_inputs + structured_generator_inputs)

生成器和判别器这样调用:

# 由机器生成数据
generated_data = generator_fn(generator_inputs) # 判别器判断机器生成图片的真实性
dis_gen_outputs, predicted_distributions = discriminator_fn(generated_data, generator_inputs) # 判别器判断真实图片的真实性
dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

其他说明:

  • 关于生成器和判别器网络模型的搭建,请参照Info-GAN的论文。
  • real_data一般为一个next_batch。如:next_batch = tf.compat.v1.data.make_one_shot_iterator(image_ds).get_next()

3.acgan_model:

函数原型:

def acgan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# Real data and conditioning.
real_data,
generator_inputs,
one_hot_labels,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
# Options.
check_shapes=True)

参数:

与gan_model中的参数基本一致,除了:

discriminator_fn:预定义的判别器函数应当返回一个二维Tuple。第一维是网络输出层的real或者fake的logits;第二维是分类器的logits。他们两个的范围都应该是[-inf, +inf]。

one_hot_labels:对应于一个batch图像的one_hot_label。

返回值:

返回“AcGANModel 命名管道”。同“GANModel 命名管道”一样,我们无需关心管道中的具体内容。

函数内部实现:

生成器和判别器这样调用:

# 由机器生成数据
generated_data = generator_fn(generator_inputs) # 判别器判断机器生成图片的真实性
(discriminator_gen_outputs, discriminator_gen_classification_logits) = _validate_acgan_discriminator_outputs(discriminator_fn(generated_data, generator_inputs)) # 判别器判断真实图片的真实性
(discriminator_real_outputs, discriminator_real_classification_logits) = _validate_acgan_discriminator_outputs(discriminator_fn(real_data, generator_inputs))

其他说明:

  • one_hot_labels在此函数内部没有被使用,而是直接通过命名管道(返回值)传递给gan_loss函数(下一篇详细说明)。
  • one_hot_labels与real_data均为batch。

4.cyclegan_model:

函数原型:

def cyclegan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# data X and Y.
data_x,
data_y,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
model_x2y_scope='ModelX2Y',
model_y2x_scope='ModelY2X',
# Options.
check_shapes=True)

参数:

generator_fn:预先定义好的生成器函数。此生成器的输入有一个参数,与gan_model的generator_fn一样。返回值为生成器网络的输出。

discriminator_fn:预先定义好的判别器函数。与gan_model的discriminator_fn定义一样。

data_x:x域的真实数据。

data_y:y域的真实数据。

generator_scope:与gan_model的generator_scope意义一样。

discriminator_scope:与gan_model的discriminator_scope意义一样。

model_x2y_scope:x->y转换过程的variable_scope。

model_y2x_scope:y->x转换过程的variable_scope。

check_shapes:如果为真,将检查生成器生成的数据与真实数据是否有相同的shape。如果为假,则跳过检查。

返回值:

返回“CycleGANModel 命名空间”。

函数内部实现:

此函数实际上调用了gan_model函数,如下所示:

# Create models.
def _define_partial_model(input_data, output_data): # 内部函数定义
return gan_model(
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
real_data=output_data,
generator_inputs=input_data,
generator_scope=generator_scope,
discriminator_scope=discriminator_scope,
check_shapes=check_shapes) with tf.compat.v1.variable_scope(model_x2y_scope):
model_x2y = _define_partial_model(data_x, data_y)
with tf.compat.v1.variable_scope(model_y2x_scope):
model_y2x = _define_partial_model(data_y, data_x) with tf.compat.v1.variable_scope(model_y2x.generator_scope, reuse=True):
reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
with tf.compat.v1.variable_scope(model_x2y.generator_scope, reuse=True):
reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data) return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
reconstructed_y)

其他说明:

5.stargan_model

函数原型:

def stargan_model(generator_fn,
discriminator_fn,
input_data,
input_data_domain_label,
generator_scope='Generator',
discriminator_scope='Discriminator')

参数:

generator_fn:预先定义好的函数的函数名称。函数的输入有两个,应分别为:input、target,返回值是根据inputs和targets由机器生成的图像。inputs的形状应该是(batch, height, width, channel),targets的形状是(batch, num_domain)。返回值有和inputs相同的形状。

discriminator_fn:预先定义好的函数的函数名称。此函数的输入有两个,分别为input和num_domain。返回值是一个Tuple:(`source_prediction`, `domain_prediction`)。`source_prediction`表示预测的图像(真实或生成的)真实度,“ domain_prediction”代表判别器对域分类的预测(真实度)。 `source_prediction`的形状是(batch), `domain_prediction`具有形状(batch,num_domains)。

input_data:Tensor或Tensor组成的列表。代表真实输入的图片。形状是(batch, height, width, channel)。

input_data_domain_label:Tensor或Tensor组成的列表。形状为(batch, num_domains)。表示真实数据相对应的代表域的标签。

generator_scope:与gan_model的此参数意义相同。

discriminator_scope:与gan_model的此参数意义相同。

返回值:

返回“StarGANModel 命名空间”。

函数内部实现:

函数内部重要代码如下:

  # Transform input_data to random target domains.
with tf.compat.v1.variable_scope(generator_scope) as generator_scope:
generated_data_domain_target = generate_stargan_random_domain_target(
batch_size, num_domains)
generated_data = generator_fn(input_data, generated_data_domain_target) # Transform generated_data back to the original input_data domain.
with tf.compat.v1.variable_scope(generator_scope, reuse=True):
reconstructed_data = generator_fn(generated_data, input_data_domain_label) # Predict source and domain for the generated_data using the discriminator.
with tf.compat.v1.variable_scope(discriminator_scope) as discriminator_scope:
disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
generated_data, num_domains) # Predict source and domain for the input_data using the discriminator.
with tf.compat.v1.variable_scope(discriminator_scope, reuse=True):
disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
input_data, num_domains)

其他说明:

tfgan折腾笔记(二):核心函数详述——gan_model族的更多相关文章

  1. tfgan折腾笔记(三):核心函数详述——gan_loss族

    gan_loss族的函数有: 1.gan_loss: 函数原型: def gan_loss( # GANModel. model, # Loss functions. generator_loss_f ...

  2. tfgan折腾笔记(一):核心功能简要概述

    tfgan是什么? tfgan是tensorflow团队开发出的一个专门用于训练各种GAN的轻量级库,它是基于tensorflow开发的,所以兼容于tensorflow.在tensorflow1.x版 ...

  3. linux shell学习笔记二---自定义函数(定义、返回值、变量作用域)介绍

    linux shell 可以用户定义函数,然后在shell脚本中可以随便调用.下面说说它的定义方法,以及调用需要注意那些事项. 一.定义shell函数(define function) 语法: [ f ...

  4. ASP.NET Core 折腾笔记二:自己写个完整的Cache缓存类来支持.NET Core

    背景: 1:.NET Core 已经没System.Web,也木有了HttpRuntime.Cache,因此,该空间下Cache也木有了. 2:.NET Core 有新的Memory Cache提供, ...

  5. guxh的python笔记二:函数基础

    1,函数的参数 1.1,查看函数的参数类型 def run(a, *args, b, **kwargs): return a + b 可以通过如下方式查看参数类型: import inspect k ...

  6. OpenCv 2.4.9 (二) 核心函数

    前言 经过前面一节的怎样读取图片,我们可以做一些有趣的图像变换,下面我们首先介绍使用遍历的方法实现,然后我们使用内置的函数实现. 矩阵掩码实现 矩阵掩码,和卷积神经网络中的卷积类似.一个例子如下: 现 ...

  7. [C语言学习笔记二] extern 函数的用法

    extern 用来定义一个或多个变量.其后跟数据类型名和初始值.例如: extern int a =10 它与 int,long long int,double,char的本质区别,在于 extern ...

  8. 【C++初学者自学笔记二】函数重载(模块一)

    1.概念:同意作用域的一组参数列表不同,函数名相同的函数,这组函数叫函数重载(C语言是不能定义相同名称的函数,但是C++可以允许定义). 2作用:重载函数通常来命名一组功能相似的函数,这样做减少了函数 ...

  9. wr720n v4 折腾笔记(二):刷入不死Uboot

    0x01 前言 接着上节刷入Openwrt开始说起,此次开始刷入不死Uboot,刷入之后就可以在Uboot里面随便刷机,再也不怕成砖了. 固件附件地址: 下载地址1(还是之前一的包) flash文件地 ...

随机推荐

  1. 892B. Wrath#愤怒的连环杀人事件(cin/cout的加速)

    题目出处:http://codeforces.com/problemset/problem/892/B 题目大意:一队人同时举刀捅死前面一些人后还活着几个 #include<iostream&g ...

  2. 109)PHP与oracle网址

    https://pecl.php.net/package/oci8/2.1.8/windowshttps://www.toadworld.com/platforms/oracle/w/wiki/116 ...

  3. rest framework-restful介绍-长期维护

    ###############   django框架-rest framework    ############### # django rest framework 框架 # 为什么学习这个res ...

  4. jeesite 去掉 /a

    1.修改 jeesite.properties文件 adminPath=/a为 adminPath= 2.修改 web.xml文件找到如下设置 <filter-mapping> <f ...

  5. 结构体struct,类class

    1.struct,值类型,结构体会自动生成初始化方法,class是引用类型 struct Person { var name : String var age : Int func simpleDes ...

  6. php函数 之 iconv 不是php的默认函数,也是默认安装的模块。需要安装才能用的。

    windows下最近在做一个小偷程序,需要用到iconv函数把抓取来过的utf-8编码的页面转成gb2312, 发现只有用iconv函数把抓取过来的数据一转码数据就会无缘无故的少一些.  让我郁闷了好 ...

  7. Yii框架的学习指南(策码秀才篇)1-3 我是这么学习的yii framework (不间断更新中)

    Ⅰ.基本概念一.入口文件入口文件内容:一般格式如下:<?php $yii=dirname(__FILE__).'/../../framework/yii.php';//Yii框架位置$confi ...

  8. [LC] 557. Reverse Words in a String III

    Given a string, you need to reverse the order of characters in each word within a sentence while sti ...

  9. python使用geopandas和shapely处理shp文件

    一.环境搭建 所需库:geopandas (以及前置库)  doc:http://geopandas.org/ shapely(以及前置库)  doc: 二.数据预处理 1.将shp文件进行切片 2. ...

  10. mysql配置白名单

    1. 测试是否允许远程连接 $ telnet 192.168.1.8 3306 host 192.168.1.4 is not allowed to connect to this mysql ser ...