AI - TensorFlow - 示例05:保存和恢复模型
保存和恢复模型(Save and restore models)
官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_models
在训练期间保存检查点
在训练期间或训练结束时自动保存检查点。
权重存储在检查点格式的文件集合中,这些文件仅包含经过训练的权重(采用二进制格式)。
可以使用经过训练的模型,而无需重新训练该模型,或从上次暂停的地方继续训练,以防训练过程中断
- 检查点回调用法:创建检查点回调,训练模型并将ModelCheckpoint回调传递给该模型,得到检查点文件集合,用于分享权重
- 检查点回调选项:该回调提供了多个选项,用于为生成的检查点提供独一无二的名称,以及调整检查点创建频率。
手动保存权重
使用 Model.save_weights 方法即可手动保存权重
保存整个模型
整个模型可以保存到一个文件中,其中包含权重值、模型配置(架构)、优化器配置。
可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。
Keras通过检查架构来保存模型,使用HDF5标准提供基本的保存格式。
特别注意:
- 目前无法保存TensorFlow优化器(来自tf.train)。
- 使用此类优化器时,需要在加载模型后对其进行重新编译,使优化器的状态变松散。
MNIST数据集
MNIST(Mixed National Institute of Standards and Technology database)是一个计算机视觉数据集
- 官方下载地址:http://yann.lecun.com/exdb/mnist/
- 包含70000张手写数字的灰度图片,其中60000张为训练图像和10000张为测试图像
- 每一张图片都是28*28个像素点大小的灰度图像
- https://keras.io/datasets/#mnist-database-of-handwritten-digits
- TensorFlow:https://www.tensorflow.org/api_docs/python/tf/keras/datasets/mnist
示例
脚本内容
GitHub:https://github.com/anliven/Hello-AI/blob/master/Google-Learn-and-use-ML/5_save_and_restore_models.py
# coding=utf-8
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pathlib
import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
print("# TensorFlow version: {} - tf.keras version: {}".format(tf.VERSION, tf.keras.__version__)) # 查看版本 # ### 获取示例数据集 ds_path = str(pathlib.Path.cwd()) + "\\datasets\\mnist\\" # 数据集路径
np_data = np.load(ds_path + "mnist.npz") # 加载numpy格式数据
print("# np_data keys: ", list(np_data.keys())) # 查看所有的键 # 加载mnist数据集
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data(path=ds_path + "mnist.npz")
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0 # ### 定义模型
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation=tf.nn.softmax)
]) # 构建一个简单的模型
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
return model mod = create_model()
mod.summary() # ### 在训练期间保存检查点 # 检查点回调用法
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path) # 检查点存放目录
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True,
verbose=2) # 创建检查点回调
model1 = create_model()
model1.fit(train_images, train_labels,
epochs=10,
validation_data=(test_images, test_labels),
verbose=0,
callbacks=[cp_callback] # 将ModelCheckpoint回调传递给该模型
) # 训练模型,将创建一个TensorFlow检查点文件集合,这些文件在每个周期结束时更新 model2 = create_model() # 创建一个未经训练的全新模型(与原始模型架构相同,才能分享权重)
loss, acc = model2.evaluate(test_images, test_labels) # 使用测试集进行评估
print("# Untrained model2, accuracy: {:5.2f}%".format(100 * acc)) # 未训练模型的表现(准确率约为10%) model2.load_weights(checkpoint_path) # 从检查点加载权重
loss, acc = model2.evaluate(test_images, test_labels) # 使用测试集,重新进行评估
print("# Restored model2, accuracy: {:5.2f}%".format(100 * acc)) # 模型表现得到大幅提升 # 检查点回调选项
checkpoint_path2 = "training_2/cp-{epoch:04d}.ckpt" # 使用“str.format”方式为每个检查点设置唯一名称
checkpoint_dir2 = os.path.dirname(checkpoint_path)
cp_callback2 = tf.keras.callbacks.ModelCheckpoint(checkpoint_path2,
verbose=1,
save_weights_only=True,
period=5 # 每隔5个周期保存一次检查点
) # 创建检查点回调
model3 = create_model()
model3.fit(train_images, train_labels,
epochs=50,
callbacks=[cp_callback2], # 将ModelCheckpoint回调传递给该模型
validation_data=(test_images, test_labels),
verbose=0) # 训练一个新模型,每隔5个周期保存一次检查点并设置唯一名称
latest = tf.train.latest_checkpoint(checkpoint_dir2)
print("# latest checkpoint: {}".format(latest)) # 查看最新的检查点 model4 = create_model() # 重新创建一个全新的模型
loss, acc = model2.evaluate(test_images, test_labels) # 使用测试集进行评估
print("# Untrained model4, accuracy: {:5.2f}%".format(100 * acc)) # 未训练模型的表现(准确率约为10%) model4.load_weights(latest) # 加载最新的检查点
loss, acc = model4.evaluate(test_images, test_labels) #
print("# Restored model4, accuracy: {:5.2f}%".format(100 * acc)) # 模型表现得到大幅提升 # ### 手动保存权重
model5 = create_model()
model5.fit(train_images, train_labels,
epochs=10,
validation_data=(test_images, test_labels),
verbose=0) # 训练模型
model5.save_weights('./training_3/my_checkpoint') # 手动保存权重 model6 = create_model()
loss, acc = model6.evaluate(test_images, test_labels)
print("# Restored model6, accuracy: {:5.2f}%".format(100 * acc))
model6.load_weights('./training_3/my_checkpoint')
loss, acc = model6.evaluate(test_images, test_labels)
print("# Restored model6, accuracy: {:5.2f}%".format(100 * acc)) # ### 保存整个模型
model7 = create_model()
model7.fit(train_images, train_labels, epochs=5)
model7.save('my_model.h5') # 保存整个模型到HDF5文件 model8 = keras.models.load_model('my_model.h5') # 重建完全一样的模型,包括权重和优化器
model8.summary()
loss, acc = model8.evaluate(test_images, test_labels)
print("Restored model8, accuracy: {:5.2f}%".format(100 * acc))
运行结果
C:\Users\anliven\AppData\Local\conda\conda\envs\mlcc\python.exe D:/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML/5_save_and_restore_models.py
# TensorFlow version: 1.12.0 - tf.keras version: 2.1.6-tf
# np_data keys: ['x_test', 'x_train', 'y_train', 'y_test']
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 512) 401920
_________________________________________________________________
dropout (Dropout) (None, 512) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________ Epoch 00001: saving model to training_1/cp.ckpt
Epoch 00002: saving model to training_1/cp.ckpt
Epoch 00003: saving model to training_1/cp.ckpt
Epoch 00004: saving model to training_1/cp.ckpt
Epoch 00005: saving model to training_1/cp.ckpt
Epoch 00006: saving model to training_1/cp.ckpt
Epoch 00007: saving model to training_1/cp.ckpt
Epoch 00008: saving model to training_1/cp.ckpt
Epoch 00009: saving model to training_1/cp.ckpt
Epoch 00010: saving model to training_1/cp.ckpt 32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 140us/step
# Untrained model2, accuracy: 8.20% 32/1000 [..............................] - ETA: 0s
1000/1000 [==============================] - 0s 40us/step
# Restored model2, accuracy: 86.40% Epoch 00005: saving model to training_2/cp-0005.ckpt
Epoch 00010: saving model to training_2/cp-0010.ckpt
Epoch 00015: saving model to training_2/cp-0015.ckpt
Epoch 00020: saving model to training_2/cp-0020.ckpt
Epoch 00025: saving model to training_2/cp-0025.ckpt
Epoch 00030: saving model to training_2/cp-0030.ckpt
Epoch 00035: saving model to training_2/cp-0035.ckpt
Epoch 00040: saving model to training_2/cp-0040.ckpt
Epoch 00045: saving model to training_2/cp-0045.ckpt
Epoch 00050: saving model to training_2/cp-0050.ckpt # latest checkpoint: training_1\cp.ckpt 32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 140us/step
# Untrained model4, accuracy: 86.40% 32/1000 [..............................] - ETA: 2s
1000/1000 [==============================] - 0s 110us/step
# Restored model4, accuracy: 86.40% 32/1000 [..............................] - ETA: 5s
1000/1000 [==============================] - 0s 220us/step
# Restored model6, accuracy: 18.20% 32/1000 [..............................] - ETA: 0s
1000/1000 [==============================] - 0s 40us/step
# Restored model6, accuracy: 87.40%
Epoch 1/5 32/1000 [..............................] - ETA: 9s - loss: 2.4141 - acc: 0.0625
320/1000 [========>.....................] - ETA: 0s - loss: 1.8229 - acc: 0.4469
576/1000 [================>.............] - ETA: 0s - loss: 1.4932 - acc: 0.5694
864/1000 [========================>.....] - ETA: 0s - loss: 1.2624 - acc: 0.6481
1000/1000 [==============================] - 1s 530us/step - loss: 1.1978 - acc: 0.6620
Epoch 2/5 32/1000 [..............................] - ETA: 0s - loss: 0.5490 - acc: 0.8750
320/1000 [========>.....................] - ETA: 0s - loss: 0.4832 - acc: 0.8594
576/1000 [================>.............] - ETA: 0s - loss: 0.4630 - acc: 0.8715
864/1000 [========================>.....] - ETA: 0s - loss: 0.4356 - acc: 0.8808
1000/1000 [==============================] - 0s 200us/step - loss: 0.4298 - acc: 0.8790
Epoch 3/5 32/1000 [..............................] - ETA: 0s - loss: 0.1681 - acc: 0.9688
320/1000 [========>.....................] - ETA: 0s - loss: 0.2826 - acc: 0.9437
576/1000 [================>.............] - ETA: 0s - loss: 0.2774 - acc: 0.9340
832/1000 [=======================>......] - ETA: 0s - loss: 0.2740 - acc: 0.9327
1000/1000 [==============================] - 0s 200us/step - loss: 0.2781 - acc: 0.9280
Epoch 4/5 32/1000 [..............................] - ETA: 0s - loss: 0.1589 - acc: 0.9688
288/1000 [=======>......................] - ETA: 0s - loss: 0.2169 - acc: 0.9410
608/1000 [=================>............] - ETA: 0s - loss: 0.2186 - acc: 0.9457
864/1000 [========================>.....] - ETA: 0s - loss: 0.2231 - acc: 0.9479
1000/1000 [==============================] - 0s 200us/step - loss: 0.2164 - acc: 0.9480
Epoch 5/5 32/1000 [..............................] - ETA: 0s - loss: 0.1095 - acc: 1.0000
352/1000 [=========>....................] - ETA: 0s - loss: 0.1631 - acc: 0.9744
608/1000 [=================>............] - ETA: 0s - loss: 0.1671 - acc: 0.9638
864/1000 [========================>.....] - ETA: 0s - loss: 0.1545 - acc: 0.9688
1000/1000 [==============================] - 0s 210us/step - loss: 0.1538 - acc: 0.9670
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_14 (Dense) (None, 512) 401920
_________________________________________________________________
dropout_7 (Dropout) (None, 512) 0
_________________________________________________________________
dense_15 (Dense) (None, 10) 5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________ 32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 150us/step
Restored model8, accuracy: 86.10% Process finished with exit code 0
生成的文件
anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ll training_1
total 1601
-rw-r--r-- 1 anliven 197121 71 5月 5 23:36 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp.ckpt.index anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l training_1
total 1601
-rw-r--r-- 1 anliven 197121 71 5月 5 23:36 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp.ckpt.index anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l training_2
total 16001
-rw-r--r-- 1 anliven 197121 81 5月 5 23:37 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0005.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0005.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0010.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0010.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0015.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0015.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0020.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0020.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0025.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0025.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0030.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0030.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0035.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0035.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0040.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0040.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0045.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0045.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0050.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0050.ckpt.index anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l training_3
total 1601
-rw-r--r-- 1 anliven 197121 83 5月 5 23:37 checkpoint
-rw-r--r-- 1 anliven 197121 1631517 5月 5 23:37 my_checkpoint.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 my_checkpoint.index anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l my_model.h5
-rw-r--r-- 1 anliven 197121 4909112 5月 5 23:37 my_model.h5
问题处理
问题描述:出现如下告警信息。
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x00000280FD318780>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
问题处理:
正常告警,对脚本运行和结果无影响,暂不关注。
AI - TensorFlow - 示例05:保存和恢复模型的更多相关文章
- 第六节,TensorFlow编程基础案例-保存和恢复模型(中)
在我们使用TensorFlow的时候,有时候需要训练一个比较复杂的网络,比如后面的AlexNet,ResNet,GoogleNet等等,由于训练这些网络花费的时间比较长,因此我们需要保存模型的参数. ...
- 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)
学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...
- AI - TensorFlow - 示例01:基本分类
基本分类 基本分类(Basic classification):https://www.tensorflow.org/tutorials/keras/basic_classification Fash ...
- AI - TensorFlow - 示例03:基本回归
基本回归 回归(Regression):https://www.tensorflow.org/tutorials/keras/basic_regression 主要步骤:数据部分 获取数据(Get t ...
- AI - TensorFlow - 示例02:影评文本分类
影评文本分类 文本分类(Text classification):https://www.tensorflow.org/tutorials/keras/basic_text_classificatio ...
- AI - TensorFlow - 示例04:过拟合与欠拟合
过拟合与欠拟合(Overfitting and underfitting) 官网示例:https://www.tensorflow.org/tutorials/keras/overfit_and_un ...
- TensorFlow学习笔记:保存和读取模型
TensorFlow 更新频率实在太快,从 1.0 版本正式发布后,很多 API 接口就发生了改变.今天用 TF 训练了一个 CNN 模型,结果在保存模型的时候居然遇到各种问题.Google 搜出来的 ...
- 保存与恢复变量和模型,tensorflow官方文档阅读笔记
官方中文文档的网址先贴出来:https://tensorflow.google.cn/programmers_guide/saved_model tf.train.Saver 类别提供了保存和恢复模型 ...
- tensorflow 1.0 学习:模型的保存与恢复(Saver)
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...
随机推荐
- Linux Shell 如何获取参数
$# 是传给脚本的参数个数 $0 是脚本本身的名字 $1 是传递给该shell脚本的第一个参数 $2 是传递给该shell脚本的第二个参数 $@ 是传给脚本的所有参数的列表 $* 是以一个单字符串显示 ...
- PostgreSQL的 pg_hba.conf 配置参数详解
pg_hba.conf 配置详解 该文件位于初始化安装的数据库目录下 编辑 pg_hba.conf 配置文件 postgres@clw-db1:/pgdata/9.6/poc/data> v ...
- C语言for 循环 9*9 实现九九乘法表
#include <stdio.h> int main(void) { //for循环实现9*9乘法表 /* 1*1=1 1*2=2 2*2=4 1*3=3 2*3=6 3*3=9 */ ...
- loj #10131
抽离题意 求删除一条树边和一条非树边后将图分成不连通的两部分的方案数 对于一棵树,再加入一条边就会产生环.若只有一个环,说明只加入了一条非树边 (x, y),记 lca 为 l, 那么 对于任意一条 ...
- codevs 1683 车厢重组
1683 车厢重组 时间限制: 1 s 空间限制: 1000 KB 题目等级 : 白银 Silver 题目描述 Description 在一个旧式的火车站旁边有一座桥,其桥面可以绕河中心的桥 ...
- C++2.0新特性(一)——<特性认知、__cplusplus宏开启、Variadic Templates 、左右值区分>
一.新特性介绍 2.0新特性包含了C++11和C++14的部分 1.2 启用测试c++11功能 C++ 标准特定版本的支持,/Zc:__cplusplus 编译器选项启用 __cplusplus 预处 ...
- Hadoop hadoop balancer配置
hadoop版本:2.9.2 1.带宽的设置参数: dfs.datanode.balance.bandwidthPerSec 默认值 10m 2.datanode之间数据块的传输线程大小:dfs. ...
- [树链剖分]BZOJ3589动态树
题目描述 别忘了这是一棵动态树, 每时每刻都是动态的. 小明要求你在这棵树上维护两种事件 事件0: 这棵树长出了一些果子, 即某个子树中的每个节点都会长出K个果子. 事件1: 小明希望你求出几条树枝上 ...
- Diffie-Hellman算法简介
一.DH算法是一种密钥交换协议,它可以让双方在不泄漏密钥的情况下协商出一个密钥来. DH算法基于数学原理,比如小明和小红想要协商一个密钥,可以这么做: . 小明先选一个素数和一个底数,例如,素数p=, ...
- Python3基础 global 在函数内部对全局变量进行修改
Python : 3.7.3 OS : Ubuntu 18.04.2 LTS IDE : pycharm-community-2019.1.3 ...