tensorflow 1.0 学习:模型的保存与恢复(Saver)
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。
模型保存,先要创建一个Saver对象:如
- saver=tf.train.Saver()
在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:
- saver=tf.train.Saver(max_to_keep=0)
但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。
当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即
- saver=tf.train.Saver(max_to_keep=1)
创建完saver对象后,就可以保存训练好的模型了,如:
- saver.save(sess,'ckpt/mnist.ckpt',global_step=step)
第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
看一个mnist实例:
- # -*- coding: utf-8 -*-
- """
- Created on Sun Jun 4 10:29:48 2017
- @author: Administrator
- """
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
- mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
- x = tf.placeholder(tf.float32, [None, 784])
- y_=tf.placeholder(tf.int32,[None,])
- dense1 = tf.layers.dense(inputs=x,
- units=1024,
- activation=tf.nn.relu,
- kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
- kernel_regularizer=tf.nn.l2_loss)
- dense2= tf.layers.dense(inputs=dense1,
- units=512,
- activation=tf.nn.relu,
- kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
- kernel_regularizer=tf.nn.l2_loss)
- logits= tf.layers.dense(inputs=dense2,
- units=10,
- activation=None,
- kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
- kernel_regularizer=tf.nn.l2_loss)
- loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
- train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
- correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)
- acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- sess=tf.InteractiveSession()
- sess.run(tf.global_variables_initializer())
- saver=tf.train.Saver(max_to_keep=1)
- for i in range(100):
- batch_xs, batch_ys = mnist.train.next_batch(100)
- sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
- val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
- print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
- saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
- sess.close()
代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).
在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。
- saver=tf.train.Saver(max_to_keep=1)
- max_acc=0
- for i in range(100):
- batch_xs, batch_ys = mnist.train.next_batch(100)
- sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
- val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
- print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
- if val_acc>max_acc:
- max_acc=val_acc
- saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
- sess.close()
如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。
- saver=tf.train.Saver(max_to_keep=3)
- max_acc=0
- f=open('ckpt/acc.txt','w')
- for i in range(100):
- batch_xs, batch_ys = mnist.train.next_batch(100)
- sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
- val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
- print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
- f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
- if val_acc>max_acc:
- max_acc=val_acc
- saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
- f.close()
- sess.close()
模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:
- model_file=tf.train.latest_checkpoint('ckpt/')
- saver.restore(sess,model_file)
则程序后半段代码我们可以改为:
- sess=tf.InteractiveSession()
- sess.run(tf.global_variables_initializer())
- is_train=False
- saver=tf.train.Saver(max_to_keep=3)
- #训练阶段
- if is_train:
- max_acc=0
- f=open('ckpt/acc.txt','w')
- for i in range(100):
- batch_xs, batch_ys = mnist.train.next_batch(100)
- sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
- val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
- print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
- f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
- if val_acc>max_acc:
- max_acc=val_acc
- saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
- f.close()
- #验证阶段
- else:
- model_file=tf.train.latest_checkpoint('ckpt/')
- saver.restore(sess,model_file)
- val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
- print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
- sess.close()
标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。
整个源程序:
- # -*- coding: utf-8 -*-
- """
- Created on Sun Jun 4 10:29:48 2017
- @author: Administrator
- """
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
- mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
- x = tf.placeholder(tf.float32, [None, 784])
- y_=tf.placeholder(tf.int32,[None,])
- dense1 = tf.layers.dense(inputs=x,
- units=1024,
- activation=tf.nn.relu,
- kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
- kernel_regularizer=tf.nn.l2_loss)
- dense2= tf.layers.dense(inputs=dense1,
- units=512,
- activation=tf.nn.relu,
- kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
- kernel_regularizer=tf.nn.l2_loss)
- logits= tf.layers.dense(inputs=dense2,
- units=10,
- activation=None,
- kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
- kernel_regularizer=tf.nn.l2_loss)
- loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
- train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
- correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)
- acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- sess=tf.InteractiveSession()
- sess.run(tf.global_variables_initializer())
- is_train=True
- saver=tf.train.Saver(max_to_keep=3)
- #训练阶段
- if is_train:
- max_acc=0
- f=open('ckpt/acc.txt','w')
- for i in range(100):
- batch_xs, batch_ys = mnist.train.next_batch(100)
- sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
- val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
- print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
- f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
- if val_acc>max_acc:
- max_acc=val_acc
- saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
- f.close()
- #验证阶段
- else:
- model_file=tf.train.latest_checkpoint('ckpt/')
- saver.restore(sess,model_file)
- val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
- print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
- sess.close()
参考文章:http://blog.csdn.net/u011500062/article/details/51728830
tensorflow 1.0 学习:模型的保存与恢复(Saver)的更多相关文章
- TensorFlow笔记-模型的保存,恢复,实现线性回归
模型的保存 tf.train.Saver(var_list=None,max_to_keep=5) •var_list:指定将要保存和还原的变量.它可以作为一个 dict或一个列表传递. •max_t ...
- tensorflow 1.0 学习:模型的保存与恢复
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...
- [翻译] Tensorflow模型的保存与恢复
翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...
- tensorflow模型的保存与恢复
1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...
- tensorflow模型的保存与恢复,以及ckpt到pb的转化
转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...
- tensorflow 1.0 学习:用CNN进行图像分类
tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化. 任务:花卉分类 版本:tensorflow 1 ...
- tensorflow 1.0 学习:用别人训练好的模型来进行图像分类
谷歌在大型图像数据库ImageNet上训练好了一个Inception-v3模型,这个模型我们可以直接用来进来图像分类. 下载地址:https://storage.googleapis.com/down ...
- tensorflow 1.0 学习:十图详解tensorflow数据读取机制
本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...
- Tensorflow Learning1 模型的保存和恢复
CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢 ...
随机推荐
- 文件操作命令(del)
del 命令: // 描述: 删除一个或多个文件.同等于 erase 命令. 相比较 rd 命令来说,del 命令只能删除文件,不能删除文件夹. // 语法: del [/p] [/f] [/s] [ ...
- 行盒(line box)垂直方向的属性详解:从font-size、line-height到vertical-align
视觉格式化模型 在一个文档中,每个元素都被表示为0.1或多个矩形的盒子.确定这些盒子的尺寸, 属性 --- 像它的颜色,背景,边框方面 --- 和位置是渲染引擎的目标.① 在CSS中,使用标准盒模型描 ...
- 20155312 张竞予 Exp7 网络欺诈防范
Exp7 网络欺诈防范 目录 基础问题回答 (1)通常在什么场景下容易受到DNS spoof攻击 (2)在日常生活工作中如何防范以上两攻击方法 实验总结与体会 实践过程记录 (1)简单应用SET工具建 ...
- Desktop Central —— Windows 管理工具
Desktop Central —— Windows 管理工具 定期维护对于保持系统性能平稳必不可少.诸如磁盘检查.磁盘碎片整理程序之类的工具在系统维护中至关重要.因为管理员很难定期手动执行维护. D ...
- idea取消vim模式
在安装idea时选择了vim编辑模式,但是用习惯了eclipse,总是要拷贝粘贴,在idea中一直按ctrl+c和ctrl+v不起总用.于是想把vim模式关闭掉.方法:菜单栏:tools->vi ...
- oracle执行计划走偏处理步骤
-- sql执行时间select a.EXECUTIONS,a.ELAPSED_TIME,a.ELAPSED_TIME/a.EXECUTIONS/1000/1000 as 秒,a.SQL_ID,a.H ...
- Git系列:第七篇-Maven项目下提交时忽略不必要的文件或文件夹
用.gitignore文件来进行忽略不必要的文件或文件夹 在开发中我们要提交的内容大都是src里的全部文件(java文件).gitignore(忽略文件)pom.xml(maven配置文件)----- ...
- 第三次OO总结
规格化设计的调研 随着50年代高级语言的出现,编译技术不断完善,涌现出多种流派的语言,其中就有里程碑式的Pascal语言:进入70年代,由于众多语言造成的不可移植.难于维护,Ada程序设计语言诞生了, ...
- 解答VS2013编译报错不准确是什么原因
1.当程序在错误时,VS2013编译报出的错误有时不会一起全部报出,而是按错误的英文首字母逐个报出的 2.如果报错的信息双击点过去查看时又发现无明显错误问题时,这个这个时候可以是VS编译的缓存问题,这 ...
- 从git远程仓库Checkout项目到本地
一.登录coding 并且项目已创建好 已经是项目的组员 二.打开idea 1.弹出如下页面 复制远程项目上的SSH(URL)到下框URL 并且Test测试 成功就Clone即可 2.Clone ...