保存和恢复模型(Save and restore models)

官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_models

在训练期间保存检查点

在训练期间或训练结束时自动保存检查点。
权重存储在检查点格式的文件集合中,这些文件仅包含经过训练的权重(采用二进制格式)。
可以使用经过训练的模型,而无需重新训练该模型,或从上次暂停的地方继续训练,以防训练过程中断

  • 检查点回调用法:创建检查点回调,训练模型并将ModelCheckpoint回调传递给该模型,得到检查点文件集合,用于分享权重
  • 检查点回调选项:该回调提供了多个选项,用于为生成的检查点提供独一无二的名称,以及调整检查点创建频率。

手动保存权重

使用 Model.save_weights 方法即可手动保存权重

保存整个模型

整个模型可以保存到一个文件中,其中包含权重值、模型配置(架构)、优化器配置。
可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。
Keras通过检查架构来保存模型,使用HDF5标准提供基本的保存格式。
特别注意:

  • 目前无法保存TensorFlow优化器(来自tf.train)。
  • 使用此类优化器时,需要在加载模型后对其进行重新编译,使优化器的状态变松散。

MNIST数据集

MNIST(Mixed National Institute of Standards and Technology database)是一个计算机视觉数据集

示例

脚本内容

GitHub:https://github.com/anliven/Hello-AI/blob/master/Google-Learn-and-use-ML/5_save_and_restore_models.py

  1. # coding=utf-8
  2. import tensorflow as tf
  3. from tensorflow import keras
  4. import numpy as np
  5. import pathlib
  6. import os
  7.  
  8. os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
  9. print("# TensorFlow version: {} - tf.keras version: {}".format(tf.VERSION, tf.keras.__version__)) # 查看版本
  10.  
  11. # ### 获取示例数据集
  12.  
  13. ds_path = str(pathlib.Path.cwd()) + "\\datasets\\mnist\\" # 数据集路径
  14. np_data = np.load(ds_path + "mnist.npz") # 加载numpy格式数据
  15. print("# np_data keys: ", list(np_data.keys())) # 查看所有的键
  16.  
  17. # 加载mnist数据集
  18. (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data(path=ds_path + "mnist.npz")
  19. train_labels = train_labels[:1000]
  20. test_labels = test_labels[:1000]
  21. train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
  22. test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
  23.  
  24. # ### 定义模型
  25. def create_model():
  26. model = tf.keras.models.Sequential([
  27. keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
  28. keras.layers.Dropout(0.2),
  29. keras.layers.Dense(10, activation=tf.nn.softmax)
  30. ]) # 构建一个简单的模型
  31. model.compile(optimizer=tf.keras.optimizers.Adam(),
  32. loss=tf.keras.losses.sparse_categorical_crossentropy,
  33. metrics=['accuracy'])
  34. return model
  35.  
  36. mod = create_model()
  37. mod.summary()
  38.  
  39. # ### 在训练期间保存检查点
  40.  
  41. # 检查点回调用法
  42. checkpoint_path = "training_1/cp.ckpt"
  43. checkpoint_dir = os.path.dirname(checkpoint_path) # 检查点存放目录
  44. cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
  45. save_weights_only=True,
  46. verbose=2) # 创建检查点回调
  47. model1 = create_model()
  48. model1.fit(train_images, train_labels,
  49. epochs=10,
  50. validation_data=(test_images, test_labels),
  51. verbose=0,
  52. callbacks=[cp_callback] # 将ModelCheckpoint回调传递给该模型
  53. ) # 训练模型,将创建一个TensorFlow检查点文件集合,这些文件在每个周期结束时更新
  54.  
  55. model2 = create_model() # 创建一个未经训练的全新模型(与原始模型架构相同,才能分享权重)
  56. loss, acc = model2.evaluate(test_images, test_labels) # 使用测试集进行评估
  57. print("# Untrained model2, accuracy: {:5.2f}%".format(100 * acc)) # 未训练模型的表现(准确率约为10%)
  58.  
  59. model2.load_weights(checkpoint_path) # 从检查点加载权重
  60. loss, acc = model2.evaluate(test_images, test_labels) # 使用测试集,重新进行评估
  61. print("# Restored model2, accuracy: {:5.2f}%".format(100 * acc)) # 模型表现得到大幅提升
  62.  
  63. # 检查点回调选项
  64. checkpoint_path2 = "training_2/cp-{epoch:04d}.ckpt" # 使用“str.format”方式为每个检查点设置唯一名称
  65. checkpoint_dir2 = os.path.dirname(checkpoint_path)
  66. cp_callback2 = tf.keras.callbacks.ModelCheckpoint(checkpoint_path2,
  67. verbose=1,
  68. save_weights_only=True,
  69. period=5 # 每隔5个周期保存一次检查点
  70. ) # 创建检查点回调
  71. model3 = create_model()
  72. model3.fit(train_images, train_labels,
  73. epochs=50,
  74. callbacks=[cp_callback2], # 将ModelCheckpoint回调传递给该模型
  75. validation_data=(test_images, test_labels),
  76. verbose=0) # 训练一个新模型,每隔5个周期保存一次检查点并设置唯一名称
  77. latest = tf.train.latest_checkpoint(checkpoint_dir2)
  78. print("# latest checkpoint: {}".format(latest)) # 查看最新的检查点
  79.  
  80. model4 = create_model() # 重新创建一个全新的模型
  81. loss, acc = model2.evaluate(test_images, test_labels) # 使用测试集进行评估
  82. print("# Untrained model4, accuracy: {:5.2f}%".format(100 * acc)) # 未训练模型的表现(准确率约为10%)
  83.  
  84. model4.load_weights(latest) # 加载最新的检查点
  85. loss, acc = model4.evaluate(test_images, test_labels) #
  86. print("# Restored model4, accuracy: {:5.2f}%".format(100 * acc)) # 模型表现得到大幅提升
  87.  
  88. # ### 手动保存权重
  89. model5 = create_model()
  90. model5.fit(train_images, train_labels,
  91. epochs=10,
  92. validation_data=(test_images, test_labels),
  93. verbose=0) # 训练模型
  94. model5.save_weights('./training_3/my_checkpoint') # 手动保存权重
  95.  
  96. model6 = create_model()
  97. loss, acc = model6.evaluate(test_images, test_labels)
  98. print("# Restored model6, accuracy: {:5.2f}%".format(100 * acc))
  99. model6.load_weights('./training_3/my_checkpoint')
  100. loss, acc = model6.evaluate(test_images, test_labels)
  101. print("# Restored model6, accuracy: {:5.2f}%".format(100 * acc))
  102.  
  103. # ### 保存整个模型
  104. model7 = create_model()
  105. model7.fit(train_images, train_labels, epochs=5)
  106. model7.save('my_model.h5') # 保存整个模型到HDF5文件
  107.  
  108. model8 = keras.models.load_model('my_model.h5') # 重建完全一样的模型,包括权重和优化器
  109. model8.summary()
  110. loss, acc = model8.evaluate(test_images, test_labels)
  111. print("Restored model8, accuracy: {:5.2f}%".format(100 * acc))

