TensorFlow——训练模型的保存和载入的方法介绍
我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入。
1.保存模型
首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起来,具体的代码流程如下
# 前面的是定义好的模型结构
# 前面的代码是模型的定义代码 saver = tf.train.Saver() # 生成saver with tf.Session() as sess:
sess.run(init) # 模型的初始化
#
# 模型的训练代码,当模型训练完毕后,下面就可以对模型进行保存了
#
saver.save(sess, "model/linear") # 当路径不存在时,会自动创建路径
2.载入模型
将模型保存后,在保存的路径中,可以看到生成的模型路径,下面我们就能够加载模型了:
saver = tf.train.Saver() with tf.Session() as sess:
# 可以对模型进行初始化,也可以不进行模型的初始化,因为后面的加载会覆盖之前的
# 初始化操作
sess.run(init) saver.restore(sess, "model/linear")
下面我们以linearmodel为例进行讲解:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5 plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show() X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32) w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias') z = tf.multiply(X, w) + b cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) init = tf.global_variables_initializer() training_epochs = 20
display_step = 2 saver = tf.train.Saver() if __name__ == '__main__':
with tf.Session() as sess:
sess.run(init)
if os.path.exists("model/"):
saver.restore(sess, "model/linear") w_, b_ = sess.run([w, b]) print(" Finished ")
print("W: ", w_, " b: ", b_)
plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
plt.grid(True)
plt.show()
else:
loss_list = []
for epoch in range(training_epochs):
for (x, y) in zip(train_x, train_y):
sess.run(optimizer, feed_dict={X: x, Y: y}) if epoch % display_step == 0:
loss = sess.run(cost, feed_dict={X: x, Y: y})
loss_list.append(loss)
print('Iter: ', epoch, ' Loss: ', loss) w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y}) saver.save(sess, "model/linear") print(" Finished ")
print("W: ", w_, " b: ", b_, " loss: ", loss)
plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
plt.grid(True)
plt.show()
3.查看模型的内容
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
modeldir = 'model/'
print_tensors_in_checkpoint_file(modeldir + 'linear.cpkt', None, True)
在上述使用saver的代码中,我们还可以将参数放入Saver中实现指定存储参数的功能,可以指定存储变量名字和变量的对应关系,如下形式:
saver = tf.train.Saver({'weight_':w, 'bias_':b})
# saver = tf.train.Saver([w, b])
TensorFlow——训练模型的保存和载入的方法介绍的更多相关文章
- TensorFlow 模型的保存与载入
参考学习博客: # https://www.cnblogs.com/felixwang2/p/9190692.html 一.模型保存 # https://www.cnblogs.com/felixwa ...
- 『TensorFlow』模型保存和载入方法汇总
『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...
- caffe 日志保存以及matlab绘制方法(windows以及ubuntu下)
caffe 用matlab解析日志画loss和accuracy clc; clear; % load the log file of caffe model fid = fopen('log-prev ...
- 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()
save = tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...
- 使用TensorFlow训练模型的基本流程
本文已在公众号机器视觉与算法建模发布,转载请联系我. 使用TensorFlow的基本流程 本篇文章将介绍使用tensorflow的训练模型的基本流程,包括制作读取TFRecord,训练和保存模型,读取 ...
- (原+译)pytorch中保存和载入模型
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/8108466.html 参考网址: http://pytorch.org/docs/master/not ...
- matlab工作空间,变量的保存和载入
对于工作空间中变量的保存和载入可以使用save和load命令,详细的使用方法通过help指令获取(help save,help load). 两条指令最常用的情况为: 1.% 保存整个工作空间至指定 ...
- tensorflow模型的保存与恢复,以及ckpt到pb的转化
转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
随机推荐
- Python--day19--os模块
os模块 os模块是与操作系统交互的一个接口 os.makedirs('dirname1/dirname2') 可生成多层递归目录 os.removedirs('dirname1') 若目录为空,则删 ...
- laravel .env 文件的使用
转载地址 http://www.cnblogs.com/Eden-cola/p/DotEnv-in-lumen.html umen 是 laravel 的衍生品,核心功能的使用和 laravel 都 ...
- 浅谈集合框架三、Map常用方法及常用工具类
最近刚学完集合框架,想把自己的一些学习笔记与想法整理一下,所以本篇博客或许会有一些内容写的不严谨或者不正确,还请大神指出.初学者对于本篇博客只建议作为参考,欢迎留言共同学习. 之前有介绍集合框架的体系 ...
- Python3 dir() 函数
Python dir() 函数 描述 dir() 函数不带参数时,返回当前范围内的变量.方法和定义的类型列表:带参数时,返回参数的属性.方法列表.如果参数包含方法__dir__(),该方法将被调用.如 ...
- C# 7.2 通过 in 和 readonly struct 减少方法值复制提高性能
在 C# 7.2 提供了一系列的方法用于方法参数传输的时候减少对结构体的复制从而可以高效使用内存同时提高性能 在开始阅读之前,希望读者对 C# 的值类型.引用类型有比较深刻的认知. 在 C# 中,如果 ...
- dotnet core 输出调试信息到 DebugView 软件
本文告诉大家如何在 dotnet core 输出调试信息到 DebugView 软件 在之前告诉小伙伴,如何在 WPF 输出调试信息到 DebugView 软件,请看文章 WPF 调试 获得追踪输出 ...
- 【12.78%】【codeforces 677D】Vanya and Treasure
time limit per test1.5 seconds memory limit per test256 megabytes inputstandard input outputstandard ...
- ansible核心模块playbook介绍
ansible的playbook采用yaml语法,它简单地实现了json格式的事件描述.yaml之于json就像markdown之于html一样,极度简化了json的书写.在学习ansible pla ...
- Python13_安装、解释器
Linux下大部分系统默认自带python2.x的版本,最常见的是python2.6或python2.7版本,默认的python被系统很多程序所依赖, 比如centos下的yum就是python2写的 ...
- 0003 HTML常用标签(含base、锚点)、路径
学习目标 理解: 相对路径三种形式 应用 排版标签 文本格式化标签 图像标签 链接 相对路径,绝对路径的使用 1. HTML常用标签 首先 HTML和CSS是两种完全不同的语言,我们学的是结构,就只写 ...