Tensorflow:模型变量保存

觉得有用的话,欢迎一起讨论相互学习~

参考文献Tensorflow实战Google深度学习框架

实验平台:

Tensorflow1.4.0

python3.5.0

Tensorflow常用保存模型方法

  1. import tensorflow as tf
  2. saver = tf.train.Saver() # 创建保存器
  3. with tf.Session() as sess:
  4. saver.save(sess,"/path/model.ckpt") #保存模型到相应ckpt文件
  5. saver.restore(sess,"/path/model.ckpt") #从相应ckpt文件中恢复模型变量
  • 使用tf.train.Saver会保存运行Tensorflow程序所需要的全部信息,然而有时并不需要某些信息。比如在测试或离线预测时,只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要类似的变量初始化,模型保存等辅助节点的信息。Tensorflow提供了convert_varibales_to_constants函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个Tensorflow计算图可以统一存放在一个文件中。

将变量取值保存为pb文件

  1. # pb文件保存方法
  2. import tensorflow as tf
  3. from tensorflow.python.framework import graph_util
  4. v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
  5. v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
  6. result = v1 + v2
  7. init_op = tf.global_variables_initializer()
  8. with tf.Session() as sess:
  9. sess.run(init_op) # 初始化所有变量
  10. # 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程
  11. graph_def = tf.get_default_graph().as_graph_def()
  12. # 将需要保存的add节点名称传入参数中,表示将所需的变量转化为常量保存下来。
  13. output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
  14. # 将导出的模型存入文件中
  15. with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
  16. f.write(output_graph_def.SerializeToString())
  17. # 2. 加载pb文件。
  18. from tensorflow.python.platform import gfile
  19. with tf.Session() as sess:
  20. model_filename = "Saved_model/combined_model.pb"
  21. # 读取保存的模型文件,并将其解析成对应的GraphDef Protocol Buffer
  22. with gfile.FastGFile(model_filename, 'rb') as f:
  23. graph_def = tf.GraphDef()
  24. graph_def.ParseFromString(f.read())
  25. # 将graph_def中保存的图加载到当前图中,其中保存的时候保存的是计算节点的名称,为add
  26. # 但是读取时使用的是张量的名称所以是add:0
  27. result = tf.import_graph_def(graph_def, return_elements=["add:0"])
  28. print(sess.run(result))
  29. # Converted 2 variables to const ops.
  30. # [array([3.], dtype=float32)]

Tensorflow模型变量保存的更多相关文章

  1. [翻译] Tensorflow模型的保存与恢复

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  2. 超详细的Tensorflow模型的保存和加载(理论与实战详解)

    1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...

  3. tensorflow模型的保存与恢复,以及ckpt到pb的转化

    转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...

  4. TensorFlow 模型的保存与载入

    参考学习博客: # https://www.cnblogs.com/felixwang2/p/9190692.html 一.模型保存 # https://www.cnblogs.com/felixwa ...

  5. 三、TensorFlow模型的保存和加载

    1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...

  6. tensorflow模型的保存与恢复

    1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...

  7. tensorflow模型的保存与加载

    模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...

  8. tensorflow模型持久化保存和加载

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

  9. 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

    save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...

随机推荐

  1. PHP开发中常见的漏洞及防范

    PHP开发中常见的漏洞及防范 对于PHP的漏洞,目前常见的漏洞有五种.分别是Session文件漏洞.SQL注入漏洞.脚本命令执行漏洞.全局变量漏洞和文件漏洞.这里分别对这些漏洞进行简要的介绍和防范. ...

  2. 第二次c++作业

    用c语言实现电梯问题的方法: 先用一堆变量存储各种变量,在写一个函数模拟电梯上下移动载人放人的过程. c++: 构造一个电梯的类,用成员函数实现电梯运作的过程. 对c和c++的理解太浅,并没有感觉到用 ...

  3. hdu 1241--入门DFS

    Oil Deposits Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others) Tot ...

  4. 图论---POJ 3660 floyd 算法(模板题)

    是一道floyd变形的题目.题目让确定有几个人的位置是确定的,如果一个点有x个点能到达此点,从该点出发能到达y个点,若x+y=n-1,则该点的位置是确定的.用floyd算发出每两个点之间的距离,最后统 ...

  5. HDU 5228 ZCC loves straight flush 暴力

    题目链接: hdu:http://acm.hdu.edu.cn/showproblem.php?pid=5228 bc(中文):http://bestcoder.hdu.edu.cn/contests ...

  6. C++进阶之_类型转换

    C++进阶之_类型转换 1.类型转换名称和语法 C风格的强制类型转换(Type Cast)很简单,不管什么类型的转换统统是: TYPE b = (TYPE)a C++风格的类型转换提供了4种类型转换操 ...

  7. CCF——数列分段201509-1

    问题描述 给定一个整数数列,数列中连续相同的最长整数序列算成一段,问数列中共有多少段? 输入格式 输入的第一行包含一个整数n,表示数列中整数的个数. 第二行包含n个整数a1, a2, …, an,表示 ...

  8. 从装饰者模式的理解说JAVA的IO包

    1. 装饰者模式的详解 装饰者模式动态地将责任附加到对象上.若要扩展功能,装饰者提供了比继承更有弹性 的替代方案. 装饰者模式设计类之间的关系: 其 中Component是一个超类,ConcreteC ...

  9. linux安装py3.6

    随手记录: https://www.python.org/ftp/python/3.6.8/Python-3.6.8rc1.tgz 所有linux版本: https://www.python.org/ ...

  10. Selenium遇到问题unknown error:cannot create default profile directory......

    1.selenium遇到问题unknown error:cannot create default profile directory...... 2.解决方案 问题1:把驱动放入C:\Windows ...