运行结果

  1. C:\Users\anliven\AppData\Local\conda\conda\envs\mlcc\python.exe D:/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML/5_save_and_restore_models.py
  2. # TensorFlow version: 1.12.0 - tf.keras version: 2.1.6-tf
  3. # np_data keys: ['x_test', 'x_train', 'y_train', 'y_test']
  4. _________________________________________________________________
  5. Layer (type) Output Shape Param #
  6. =================================================================
  7. dense (Dense) (None, 512) 401920
  8. _________________________________________________________________
  9. dropout (Dropout) (None, 512) 0
  10. _________________________________________________________________
  11. dense_1 (Dense) (None, 10) 5130
  12. =================================================================
  13. Total params: 407,050
  14. Trainable params: 407,050
  15. Non-trainable params: 0
  16. _________________________________________________________________
  17.  
  18. Epoch 00001: saving model to training_1/cp.ckpt
  19. Epoch 00002: saving model to training_1/cp.ckpt
  20. Epoch 00003: saving model to training_1/cp.ckpt
  21. Epoch 00004: saving model to training_1/cp.ckpt
  22. Epoch 00005: saving model to training_1/cp.ckpt
  23. Epoch 00006: saving model to training_1/cp.ckpt
  24. Epoch 00007: saving model to training_1/cp.ckpt
  25. Epoch 00008: saving model to training_1/cp.ckpt
  26. Epoch 00009: saving model to training_1/cp.ckpt
  27. Epoch 00010: saving model to training_1/cp.ckpt
  28.  
  29. 32/1000 [..............................] - ETA: 3s
  30. 1000/1000 [==============================] - 0s 140us/step
  31. # Untrained model2, accuracy: 8.20%
  32.  
  33. 32/1000 [..............................] - ETA: 0s
  34. 1000/1000 [==============================] - 0s 40us/step
  35. # Restored model2, accuracy: 86.40%
  36.  
  37. Epoch 00005: saving model to training_2/cp-0005.ckpt
  38. Epoch 00010: saving model to training_2/cp-0010.ckpt
  39. Epoch 00015: saving model to training_2/cp-0015.ckpt
  40. Epoch 00020: saving model to training_2/cp-0020.ckpt
  41. Epoch 00025: saving model to training_2/cp-0025.ckpt
  42. Epoch 00030: saving model to training_2/cp-0030.ckpt
  43. Epoch 00035: saving model to training_2/cp-0035.ckpt
  44. Epoch 00040: saving model to training_2/cp-0040.ckpt
  45. Epoch 00045: saving model to training_2/cp-0045.ckpt
  46. Epoch 00050: saving model to training_2/cp-0050.ckpt
  47.  
  48. # latest checkpoint: training_1\cp.ckpt
  49.  
  50. 32/1000 [..............................] - ETA: 3s
  51. 1000/1000 [==============================] - 0s 140us/step
  52. # Untrained model4, accuracy: 86.40%
  53.  
  54. 32/1000 [..............................] - ETA: 2s
  55. 1000/1000 [==============================] - 0s 110us/step
  56. # Restored model4, accuracy: 86.40%
  57.  
  58. 32/1000 [..............................] - ETA: 5s
  59. 1000/1000 [==============================] - 0s 220us/step
  60. # Restored model6, accuracy: 18.20%
  61.  
  62. 32/1000 [..............................] - ETA: 0s
  63. 1000/1000 [==============================] - 0s 40us/step
  64. # Restored model6, accuracy: 87.40%
  65. Epoch 1/5
  66.  
  67. 32/1000 [..............................] - ETA: 9s - loss: 2.4141 - acc: 0.0625
  68. 320/1000 [========>.....................] - ETA: 0s - loss: 1.8229 - acc: 0.4469
  69. 576/1000 [================>.............] - ETA: 0s - loss: 1.4932 - acc: 0.5694
  70. 864/1000 [========================>.....] - ETA: 0s - loss: 1.2624 - acc: 0.6481
  71. 1000/1000 [==============================] - 1s 530us/step - loss: 1.1978 - acc: 0.6620
  72. Epoch 2/5
  73.  
  74. 32/1000 [..............................] - ETA: 0s - loss: 0.5490 - acc: 0.8750
  75. 320/1000 [========>.....................] - ETA: 0s - loss: 0.4832 - acc: 0.8594
  76. 576/1000 [================>.............] - ETA: 0s - loss: 0.4630 - acc: 0.8715
  77. 864/1000 [========================>.....] - ETA: 0s - loss: 0.4356 - acc: 0.8808
  78. 1000/1000 [==============================] - 0s 200us/step - loss: 0.4298 - acc: 0.8790
  79. Epoch 3/5
  80.  
  81. 32/1000 [..............................] - ETA: 0s - loss: 0.1681 - acc: 0.9688
  82. 320/1000 [========>.....................] - ETA: 0s - loss: 0.2826 - acc: 0.9437
  83. 576/1000 [================>.............] - ETA: 0s - loss: 0.2774 - acc: 0.9340
  84. 832/1000 [=======================>......] - ETA: 0s - loss: 0.2740 - acc: 0.9327
  85. 1000/1000 [==============================] - 0s 200us/step - loss: 0.2781 - acc: 0.9280
  86. Epoch 4/5
  87.  
  88. 32/1000 [..............................] - ETA: 0s - loss: 0.1589 - acc: 0.9688
  89. 288/1000 [=======>......................] - ETA: 0s - loss: 0.2169 - acc: 0.9410
  90. 608/1000 [=================>............] - ETA: 0s - loss: 0.2186 - acc: 0.9457
  91. 864/1000 [========================>.....] - ETA: 0s - loss: 0.2231 - acc: 0.9479
  92. 1000/1000 [==============================] - 0s 200us/step - loss: 0.2164 - acc: 0.9480
  93. Epoch 5/5
  94.  
  95. 32/1000 [..............................] - ETA: 0s - loss: 0.1095 - acc: 1.0000
  96. 352/1000 [=========>....................] - ETA: 0s - loss: 0.1631 - acc: 0.9744
  97. 608/1000 [=================>............] - ETA: 0s - loss: 0.1671 - acc: 0.9638
  98. 864/1000 [========================>.....] - ETA: 0s - loss: 0.1545 - acc: 0.9688
  99. 1000/1000 [==============================] - 0s 210us/step - loss: 0.1538 - acc: 0.9670
  100. _________________________________________________________________
  101. Layer (type) Output Shape Param #
  102. =================================================================
  103. dense_14 (Dense) (None, 512) 401920
  104. _________________________________________________________________
  105. dropout_7 (Dropout) (None, 512) 0
  106. _________________________________________________________________
  107. dense_15 (Dense) (None, 10) 5130
  108. =================================================================
  109. Total params: 407,050
  110. Trainable params: 407,050
  111. Non-trainable params: 0
  112. _________________________________________________________________
  113.  
  114. 32/1000 [..............................] - ETA: 3s
  115. 1000/1000 [==============================] - 0s 150us/step
  116. Restored model8, accuracy: 86.10%
  117.  
  118. Process finished with exit code 0

