Tensorflow:模型变量保存

觉得有用的话,欢迎一起讨论相互学习~

参考文献Tensorflow实战Google深度学习框架

实验平台:

Tensorflow1.4.0

python3.5.0

Tensorflow常用保存模型方法

import tensorflow as tf
saver = tf.train.Saver() # 创建保存器
with tf.Session() as sess:
saver.save(sess,"/path/model.ckpt") #保存模型到相应ckpt文件
saver.restore(sess,"/path/model.ckpt") #从相应ckpt文件中恢复模型变量
  • 使用tf.train.Saver会保存运行Tensorflow程序所需要的全部信息,然而有时并不需要某些信息。比如在测试或离线预测时,只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要类似的变量初始化,模型保存等辅助节点的信息。Tensorflow提供了convert_varibales_to_constants函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个Tensorflow计算图可以统一存放在一个文件中。

将变量取值保存为pb文件

# pb文件保存方法
import tensorflow as tf
from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2 init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op) # 初始化所有变量
# 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程
graph_def = tf.get_default_graph().as_graph_def()
# 将需要保存的add节点名称传入参数中,表示将所需的变量转化为常量保存下来。
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
# 将导出的模型存入文件中
with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
f.write(output_graph_def.SerializeToString()) # 2. 加载pb文件。 from tensorflow.python.platform import gfile with tf.Session() as sess:
model_filename = "Saved_model/combined_model.pb"
# 读取保存的模型文件,并将其解析成对应的GraphDef Protocol Buffer
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 将graph_def中保存的图加载到当前图中,其中保存的时候保存的是计算节点的名称,为add
# 但是读取时使用的是张量的名称所以是add:0
result = tf.import_graph_def(graph_def, return_elements=["add:0"])
print(sess.run(result))
# Converted 2 variables to const ops.
# [array([3.], dtype=float32)]

Tensorflow模型变量保存的更多相关文章

  1. [翻译] Tensorflow模型的保存与恢复

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  2. 超详细的Tensorflow模型的保存和加载(理论与实战详解)

    1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...

  3. tensorflow模型的保存与恢复,以及ckpt到pb的转化

    转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...

  4. TensorFlow 模型的保存与载入

    参考学习博客: # https://www.cnblogs.com/felixwang2/p/9190692.html 一.模型保存 # https://www.cnblogs.com/felixwa ...

  5. 三、TensorFlow模型的保存和加载

    1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...

  6. tensorflow模型的保存与恢复

    1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...

  7. tensorflow模型的保存与加载

    模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...

  8. tensorflow模型持久化保存和加载

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

  9. 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

    save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...

随机推荐

  1. Python20-Day02

    1.数据 数据为什么要分不同的类型 数据是用来表示状态的,不同的状态就应该用不同类型的数据表示: 数据类型 数字(整形,长整形,浮点型,复数),字符串,列表,元组,字典,集合 2.字符串 1.按索引取 ...

  2. 获取文件夹下某个类型的文件名---基于python

    方法1:import osclass flist_name(): def __init__(self,path): self.flist_name=os.listdir(path) def pcap_ ...

  3. php异步学习(1)

    1.为啥PHP需要异步操作? 一般来说PHP适用的场合是web页面展示等耗时比较短的任务,如果对于比较花时间的操作如resize图片.大数据导入.批量发送EDM.SMS等,就很容易出现操作超时情况.你 ...

  4. 个人第十一周PSP

    11.24 --11.30本周例行报告 1.PSP(personal software process )个人软件过程. 类型 任务 开始时间                结束时间 中断时间 实际用 ...

  5. AJAX请求.net controller数据交互过程

    AJAX发出请求 $.ajax({ url: "/Common/CancelTaskDeal", //CommonController下的CancelTaskDeal方法 type ...

  6. JDK,JRE,JVM之间的关系

    JDK包括JRE包括JVM http://java-mzd.iteye.com/blog/838514

  7. 学习调用第三方的WebService服务

    互联网上面有很多的免费webService服务,我们可以调用这些免费的WebService服务,将一些其他网站的内容信息集成到我们的应用中显示,下面就以查询国内手机号码归属地为例进行说明. 首先安利一 ...

  8. 404 Note Found队——现场编程

    目录 组员职责分工 github 的提交日志截图 程序运行截图 程序运行环境 GUI界面 基础功能实现 运行视频 LCG算法 过滤(降权)算法 算法思路 红黑树 附加功能一 背景 实现 附加功能二(迭 ...

  9. 模拟登入教务处(header)

    import HTMLParser import urlparse import urllib import urllib2 import cookielib import string import ...

  10. 第七周C语言代码

    #ifndef NMN_LIST_H #define NMN_LIST_H   #include <stdio.h>   struct list_head {     struct lis ...