用keras搭好模型架构之后的下一步,就是执行编译操作。在编译时,经常需要指定三个参数

  • loss
  • optimizer
  • metrics

这三个参数有两类选择:

  • 使用字符串
  • 使用标识符,如keras.losses,keras.optimizers,metrics包下面的函数

例如:

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
optimizer=sgd,
metrics=['accuracy'])

因为有时可以使用字符串,有时可以使用标识符,令人很想知道背后是如何操作的。下面分别针对optimizer,loss,metrics三种对象的获取进行研究。

optimizer

一个模型只能有一个optimizer,在执行编译的时候只能指定一个optimizer。

在keras.optimizers.py中,有一个get函数,用于根据用户传进来的optimizer参数获取优化器的实例:

def get(identifier):
# 如果后端是tensorflow并且使用的是tensorflow自带的优化器实例,可以直接使用tensorflow原生的优化器
if K.backend() == 'tensorflow':
# Wrap TF optimizer instances
if isinstance(identifier, tf.train.Optimizer):
return TFOptimizer(identifier)
# 如果以json串的形式定义optimizer并进行参数配置
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, six.string_types):
# 如果以字符串形式指定optimizer,那么使用优化器的默认配置参数
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
if isinstance(identifier, Optimizer):
# 如果使用keras封装的Optimizer的实例
return identifier
else:
raise ValueError('Could not interpret optimizer identifier: ' +
str(identifier))

其中,deserilize(config)函数的作用就是把optimizer反序列化制造一个实例。

loss

keras.losses函数也有一个get(identifier)方法。其中需要注意以下一点:

如果identifier是可调用的一个函数名,也就是一个自定义的损失函数,这个损失函数返回值是一个张量。这样就轻而易举的实现了自定义损失函数。除了使用str和dict类型的identifier,我们也可以直接使用keras.losses包下面的损失函数。

def get(identifier):
if identifier is None:
return None
if isinstance(identifier, six.string_types):
identifier = str(identifier)
return deserialize(identifier)
if isinstance(identifier, dict):
return deserialize(identifier)
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret '
'loss function identifier:', identifier)

metrics

在model.compile()函数中,optimizer和loss都是单数形式,只有metrics是复数形式。因为一个模型只能指明一个optimizer和loss,却可以指明多个metrics。metrics也是三者中处理逻辑最为复杂的一个。

在keras最核心的地方keras.engine.train.py中有如下处理metrics的函数。这个函数其实就做了两件事:

  • 根据输入的metric找到具体的metric对应的函数
  • 计算metric张量

在寻找metric对应函数时,有两种步骤:

  • 使用字符串形式指明准确率和交叉熵
  • 使用keras.metrics.py中的函数
