TensorFlow 更新频率实在太快,从 1.0 版本正式发布后,很多 API 接口就发生了改变。今天用 TF 训练了一个 CNN 模型,结果在保存模型的时候居然遇到各种问题。Google 搜出来的答案也是莫衷一是,有些回答对 1.0 版本的已经不适用了。后来实在没办法,就翻了墙去官网看了下,结果分分钟就搞定了~囧~。

这篇文章内容不多,主要讲讲 TF v1.0 版本中保存和读取模型的最简单用法,其实就是对官网教程的简要翻译摘抄。

保存和恢复

在 TensorFlow 中,保存和恢复模型最简单的方法就是使用 tf.train.Saver 类。这个类会将变量的保存和恢复操作添加到 TF 的图(graph)中。

Checkpoint 文件

TF 将变量保存在二进制文件中,这个文件包含一个从变量名到 tensor 值的映射。当我们创建一个 Saver 对象的时候,我们可以指定 checkpoint 文件中的变量名。默认会使用变量的 Variable.name 属性。

这一段读起来比较生涩难懂,具体看下面的例子。

保存变量

可以通过创建 Saver 来管理模型内的所有变量。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer() # Add ops to save and restore all the variables.
saver = tf.train.Saver() # Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)

恢复变量

可以通过同一个 Saver 对象(指定相同的保存路径)来恢复变量。这种情况下,我们不需要事先初始化变量(即无需调用 tf.global_variables_initializer()

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver() # Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Do some work with the model
...

例子

下面用我自己的例子解释一下。

首先,我们先定义一个图模型(只截选出变量部分):

    graph = tf.Graph()

    with graph.as_default():
# Input data
# ....省略代码若干 # Variables
layer1_weights = tf.Variable(tf.truncated_normal(
[patch_size, patch_size, image_channels, depth], stddev=0.1), name="layer1_weights")
layer1_biases = tf.Variable(tf.zeros([depth]), name="layer1_biases") layer2_weights = tf.Variable(tf.truncated_normal(
[image_size // 4 * image_size // 4 * depth, num_hidden], stddev=0.1, name="layer2_weights")
)
layer2_biases = tf.Variable(tf.constant(1.0, shape=[num_hidden]), name="layer2_biases") layer3_weights = tf.Variable(tf.truncated_normal(
[num_hidden, num_labels], stddev=0.1, name="layer3_weights"),
)
layer3_biases = tf.Variable(tf.constant(1.0, shape=[num_labels]), name="layer3_biases") def model(data):
#....省略代码若干
return tf.matmul(fc1, layer3_weights) + layer3_biases # Training computation
#....省略代码若干 # Optimizer
optimizer = tf.train.GradientDescentOptimizer(0.05).minimize(loss)

这个模型里的变量其实只有三个网络层的参数:layer1_weightslayer1_biaseslayer2_weightslayer2_biaseslayer3_weightslayer3_biases

然后就是启动会话进行训练:

    with tf.Session(graph=graph) as session:
saver = tf.train.Saver() if loading_model:
saver.restore(session, model_folder + "/" + model_file)
print("Model restored")
else:
tf.global_variables_initializer().run()
print("Initialized") for step in range(num_steps):
# ....省略训练模型的代码 print('Test accuracy: %.1f%%' % accuracy(test_prediction.eval(), test_labels))
save_path = saver.save(session, model_folder + "/" + model_file)
print("Model saved in file: ", save_path)

这段代码是本文的关键,我们先通过 tf.train.Saver() 构造一个 Saver 对象,注意,这一步要在 Session 启动之后执行,否则会抛异常 ValueError("No variables to save"),至少 v1.0 是这样。

通过 Saver,我们可以在模型训练完之后,将参数保存下来。Saver 保存数据的方法十分简单,只要将 session 和文件路径传入 save 函数即可:saver.save(session, model_folder + "/" + model_file)

如果我们一开始想载入本地的模型文件,而不是让 TF 自动初始化训练,则可以通过 Saverrestore 函数读取模型文件,文件路径需要和之前保存的文件路径一致。注意,如果是通过这种方式初始化变量,则不能再调用 tf.global_variables_initializer() 函数。之后,训练或预测的代码不需要改变,TensorFlow 会自动根据模型文件,将你的模型参数初始化。

当然啦,以上都是最基础的用法,只是简单地将所有参数保存下来。更高级的用法,之后如果使用到再继续总结。

参考

TensorFlow学习笔记:保存和读取模型的更多相关文章

  1. tensorflow学习笔记四----------构造线性回归模型

    首先通过构造随机数,模拟数据. import numpy as np import tensorflow as tf import matplotlib.pyplot as plt # 随机生成100 ...

  2. Tensorflow学习笔记2019.01.22

    tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...

  3. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  4. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  5. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  6. TensorFlow学习笔记(一)

    [TensorFlow API](https://www.tensorflow.org/versions/r0.12/how_tos/variable_scope/index.html) Tensor ...

  7. Tensorflow学习笔记2019.01.03

    tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...

  8. Tensorflow学习笔记No.5

    tf.data卷积神经网络综合应用实例 使用tf.data建立自己的数据集,并使用CNN卷积神经网络实现对卫星图像的二分类问题. 数据下载链接:https://pan.baidu.com/s/141z ...

  9. Tensorflow学习笔记No.7

    tf.data与自定义训练综合实例 使用tf.data自定义猫狗数据集,并使用自定义训练实现猫狗数据集的分类. 1.使用tf.data创建自定义数据集 我们使用kaggle上的猫狗数据以及tf.dat ...

随机推荐

  1. Prometheus-自定义Node_Exporter

    标量(Scalar):一个浮点型的数字值 标量只有一个数字,没有时序. 需要注意的是,当使用表达式count(http_requests_total),返回的数据类型,依然是瞬时向量.用户可以通过内置 ...

  2. bzoj2957 奥妙重重的线段树

    https://www.lydsy.com/JudgeOnline/problem.php?id=2957 线段树的query和update竟然还可以结合起来用! 题意:小A的楼房外有一大片施工工地, ...

  3. JAVA-Enum 枚举

    [参考]枚举类名建议带上 Enum 后缀,枚举成员名称需要全大写,单词间用下划线隔开. 说明:枚举其实就是特殊的类,域成员均为常量,且构造方法被默认强制是私有. 正例:枚举名字为 ProcessSta ...

  4. python 面向对象(五)约束 异常处理 MD5 日志处理

    ###############################总结###################### 1.异常处理 raise:抛出异常 try: 可能出现错误代码 execpt 异常类 a ...

  5. python 深浅拷贝 for循环删除

    ###########################总结########################### 1. 基础数据类型补充 大多数的基本数据类型的知识.已经学完了 a='aaaa' ls ...

  6. Hadoop记录-yarn ResourceManager Active频繁易主问题排查(转载)

    一.故障现象 两个节点的ResourceManger频繁在active和standby角色中切换.不断有active易主的告警发出 许多任务的状态没能成功更新,导致一些任务状态卡在NEW_SAVING ...

  7. 【1】【leetcode-99】 恢复二叉搜索树

    (没思路) 99. 恢复二叉搜索树 二叉搜索树中的两个节点被错误地交换. 请在不改变其结构的情况下,恢复这棵树. 示例 1: 输入: [1,3,null,null,2]   1   /  3   \ ...

  8. 解析ArcGis的标注(一)——先看看分数式、假分数式标注是怎样实现的

    该“标注”系列博文的标注引擎使用“标准标注引擎(standard label engine)”,这个概念如不知道,可不理会,ArcGis默认标注引擎就是它. ArcGis的标注表达式支持VBScrip ...

  9. 使用wget命令下载JDK失败(文件特别小)

    问题RT: 我们在网页上下载的时候要点一下 “Accept License Agreement ” ,使用wget下载的时候也需要提交这个 accept,方法如下: wget --no-check-c ...

  10. 关于REST API设计的文章整理

    1. rest api uri设计的7个准则(1)uri末尾不需要出现斜杠/(2)在uri中使用斜杠/表达层级关系(3)在uri中可以使用连接符-提升可读性(4)在uri中不允许出现下划线字符_(5) ...