Keras保存模型并载入模型继续训练
我们以MNIST手写数字识别为例
- import numpy as np
- from keras.datasets import mnist
- from keras.utils import np_utils
- from keras.models import Sequential
- from keras.layers import Dense
- from keras.optimizers import SGD
- # 载入数据
- (x_train,y_train),(x_test,y_test) = mnist.load_data()
- # (60000,28,28)
- print('x_shape:',x_train.shape)
- # (60000)
- print('y_shape:',y_train.shape)
- # (60000,28,28)->(60000,784)
- x_train = x_train.reshape(x_train.shape[0],-1)/255.0
- x_test = x_test.reshape(x_test.shape[0],-1)/255.0
- # 换one hot格式
- y_train = np_utils.to_categorical(y_train,num_classes=10)
- y_test = np_utils.to_categorical(y_test,num_classes=10)
- # 创建模型,输入784个神经元,输出10个神经元
- model = Sequential([
- Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
- ])
- # 定义优化器
- sgd = SGD(lr=0.2)
- # 定义优化器,loss function,训练过程中计算准确率
- model.compile(
- optimizer = sgd,
- loss = 'mse',
- metrics=['accuracy'],
- )
- # 训练模型
- model.fit(x_train,y_train,batch_size=64,epochs=5)
- # 评估模型
- loss,accuracy = model.evaluate(x_test,y_test)
- print('\ntest loss',loss)
- print('accuracy',accuracy)
- # 保存模型
- model.save('model.h5') # HDF5文件,pip install h5py
载入初次训练的模型,再训练
- import numpy as np
- from keras.datasets import mnist
- from keras.utils import np_utils
- from keras.models import Sequential
- from keras.layers import Dense
- from keras.optimizers import SGD
- from keras.models import load_model
- # 载入数据
- (x_train,y_train),(x_test,y_test) = mnist.load_data()
- # (60000,28,28)
- print('x_shape:',x_train.shape)
- # (60000)
- print('y_shape:',y_train.shape)
- # (60000,28,28)->(60000,784)
- x_train = x_train.reshape(x_train.shape[0],-1)/255.0
- x_test = x_test.reshape(x_test.shape[0],-1)/255.0
- # 换one hot格式
- y_train = np_utils.to_categorical(y_train,num_classes=10)
- y_test = np_utils.to_categorical(y_test,num_classes=10)
- # 载入模型
- model = load_model('model.h5')
- # 评估模型
- loss,accuracy = model.evaluate(x_test,y_test)
- print('\ntest loss',loss)
- print('accuracy',accuracy)
- # 训练模型
- model.fit(x_train,y_train,batch_size=64,epochs=2)
- # 评估模型
- loss,accuracy = model.evaluate(x_test,y_test)
- print('\ntest loss',loss)
- print('accuracy',accuracy)
- # 保存参数,载入参数
- model.save_weights('my_model_weights.h5')
- model.load_weights('my_model_weights.h5')
- # 保存网络结构,载入网络结构
- from keras.models import model_from_json
- json_string = model.to_json()
- model = model_from_json(json_string)
- print(json_string)
关于compile和load_model()的使用顺序
这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?
compile做什么?
compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。
如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。
是否需要多次编译?
除非我们要更改其中之一:损失函数、优化器 / 学习率、度量
又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。
再次compile的后果?
如果再次编译模型,将会丢失优化器状态.
这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。
Keras保存模型并载入模型继续训练的更多相关文章
- keras 保存模型
转自:https://blog.csdn.net/u010159842/article/details/54407745,感谢分享! 我们不推荐使用pickle或cPickle来保存Keras模型 你 ...
- TensorFlow保存和载入模型
首先定义一个tf.train.Saver类: saver = tf.train.Saver(max_to_keep=1) 其中,max_to_keep参数设定只保存最后一个参数,默认值是5,即保存最后 ...
- (原+译)pytorch中保存和载入模型
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/8108466.html 参考网址: http://pytorch.org/docs/master/not ...
- Keras框架下的保存模型和加载模型
在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...
- Keras入门(六)模型训练实时可视化
在北京做某个项目的时候,客户要求能够对数据进行训练.预测,同时能导出模型,还有在页面上显示训练的进度.前面的几个要求都不难实现,但在页面上显示训练进度当时笔者并没有实现. 本文将会分享如何在K ...
- TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化
线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...
- keras系列︱Sequential与Model模型、keras基本结构功能(一)
引自:http://blog.csdn.net/sinat_26917383/article/details/72857454 中文文档:http://keras-cn.readthedocs.io/ ...
- 三分钟快速上手TensorFlow 2.0 (下)——模型的部署 、大规模训练、加速
前文:三分钟快速上手TensorFlow 2.0 (中)——常用模块和模型的部署 TensorFlow 模型导出 使用 SavedModel 完整导出模型 不仅包含参数的权值,还包含计算的流程(即计算 ...
- 【Keras篇】---Keras初始,两种模型构造方法,利用keras实现手写数字体识别
一.前述 Keras 适合快速体验 ,keras的设计是把大量内部运算都隐藏了,用户始终可以用theano或tensorflow的语句来写扩展功能并和keras结合使用. 二.安装 Pip insta ...
随机推荐
- 分析Android APK-砸壳-Fdex2
砸壳的工具千千万,但是FDex2 是最有能耐的,我尝试过各种壳,都是秒砸的.特别说明一下,360的壳,oncreated 方法还是空的,但是其他大部分内容还是有的,反正是可以参考一下的. 安装环境: ...
- 搭建Vue开发环境
1.安装Node.js 安装包下载地址: https://nodejs.org/en/ 安装时可以选择是否自动安装必要的工具,如Chocolatey.Python2,这里我选择了自动安装 Node.j ...
- solo升级以及自动化更新的方法
使用solo过程总涉及到更新问题,所以就在这里把solo更新的方法总结一下.希望能给小伙伴们一些帮助.如何选择更新方法主要是跟你的部署方式有关,如果你是通过 docker方式进行部署,那么你可以还可以 ...
- 推荐一个好用的行内可编辑的table组件 vxe-table
项目中有一个需要用户点击table单元格可编辑的需求,由于博主用的是elementUI,element组件内实现可编辑,用过的小伙伴都知道,非常麻烦,后来博主在浏览组件的时候发现了 一款非常好用的ta ...
- Hive 时间函数总结【转】
1.日期函数UNIX时间戳转日期函数: from_unixtime语法:from_unixtime(bigint unixtime[, stringformat]) 返回值: string说明: 转化 ...
- openldap数据备份还原
数据备份[root@Server ~]# slapcat -n 2 -l /root/ldapbackup_ilanni.ldif脚本 ----- #!/bin/bash # 备份脚本 PATH=&q ...
- 【学习笔记】《Java编程思想》 第1~7章
第一章 对象导论 对整书的概要. 略读. 第二章 一切都是对象 创建一个引用,指向一个对象. 安全的做法:创建一个引用的同时便进行初始化. 对象存储的地方:1)寄存器:这是最快的存储区,因为它位于不同 ...
- 关于thymeleaf中th:if的使用
运用于判断表达式中时,关系判断使用 gt / ge / eq / lt / le / ne (即:使用缩写) gt: great than(大于)> ge: great equal(大于等于)& ...
- 51和32共用keil5方法
链接:https://blog.csdn.net/qq_41639829/article/details/81813992 看这位道友写的方法挺好的,可以实现共用,不过有点小问题是,安装 以后,用32 ...
- 第04组 Beta冲刺(3/4)
队名:斗地组 组长博客:地址 作业博客:Beta冲刺(3/4) 各组员情况 林涛(组长) 过去两天完成了哪些任务: 1.分配展示任务 2.收集各个组员的进度 3.写博客 展示GitHub当日代码/文档 ...