def handle_metrics(metrics, weights=None):
metric_name_prefix = 'weighted_' if weights is not None else '' for metric in metrics:
# 如果metrics是最常见的那种:accuracy,交叉熵
if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
# custom handling of accuracy/crossentropy
# (because of class mode duality)
output_shape = K.int_shape(self.outputs[i])
# 如果输出维度是1或者损失函数是二分类损失函数,那么说明是个二分类问题,应该使用二分类的accuracy和二分类的的交叉熵
if (output_shape[-1] == 1 or
self.loss_functions[i] == losses.binary_crossentropy):
# case: binary accuracy/crossentropy
if metric in ('accuracy', 'acc'):
metric_fn = metrics_module.binary_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.binary_crossentropy
# 如果损失函数是sparse_categorical_crossentropy,那么目标y_input就不是one-hot的,所以就需要使用sparse的多类准去率和sparse的多类交叉熵
elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
# case: categorical accuracy/crossentropy
# with sparse targets
if metric in ('accuracy', 'acc'):
metric_fn = metrics_module.sparse_categorical_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.sparse_categorical_crossentropy
else:
# case: categorical accuracy/crossentropy
if metric in ('accuracy', 'acc'):
metric_fn = metrics_module.categorical_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.categorical_crossentropy
if metric in ('accuracy', 'acc'):
suffix = 'acc'
elif metric in ('crossentropy', 'ce'):
suffix = 'ce'
weighted_metric_fn = weighted_masked_objective(metric_fn)
metric_name = metric_name_prefix + suffix
else:
# 如果输入的metric不是字符串,那么就调用metrics模块获取
metric_fn = metrics_module.get(metric)
weighted_metric_fn = weighted_masked_objective(metric_fn)
# Get metric name as string
if hasattr(metric_fn, 'name'):
metric_name = metric_fn.name
else:
metric_name = metric_fn.__name__
metric_name = metric_name_prefix + metric_name with K.name_scope(metric_name):
metric_result = weighted_metric_fn(y_true, y_pred,
weights=weights,
mask=masks[i]) # Append to self.metrics_names, self.metric_tensors,
# self.stateful_metric_names
if len(self.output_names) > 1:
metric_name = self.output_names[i] + '_' + metric_name
# Dedupe name
j = 1
base_metric_name = metric_name
while metric_name in self.metrics_names:
metric_name = base_metric_name + '_' + str(j)
j += 1
self.metrics_names.append(metric_name)
self.metrics_tensors.append(metric_result) # Keep track of state updates created by
# stateful metrics (i.e. metrics layers).
if isinstance(metric_fn, Layer) and metric_fn.stateful:
self.stateful_metric_names.append(metric_name)
self.stateful_metric_functions.append(metric_fn)
self.metrics_updates += metric_fn.updates

无论怎么使用metric,最终都会变成metrics包下面的函数。当使用字符串形式指明accuracy和crossentropy时,keras会非常智能地确定应该使用metrics包下面的哪个函数。因为metrics包下的那些metric函数有不同的使用场景,例如:

  • 有的处理的是one-hot形式的y_input(数据的类别),有的处理的是非one-hot形式的y_input
  • 有的处理的是二分类问题的metric,有的处理的是多分类问题的metric

当使用字符串“accuracy”和“crossentropy”指明metric时,keras会根据损失函数、输出层的shape来确定具体应该使用哪个metric函数。在任何情况下,直接使用metrics下面的函数名是总不会出错的。

keras.metrics.py文件中也有一个get(identifier)函数用于获取metric函数。

def get(identifier):
if isinstance(identifier, dict):
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
elif isinstance(identifier, six.string_types):
return deserialize(str(identifier))
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret '
'metric function identifier:', identifier)

如果identifier是字符串或者字典,那么会根据identifier反序列化出一个metric函数。

如果identifier本身就是一个函数名,那么就直接返回这个函数名。这种方式就为自定义metric提供了巨大便利。

keras中的设计哲学堪称完美。