生成的文件

  1. anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
  2. $ ll training_1
  3. total 1601
  4. -rw-r--r-- 1 anliven 197121 71 5 5 23:36 checkpoint
  5. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:36 cp.ckpt.data-00000-of-00001
  6. -rw-r--r-- 1 anliven 197121 647 5 5 23:36 cp.ckpt.index
  7.  
  8. anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
  9. $
  10.  
  11. anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
  12. $
  13.  
  14. anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
  15. $
  16.  
  17. anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
  18. $ ls -l training_1
  19. total 1601
  20. -rw-r--r-- 1 anliven 197121 71 5 5 23:36 checkpoint
  21. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:36 cp.ckpt.data-00000-of-00001
  22. -rw-r--r-- 1 anliven 197121 647 5 5 23:36 cp.ckpt.index
  23.  
  24. anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
  25. $ ls -l training_2
  26. total 16001
  27. -rw-r--r-- 1 anliven 197121 81 5 5 23:37 checkpoint
  28. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:36 cp-0005.ckpt.data-00000-of-00001
  29. -rw-r--r-- 1 anliven 197121 647 5 5 23:36 cp-0005.ckpt.index
  30. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:36 cp-0010.ckpt.data-00000-of-00001
  31. -rw-r--r-- 1 anliven 197121 647 5 5 23:36 cp-0010.ckpt.index
  32. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:36 cp-0015.ckpt.data-00000-of-00001
  33. -rw-r--r-- 1 anliven 197121 647 5 5 23:36 cp-0015.ckpt.index
  34. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:36 cp-0020.ckpt.data-00000-of-00001
  35. -rw-r--r-- 1 anliven 197121 647 5 5 23:36 cp-0020.ckpt.index
  36. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:36 cp-0025.ckpt.data-00000-of-00001
  37. -rw-r--r-- 1 anliven 197121 647 5 5 23:36 cp-0025.ckpt.index
  38. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:37 cp-0030.ckpt.data-00000-of-00001
  39. -rw-r--r-- 1 anliven 197121 647 5 5 23:37 cp-0030.ckpt.index
  40. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:37 cp-0035.ckpt.data-00000-of-00001
  41. -rw-r--r-- 1 anliven 197121 647 5 5 23:37 cp-0035.ckpt.index
  42. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:37 cp-0040.ckpt.data-00000-of-00001
  43. -rw-r--r-- 1 anliven 197121 647 5 5 23:37 cp-0040.ckpt.index
  44. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:37 cp-0045.ckpt.data-00000-of-00001
  45. -rw-r--r-- 1 anliven 197121 647 5 5 23:37 cp-0045.ckpt.index
  46. -rw-r--r-- 1 anliven 197121 1631508 5 5 23:37 cp-0050.ckpt.data-00000-of-00001
  47. -rw-r--r-- 1 anliven 197121 647 5 5 23:37 cp-0050.ckpt.index
  48.  
  49. anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
  50. $ ls -l training_3
  51. total 1601
  52. -rw-r--r-- 1 anliven 197121 83 5 5 23:37 checkpoint
  53. -rw-r--r-- 1 anliven 197121 1631517 5 5 23:37 my_checkpoint.data-00000-of-00001
  54. -rw-r--r-- 1 anliven 197121 647 5 5 23:37 my_checkpoint.index
  55.  
  56. anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
  57. $ ls -l my_model.h5
  58. -rw-r--r-- 1 anliven 197121 4909112 5 5 23:37 my_model.h5

