tf用 tf.train.Saver类来实现神经网络模型的保存和读取。无论保存还是读取,都首先要创建saver对象。

用saver对象的save方法保存模型

保存的是所有变量

save(
sess,
save_path,
global_step=None,  
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True
)

保存模型需要session,初始化变量

用法示例

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, "Model/model.ckpt", global_step=3)

输出

1. global_step 放在文件名后面,起个标记作用

2. save方法输出4个文件

  // checkpoint 里面是一堆路径,model_checkpoint_path 记录了最新模型的路径,all_model_checkpoint_paths 记录了之前模型的路径

  // model.ckpt-3.data-00000-of-00001 存放的是模型参数

  // model.ckpt-3.meta 存放的是计算图

3. 最多只能保存近5次模型,比如我们迭代100次,每次保存一下,最后只留下了最近的5次。

用saver对象的restore方法加载模型

加载的是所有变量,以name为准,假如保存的模型中有变量叫 a ,value是2,那么在加载后,即使重新建立变量a,并赋其他value,其value仍然是2

restore(
sess,
save_path
)

加载模型需要session,不需要初始化变量

用法示例(接前例)

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
# v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22") # Key v22 not found in checkpoint
result = v1 + v2 saver = tf.train.Saver()
#
with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./"
print(sess.run(result)) # [ 3.]

1. 重新给 name为 v2的变量 赋值,其结果仍然是3,说明加载了之前的v2

2. 新建name为 v22 的变量,报错, 在保存的模型中没找到v2 。说明寻找变量以name为准,不以变量名为准

继续做如下尝试

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
# v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22") # Key v22 not found in checkpoint
result = v1 + v3 saver = tf.train.Saver()
#
with tf.Session() as sess:
# sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint
saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./"
# sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint
print(sess.run(result)) # [ 3.]

1. 新建name为v22的变量v3,仍然报错,说明新的变量没有被接受

2. 在加载模型前初始化v3,仍然报错,加载模型后初始化v3,仍然报错,这说明在加载的模型中不接受新的变量。

继续尝试

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
# v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22") # Key v22 not found in checkpoint
result = v1 + v3 saver = tf.train.Saver()
#
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint
print(sess.run(v3)) # [7.]
saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./"
sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint
print(sess.run(result)) # [ 3.]

在加载模型前初始化变量,正确输出,但在加载后,报错,证实了我上面的说法,“不接受新的变量”

总结:

1. 模型加载加载的是所有变量,以name为准

2. 模型加载后不接受任何新的变量

3. 在加载模型时需要重新定义计算图上的所有节点,但是变量无需初始化

加载计算图

直接加载计算图就无需重新定义计算图上的节点

用法示例

saver = tf.train.import_meta_graph("Model/model.ckpt-3.meta")

with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt-3") # 注意路径写法
print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [3.]
# print(sess.run(sess.graph.get_tensor_by_name('add:0'))) # [3.]

重命名变量

在加载模型时不接受新的变量,这会造成很多麻烦。

为解决这个问题,加载模型时可以给变量重命名。

用法示例

u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = u1 + u2 # 若直接声明Saver类对象,会报错变量找不到
# 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}
# 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中
saver = tf.train.Saver({"v1": u1, "v2": u2}) with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt-3")
print(sess.run(result)) # [ 3.]

注意重命名格式  老变量的name: 新变量名

参考资料:

https://blog.csdn.net/marsjhao/article/details/72829635

https://blog.csdn.net/shuzfan/article/details/79197432

tf 模型保存的更多相关文章

  1. TensorFlow:tf.train.Saver()模型保存与恢复

    1.保存 将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.S ...

  2. tensorflow的tf.train.Saver()模型保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver( ...

  3. TensorFlow构建卷积神经网络/模型保存与加载/正则化

    TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...

  4. 10 Tensorflow模型保存与读取

    我们的模型训练出来想给别人用,或者是我今天训练不完,明天想接着训练,怎么办?这就需要模型的保存与读取.看代码: import tensorflow as tf import numpy as np i ...

  5. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  6. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  7. TensorFlow进阶(六)---模型保存与恢复、自定义命令行参数

    模型保存与恢复.自定义命令行参数. 在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用.模型的保存和恢复也是通过tf.train.Saver类去实现,它主 ...

  8. Sklearn,TensorFlow,keras模型保存与读取

    一.sklearn模型保存与读取 1.保存 from sklearn.externals import joblib from sklearn import svm X = [[0, 0], [1, ...

  9. TensorFlow模型保存和提取方法

    一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将Tens ...

随机推荐

  1. ADO.NET Entity Framework学习笔记(3)ObjectContext

    ADO.NET Entity Framework学习笔记(3)ObjectContext对象[转]   说明 ObjectContext提供了管理数据的功能 Context操作数据 AddObject ...

  2. 架构探险笔记12-安全控制框架Shiro

    什么是Shiro Shiro是Apache组织下的一款轻量级Java安全框架.Spring Security相对来说比较臃肿. 官网 Shiro提供的服务 1.Authentication(认证) 2 ...

  3. CentOS6.8下实现配置配额

    CentOS6.8下实现配置配额 Linux系统是支持多用户的,即允许多个用户同时使用linux系统,普通用户在/home/目录下均有自己的家目录,在默认状态下,各个用户可以在自己的家目录下任意创建文 ...

  4. 前端Vue之vue的基本操作

    1.1 vue.js的快速入门使用 vue.js是目前前端web开发最流行的工具库之一,由尤雨溪在2014年2月发布的. 另外几个常见的工具库:react.js /angular.js 官方网站: 中 ...

  5. 重写TreeMap的compare方法处理配置表

    需要处理的配置表如下: 接上一篇的优化,接着优化,优化代码如下:  这段代码的关键在于重写TreeMap的compare方法. 关于如何重写TreeMap的compare方法,以及返回值代表的意义,可 ...

  6. [Fiddler] ReadResponse() failed: The server did not return a complete response for this request. Server returned 0 bytes.

    待解决 [Fiddler] ReadResponse() failed: The server did not return a complete response for this request. ...

  7. array 数组去重 过滤空值等方法

    去重操作 第一种方式, ES 6 引入的新书据结构 Set 本身就是没有重复数据的, 可以使用这个数据结构来转化数组.时间复杂度 O(n) 123456 const target = [];const ...

  8. 深入理解php内核

    目录 第一部分 基本原理 第一章 准备工作和背景知识 第一节 环境搭建 第二节 源码布局及阅读方法 第三节 常用代码 第四节 小结 第二章 用户代码的执行 第一节 PHP生命周期 第二节 从SAPI开 ...

  9. Ping 的TTL理解

    http://www.webkaka.com/tutorial/zhanzhang/2017/061570/ 根据自己的扩展重新整理了一下,虽然不是运维,想了解一点东西就希望了解清楚. 一.含义 “T ...

  10. MSSQL2012中SQL调优(SQL TUNING)时CBO支持和常用的hints

    虽然当前各关系库CBO都已经非常先进和智能,但因为关系库理论和实现上的限制,CBO在特殊场景下也会给出次优甚至存在严重性能问题的执行计划,而这些场景中,有一部分只能或适合通过关系库提供的hints来进 ...