用keras训练模型并实时显示loss/acc曲线,(重要的事情说三遍:实时!实时!实时!)实时导出loss/acc数值(导出的方法就是实时把loss/acc等写到一个文本文件中,其他模块如前端调用时可直接读取文本文件),同时也涉及了plt画图方法

ps:以下代码基于网上的一段程序修改完成,如有侵权,请联系我哈!

上代码:

from keras import Sequential, initializers, optimizers
from keras.layers import Activation, Dense
import numpy as np
import pylab as pl
from IPython import display
from keras.callbacks import Callback
from keras.datasets import mnist
import keras
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Dense, Dropout, Flatten #定义回调函数的类,用于实时显示loss/acc曲线和导出loss/acc数值
class DrawCallback(Callback):
def __init__(self, runtime_plot=True): # 初始化 self.init_loss = None
self.init_val_loss = None
self.init_acc = None
self.init_val_acc = None
self.runtime_plot = runtime_plot self.xdata = []
self.ydata = []
self.ydata2 = []
self.ydata3 = []
self.ydata4 = []
def _plot(self, epoch=None):
epochs = self.params.get("epochs")
pl.subplot(121) #画第一个图,121表示纵向1个图,横向2个图,当前第1个图
pl.ylim(0, int(self.init_loss*1.2)) #限制坐标轴范围
pl.xlim(0, epochs)
pl.plot(self.xdata, self.ydata,'r', label='loss') #xdata/ydata均为不断增长的一维数组,同时定义了线段颜色/类型/图例
pl.plot(self.xdata, self.ydata2, 'b--', label='val_loss')
pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs)) #坐标轴显示变化的标签
pl.ylabel('Loss {:.4f}'.format(self.ydata[-1]))
pl.legend() #显示图例,不加这个即便是定义图例了也没用
pl.title('loss') #显示标题 pl.subplot(122)
pl.ylim(0, 1.2)
pl.xlim(0, epochs)
pl.plot(self.xdata, self.ydata3,'r', label='acc')
pl.plot(self.xdata, self.ydata4, 'b--', label='val_acc')
pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs))
pl.ylabel('Loss {:.4f}'.format(self.ydata[-1]))
pl.legend()
pl.title('acc') def _runtime_plot(self, epoch):
self._plot(epoch)
#不断的清图
display.clear_output(wait=True)
display.display(pl.gcf())
pl.gcf().clear() def plot(self):
self._plot()
pl.show() #显示窗口 def on_epoch_end(self, epoch, logs = None): #更新xdata/ydata
logs = logs or {}
# batch_size = self.params.get("batch_size")
epochs = self.params.get("epochs") #获取训练相关数据
loss = logs.get("loss")
val_loss = logs.get("val_loss")
acc = logs.get("acc")
val_acc = logs.get("val_acc") epochs_str = str(epochs)[0:6] #为了写入txt,必须转为字符型,为了美观只保留小数点后4位
loss_str = str(loss)[0:6]
val_loss_str = str(val_loss)[0:6]
acc_str = str(acc)[0:6]
val_acc_str = str(val_acc)[0:6] f = open('logs_r/record.txt','a') #要用追加方式‘a’写入txt,所在行数就是当前迭代次数
f.write('epochs:{}_loss:{}_val_loss:{}_acc:{}_val_acc{}'.format(epochs_str,loss_str,val_loss_str,acc_str,val_acc_str))
f.write('\n')
f.close() if self.init_loss is None: #增加xdata/ydata内容
self.init_loss = loss
self.init_val_loss = val_loss
self.xdata.append(epoch)
self.ydata.append(loss)
self.ydata2.append(val_loss)
self.ydata3.append(acc)
self.ydata4.append(val_acc)
if self.runtime_plot:
self._runtime_plot(epoch) # 下面开始构建keras需要的东西
def viz_keras_fit(runtime_plot=False):
d = DrawCallback(runtime_plot = runtime_plot) #实例化回调函数
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1,28,28,1)
x_test = x_test.reshape(-1,28,28,1)
input_shape = (28,28,1)
x_train = x_train/255
x_test = x_test/255
y_train = keras.utils.to_categorical(y_train,10)
y_test = keras.utils.to_categorical(y_test,10)
#为了减小计算量,减少了训练/测试数据
x_train = x_train[0:600,:,:,:]
x_test = x_test[0:100,:,:,:]
y_train = y_train[0:600,:]
y_test = y_test[0:100,:] model = Sequential() #实例化一个模型
#接下来一顿操作,就是搭建网络
model.add(Conv2D(filters=32, kernel_size=(3,3),
activation='relu', input_shape=input_shape,
name='conv1'))
model.add(Conv2D(64,(3,3),activation='relu',name='conv2'))
model.add(MaxPooling2D(pool_size=(2,2),name='pool2'))
model.add(Dropout(0.25,name='dropout1'))
model.add(Flatten(name='flat1'))
model.add(Dense(128,activation='relu'))
model.add(Dropout(0.5,name='dropout2'))
model.add(Dense(10,activation='softmax',name='output'))
#编译网络,同时定义了loss方法/优化方法/监测内容
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy'])
#开始训练
model.fit(x = x_train,
y = y_train,
epochs=30,
verbose=0, #当值为1时,会打印训练过程
validation_data=(x_test, y_test), #加入测试数据,不然有些数据时看不到的
callbacks=[d]) #指定回调函数
return d

  

最后运行:

