Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制。当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型。

这里以tensorflow2官网中的例子来说明:

import numpy as np
import tensorflow as tf
from tensorflow import keras
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
class CustomModel(keras.Model):
tf.random.set_seed(100)
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses) # Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update metrics (includes the metric that tracks the loss)
self.compiled_metrics.update_state(y, y_pred)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics} # Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss=tf.losses.MSE, metrics=["mae"]) # Just use `fit` as usual model.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 1ms/step - loss: 0.2783 - mae: 0.4257

<tensorflow.python.keras.callbacks.History at 0x7ff7edf6dfd0>

这里的loss是tensorflow库中实现了的损失函数,如果想自定义损失函数,然后将损失函数传入model.compile中,能正常按我们预想的work吗?

答案竟然是否定的,而且没有错误提示,只是loss计算不会符合我们的预期。

def custom_mse(y_true, y_pred):
return tf.reduce_mean((y_true - y_pred)**2, axis=-1)
a_true = tf.constant([1., 1.5, 1.2])
a_pred = tf.constant([1., 2, 1.5])
custom_mse(a_true, a_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.11333332>
tf.losses.MSE(a_true, a_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.11333332>

以上结果证实了我们自定义loss的正确性,下面我们直接将自定义的loss置入compile中的loss参数中,看看会发生什么。

my_model = CustomModel(inputs, outputs)
my_model.compile(optimizer="adam", loss=custom_mse, metrics=["mae"])
my_model.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 820us/step - loss: 0.1628 - mae: 0.3257

<tensorflow.python.keras.callbacks.History at 0x7ff7edeb7810>

我们看到,这里的loss与我们与标准的tf.losses.MSE明显不同。这说明我们自定义的loss以这种方式直接传递进model.compile中,是完全错误的操作。

正确运用自定义loss的姿势是什么呢?下面揭晓。

loss_tracker = keras.metrics.Mean(name="loss")
mae_metric = keras.metrics.MeanAbsoluteError(name="mae") class MyCustomModel(keras.Model):
tf.random.set_seed(100)
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = custom_mse(y, y_pred)
# loss += self.losses # Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars)) # Compute our own metrics
loss_tracker.update_state(loss)
mae_metric.update_state(y, y_pred)
return {"loss": loss_tracker.result(), "mae": mae_metric.result()} @property
def metrics(self):
# We list our `Metric` objects here so that `reset_states()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
# If you don't implement this property, you have to call
# `reset_states()` yourself at the time of your choosing.
return [loss_tracker, mae_metric] # Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
my_model_beta = MyCustomModel(inputs, outputs)
my_model_beta.compile(optimizer="adam") # Just use `fit` as usual my_model_beta.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 960us/step - loss: 0.2783 - mae: 0.4257

<tensorflow.python.keras.callbacks.History at 0x7ff7eda3d810>

终于,通过跳过在 compile() 中传递损失函数,而在 train_step 中手动完成所有计算内容,我们获得了与之前默认tf.losses.MSE完全一致的输出,这才是我们想要的结果。

总结一下,当我们在模型中想用自定义的损失函数,不能直接传入fit函数,而是需要在train_step中手动传入,完成计算过程。

tensorflow2 自定义损失函数使用的隐藏坑的更多相关文章

  1. TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵

    TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵 神经元模型:用数学公式比表示为:f(Σi xi*wi + b), f为激活函数 神经网络 是以神经元为基本单位构成的 激 ...

  2. tensorflow 自定义损失函数示例

    这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变) 我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元. 如果我们 ...

  3. tensflow自定义损失函数

    tensflow 不仅支持经典的损失函数,还可以优化任意的自定义损失函数. 预测商品销量时,如果预测值比真实销量大,商家损失的是生产商品的成本:如果预测值比真实值小,损失的则是商品的利润. 比如如果一 ...

  4. 机器学习之路: tensorflow 自定义 损失函数

    git: https://github.com/linyi0604/MachineLearning/tree/master/07_tensorflow/ import tensorflow as tf ...

  5. Tensorflow 损失函数(loss function)及自定义损失函数(三)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/limiyudianzi/article ...

  6. SpringMVC自定义配置消息转换器踩坑总结

    问题描述 最近在开发时候碰到一个问题,springmvc页面向后台传数据的时候,通常我是这样处理的,在前台把数据打成一个json,在后台接口中使用@requestbody定义一个对象来接收,但是这次数 ...

  7. Fidder详解-工具简介(保存会话、decode解码、Repaly、自定义会话框、隐藏会话、会话排序)

    前言 本文会对Fidder这款工具的一些重要功能,进行详细讲解,带大家进入Fidder的世界,本文会让你明白,Fidder不仅是一个抓包分析工具,也是一个请求发送工具,更加可以当作为Mock Serv ...

  8. 隐藏软键盘(解决自定义Dialog中无法隐藏的问题)

    /** * Dialog中隐藏软键盘不管用 * @param activity */ public static void HideSoftKeyBoard(Activity activity){ t ...

  9. IOS 极光推送自定义通知遇到的一些坑

    主要方法: //自定义推送 - (void)networkDidReceiveMessage:(NSNotification *)notification { NSDictionary * userI ...

随机推荐

  1. 使用JS获取两个时间差(JS写一个倒计时功能)

    <body onload="myFunction()"> <p id="demo"></p> <script> ...

  2. 编译原理-一种词法分析器LEX原理

    1.将所有单词的正规集用正规式描述 2.用正规式到NFA的转换算 得到识别所有单词用NFA 3.用NFA到DFA的转换算法 得到识别所有单词用DFA 4.将DFA的状态转换函数表示成二维数组 并与DF ...

  3. Linkerd 2.10(Step by Step)—1. 将您的服务添加到 Linkerd

    为了让您的服务利用 Linkerd,它们还需要通过将 Linkerd 的数据平面代理(data plane proxy)注入到它们服务的 pod 中,从而进行网格化. Linkerd 2.10 中文手 ...

  4. Qt信号槽机制理解

    1. 信号和槽概述 > 信号槽是 Qt 框架引以为豪的机制之一.所谓信号槽,实际就是观察者模式(发布-订阅模式).当某个`事件`发生之后,比如,按钮检测到自己被点击了一下,它就会发出一个信号(s ...

  5. 简单测试 APISIX2.6 网关

    Apache APISIX是一个动态的.实时的.高性能的 API 网关.它提供丰富的流量管理功能,例如负载均衡.动态上游服务.金丝雀发布.断路.身份验证.可观察性等.您可以使用 Apache APIS ...

  6. P1831 杠杆数(数位Dp)

    题目描述 如果把一个数的某一位当成支点,且左边的数字到这个点的力矩和等于右边的数字到这个点的力矩和,那么这个数就可以被叫成杠杆数. 比如$4139$就是杠杆数,把3当成支点,我们有这样的等式:$4 \ ...

  7. 八、配置Tomcat日志

    [root@svr5 ~]# vim /usr/local/tomcat/conf/server.xml .. .. <Host name="www.a.com" appBa ...

  8. 【dp】状压dp

    二进制的力量 状态压缩DP 愤怒的小鸟 第一次接触状态压缩DP是在NOIP2016的愤怒的小鸟,当时菜得连题目都没看懂,不过现在回过头来看还是挺简单的,那么我们再来看看这道题吧. 题意&数据范 ...

  9. Java字符串比较(3种方法)以及对比 C++ 时的注意项

    字符串比较是常见的操作,包括比较相等.比较大小.比较前缀和后缀串等.在 Java 中,比较字符串的常用方法有 3 个:equals() 方法.equalsIgnoreCase() 方法. compar ...

  10. Kubernetes的认证机制

    1.了解认证机制 API服务器可以配置一到多个认证的插件(授权插件同样也可以).API服务器接收到的请求会经过一个认证插件的列表,列表中的每个插件都可以检查这个请求和尝试确定谁在发送这个请求.列表中的 ...