我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。

  • Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
  • 只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
  • 为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

示例代码:

import tensorflow as tf
import numpy as np
from six.moves import xrange x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 2 w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) #isTrain = True
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = 'test/' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
if isTrain:
for i in xrange(train_steps):
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0:
saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i + 1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b)) y_result = sess.run(y_predict, feed_dict={x: np.reshape(4, (1, 1))})
print(y_result)

2.1 训练阶段

使用Saver.save()方法保存模型:
  1. sess:表示当前会话,当前会话记录了当前的变量值
  2. checkpoint_dir + 'model.ckpt':表示存储的文件名
  3. global_step:表示当前是第几步

训练完成后,当前目录底下会多出5个文件。

    打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置。

2.2测试阶段

    测试阶段使用saver.restore()方法恢复变量:
  1. sess:表示当前会话,之前保存的结果将被加载入这个会话
  2. ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

运行结果如下图所示,加载了之前训练的参数w和b的结果

TensorFlow Saver的使用方法的更多相关文章

  1. 解决tensorflow Saver.restore()无效的问题

    解决tensorflow 的 Saver.restore()无法从本地读取变量的问题 最近做tensorflow 手写数字识别的时候遇到了一个问题,Saver的restore()方法无法从本地恢复变量 ...

  2. TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

      TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...

  3. 安装tensorflow的最简单方法(Ubuntu 16.04 && CentOS)

    先说点题外话:在用anaconda安装很多次tensorflow失败之后,我放弃了,如果你遇到这样的问题:Traceback (most recent call last)-如果不是因为pip版本,就 ...

  4. TensorFlow 常用函数与方法

    摘要:本文主要对tf的一些常用概念与方法进行描述. tf函数 TensorFlow 将图形定义转换成分布式执行的操作, 以充分利用可用的计算资源(如 CPU 或 GPU.一般你不需要显式指定使用 CP ...

  5. TensorFlow之Varibale 使用方法

    ------------------------------------------- 转载请注明: 来自博客园 xiuyuxuanchen 地址:http://www.cnblogs.com/gre ...

  6. 『TensorFlow』正则化添加方法整理

    一.基础正则化函数 tf.contrib.layers.l1_regularizer(scale, scope=None) 返回一个用来执行L1正则化的函数,函数的签名是func(weights).  ...

  7. TensorFlow——共享变量的使用方法

    1.共享变量用途 在构建模型时,需要使用tf.Variable来创建一个变量(也可以理解成节点).当两个模型一起训练时,一个模型需要使用其他模型创建的变量,比如,对抗网络中的生成器和判别器.如果使用t ...

  8. 两种Tensorflow模型保存的方法

    在Tensorflow中,有两种保存模型的方法:一种是Checkpoint,另一种是Protobuf,也就是PB格式: 一. Checkpoint方法: 1.保存时使用方法: tf.train.Sav ...

  9. 吴裕雄 python 神经网络——TensorFlow ckpt文件保存方法

    import tensorflow as tf v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1)) v2 = tf.Variable(t ...

随机推荐

  1. make命令回显Makefile执行脚本命令

    /********************************************************************** * make命令回显Makefile执行脚本命令 * 说 ...

  2. ss server端配置

    关于ss server的配置,可以参考一个网址 关于服务器的购买可以上VIRMACH购买 和本地安装ss类似,首先安装ss,pip install shadowsocks 配置服务器参数,vim /e ...

  3. doc转html

    Doc2Html  li标签没闭合 在线Word转HTML  style样式

  4. 使用NATS替换NSQ为后台服务解耦

    简介 满世界的后台都在向微服务架构发展,我对微服务的理解是将一个复杂的业务分拆为多个服务,由多个服务协作完成一个服务:在后台微服务架构时需要考虑高可用.一致性等问题,也要考虑在实现上.编码上的复杂程度 ...

  5. C#线程、前后台线程

    C#线程.前后台线程 本文提供全流程,中文翻译. Chinar 坚持将简单的生活方式,带给世人!(拥有更好的阅读体验 -- 高分辨率用户请根据需求调整网页缩放比例) Chinar -- 心分享.心创新 ...

  6. 牛客HJ浇花。

    我也不知道这是什么类型的题,算是简单模拟吧.但是有个方法很赞. 开两个数组,一个模拟花,一个记录不同浇花次数花的数量: 要找浇水的次数,那么记每次浇水的开头和结尾就行了,a—b;那么f[a]++;f[ ...

  7. SQLI DUMB SERIES-3

    less3 输入?id=1' 说明输入的id旁边加了单引号和括号('id'),直接在1后面加入“ ') ”,闭合前面的单引号和括号. 方法同less1相同. 例如:查询PHP版本和数据库名字 ?id= ...

  8. c++输入输出流加速器

      发现同样是cin,cout,其他大佬(orz)的耗时短很多.看了他们的代码,我发现他们加了一个很神奇的匿名函数(Lambda捕获)提高了cin,cout效率,因此去百度了解了一下.以下是大佬所使用 ...

  9. (1)HTML的组成(什么是标签、指令、转义字符、数据、标签字符表)

    html的组成:标签+指令+转义字符+数据 1.标签 <>内的,以字母开头,可以结合合法字符(- 或者数字),能被浏览器解析的符号 <!DOCTYPE html> #这个是系统 ...

  10. 实验吧—Web——WP之 FALSE

    打开链接,点击源码按钮,我们开始分析源码: 在这源码中我们能看出 如果名字等于密码就输出:你的名字不能等于密码 如果名字的哈希值等于密码的哈希值,那么就会输出flag 这就意味着我们要上传两个值,值不 ...