tfgan折腾笔记(三):核心函数详述——gan_loss族
gan_loss族的函数有:
1.gan_loss:
函数原型:
def gan_loss(
# GANModel.
model,
# Loss functions.
generator_loss_fn=tuple_losses.wasserstein_generator_loss,
discriminator_loss_fn=tuple_losses.wasserstein_discriminator_loss,
# Auxiliary losses.
gradient_penalty_weight=None,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
gradient_penalty_one_sided=False,
mutual_information_penalty_weight=None,
aux_cond_generator_weight=None,
aux_cond_discriminator_weight=None,
tensor_pool_fn=None,
# Options.
reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=True)
参数:
model:gan_model族函数的返回值
generator_loss_fn:生成器使用的损失函数,可用函数见其他说明。
discriminator_loss_fn:判别器使用的损失函数,可用函数见其他说明。
gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。
gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。
gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。
gradient_penalty_one_sided:(暂不明白什么意思)。
mutual_information_penalty_weight:交叉信息惩罚权值。如果不是None,必须提供一个非负数或Tensor。
aux_cond_generator_weight:如果不是None,则添加生成器分类损失。
aux_cond_discriminator_weight:如果不是None,则添加判别器分类损失。
tensor_pool_fn:tensor pool函数。此函数传入tuple类型:(generated_data, generator_inputs),函数将它们放在内部pool中,并且返回上一个pool中的值。如,可以传入tfgan.features.tensor_pool。
reduction:传入tf.losses.Reduction类的函数。
add_summaries:是否添加总结到Tensorboard日志。
返回值:
返回“GANLoss 命名元组”。
函数内部实现:
# Create standard losses with optional kwargs, if the loss functions accept
# them.
def _optional_kwargs(fn, possible_kwargs):
"""Returns a kwargs dictionary of valid kwargs for a given function."""
if inspect.getargspec(fn).keywords is not None:
return possible_kwargs
actual_args = inspect.getargspec(fn).args
actual_kwargs = {}
for k, v in possible_kwargs.items():
if k in actual_args:
actual_kwargs[k] = v
return actual_kwargs
possible_kwargs = {'reduction': reduction, 'add_summaries': add_summaries}
gen_loss = generator_loss_fn(
model, **_optional_kwargs(generator_loss_fn, possible_kwargs))
dis_loss = discriminator_loss_fn(
pooled_model, **_optional_kwargs(discriminator_loss_fn, possible_kwargs))
其他说明:
- tfgan内置损失函数:
__all__ = [
'acgan_discriminator_loss',
'acgan_generator_loss',
'least_squares_discriminator_loss',
'least_squares_generator_loss',
'modified_discriminator_loss',
'modified_generator_loss',
'minimax_discriminator_loss',
'minimax_generator_loss',
'wasserstein_discriminator_loss',
'wasserstein_hinge_discriminator_loss',
'wasserstein_hinge_generator_loss',
'wasserstein_generator_loss',
'wasserstein_gradient_penalty',
'mutual_information_penalty',
'combine_adversarial_loss',
'cycle_consistency_loss',
'stargan_generator_loss_wrapper',
'stargan_discriminator_loss_wrapper',
'stargan_gradient_penalty_wrapper'
]
2.cyclegan_loss:
函数原型:
def cyclegan_loss(
model,
# Loss functions.
generator_loss_fn=tuple_losses.least_squares_generator_loss,
discriminator_loss_fn=tuple_losses.least_squares_discriminator_loss,
# Auxiliary losses.
cycle_consistency_loss_fn=tuple_losses.cycle_consistency_loss,
cycle_consistency_loss_weight=10.0,
# Options
**kwargs)
参数:
model:gan_model族函数的返回值
generator_loss_fn:生成器使用的损失函数。
discriminator_loss_fn:判别器使用的损失函数。
cycle_consistency_loss_fn:循环一致性损失函数。
cycle_consistency_loss_weight:循环一致性损失的权值。
**kwargs:这里的参数将直接传递给cyclegan_loss函数内部调用的gan_loss函数。
返回值:
返回“CycleGANLoss 命名元组”。
函数内部实现:
循环一致性损失函数与权值的定义:
# Defines cycle consistency loss.
cycle_consistency_loss = cycle_consistency_loss_fn(
model, add_summaries=kwargs.get('add_summaries', True))
cycle_consistency_loss_weight = _validate_aux_loss_weight(
cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss
**kwargs的实现:
# Defines losses for each partial model.
def _partial_loss(partial_model):
partial_loss = gan_loss(
partial_model,
generator_loss_fn=generator_loss_fn,
discriminator_loss_fn=discriminator_loss_fn,
**kwargs)
return partial_loss._replace(generator_loss=partial_loss.generator_loss +
aux_loss) with tf.compat.v1.name_scope('cyclegan_loss_x2y'):
loss_x2y = _partial_loss(model.model_x2y)
with tf.compat.v1.name_scope('cyclegan_loss_y2x'):
loss_y2x = _partial_loss(model.model_y2x)
其他说明:
- cycle-gan实际上是由两个普通gan组合而成的,其loss是普通gan的loss加上循环一致性损失。
- 循环一致性损失权值越大,则X->Y->X循环的相似性方面学习的越快。
3.stargan_loss:
函数原型:
def stargan_loss(
model,
generator_loss_fn=tuple_losses.stargan_generator_loss_wrapper(
losses_wargs.wasserstein_generator_loss),
discriminator_loss_fn=tuple_losses.stargan_discriminator_loss_wrapper(
losses_wargs.wasserstein_discriminator_loss),
gradient_penalty_weight=10.0,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
gradient_penalty_one_sided=False,
reconstruction_loss_fn=tf.compat.v1.losses.absolute_difference,
reconstruction_loss_weight=10.0,
classification_loss_fn=tf.compat.v1.losses.softmax_cross_entropy,
classification_loss_weight=1.0,
classification_one_hot=True,
add_summaries=True)
参数:
model:gan_model族函数的返回值
generator_loss_fn:生成器使用的损失函数。
discriminator_loss_fn:判别器使用的损失函数。
gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。
gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。
gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。
gradient_penalty_one_sided:(暂不明白什么意思)。
reconstruction_loss_fn:重建损失函数。
reconstruction_loss_weight:重建损失的权重。
classification_loss_fn:分类损失函数。
classification_loss_weight:分类损失的权重。
classification_one_hot:分类的one_hot_label。
add_summaries:是否向tensorboard添加总结。
返回值:
返回“StarGANLoss 命名元组”。
函数内部实现:
梯度惩罚函数与权值的定义:
# Gradient Penalty.
if _use_aux_loss(gradient_penalty_weight):
gradient_penalty_fn = tuple_losses.stargan_gradient_penalty_wrapper(
losses_wargs.wasserstein_gradient_penalty)
discriminator_loss += gradient_penalty_fn(
model,
epsilon=gradient_penalty_epsilon,
target=gradient_penalty_target,
one_sided=gradient_penalty_one_sided,
add_summaries=add_summaries) * gradient_penalty_weight
重建损失函数与权值的定义:
# Reconstruction Loss.
reconstruction_loss = reconstruction_loss_fn(model.input_data,
model.reconstructed_data)
generator_loss += reconstruction_loss * reconstruction_loss_weight
if add_summaries:
tf.compat.v1.summary.scalar('reconstruction_loss', reconstruction_loss)
分类损失函数与权值定义:
# Classification Loss.
generator_loss += _classification_loss_helper(
true_labels=model.generated_data_domain_target,
predict_logits=model.discriminator_generated_data_domain_predication,
scope_name='generator_classification_loss') * classification_loss_weight
discriminator_loss += _classification_loss_helper(
true_labels=model.input_data_domain_label,
predict_logits=model.discriminator_input_data_domain_predication,
scope_name='discriminator_classification_loss'
) * classification_loss_weight
其他说明:
无
tfgan折腾笔记(三):核心函数详述——gan_loss族的更多相关文章
- tfgan折腾笔记(二):核心函数详述——gan_model族
定义model的函数有: 1.gan_model 函数原型: def gan_model( # Lambdas defining models. generator_fn, discriminator ...
- Typescript 学习笔记三:函数
中文网:https://www.tslang.cn/ 官网:http://www.typescriptlang.org/ 目录: Typescript 学习笔记一:介绍.安装.编译 Typescrip ...
- ES6学习笔记<三> 生成器函数与yield
为什么要把这个内容拿出来单独做一篇学习笔记? 生成器函数比较重要,相对不是很容易理解,单独做一篇笔记详细聊一聊生成器函数. 标题为什么是生成器函数与yield? 生成器函数类似其他服务器端语音中的接口 ...
- tfgan折腾笔记(一):核心功能简要概述
tfgan是什么? tfgan是tensorflow团队开发出的一个专门用于训练各种GAN的轻量级库,它是基于tensorflow开发的,所以兼容于tensorflow.在tensorflow1.x版 ...
- python学习笔记三:函数及变量作用域
一.定义 def functionName([arg1,arg2,...]): code 二.示例 #!/usr/bin/python #coding:utf8 #coding=utf8 #encod ...
- python 学习笔记三 (函数)
1.把函数视为对象 def factorial(n): '''return n!''' return 1 if n < 2 else n*factorial(n-1) print(factori ...
- MySql学习笔记(三) —— 聚集函数的使用
1.AVG() 求平均数 select avg(prod_price) as avg_price from products; --返回商品价格的平均值 ; --返回生产商id为1003的商品价格平均 ...
- wr720n v4 折腾笔记(三):网络配置与扩充USB
0x01 前言 网络配置比较简单,但是USB拓展就麻烦许多了,这里由于overlay的内存分配问题导致软件安装失败,这里找到了一种方法就是直接从uboot刷入南浦月大神的wr720n的openwrt固 ...
- Python 学习笔记三
笔记三:函数 笔记二已取消置顶链接地址:http://www.cnblogs.com/dzzy/p/5289186.html 函数的作用: 給代码段命名,就像变量給数字命名一样 可以接收参数,像arg ...
随机推荐
- jeesite 去掉 /a
1.修改 jeesite.properties文件 adminPath=/a为 adminPath= 2.修改 web.xml文件找到如下设置 <filter-mapping> <f ...
- java连接access的用户名、密码异常Decoding not supported解决
Java通过ucanaccess对Access数据库.accdb文件连接: public static Connection getConn() { try { String dbURL = &quo ...
- BGP2
1) 按照拓扑搭建网络,在所有AS间使用直连接口建立EBGP邻居关系: 2) 在公司总部AS400中,R4与R5,R5与R7,R7与R6,R6与R4间使用环回接口建立IBGP邻居关系,IGP协议使用O ...
- C# 将多个DataTable添加到指定的DataSet中
DataSet ds = new DataSet();//创建数据集 DataTable dt1=new DataTable(); //表1 DataTable dt2 = new DataTable ...
- C++语言堆栈的详细讲解
本文主要向大家介绍了C++语言堆栈的详细讲解,通过具体的内容向大家展示,希望对大家学习C++语言有所帮助. 一.预备知识—程序的内存分配 一个由c/C++编译的程序占用的内存分为以下几个部分 1.栈区 ...
- Doc: NetBeans
NetBeans的最新版本已经更新为Apache NetBeans. 安装JDK 在Mac OS X下,有".dmg"的安装包,可以直接安装.只要JDK的版本大于1.8.0就可以安 ...
- markdown直接粘贴截图
通过代码方式 cmd markdown粘贴截图 https://www.jianshu.com/p/ae048b5090f8
- Flask从负到零的一周
新的一年,因为似乎要做很多的数据库,准备入坑Flask.开了一次讨论,我感觉自己燃起来了.于是,先买了一个号角状的水杯压压惊.目前通过一周的艰辛努力,终于做了一个小网站,支持数据库增删改查,算是从零到 ...
- 转:zabbix 更改maps图标
更改Zabbix map图标 Zabbix的maps用来图形化显示监控设备的拓扑图,并且以不同的标记显示故障事件,通过该图表很直观的显示设备的整体情况.系统默认的图标比较简陋,如图十一所示.通过更改系 ...
- Catch That Cow (BFS)
题目: Farmer John has been informed of the location of a fugitive cow and wants to catch her immediate ...