目录

设置

基于checkpoints的模型保存

通过ModelCheckpoint模块来自动保存数据

手动保存权重

整个模型保存

总体代码


模型可以在训练中或者训练完成后保存。具体文档参考:https://tensorflow.google.cn/tutorials/keras/save_and_restore_models

设置

依赖项设置:

  1. !pip install -q h5py pyyaml

模型建立:

  1. from __future__ import absolute_import, division, print_function
  2. import os
  3. import tensorflow as tf
  4. from tensorflow import keras
  5. tf.__version__
  6. (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
  7. train_labels = train_labels[:1000]
  8. test_labels = test_labels[:1000]
  9. train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
  10. test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
  11. # 模型创建模型
  12. def create_model():
  13. model = tf.keras.models.Sequential([
  14. keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
  15. keras.layers.Dropout(0.2),
  16. keras.layers.Dense(10, activation=tf.nn.softmax)
  17. ])
  18. model.compile(optimizer=tf.keras.optimizers.Adam(),
  19. loss=tf.keras.losses.sparse_categorical_crossentropy,
  20. metrics=['accuracy'])
  21. return model
  22. #创建模型
  23. model = create_model()
  24. model.summary()

基于checkpoints的模型保存

通过ModelCheckpoint模块来自动保存数据

  1. #创建回调函数
  2. cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
  3. save_weights_only=True, #只保存权重
  4. verbose=1)
  5. model = create_model()
  6. model.fit(train_images, train_labels, epochs = 10,
  7. validation_data = (test_images,test_labels),
  8. callbacks = [cp_callback]) #保存模型

通过load_weight读取权重

  1. #对全新没有训练的模型进行预测
  2. model = create_model()
  3. loss, acc = model.evaluate(test_images, test_labels)
  4. print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) #11.4%
  5. #载入权重参数后的模型
  6. model.load_weights(checkpoint_path)
  7. loss,acc = model.evaluate(test_images, test_labels)
  8. print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.2

手动保存权重

  1. # 保存权重
  2. model.save_weights('./checkpoints/my_checkpoint')
  3. #恢复模型
  4. model = create_model()
  5. model.load_weights('./checkpoints/my_checkpoint')
  6. loss,acc = model.evaluate(test_images, test_labels)
  7. print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #87.00%

整个模型保存

基于keras的HD5文件保存整个模型所有参数,优化器参数等。

  1. #将整个模型保存为HDF5文件
  2. model = create_model()
  3. model.fit(train_images, train_labels, epochs=5)
  4. model.save('my_model.h5')
  5. #载入一个相同的模型
  6. new_model = keras.models.load_model('my_model.h5')
  7. new_model.summary()
  8. loss, acc = new_model.evaluate(test_images, test_labels)
  9. print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.30%