问题处理

问题描述:出现如下告警信息。

  1. WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x00000280FD318780>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.
  2.  
  3. Consider using a TensorFlow optimizer from `tf.train`.

问题处理:

正常告警,对脚本运行和结果无影响,暂不关注。

AI - TensorFlow - 示例05:保存和恢复模型的更多相关文章

  1. 第六节,TensorFlow编程基础案例-保存和恢复模型(中)

    在我们使用TensorFlow的时候,有时候需要训练一个比较复杂的网络,比如后面的AlexNet,ResNet,GoogleNet等等,由于训练这些网络花费的时间比较长,因此我们需要保存模型的参数. ...

  2. 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)

    学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...

  3. AI - TensorFlow - 示例01:基本分类

    基本分类 基本分类(Basic classification):https://www.tensorflow.org/tutorials/keras/basic_classification Fash ...

  4. AI - TensorFlow - 示例03:基本回归

    基本回归 回归(Regression):https://www.tensorflow.org/tutorials/keras/basic_regression 主要步骤:数据部分 获取数据(Get t ...

  5. AI - TensorFlow - 示例02:影评文本分类

    影评文本分类 文本分类(Text classification):https://www.tensorflow.org/tutorials/keras/basic_text_classificatio ...

  6. AI - TensorFlow - 示例04:过拟合与欠拟合

    过拟合与欠拟合(Overfitting and underfitting) 官网示例:https://www.tensorflow.org/tutorials/keras/overfit_and_un ...

  7. TensorFlow学习笔记:保存和读取模型

    TensorFlow 更新频率实在太快,从 1.0 版本正式发布后,很多 API 接口就发生了改变.今天用 TF 训练了一个 CNN 模型,结果在保存模型的时候居然遇到各种问题.Google 搜出来的 ...

  8. 保存与恢复变量和模型,tensorflow官方文档阅读笔记

    官方中文文档的网址先贴出来:https://tensorflow.google.cn/programmers_guide/saved_model tf.train.Saver 类别提供了保存和恢复模型 ...

  9. tensorflow 1.0 学习:模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

