【tf.keras】实现 F1 score、precision、recall 等 metric
tf.keras.metric 里面竟然没有实现 F1 score、recall、precision 等指标,一开始觉得真不可思议。但这是有原因的,这些指标在 batch-wise 上计算都没有意义,需要在整个验证集上计算,而 tf.keras 在训练过程(包括验证集)中计算 acc、loss 都是一个 batch 计算一次的,最后再平均起来。Keras 2.0 版本将 precision, recall, fbeta_score, fmeasure 等 metrics 移除了。
虽然 tf.keras.metric 中没有实现 f1 socre、precision、recall,但我们可以通过 tf.keras.callbacks.Callback 实现。即在每个 epoch 末尾,在整个 val 上计算 f1、precision、recall。
一些博客实现了二分类下的 f1 socre、precision、recall,如下所示:
- How to compute f1 score for each epoch in Keras -- Thong Nguyen
- keras如何求分类问题中的准确率和召回率? - 鱼塘邓少的回答 - 知乎
以下代码实现了多分类下对验证集 F1 值、precision、recall 的计算,并且保存 val_f1 值最好的模型:
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, recall_score, precision_score
import numpy as np
import os
class Metrics(tf.keras.callbacks.Callback):
def __init__(self, valid_data):
super(Metrics, self).__init__()
self.validation_data = valid_data
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
val_predict = np.argmax(self.model.predict(self.validation_data[0]), -1)
val_targ = self.validation_data[1]
if len(val_targ.shape) == 2 and val_targ.shape[1] != 1:
val_targ = np.argmax(val_targ, -1)
_val_f1 = f1_score(val_targ, val_predict, average='macro')
_val_recall = recall_score(val_targ, val_predict, average='macro')
_val_precision = precision_score(val_targ, val_predict, average='macro')
logs['val_f1'] = _val_f1
logs['val_recall'] = _val_recall
logs['val_precision'] = _val_precision
print(" — val_f1: %f — val_precision: %f — val_recall: %f" % (_val_f1, _val_precision, _val_recall))
return
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=10000, random_state=32)
# LeNet-5
model = tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(32, 32, 3)),
tf.keras.layers.Conv2D(6, 5, activation='relu'),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Conv2D(16, 5, activation='relu'),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(120, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(84, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
if not os.path.exists('./checkpoints'):
os.makedirs('./checkpoints')
# 按照 val_f1 保存模型
ck_callback = tf.keras.callbacks.ModelCheckpoint('./checkpoints/weights.{epoch:02d}-{val_f1:.4f}.hdf5',
monitor='val_f1',
mode='max', verbose=2,
save_best_only=True,
save_weights_only=True)
tb_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=0)
model.fit(x_train, y_train,
validation_data=(x_val, y_val),
epochs=100,
callbacks=[Metrics(valid_data=(x_val, y_val)),
ck_callback,
tb_callback])
注意 Metrics()
和 ck_callback
两个 callback 的顺序,互换之后将报错。
References
How to calculate F1 Macro in Keras? -- StackOverflow
How to compute f1 score for each epoch in Keras -- Thong Nguyen
keras如何求分类问题中的准确率和召回率? - 鱼塘邓少的回答 - 知乎
Keras 2.0 release notes -- keras-team/keras
【tf.keras】实现 F1 score、precision、recall 等 metric的更多相关文章
- How to compute f1 score for each epoch in Keras
https://medium.com/@thongonary/how-to-compute-f1-score-for-each-epoch-in-keras-a1acd17715a2 https:// ...
- Precision,Recall,F1的计算
Precision又叫查准率,Recall又叫查全率.这两个指标共同衡量才能评价模型输出结果. TP: 预测为1(Positive),实际也为1(Truth-预测对了) TN: 预测为0(Negati ...
- 机器学习:评价分类结果(F1 Score)
一.基础 疑问1:具体使用算法时,怎么通过精准率和召回率判断算法优劣? 根据具体使用场景而定: 例1:股票预测,未来该股票是升还是降?业务要求更精准的找到能够上升的股票:此情况下,模型精准率越高越优. ...
- 【笔记】F1 score
F1 score 关于精准率和召回率 精准率和召回率可以很好的评价对于数据极度偏斜的二分类问题的算法,有个问题,毕竟是两个指标,有的时候这两个指标也会产生差异,对于不同的算法,精准率可能高一些,召回率 ...
- 机器学习中的 precision、recall、accuracy、F1 Score
1. 四个概念定义:TP.FP.TN.FN 先看四个概念定义: - TP,True Positive - FP,False Positive - TN,True Negative - FN,False ...
- 机器学习--如何理解Accuracy, Precision, Recall, F1 score
当我们在谈论一个模型好坏的时候,我们常常会听到准确率(Accuracy)这个词,我们也会听到"如何才能使模型的Accurcy更高".那么是不是准确率最高的模型就一定是最好的模型? 这篇博文会向大家解释 ...
- 【tf.keras】使用手册
目录 0. 简介 1. 安装 1.1 安装 CUDA 和 cuDNN 2. 数据集 2.1 使用 tensorflow_datasets 导入公共数据集 2.2 数据集过大导致内存溢出 2.3 加载 ...
- 查准与召回(Precision & Recall)
Precision & Recall 先看下面这张图来理解了,后面再具体分析.下面用P代表Precision,R代表Recall 通俗的讲,Precision 就是检索出来的条目中(比如网页) ...
- hihocoder 1522 : F1 Score
题目链接 时间限制:10000ms 单点时限:1000ms 内存限制:256MB 描述 小Hi和他的小伙伴们一起写了很多代码.时间一久有些代码究竟是不是自己写的,小Hi也分辨不出来了. 于是他实现 ...
随机推荐
- python基础-集合set及内置方法
数据类型之集合-set 用途:多用于去重,关系运算 定义方式:通过大括号存储,集合中的每个元素通过逗号分隔.集合内存储的元素必须是不可变的,因此,列表-List 和字典dict 不能存储在集合中 注意 ...
- MySQL查询-分组取组中某字段最大(小)值所有记录
最近做东西的时候,用到一个数据库的查询.将记录按某个字段分组,取每个分组中某个字段的最大值的所有记录.举栗子来说. 已知分数表“score”,包含字段“id", "name&quo ...
- [考试反思]1105csp-s模拟测试101: 临别
先不改题,这次主要不在T3上. 这次有必要粘文件得分了. 临考前总解锁新锅我也不知道这是什么个事啊... T1宏定义写挂.因为原来在OJ上没事所以一直没注意.在Lemon评测下直接全部RE. GG在主 ...
- [考试反思]1031csp-s模拟测试95:优势
假的三首杀.因为交文件所以他们都不往OJ上交. 假装是三首杀吧.嗯. 然而昨天还是没有AK.好像说是按照64位评测机的评测结果为准. 但是联赛省选的机子好像都是32位的?也就是和我们正在用的一致. 所 ...
- [UWP]使用SpringAnimation创建有趣的动画
1. 什么是自然动画 最近用弹簧动画(SpringAnimation)做了两个番茄钟,关于弹簧动画官方文档已经介绍得够详细了,这篇文章就摘录一些官方文档核心内容. 在传统UI中,关键帧动画(KeyFr ...
- 「Luogu 1821」[USACO07FEB]银牛派对Silver Cow Party
更好的阅读体验 Portal Portal1: Luogu Portal2: POJ Description One cow from each of N farms \((1 \le N \le 1 ...
- 腾讯开源进入爆发期,Plato助推十亿级节点图计算进入分钟级时代
腾讯开源再次迎来重磅项目,14日,腾讯正式宣布开源高性能图计算框架Plato,这是在短短一周之内,开源的第五个重大项目. 相对于目前全球范围内其它的图计算框架,Plato可满足十亿级节点的超大规模图计 ...
- 009-2010网络最热的 嵌入式学习|ARM|Linux|wince|ucos|经典资料与实例分析
前段时间做了一个关于ARM9 2440资料的汇总帖,很高兴看到21ic和CSDN等论坛朋友们的支持和鼓励.当年学单片机的时候datasheet和学习资料基本都是在论坛上找到的,也遇到很多好心的高手朋友 ...
- centos6的redis安装
1.到redis的官网下载redis压缩包 https://redis.io/ 2.利用命令 mkdir /usr/local/redis 新建redis文件夹 并将redis压缩包移动到新建的文件夹 ...
- go中的数据结构接口-interface
1. 接口的基本使用 golang中的interface本身也是一种类型,它代表的是一个方法的集合.任何类型只要实现了接口中声明的所有方法,那么该类就实现了该接口.与其他语言不同,golang并不需要 ...