模型的保存与加载一般有三种模式:save/load weights(最干净、最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有的状态都保存起来),saved_model(更通用的方式,以固定模型格式保存,该格式是各种语言通用的)

具体使用方法如下:

        # 保存模型
model.save_weights('./checkpoints/my_checkpoint')
# 加载模型
model = keras.create_model()
model.load_weights('./checkpoints/my_checkpoint')

示例:

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics def preprocess(x, y):
x = tf.cast(x, dtype=tf.float32) / 255.
x = tf.reshape(x, [28 * 28])
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
return x, y batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max()) db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz) sample = next(iter(db))
print(sample[0].shape, sample[1].shape) network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary() network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
) network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2) network.evaluate(ds_val) network.save_weights('weights.ckpt')
print('saved weights.')
del network network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
network.load_weights('weights.ckpt')
print('loaded weights!')
network.evaluate(ds_val)

运行效果如下:

可以看到保存前后的精度和损失差距不大,这是由于神经网络的运算过程中会有很多不确定因子,这些不确定因子不会通过save_weights方法保存,要想保存前后运行结果一致,就需要完整的保存网络模型。即model.save方法

使用方法如下:

# 模型保存
network.save('model.h5')
print('saved total model.')
# 模型加载
print('load model from file')
network = tf.keras.models.load_model('model.h5')
# 评估
network.evaluate(x_val,y_val)

除了这种方法之外,tensorflow还支持保存为标准的可以给其他语言使用的模型,使用saved_model即可

使用方法如下:

tf.saved_model.save(m,'/tmp/saved_model/')
imported = tf.saved_model.load(path)
f = imported.signatures["serving_default"]
print(f(x=tf.ones([1,28,28,3])))

tensorflow模型的保存与加载的更多相关文章

  1. tensorflow 之模型的保存与加载(二)

    上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...

  2. tensorflow 之模型的保存与加载(三)

    前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...

  3. tensorflow 之模型的保存与加载(一)

    怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...

  4. Python之TensorFlow的模型训练保存与加载-3

    一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...

  5. Tensorflow 模型持久化saver及加载图结构

    主要内容: 1. 直接保存,加载模型; (可以指定加载,保存的var_list) 2. 加载,保存指定变量的模型 3. slim加载模型使用 4. 加载模型图结构和参数等 tensorflow 恢复部 ...

  6. (sklearn)机器学习模型的保存与加载

    需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是 ...

  7. pytorch_模型参数-保存,加载,打印

    1.保存模型参数(gen-我自己的模型名字) torch.save(self.gen.state_dict(), os.path.join(self.gen_save_path, 'gen_%d.pt ...

  8. pytorch 中模型的保存与加载,增量训练

     让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...

  9. fashion_mnist多分类训练,两种模型的保存与加载

    from tensorflow.python.keras.preprocessing.image import load_img,img_to_array from tensorflow.python ...

随机推荐

  1. codeforces 1020 C Elections(枚举+贪心)

    题意: 有 n个人,m个党派,第i个人开始想把票投给党派pi,而如果想让他改变他的想法需要花费ci元.你现在是党派1,问你最少花多少钱使得你的党派得票数大于其它任意党派. n,m<3000 思路 ...

  2. 基于python2+selenium3+pytest4的UI自动化框架

    环境:Python2.7.10, selenium3.141.0, pytest4.6.6, pytest-html1.22.0, Windows-7-6.1.7601-SP1 特点:- 二次封装了s ...

  3. 在centos6.3下安装php的Xdebug

    首先下载一个xdebug http://www.xdebug.org/docs/ 官网上有windos版本和linux源码版本的,我们下载一个源码包xdebug-2.2.5.tgz 然后进入安装流程 ...

  4. Spring学习笔记:自动创建Proxy

    为什么需要自动创建Proxy 手动为所有需要代理的类用ProxyFactoryBean创建代理Proxy需要大量的配置. 这样如果需要代理的类很多,配置就很繁琐,而且也不便于xml配置的维护. 因此S ...

  5. 都闪开,不用任何游戏引擎,html也能开发格斗游戏

    html格斗游戏,对打游戏 不用引擎,不用画布canvas,不用任何库(包括jquery), 原生div+img组件,开发格斗游戏游戏教程视频已经上传 b站:https://www.bilibili. ...

  6. 记一次IE浏览器做图片预览的坑

    随便写写吧,被坑死了 IE 10 及 IE10以上,可以使用FileReader的方式,来做图片预览,加载本地图片显示 IE 9 8 7 没有FileReader这个对象,所以只能使用微软自己的东西来 ...

  7. 开源堡垒机jumpserver的安装

    开源跳板机jumpserver安装 简介 Jumpserver 是全球首款完全开源的堡垒机, 使用GNU GPL v2.0 开源协议, 是符合4A 的专业运维审计系统 Jumpserver 使用Pyt ...

  8. 非对称加密 秘钥登录 https

    非对称加密简介: 对称加密算法在加密和解密时使用的是同一个秘钥:而非对称加密算法需要两个密钥来进行加密和解密,这两个秘钥是公开密钥(public key,简称公钥)私有密钥(private key,简 ...

  9. [Redis-CentOS7]Redis打开远程连接(十) Could not connect to Redis at 127.0.0.1:6379: Connection refused

    通过网络无法访问Redis redis-cli 172.16.1.111 Could not connect to Redis at 127.0.0.1:6379: Connection refuse ...

  10. TCP协议可靠性是如何保证之滑动窗口,超时重发,序列号确认应答信号

    原创文章首发于公众号:「码农富哥」,欢迎收藏和关注,如转载请注明出处! TCP 是一种提供可靠性交付的协议. 也就是说,通过 TCP 连接传输的数据,无差错.不丢失.不重复.并且按序到达. 但是在网络 ...