viz_keras_fit(runtime_plot=True) #调用函数

显示结果:

keras训练实例-python实现的更多相关文章

  1. keras训练cnn模型时loss为nan

    keras训练cnn模型时loss为nan 1.首先记下来如何解决这个问题的:由于我代码中 model.compile(loss='categorical_crossentropy', optimiz ...

  2. Keras 训练 inceptionV3 并移植到OpenCV4.0 in C++

    1. 训练 # --coding:utf--- import os import sys import glob import argparse import matplotlib.pyplot as ...

  3. 使用Keras训练大规模数据集

    官方提供的.flow_from_directory(directory)函数可以读取并训练大规模训练数据,基本可以满足大部分需求.但是在有些场合下,需要自己读取大规模数据以及对应标签,下面提供一种方法 ...

  4. keras训练和保存

    https://cloud.tencent.com/developer/article/1010815 8.更科学地模型训练与模型保存 filepath = 'model-ep{epoch:03d}- ...

  5. Keras 训练一个单层全连接网络的线性回归模型

    1.准备环境,探索数据 import numpy as np from keras.models import Sequential from keras.layers import Dense im ...

  6. Keras 入门实例

    使用Keras构建神经网络的基本工作流程主要可以分为 4个部分.(而这个用法和思路,很像是在使用Scikit-learn中的机器学习方法) Model definition → Model compi ...

  7. 使用Keras训练神经网络备忘录

    小书匠深度学习 文章太长,放个目录: 1.优化函数的选择 2.损失函数的选择 2.2常用的损失函数 2.2自定义函数 2.1实践 2.2将损失函数自定义为网络层 3.模型的保存 3.1同时保持结构和权 ...

  8. keras训练大量数据的办法

    最近在做一个鉴黄的项目,数据量比较大,有几百个G,一次性加入内存再去训练模青型是不现实的. 查阅资料发现keras中可以用两种方法解决,一是将数据转为tfrecord,但转换后数据大小会方法不好:另外 ...

  9. 【机器学习实战学习笔记(1-2)】k-近邻算法应用实例python代码

    文章目录 1.改进约会网站匹配效果 1.1 准备数据:从文本文件中解析数据 1.2 分析数据:使用Matplotlib创建散点图 1.3 准备数据:归一化特征 1.4 测试算法:作为完整程序验证分类器 ...

随机推荐

  1. Java实现 LeetCode 面试题13. 机器人的运动范围(DFS)

    面试题13. 机器人的运动范围 地上有一个m行n列的方格,从坐标 [0,0] 到坐标 [m-1,n-1] .一个机器人从坐标 [0, 0] 的格子开始移动,它每次可以向左.右.上.下移动一格(不能移动 ...

  2. Java实现 蓝桥杯 算法提高 上帝造题五分钟

    算法提高 上帝造题五分钟 时间限制:1.0s 内存限制:256.0MB 问题描述 第一分钟,上帝说:要有题.于是就有了L,Y,M,C 第二分钟,LYC说:要有向量.于是就有了长度为n写满随机整数的向量 ...

  3. Java实现 蓝桥杯 算法提高 7-1用宏求球的体积

    算法提高 7-1用宏求球的体积 时间限制:1.0s 内存限制:256.0MB 问题描述 使用宏实现计算球体体积的功能.用户输入半径,系统输出体积.不能使用函数,pi=3.1415926,结果精确到小数 ...

  4. Java实现 LeetCode 470 用 Rand7() 实现 Rand10()

    470. 用 Rand7() 实现 Rand10() 已有方法 rand7 可生成 1 到 7 范围内的均匀随机整数,试写一个方法 rand10 生成 1 到 10 范围内的均匀随机整数. 不要使用系 ...

  5. java实现识别复制串

    ** 识别复制串** 代码的目标:判断一个串是否为某个基本串的简单复制构成的. 例如: abcabcabc,它由"abc"复制3次构成,则程序输出:abc aa 由"a& ...

  6. vim编辑器添加插件NERDTree

    0x01 首先在 http://www.vim.org/scripts/script.php?script_id=1658 下载插件 (可能要爬梯,也可以在https://github.com/scr ...

  7. 利用 Powershell 编写简单的浏览器脚本

    生活中有很多事情是低效益,重复性.比如每天上某些网站,先登录再签到打卡,比如每隔一段时间清理回收站的文件等等.一个成熟的软件工程师应该想到用软件解决他. 对于这些简单的小任务,一般用脚本实现.比如Py ...

  8. <VCC笔记> 溢出与unchecked

    在程序运算或者数据转换的时候,由于各种数据类型有各自的范围,运算的时候,其结果如果超出这个范围,就被称之为溢出.熟悉C#的同志们应该了解用来解决溢出(Overflow)问题的checked,unche ...

  9. git新手入门问题总结

    git新手入门问题总结 前言 本人为2019年6月份刚刚毕业,大三暑假中旬来到上海,实习时间大致为十个月,在这十个月里面学到了许多关于git使用方面的知识 经常会逛开源中国水水动态,看看技术帖子学习知 ...

  10. [转] 图解单片机下载程序电路原理之USB转串口线、CH340、PL2303、MAX232芯片的使用

    点击阅读原文 目前为止,我接触单片机已有不少时日,从选择元器件.原理图.PCB.电路硬件调试.软件开发也算小有心得 .单片机软件开发里面第一步当属下载程序了,如果这一步都有问题,那么后面的一切便无从谈 ...