将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

  1. saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

  1. saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

  1. saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

  1. saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一个mnist实例:

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Sun Jun 4 10:29:48 2017
  4.  
  5. @author: Administrator
  6. """
  7. import tensorflow as tf
  8. from tensorflow.examples.tutorials.mnist import input_data
  9. mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
  10.  
  11. x = tf.placeholder(tf.float32, [None, 784])
  12. y_=tf.placeholder(tf.int32,[None,])
  13.  
  14. dense1 = tf.layers.dense(inputs=x,
  15. units=1024,
  16. activation=tf.nn.relu,
  17. kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
  18. kernel_regularizer=tf.nn.l2_loss)
  19. dense2= tf.layers.dense(inputs=dense1,
  20. units=512,
  21. activation=tf.nn.relu,
  22. kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
  23. kernel_regularizer=tf.nn.l2_loss)
  24. logits= tf.layers.dense(inputs=dense2,
  25. units=10,
  26. activation=None,
  27. kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
  28. kernel_regularizer=tf.nn.l2_loss)
  29.  
  30. loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
  31. train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
  32. correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)
  33. acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  34.  
  35. sess=tf.InteractiveSession()
  36. sess.run(tf.global_variables_initializer())
  37.  
  38. saver=tf.train.Saver(max_to_keep=1)
  39. for i in range(100):
  40. batch_xs, batch_ys = mnist.train.next_batch(100)
  41. sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  42. val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  43. print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  44. saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  45. sess.close()

代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

  1. saver=tf.train.Saver(max_to_keep=1)
  2. max_acc=0
  3. for i in range(100):
  4. batch_xs, batch_ys = mnist.train.next_batch(100)
  5. sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  6. val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  7. print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  8. if val_acc>max_acc:
  9. max_acc=val_acc
  10. saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  11. sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

  1. saver=tf.train.Saver(max_to_keep=3)
  2. max_acc=0
  3. f=open('ckpt/acc.txt','w')
  4. for i in range(100):
  5. batch_xs, batch_ys = mnist.train.next_batch(100)
  6. sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  7. val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  8. print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  9. f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
  10. if val_acc>max_acc:
  11. max_acc=val_acc
  12. saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  13. f.close()
  14. sess.close()

模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

  1. model_file=tf.train.latest_checkpoint('ckpt/')
  2. saver.restore(sess,model_file)

则程序后半段代码我们可以改为:

  1. sess=tf.InteractiveSession()
  2. sess.run(tf.global_variables_initializer())
  3.  
  4. is_train=False
  5. saver=tf.train.Saver(max_to_keep=3)
  6.  
  7. #训练阶段
  8. if is_train:
  9. max_acc=0
  10. f=open('ckpt/acc.txt','w')
  11. for i in range(100):
  12. batch_xs, batch_ys = mnist.train.next_batch(100)
  13. sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  14. val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  15. print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  16. f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
  17. if val_acc>max_acc:
  18. max_acc=val_acc
  19. saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  20. f.close()
  21.  
  22. #验证阶段
  23. else:
  24. model_file=tf.train.latest_checkpoint('ckpt/')
  25. saver.restore(sess,model_file)
  26. val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  27. print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
  28. sess.close()

标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。

整个源程序:

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Sun Jun 4 10:29:48 2017
  4.  
  5. @author: Administrator
  6. """
  7. import tensorflow as tf
  8. from tensorflow.examples.tutorials.mnist import input_data
  9. mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
  10.  
  11. x = tf.placeholder(tf.float32, [None, 784])
  12. y_=tf.placeholder(tf.int32,[None,])
  13.  
  14. dense1 = tf.layers.dense(inputs=x,
  15. units=1024,
  16. activation=tf.nn.relu,
  17. kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
  18. kernel_regularizer=tf.nn.l2_loss)
  19. dense2= tf.layers.dense(inputs=dense1,
  20. units=512,
  21. activation=tf.nn.relu,
  22. kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
  23. kernel_regularizer=tf.nn.l2_loss)
  24. logits= tf.layers.dense(inputs=dense2,
  25. units=10,
  26. activation=None,
  27. kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
  28. kernel_regularizer=tf.nn.l2_loss)
  29.  
  30. loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
  31. train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
  32. correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)
  33. acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  34.  
  35. sess=tf.InteractiveSession()
  36. sess.run(tf.global_variables_initializer())
  37.  
  38. is_train=True
  39. saver=tf.train.Saver(max_to_keep=3)
  40.  
  41. #训练阶段
  42. if is_train:
  43. max_acc=0
  44. f=open('ckpt/acc.txt','w')
  45. for i in range(100):
  46. batch_xs, batch_ys = mnist.train.next_batch(100)
  47. sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  48. val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  49. print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  50. f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
  51. if val_acc>max_acc:
  52. max_acc=val_acc
  53. saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  54. f.close()
  55.  
  56. #验证阶段
  57. else:
  58. model_file=tf.train.latest_checkpoint('ckpt/')
  59. saver.restore(sess,model_file)
  60. val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  61. print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
  62. sess.close()