keras中的loss、optimizer、metrics的更多相关文章

  1. 【tf.keras】tf.keras使用tensorflow中定义的optimizer

    Update:2019/09/21 使用 tf.keras 时,请使用 tf.keras.optimizers 里面的优化器,不要使用 tf.train 里面的优化器,不然学习率衰减会出现问题. 使用 ...

  2. keras中保存自定义层和loss

    在keras中保存模型有几种方式: (1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型 logdir = './callbacks' if not os.path.exists ...

  3. Python机器学习笔记:深入理解Keras中序贯模型和函数模型

     先从sklearn说起吧,如果学习了sklearn的话,那么学习Keras相对来说比较容易.为什么这样说呢? 我们首先比较一下sklearn的机器学习大致使用流程和Keras的大致使用流程: skl ...

  4. Python机器学习笔记:深入学习Keras中Sequential模型及方法

    Sequential 序贯模型 序贯模型是函数式模型的简略版,为最简单的线性.从头到尾的结构顺序,不分叉,是多个网络层的线性堆叠. Keras实现了很多层,包括core核心层,Convolution卷 ...

  5. (数据科学学习手札44)在Keras中训练多层感知机

    一.简介 Keras是有着自主的一套前端控制语法,后端基于tensorflow和theano的深度学习框架,因为其搭建神经网络简单快捷明了的语法风格,可以帮助使用者更快捷的搭建自己的神经网络,堪称深度 ...

  6. Deep Learning 32: 自己写的keras的一个callbacks函数,解决keras中不能在每个epoch实时显示学习速率learning rate的问题

    一.问题: keras中不能在每个epoch实时显示学习速率learning rate,从而方便调试,实际上也是为了调试解决这个问题:Deep Learning 31: 不同版本的keras,对同样的 ...

  7. 在Keras中可视化LSTM

    作者|Praneet Bomma 编译|VK 来源|https://towardsdatascience.com/visualising-lstm-activations-in-keras-b5020 ...

  8. keras中seq2seq实现

    这里只是简单的一个例子 输入序列 目标序列 [13, 28, 18, 7, 9, 5] [18, 28, 13] [29, 44, 38, 15, 26, 22] [38, 44, 29] [27, ...

  9. keras中调用tensorboard:from keras.callbacks import TensorBoard

    from keras.models import Sequential from keras.layers import Dense from keras.wrappers.scikit_learn ...

随机推荐

  1. win7系统不能用telnet命令的两种解决方法

    电脑专业人员对telnet命令都不陌生了,Telnet当成一种通信协议,在日常工作中,经常面对网络问题的人都会用到telnet命令,因为简单有效,可以帮助更快的找出问题.要是在使用过程中碰到win7纯 ...

  2. 我对android 软件栈了解

    android 软件栈如图所示: Android平台的核心是Linux内核,它负责设备驱动程序.资源访问.电源管理和完成其他操作系统的职责.提供的设备驱动程序包括显示器.照相机,键盘.WiFi.闪存. ...

  3. Kafka:ZK+Kafka+Spark Streaming集群环境搭建(二)安装hadoop2.9.0

    如何搭建配置centos虚拟机请参考<Kafka:ZK+Kafka+Spark Streaming集群环境搭建(一)VMW安装四台CentOS,并实现本机与它们能交互,虚拟机内部实现可以上网.& ...

  4. Spring(十三):使用工厂方法来配置Bean的两种方式(静态工厂方法&实例工厂方法)

    通过调用静态工厂方法创建Bean 1)调用静态工厂方法创建Bean是将对象创建的过程封装到静态方法中.当客户端需要对象时,只需要简单地调用静态方法,而不需要关心创建对象的具体细节. 2)要声明通过静态 ...

  5. 【html5】HTML5中canvas怎样画虚线

    在canvas API中,我们发现仅仅提供了画实线的方法实现,并没有虚线的相关方法,那么怎样实现画虚线呢? 现实中,虚线是由一小段小段的实线线段组成,那么仅仅要我们通过画出等长度的线段就能够组成我们想 ...

  6. IntelliJ - idea15.0.2 破解方法

    由于idea 15版本更换了注册方式,只能通过联网激活,所以现在不能通过简单的通用注册码进行离线注册了, 虽然可以继续用14版本,但是有新版本却无法尝试让强迫症也是异常抓狂. 通过度娘我找到了一个破解 ...

  7. Android Device Monitor工具的使用

    在最新的Android Studio3.x版本中,已经去掉了Android Device Monitor工具,但是不代表Android Device Monitor工具就不能用了,找到sdk的目录: ...

  8. jQuery的md5加密插件及其它js md5加密代码

    /** * jQuery MD5 hash algorithm function * * <code> * Calculate the md5 hash of a String * Str ...

  9. 微信小程序 - cb回调(typeof cb == "function" && cb(obj);)

    typeof cb == "function" && cb(obj) 但凡用了Promise,这种方式就可以抛弃了. Page({ data: {}, onLoad ...

  10. 第六周 Word目录和索引

    第六周 Word目录和索引 教学时间 2013-4-2 教学课时 2 教案序号 5 教学目标 能正确使用索引.目录等 教学过程: 复习提问 1.脚注和尾注的区别是什么?2.如何插入脚注和尾注?3.如何 ...