训练模型时,很多事情一开始都无法预测。比如之前我们为了找出迭代多少轮才能得到最佳验证损失,可能会先迭代100次,迭代完成后画出运行结果,发现在中间就开始过拟合了,于是又重新开始训练。

类似的情况很多,于是我们想要实时监测训练动态,并能根据训练情况及时对模型采取一定的措施。Keras中的回调函数和tf的TensorBoard就是为此而生。

Keras回调函数

回调函数(callbacks)是在调用fit时传入模型的一个对象,它在训练过程中的不同时间点都会被模型调用。它可以访问关于模型状态和性能的所有可用数据,还可以采取行动:中断训练、保存模型、加载一组不同的权重或者改变模型的状态。也就是说,之前在训练模型的过程中,我们不知道模型的实时状态,因此为了更好的监测和控制模型的训练过程,我们派出了一个特派员——回调函数,它可以根据情况记录、反馈或者采取措施。我们熟悉的训练进度条和fit返回的history都是回调函数,只不过它俩因为太常用,所以被单独拎出来。

fit和fit_generator函数都提供了callbacks接口。常用的回调函数有:

  • ModelCheckpoint(在每轮过后保存当前模型);
  • EarlyStopping(如果监控参数得不到改善就中断训练);
  • LearningRateScheduler(在训练过程中动态调整学习率);
  • ReduceLROnPlateau(如果验证表现得不到改善,可以用它降低学习率,跳出局部最小值);
  • CSVLogger(将每个epoch的结果写入CSV文件)。
  • 其他回调函数,也可以根据需要自行编写。

应用示例:

from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

#fit提供callbacks接口,接收一个回调函数列表,可将任意个回调函数传入模型中
callback_lists = [] callback_lists.append(EarlyStopping(monitor = 'acc', #监控模型的验证精度
patience = 1)) #如果精度在多于一轮的时间(即两轮)内不再改善,就中断训练 callback_lists.append(ModelCheckpoint(filepath = 'my_model.h5', #目标文件的保存路径
monitor = 'val_loss', #监控验证损失
save_best_only = True)) #只保存最佳模型 callback_lists.append(ReduceLROnPlateau(monitor = 'val_loss', #监控模型的验证损失
factor = 0.1, #触发时将学习率乘以系数0.1
patience = 10) #若验证损失在10轮内都没有改善,则触发该回调函数 #由于回调函数要监控验证损失和验证精度,所以在调用fit时需要传入validation_data
model.fit(x, y, epochs = 10, batch_size = 32,
callbacks = callbacks_list,
validation_data = (x_val, y_val))

TensorBoard:实时可视化工具

TensorBoard是内置于TensorFlow中基于浏览器的可视化工具,安装TensorFlow时会自动安装这个工具。简单来说,它就是把训练过程数据写入文件,然后用浏览器查看的工具。在Keras中,它也被包装成一个回调函数。

示例如下:

#引入Tensorboard
from keras.callbacks import TensorBoard #定义回调函数列表,现在只放一个简单的TensorBoard
log_path = './logs' #指定TensorBoard读取的文件路径,可以新建一个
callback_lists = [TensorBoard(log_dir=log_path, histogram_freq=1)] #模型调用fit时,通过回调函数接口传入
model.fit(...inputs and parameters..., callbacks=callback_lists)

为了在训练的过程中可视化各项指标,需要自己在终端启动TensorBoard。

打开终端的方式有两种:一种是系统自带的终端cmd;另一种是在Anaconda Prompt终端。选择用哪种终端打开,根据当时安装tensorflow时用的终端方式。我试了下cmd,总是出错,但在Anaconda Prompt终端就能正常启动。

启动方式:在终端输入 tensorboard --logdir=C:\Users...\logs (自己文件的路径),就会返回一行信息,包含了一个http网址。这个地址一般是不会改变的,在浏览器中输入提示的http地址,即可查看模型的训练过程和相关状态,如下图所示。



Reference

书籍:Python深度学习

