tensorflow tf.train.Supervisor作用
tf.train.Supervisor可以简化编程,避免显示地实现restore操作.通过一个例子看.
import tensorflow as tf
import numpy as np
import os
log_path = r"D:\Source\model\linear"
log_name = "linear.ckpt"
# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
# Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b
# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
# Before starting, initialize the variables. We will 'run' this first.
saver = tf.train.Saver()
init = tf.global_variables_initializer()
# Launch the graph.
sess = tf.Session()
sess.run(init)
if len(os.listdir(log_path)) != 0: # 已经有模型直接读取
saver.restore(sess, os.path.join(log_path, log_name))
for step in range(201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
saver.save(sess, os.path.join(log_path, log_name))
这段代码是对tensorflow官网上的demo做一个微小的改动.如果模型已经存在,就先读取模型接着训练.tf.train.Supervisor可以简化这个步骤.看下面的代码.
import tensorflow as tf
import numpy as np
import os
log_path = r"D:\Source\model\supervisor"
log_name = "linear.ckpt"
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
sv = tf.train.Supervisor(logdir=log_path, init_op=init) # logdir用来保存checkpoint和summary
saver = sv.saver # 创建saver
with sv.managed_session() as sess: # 会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化
for i in range(201):
sess.run(train)
if i % 20 == 0:
print(i, sess.run(W), sess.run(b))
saver.save(sess, os.path.join(log_path, log_name))
sv = tf.train.Supervisor(logdir=log_path, init_op=init)会判断模型是否存在.如果存在,会自动读取模型.不用显式地调用restore.
参考资料
tensorflow tf.train.Supervisor作用的更多相关文章
- tensorflow|tf.train.slice_input_producer|tf.train.Coordinator|tf.train.start_queue_runners
#### ''' tf.train.slice_input_producer :定义样本放入文件名队列的方式[迭代次数,是否乱序],但此时文件名队列还没有真正写入数据 slice_input_prod ...
- tensorflow中的Supervisor
tf.train.Supervisor()可以帮我们简化一些事情,可以保存模型参数和Summary,它有以下的作用: 1)自动去checkpoint加载数据或初始化数据 ,因此我们就不需要手动初始化或 ...
- Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析
觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...
- tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数(转)
tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...
- tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数
tensorflow数据读取机制 tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数 ...
- TensorFlow 实战(二)—— tf.train(优化算法)
Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...
- 【转载】 tensorflow中 tf.train.slice_input_producer 和 tf.train.batch 函数
原文地址: https://blog.csdn.net/dcrmg/article/details/79776876 ----------------------------------------- ...
- tensorflow数据读取机制tf.train.slice_input_producer 和 tf.train.batch 函数
tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算. 具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程 ...
- 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑
import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...
随机推荐
- js在字符串中加入一段字符串
在这个功能的实现主要是slice()方法的掌握 arrayObject.slice(start,end) start 必需.规定从何处开始选取.如果是负数,那么它规定从数组尾部开始算起的位置.也就是说 ...
- 在linux (centos)上使用puppeteer实现网页截图
1.安装nodejs和npm # 下载解压 wget -c https://nodejs.org/dist/v8.9.1/node-v8.9.1-linux-x64.tar.xz tar -xvf n ...
- nyoj 37-回文字符串(reverse, 动态规划, lcs)
37-回文字符串 内存限制:64MB 时间限制:3000ms Special Judge: No accepted:10 submit:17 题目描述: 所谓回文字符串,就是一个字符串,从左到右读和从 ...
- requests保存图片
1.创建07_save_jpg.py文件 import requests #发送请求respone = requests.get("https://www.baidu.com/img/bd_ ...
- list,tuple,dict,set 思维导图整理
- python requirements.txt的创建及使用
要求文件(requirements.txt)是安装包的依赖项及版本的记录文件. pip: 创建 (venv) $ pip freeze >requirements.txt 使用 (venv) $ ...
- Redis 4.0鲜为人知的功能将加速您的应用程序
来源:Redislabs 作者:Kyle Davis 翻译:Kevin (公众号:中间件小哥) Redis 4.0给Redis生态带来了一个惊人的功能:Modules(模块).Modules是Redi ...
- FPGA基础(verilog语言)——语法篇(续1)
上一篇文章提到了FPGA中一个模块基本结构,这篇文章开始介绍语法. 首先,我们学习一门语言都要从这门语言的单词学起,所以verilog中的关键词都有哪些呢?看下面: A:always.assign B ...
- 【论文阅读】Binary Multi-View Clustering
文章地址:https://ieeexplore.ieee.org/document/8387526 出自:IEEE Trans. on Pattern Analysis and Machine Int ...
- Protues7.8仿真软件有中文路径无法正常运行怎么办?
Protues7.8是一款功能强大的单片机仿真软件,在我们的学习生活中经常会用的到,在装软件时明明已经装好了,却不能报错跳出两行红字,让人心痛. 一般都是因为账户名字是中文的问题,这个软件对中文不兼容 ...