总体代码

  1. from __future__ import absolute_import, division, print_function
  2. import os
  3. import tensorflow as tf
  4. from tensorflow import keras
  5. tf.__version__
  6. (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
  7. train_labels = train_labels[:1000]
  8. test_labels = test_labels[:1000]
  9. train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
  10. test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
  11. # 模型创建模型
  12. def create_model():
  13. model = tf.keras.models.Sequential([
  14. keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
  15. keras.layers.Dropout(0.2),
  16. keras.layers.Dense(10, activation=tf.nn.softmax)
  17. ])
  18. model.compile(optimizer=tf.keras.optimizers.Adam(),
  19. loss=tf.keras.losses.sparse_categorical_crossentropy,
  20. metrics=['accuracy'])
  21. return model
  22. #创建模型
  23. model = create_model()
  24. model.summary()
  25. checkpoint_path = "training_1/cp.ckpt"
  26. checkpoint_dir = os.path.dirname(checkpoint_path)
  27. '''
  28. #创建回调函数
  29. cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
  30. save_weights_only=True, #只保存权重
  31. verbose=1)
  32. model = create_model()
  33. model.fit(train_images, train_labels, epochs = 10,
  34. validation_data = (test_images,test_labels),
  35. callbacks = [cp_callback]) #保存模型
  36. #对全新没有训练的模型进行预测
  37. model = create_model()
  38. loss, acc = model.evaluate(test_images, test_labels)
  39. print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) #11.4%
  40. #载入权重参数后的模型
  41. model.load_weights(checkpoint_path)
  42. loss,acc = model.evaluate(test_images, test_labels)
  43. print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.2
  44. # 保存权重
  45. model.save_weights('./checkpoints/my_checkpoint')
  46. #恢复模型
  47. model = create_model()
  48. model.load_weights('./checkpoints/my_checkpoint')
  49. loss,acc = model.evaluate(test_images, test_labels)
  50. print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #87.00%
  51. '''
  52. #将整个模型保存为HDF5文件
  53. model = create_model()
  54. model.fit(train_images, train_labels, epochs=5)
  55. model.save('my_model.h5')
  56. #载入一个相同的模型
  57. new_model = keras.models.load_model('my_model.h5')
  58. new_model.summary()
  59. loss, acc = new_model.evaluate(test_images, test_labels)
  60. print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.30%

[深度学习] tf.keras入门5-模型保存和载入的更多相关文章

  1. [深度学习] tf.keras入门1-基本函数介绍

    目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...

  2. [深度学习] tf.keras入门4-过拟合和欠拟合

    过拟合和欠拟合 简单来说过拟合就是模型训练集精度高,测试集训练精度低:欠拟合则是模型训练集和测试集训练精度都低. 官方文档地址为 https://tensorflow.google.cn/tutori ...

  3. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  4. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. SpringBoot (四) - 整合Mybatis,逆向工程,JPA

    1.SpringBoot整合MyBatis 1.1 application.yml # 数据源配置 spring: datasource: driver-class-name: com.mysql.c ...

  2. 关于TP5模板输出时间戳问题--A non well formed numeric value encountered

    某日.因为一个项目.控制器我是这么写的 1 /** 2 * get admin/Picture/index 3 * 显示所有图册信息 4 * @return view 5 */ 6 public fu ...

  3. 23.mixin类源码解析

    mixin类用于提供视图的基本操作行为,注意mixin类提供动作方法,而不是直接定义处理程序方法 例如.get() .post(),这允许更灵活的定义,mixin从rest_framework.mix ...

  4. static 关键字分析

    在java中static 关键字用途很广,可以修饰成员变量 方法 甚至类(静态内部类),这里不分析static 修饰类 static修饰的内容的运行顺序 java的程序执行之前有一个类的加载的过程,在 ...

  5. JS逆向实战3——AESCBC 模式解密

    爬取某省公共资源交易中心 通过抓包数据可知 这个data是我们所需要的数据,但是已经通过加密隐藏起来了 分析 首先这是个json文件,我们可以用请求参数一个一个搜 但是由于我们已经知道了这是个json ...

  6. JS 学习笔记 (六) 函数式编程

    1.函数闭包 1.1 概述 JavaScript采用词法作用域,函数的执行依赖于变量作用域,这个作用域是在函数定义时决定的,而不是函数调用时决定的. 为了实现这种词法作用域,JavaScript函数对 ...

  7. Burpsuite(科学版)安装教程

    前言 BurpSuite是一款用于攻击web 应用程序的集成平台,在安全圈被称作"抓包神器".本文主要讲解 BurpSuite破解版的安装教程. 配置环境变量 BurpSuite是 ...

  8. java反序列化漏洞cc_link_one

    CC-LINK-one 前言 这里也正式进入的java的反序列化漏洞了,简单介绍一下CC是什么借用一些官方的解释:Apache Commons是Apache软件基金会的项目,曾经隶属于Jakarta项 ...

  9. Complementary XOR

    题目链接 题目大意: 给你两个字符串只有01组成,你可以选取区间[l, r],对字符串a在区间里面进行异或操作,对字符串b非区间值进行异或操作,问能否将两个字符串变为全0串.如果可以输出YES, 操作 ...

  10. 【iOS逆向】某车之家sign签名分析

    阅读此文档的过程中遇到任何问题,请关注公众号[移动端Android和iOS开发技术分享]或加QQ群[309580013] 1.目标 分析某车之家sign签名算法的实现 2.操作环境 frida mac ...