CNN基础四:监测并控制训练过程的法宝——Keras回调函数和TensorBoard的更多相关文章

  1. iNeuOS工业互联平台,PLC监测与控制应用过程案例。新闻:.NET 6 RC1 正式发布

    目       录 1.      概述... 1 2.      平台演示... 2 3.      应用过程... 2 1.   概述 iNeuOS工业互联网操作系统主要使用.netcore 3. ...

  2. 51定时器控制4各led,使用回调函数机制

    程序转载自51hei,经过自己的实际验证,多了一种编程的思路技能,回调函数的基本思想也是基于事件机制的,哪个事件来了, 就执行哪个事件. 程序中,最多四个子定时器,说明51的处理速度是不够的,在中断中 ...

  3. CNN基础一:从头开始训练CNN进行图像分类(猫狗大战为例)

    本文旨在总结一次从头开始训练CNN进行图像分类的完整过程(猫狗大战为例,使用Keras框架),免得经常遗忘.流程包括: 从Kaggle下载猫狗数据集: 利用python的os.shutil库,制作训练 ...

  4. CNN基础二:使用预训练网络提取图像特征

    上一节中,我们采用了一个自定义的网络结构,从头开始训练猫狗大战分类器,最终在使用图像增强的方式下得到了82%的验证准确率.但是,想要将深度学习应用于小型图像数据集,通常不会贸然采用复杂网络并且从头开始 ...

  5. 深度学习基础(CNN详解以及训练过程1)

    深度学习是一个框架,包含多个重要算法: Convolutional Neural Networks(CNN)卷积神经网络 AutoEncoder自动编码器 Sparse Coding稀疏编码 Rest ...

  6. 卷积神经网络(CNN)的训练过程

    卷积神经网络的训练过程 卷积神经网络的训练过程分为两个阶段.第一个阶段是数据由低层次向高层次传播的阶段,即前向传播阶段.另外一个阶段是,当前向传播得出的结果与预期不相符时,将误差从高层次向底层次进行传 ...

  7. 从零搭建Pytorch模型教程(四)编写训练过程--参数解析

    ​  前言 训练过程主要是指编写train.py文件,其中包括参数的解析.训练日志的配置.设置随机数种子.classdataset的初始化.网络的初始化.学习率的设置.损失函数的设置.优化方式的设置. ...

  8. 卷积神经网络(CNN)基础介绍

    本文是对卷积神经网络的基础进行介绍,主要内容包含卷积神经网络概念.卷积神经网络结构.卷积神经网络求解.卷积神经网络LeNet-5结构分析.卷积神经网络注意事项. 一.卷积神经网络概念 上世纪60年代. ...

  9. CNN基础框架简介

    卷积神经网络简介 卷积神经网络是多层感知机的变种,由生物学家休博尔和维瑟尔在早期关于猫视觉皮层的研究发展而来.视觉皮层的细胞存在一个复杂的构造,这些细胞对视觉输入空间的子区域非常敏感,我们称之为感受野 ...

随机推荐

  1. [BZOJ2341][Shoi2011]双倍回文 manacher+std::set

    题目链接 发现双倍回文串一定是中心是#的回文串. 所以考虑枚举#点.发现以\(i\)为中心的双倍回文的左半部分是个回文串,其中心一定位于\(i-\frac{pal[i]-1}2\)到\(i-1\)之间 ...

  2. python3输出中文报错的原因,及解决办法(基于pycharm)

    通常python3里面如果有中文,在不连接其他设备和程序的情况下,报错信息大致如下: SyntaxError: Non-UTF-8 code starting with '\xd6' in file ...

  3. hadoop通过java输出HAFS上的文件内容

    package org.apache.hadoop.book;import java.io.InputStream;import java.net.URL;import org.apache.hado ...

  4. SpringBoot---Favicon配置

    1.概述 1.1.SpringBoot提供了一个默认的Favicon,每次访问都能看到: 2.关闭Favicon 在application.yml中设置关闭Favicon,默认开启: spring.m ...

  5. java中switch的用法以及判断的类型有哪些(String\byte\short\int\char\枚举类型)

    switch关键字对于多数java学习者来说并不陌生,由于笔试和面试经常会问到它的用法,这里做了一个简单的总结: 能用于switch判断的类型有:byte.short.int.char(JDK1.6) ...

  6. JavaScript之ECMAScript

    JavaScript脚本语言, 运行在浏览器上,无需编译, 轻量级的语言. 功能:让页面有执行逻辑的功能, 可以产生一些动态的效果 JavaScript = ECMAScript + BOM + DO ...

  7. Dataphin帮助企业构建数据中台系列之--萃取数据中心

    Dataphin作为阿里巴巴数据中台OneData (OneModel.OneID.OneService)方法论的产品载体,帮助企业构建三大数据中心:基于数据集成形成的垂直数据中心.基于数据开发沉淀的 ...

  8. spring-cloud:eureka server单机、双机、集群示例

    1.运行环境 开发工具:intellij idea JDK版本:1.8 项目管理工具:Maven 4.0.0 2.GITHUB地址 https://github.com/nbfujx/springCl ...

  9. Android中实现Activity的启动拦截之----实现360卫士的安装应用界面

    第一.摘要 今天不是周末,但是我已经放假了,所以就开始我们的技术探索之旅,今天我们来讲一下Android中最期待的技术,就是拦截Activity的启动,其实我在去年的时候,就像实现这个技术了,但是因为 ...

  10. kNN(从文本文件中解析数据)

    # 准备数据:从文本文件中解析数据# 在kNN.py中创建名为file2matrix的函数,处理输入格式问题# 该函数的输入为文件名字符串,输出为训练样本矩阵和类标签向量# 将文本记录到转换Numpy ...