参考文章:http://blog.csdn.net/u011500062/article/details/51728830

tensorflow 1.0 学习:模型的保存与恢复(Saver)的更多相关文章

  1. TensorFlow笔记-模型的保存,恢复,实现线性回归

    模型的保存 tf.train.Saver(var_list=None,max_to_keep=5) •var_list:指定将要保存和还原的变量.它可以作为一个 dict或一个列表传递. •max_t ...

  2. tensorflow 1.0 学习:模型的保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  3. [翻译] Tensorflow模型的保存与恢复

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  4. tensorflow模型的保存与恢复

    1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...

  5. tensorflow模型的保存与恢复,以及ckpt到pb的转化

    转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...

  6. tensorflow 1.0 学习:用CNN进行图像分类

    tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化. 任务:花卉分类 版本:tensorflow 1 ...

  7. tensorflow 1.0 学习:用别人训练好的模型来进行图像分类

    谷歌在大型图像数据库ImageNet上训练好了一个Inception-v3模型,这个模型我们可以直接用来进来图像分类. 下载地址:https://storage.googleapis.com/down ...

  8. tensorflow 1.0 学习:十图详解tensorflow数据读取机制

    本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...

  9. Tensorflow Learning1 模型的保存和恢复

    CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢 ...

随机推荐

  1. 文件操作命令(del)

    del 命令: // 描述: 删除一个或多个文件.同等于 erase 命令. 相比较 rd 命令来说,del 命令只能删除文件,不能删除文件夹. // 语法: del [/p] [/f] [/s] [ ...

  2. 行盒(line box)垂直方向的属性详解:从font-size、line-height到vertical-align

    视觉格式化模型 在一个文档中,每个元素都被表示为0.1或多个矩形的盒子.确定这些盒子的尺寸, 属性 --- 像它的颜色,背景,边框方面 --- 和位置是渲染引擎的目标.① 在CSS中,使用标准盒模型描 ...

  3. 20155312 张竞予 Exp7 网络欺诈防范

    Exp7 网络欺诈防范 目录 基础问题回答 (1)通常在什么场景下容易受到DNS spoof攻击 (2)在日常生活工作中如何防范以上两攻击方法 实验总结与体会 实践过程记录 (1)简单应用SET工具建 ...

  4. Desktop Central —— Windows 管理工具

    Desktop Central —— Windows 管理工具 定期维护对于保持系统性能平稳必不可少.诸如磁盘检查.磁盘碎片整理程序之类的工具在系统维护中至关重要.因为管理员很难定期手动执行维护. D ...

  5. idea取消vim模式

    在安装idea时选择了vim编辑模式,但是用习惯了eclipse,总是要拷贝粘贴,在idea中一直按ctrl+c和ctrl+v不起总用.于是想把vim模式关闭掉.方法:菜单栏:tools->vi ...

  6. oracle执行计划走偏处理步骤

    -- sql执行时间select a.EXECUTIONS,a.ELAPSED_TIME,a.ELAPSED_TIME/a.EXECUTIONS/1000/1000 as 秒,a.SQL_ID,a.H ...

  7. Git系列:第七篇-Maven项目下提交时忽略不必要的文件或文件夹

    用.gitignore文件来进行忽略不必要的文件或文件夹 在开发中我们要提交的内容大都是src里的全部文件(java文件).gitignore(忽略文件)pom.xml(maven配置文件)----- ...

  8. 第三次OO总结

    规格化设计的调研 随着50年代高级语言的出现,编译技术不断完善,涌现出多种流派的语言,其中就有里程碑式的Pascal语言:进入70年代,由于众多语言造成的不可移植.难于维护,Ada程序设计语言诞生了, ...

  9. 解答VS2013编译报错不准确是什么原因

    1.当程序在错误时,VS2013编译报出的错误有时不会一起全部报出,而是按错误的英文首字母逐个报出的 2.如果报错的信息双击点过去查看时又发现无明显错误问题时,这个这个时候可以是VS编译的缓存问题,这 ...

  10. 从git远程仓库Checkout项目到本地

    一.登录coding  并且项目已创建好  已经是项目的组员 二.打开idea 1.弹出如下页面  复制远程项目上的SSH(URL)到下框URL 并且Test测试 成功就Clone即可 2.Clone ...