随机推荐

  1. vs 在高分屏下开发 winform 配置

    一.窗体控件大小 第一种方法:使用网格避免整除误差 在选项中将Windows窗体设计器的LayoutMode(布局模式)改成SnapToGrid(对齐到网格),并将Default Grid Cell ...

  2. 【转发】c#做端口转发程序支持正向连接和反向链接

    可以通过中转server来连接sql server,连接的时候用ip,port,不是冒号,是逗号 但试过local port 21想连接AS400的FTP却不成功...为咩涅... https://w ...

  3. C# Ninject使用

    Ninject是一个IOC容器,用来解决程序中组件的耦合问题,它的目的在于做到最少配置.简单来讲就是 为我们选择一个想要的类来处理事务. 百度百科的解释:一个快如闪电.超轻量级的基于.Net平台的依赖 ...

  4. C++中unique函数的用法总结

    个人感觉,unique是STL中很实用的函数之一,需要#include,下面来简单介绍一下它的作用. unique的作用是"去掉"容器中相邻元素的重复元素,这里去掉要加一个引号,为 ...

  5. python 之 列表常用 操作

  6. luogu_P3313 [SDOI2014]旅行

    传送门 Solution 第二次学习可持久化线段树 打了一道裸题来练习一下-- 对于每个宗教都可以开一个主席树 基础操作 树剖lca Code  #include<bits/stdc++.h&g ...

  7. Leetcode84. 柱状图中最大的矩形(单调栈)

    84. 柱状图中最大的矩形 前置 单调栈 做法 连续区间组成的矩形,是看最短的那一块,求出每一块左边第一个小于其高度的位置,右边也同理,此块作为最短限制.需要两次单调栈 单调栈维护递增区间,每次不满足 ...

  8. SQL题(子文章)(持续更新)

    -----> 总文章 入口 文章目录 [-----> 总文章 入口](https://blog.csdn.net/qq_37214567/article/details/90174445) ...

  9. elasticsearch routing

    当索引一个文档的时候,文档会被存储到一个主分片中. Elasticsearch 如何知道一个文档应该存放到哪个分片中呢?当我们创建文档时,它如何决定这个文档应当被存储在分片 1 还是分片 2 中呢?首 ...

  10. 小福bbs-冲刺日志(第三天)

    [小福bbs-冲刺日志(第三天)] 这个作业属于哪个课程 班级链接 这个作业要求在哪里 作业要求的链接 团队名称 小福bbs 这个作业的目标 前端交付部分页面给后端 ,后端开始完成部分功能 作业的正文 ...