转自:http://blog.csdn.net/lwplwf/article/details/62419087

之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西。

TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。


下面代码给出了保存TensorFlow模型的方法:

  1. import tensorflow as tf
  2. # 声明两个变量
  3. v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
  4. v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
  5. init_op = tf.global_variables_initializer() # 初始化全部变量
  6. saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
  7. with tf.Session() as sess:
  8. sess.run(init_op)
  9. print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
  10. print("v2:", sess.run(v2))
  11. saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
  12. print("Model saved in file:", saver_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

这段代码中,通过saver.save函数将TensorFlow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt",也就是保存到了当前程序所在文件夹里面的save文件夹中。

TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中实际会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

  • model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
  • model.ckpt文件保存了TensorFlow程序中每一个变量的取值
  • checkpoint文件保存了一个目录下所有的模型文件列表


下面代码给出了加载TensorFlow模型的方法:

可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?

  1. import tensorflow as tf
  2. # 使用和保存模型代码中一样的方式来声明变量
  3. v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
  4. v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
  5. saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
  6. with tf.Session() as sess:
  7. saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来
  8. print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比
  9. print("v2:", sess.run(v2))
  10. print("Model Restored")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

运行结果:

  1. v1: [[ 0.76705766 1.82217288]]
  2. v2: [[-0.98012197 1.2369734 0.5797025 ]
  3. [ 2.50458145 0.81897354 0.07858191]]
  4. Model Restored
  • 1
  • 2
  • 3
  • 4
  • 1
  • 2
  • 3
  • 4

这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。 
也就是说使用TensorFlow完成了一次模型的保存和读取的操作。



如果不希望重复定义图上的运算,也可以直接加载已经持久化的图:

  1. import tensorflow as tf
  2. # 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量
  3. # 直接加载持久化的图
  4. saver = tf.train.import_meta_graph("save/model.ckpt.meta")
  5. with tf.Session() as sess:
  6. saver.restore(sess, "save/model.ckpt")
  7. # 通过张量的名称来获取张量
  8. print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

运行程序,输出:

  1. [[ 0.76705766 1.82217288]]
  • 1
  • 1

有时可能只需要保存或者加载部分变量。 
比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。

为了保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。

…未完待续

TensorFlow学习笔记(8)--网络模型的保存和读取【转】的更多相关文章

  1. matlab学习笔记4--MAT文件的保存和读取

    一起来学matlab-matlab学习笔记4 数据导入和导出_1 MAT文件的保存和读取 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考书籍 <matlab 程序设计与综合应用&g ...

  2. TensorFlow基础笔记(14) 网络模型的保存与恢复_mnist数据实例

    http://blog.csdn.net/huachao1001/article/details/78502910 http://blog.csdn.net/u014432647/article/de ...

  3. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  4. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  5. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  6. Tensorflow学习笔记No.7

    tf.data与自定义训练综合实例 使用tf.data自定义猫狗数据集,并使用自定义训练实现猫狗数据集的分类. 1.使用tf.data创建自定义数据集 我们使用kaggle上的猫狗数据以及tf.dat ...

  7. Tensorflow学习笔记No.8

    使用VGG16网络进行迁移学习 使用在ImageNet数据上预训练的VGG16网络模型对猫狗数据集进行分类识别. 1.预训练网络 预训练网络是一个保存好的,已经在大型数据集上训练好的卷积神经网络. 如 ...

  8. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

  9. Tensorflow学习笔记2019.01.22

    tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...

  10. Tensorflow学习笔记2019.01.03

    tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...

随机推荐

  1. 【mysql】数据库Schema的优化

    由于MySQL数据库是基于行(Row)存储的数据库,而数据库操作 IO 的时候是以 page(block)的方式,也就是说,如果我们每条记录所占用的空间量减小,就会使每个page中可存放的数据行数增大 ...

  2. 解决ubuntu13.04 有线网络 时常掉线的问题

    不少朋友在升级或新装ubuntu13.04时遇到有线老掉线的问题:连上不到半分钟又掉了,把网线重新拔插一下又可以接着又掉..基本不能正常使用或工作,很恼人的问题. 网上这方面的资料很少现在我把解决方法 ...

  3. POST、GET请求中文参数乱码问题

    POST请求中文乱码问题解决方法: 在web.xml文件中添加编码过滤器,如下: <!-- 解决post乱码 --> <filter> <filter-name>C ...

  4. 【LeetCode】Longest Substring with At Most Two Distinct Characters (2 solutions)

    Longest Substring with At Most Two Distinct Characters Given a string, find the length of the longes ...

  5. Linux安装ElasticSearch-2.2.0-分词器插件(Mmseg)

    1.在gitpub上搜索elasticsearch-analysis,能够看到所有elasticsearch的分词器: 2.安装Mmseg分词器:https://github.com/medcl/el ...

  6. C# 程序打包Release版本

    注意:DEBUG和RELEASE的区别,DEBUG下可以直接运行,而RELEASE不一定能直接运行,这并不是表示RELEASE版本有问题,而是表示两者需要操作不同.RELEASE版本要比DEBUG版本 ...

  7. C语言中 不定义结构体变量求成员大小

    所谓的求成员大小, 是求成员在该结构体中 用 sizeof(结构体名.结构体成员名) 求来的. 很多时候我们需要知道一个结构体成员中的某个成员的大小, 但是我们又不需要定义该结构体类型的变量(定义的话 ...

  8. Easyui 中 Tabsr的常用方法

    注:index必须为变量 tab页从0开始 //新增一个tab页var index = 0;$('#tt').tabs('add',{ title: 'Tab'+index, content: '&l ...

  9. 跟我学SharePoint 2013视频培训课程——排序、过滤在列表、库中的使用(10)

    课程简介 第10天,SharePoint 2013排序.过滤在列表.库中的使用. 视频 SharePoint 2013 交流群 41032413

  10. Shiro(一):shiro架构和组件介绍

    简介 Apache Shiro是一个强大且易用的Java安全框架,执行身份认证.授权.加密和会话管理.使用Shiro的易于理解的API,可以快速.轻松地获得任何应用程序,从最小的移动应用程序到最大的网 ...