TensorFlow——训练模型的保存和载入的方法介绍
我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入。
1.保存模型
首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起来,具体的代码流程如下
# 前面的是定义好的模型结构
- # 前面的代码是模型的定义代码
- saver = tf.train.Saver() # 生成saver
- with tf.Session() as sess:
- sess.run(init) # 模型的初始化
- #
- # 模型的训练代码,当模型训练完毕后,下面就可以对模型进行保存了
- #
- saver.save(sess, "model/linear") # 当路径不存在时,会自动创建路径
2.载入模型
将模型保存后,在保存的路径中,可以看到生成的模型路径,下面我们就能够加载模型了:
- saver = tf.train.Saver()
- with tf.Session() as sess:
- # 可以对模型进行初始化,也可以不进行模型的初始化,因为后面的加载会覆盖之前的
- # 初始化操作
- sess.run(init)
- saver.restore(sess, "model/linear")
下面我们以linearmodel为例进行讲解:
- import tensorflow as tf
- import numpy as np
- import matplotlib.pyplot as plt
- import os
- train_x = np.linspace(-5, 3, 50)
- train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5
- plt.plot(train_x, train_y, 'r.')
- plt.grid(True)
- plt.show()
- X = tf.placeholder(dtype=tf.float32)
- Y = tf.placeholder(dtype=tf.float32)
- w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
- b = tf.Variable(tf.random.truncated_normal([1]), name='bias')
- z = tf.multiply(X, w) + b
- cost = tf.reduce_mean(tf.square(Y - z))
- learning_rate = 0.01
- optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
- init = tf.global_variables_initializer()
- training_epochs = 20
- display_step = 2
- saver = tf.train.Saver()
- if __name__ == '__main__':
- with tf.Session() as sess:
- sess.run(init)
- if os.path.exists("model/"):
- saver.restore(sess, "model/linear")
- w_, b_ = sess.run([w, b])
- print(" Finished ")
- print("W: ", w_, " b: ", b_)
- plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
- plt.grid(True)
- plt.show()
- else:
- loss_list = []
- for epoch in range(training_epochs):
- for (x, y) in zip(train_x, train_y):
- sess.run(optimizer, feed_dict={X: x, Y: y})
- if epoch % display_step == 0:
- loss = sess.run(cost, feed_dict={X: x, Y: y})
- loss_list.append(loss)
- print('Iter: ', epoch, ' Loss: ', loss)
- w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
- saver.save(sess, "model/linear")
- print(" Finished ")
- print("W: ", w_, " b: ", b_, " loss: ", loss)
- plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
- plt.grid(True)
- plt.show()
3.查看模型的内容
- from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
- modeldir = 'model/'
- print_tensors_in_checkpoint_file(modeldir + 'linear.cpkt', None, True)
在上述使用saver的代码中,我们还可以将参数放入Saver中实现指定存储参数的功能,可以指定存储变量名字和变量的对应关系,如下形式:
- saver = tf.train.Saver({'weight_':w, 'bias_':b})
- # saver = tf.train.Saver([w, b])
TensorFlow——训练模型的保存和载入的方法介绍的更多相关文章
- TensorFlow 模型的保存与载入
参考学习博客: # https://www.cnblogs.com/felixwang2/p/9190692.html 一.模型保存 # https://www.cnblogs.com/felixwa ...
- 『TensorFlow』模型保存和载入方法汇总
『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...
- caffe 日志保存以及matlab绘制方法(windows以及ubuntu下)
caffe 用matlab解析日志画loss和accuracy clc; clear; % load the log file of caffe model fid = fopen('log-prev ...
- 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()
save = tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...
- 使用TensorFlow训练模型的基本流程
本文已在公众号机器视觉与算法建模发布,转载请联系我. 使用TensorFlow的基本流程 本篇文章将介绍使用tensorflow的训练模型的基本流程,包括制作读取TFRecord,训练和保存模型,读取 ...
- (原+译)pytorch中保存和载入模型
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/8108466.html 参考网址: http://pytorch.org/docs/master/not ...
- matlab工作空间,变量的保存和载入
对于工作空间中变量的保存和载入可以使用save和load命令,详细的使用方法通过help指令获取(help save,help load). 两条指令最常用的情况为: 1.% 保存整个工作空间至指定 ...
- tensorflow模型的保存与恢复,以及ckpt到pb的转化
转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
随机推荐
- H3C 虚拟模板方式配置PPP MP
- Python--day23--初识面向对象复习
面向对象编程是大程序编程思想:
- el-table翻页序号不从1开始(已解决)
法一:赋值方式(亲测有效) <el-table-column type="index" fixed="left" align="center&q ...
- js基础——正则表达式
1.创建方式: var box = new RegExp('box');//第一个参数字符串 var box = new RegExp('box','ig');//第二个参数可选模式修饰符 等同于 v ...
- 如何在iOS手机上进行自动化测试
版权声明:允许转载,但转载必须保留原链接:请勿用作商业或者非法用途 Airtest支持iOS自动化测试,在Mac上为iOS手机部署iOS-Tagent之后,就可以使用AirtestIDE连接设备,像连 ...
- (转)学习C语言基本思路与参考书籍
计算机行业发展非常快,大学里的教育基本都跟不上实际的社会需求.如果你所在的学校还在指定大家使用谭浩强的教材,或使用VC6.0来教大家上机实验,那你不妨看看本文,这里有一些建议可以帮助你不会脱离社会太远 ...
- 云栖大会压轴好戏 阿里云发布视频云V5计划与系列新产品
9月25 - 27日,2019云栖大会如期召开.在大会最后一天下午,阿里云智能视频云分论坛为今年的云栖大会献上了一场精彩的压轴好戏. 视频云V5计划发布 使能生态合作伙伴 会上,阿里云智能研究员金戈进 ...
- appium启动app(ios)
Appium启动APP至少需要7个参数 'platformVersion','deviceName'.'udid'.'bundleId'.'platformName'.'automationNam ...
- 使用app-inspector时报错connect ECONNREFUSED 127.0.0.1:8001的解决方案
在使用 app-inspector -u udid时,报错如图所示 输入如下命令即可解决 npm config set proxy null 再次启动app-inspector即可成功
- EasyMock.replay()有什么用
现在很多项目都使用EasyMock来作为单元测试框架. EasyMock一个方法,基本上是三步:EasyMock.expect().EasyMock.replay().EasyMock.verify( ...