我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  1. 定义变量
  2. 使用saver.save()方法保存

载入

  1. 定义变量
  2. 使用saver.restore()方法载入

保存代码

import tensorflow as tf
import numpy as np W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') init = tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess,"save/model.ckpt")

 载入代码如下

import tensorflow as tf
import numpy as np W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
meta_graph_or_file,
clear_devices=False,
import_scope=None,
**kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存模型代码

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y') w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
sess.run(train_op)
if step % 1000 == 0:
# 保存checkpoint, 同时也默认导出一个meta_graph
# graph名为'my-model-{global_step}.meta'.
saver.save(sess, 'my-model', global_step=step)

载入模型

with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
# tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
y = tf.get_collection('pred_network')[0] graph = tf.get_default_graph() # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
input_x = graph.get_operation_by_name('input_x').outputs[0]
keep_prob = graph.get_operation_by_name('keep_prob').outputs[0] # 使用y进行预测
sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的: 
一、 saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如 
my-model-10000.metamy-model-10000.indexmy-model-10000.data-00000-of-00001 
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

TensorFlow 模型保存/载入的更多相关文章

  1. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  2. TensorFlow模型保存和提取方法

    一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将Tens ...

  3. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  4. Tensorflow模型保存与加载

    在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提 ...

  5. Tensorflow模型保存与载入

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = in ...

  6. 10 Tensorflow模型保存与读取

    我们的模型训练出来想给别人用,或者是我今天训练不完,明天想接着训练,怎么办?这就需要模型的保存与读取.看代码: import tensorflow as tf import numpy as np i ...

  7. 一份快速完整的Tensorflow模型保存和恢复教程(译)(转载)

    该文章转自https://blog.csdn.net/sinat_34474705/article/details/78995196 我在进行图像识别使用ckpt文件预测的时候,这个文章给我提供了极大 ...

  8. 转 tensorflow模型保存 与 加载

    使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获 ...

  9. tensorflow 模型保存后的加载路径问题

    import tensorflow as tf #保存模型 saver = tf.train.Saver() saver.save(sess, "e://code//python//test ...

随机推荐

  1. vue 仿今日头条

    vue 仿今日头条 为了增加移动端项目的经验,近一周通过 vue 仿写今日头条,以下就项目实现过程中遇到的问题以及解决方法给出总结,有什么不正确的地方,恳请大家批评指正^ _ ^!,代码仓库地址为 g ...

  2. 关于Linux启动文件rc.local的解惑

    背景 首先,rc.local是Linux启动程序在login程序前执行的最后一个脚本,有的服务器中在rc.local中可能会有一句touch /var/lock/subsys/local,这是干什么的 ...

  3. linux应用之tomcat的安装及配置(centos)

    CentOS 6.6下安装配置Tomcat环境 [日期:2015-08-25] 来源:Linux社区  作者:tae44 [字体:大 中 小]   实验系统:CentOS 6.6_x86_64 实验前 ...

  4. MySQL丨5.6版本插入中文显示问号解决方法

    解决办法: 1.找到安装目录下的my-default.ini 这个配置文件 2.copy一份粘贴到同目录下 另命名为my.ini 3.在my.ini 配置下加上下面几句代码 并保存 [mysql]de ...

  5. H3C-交换机密码恢复

    交换机密码恢复: 一. 拔掉电源再插上重新启动交换机,在超级终端中可以看到交换机启动画面,当出现提示按CTRL+B时,此时按住CTRL+B,我们会看到有9个选项: 1. download applic ...

  6. .html 页面修改成 .jsp 后缀后中文乱码解决办法。

    .html 后缀的文件,如果直接将 .html后缀改成 .jsp 后缀,则会乱码. 正确方法如下: 将如图的代码中 html  声明去掉,然后加上这段代码:<%@ page language=& ...

  7. BZOJ_5416_[Noi2018]冒泡排序_DP+组合数+树状数组

    BZOJ_5416_[Noi2018]冒泡排序_DP+组合数+树状数组 Description www.lydsy.com/JudgeOnline/upload/noi2018day1.pdf 好题. ...

  8. spellchecker inspection helps locate typeos and misspelling in your code, comments and literals, and fix them in one click

    项目layout文件中出现 spellchecker inspection helps locate typos and misspelling in your code, comments and ...

  9. bzoj4892

    后缀数组 先开始nc了,觉得自动机做法是指数级的,就写了个后缀数组 具体方法是暴力,枚举起点,然后用lcp向后暴力匹配,如果失配就减少一次,我们一共有3次机会,这样每次匹配复杂度是O(1)的,所以总复 ...

  10. char的定义在iOS和Android下是不同的

    char is different in iOS and Android!跨平台开发时很容易忽略的非常坑爹的一个区别. 我的需求是实现一个算法,这个算法在iOS和Android下需要保持一致的结果,很 ...