虽然说 TensorFlow 2.0 即将问世,但是有一些模块的内容却是不大变化的。其中就有 tf.saved_model 模块,主要用于模型的存储和恢复。为了防止学习记录文件丢失或者蠢笨的脑子直接遗忘掉这部分内容,在此做点简单的记录,以便将来查阅。

最近为了一个课程作业,不得已涉及到关于图像超分辨率恢复的内容,不得不准备随时存储训练的模型,只好再回过头来瞄一眼 TensorFlow 文档,真是太痛苦了。

tf.saved_model 模块下面有很多文件和函数,精力有限,只好选择于自己有用的东西来看,可能并不全面,望日后补上。

其中最重要的就是该模块下的一个类:tf.saved_model.builder.SavedModelBuilder

tf.saved_model.builder.SavedModelBuilder:

# 构造函数
.__init__(export_dir)
"""
作用:
  创建一个保存模型的实例对象
参数:
export_dir: 模型导出路径,由于 TensorFlow 会在你指定的路径上创建文件夹和文件,所以指定的路径最后不需要带 /,
   例如:export_dir='/home/***/saved_model' 即可,最后不需要加上 /
""" # 方法
#
.add_meta_graph_and_variables(sess, tags, signature_def_map=None, assets_collection=None,
clear_devices=False, main_op=None, strip_default_attrs=False, saver=None)
"""
作用:
  保存会话对象中的 graph 和所有变量,具体描述可参见文档
参数:
  sess: TensorFlow 会话对象,用于保存元图和变量
  tags: 用于保存元图的标记集(如果存在多个图对象,需要设置保证每个图标签不一样),是一个列表
  signature_def_map: 一个字典,保存模型时传入的参数,key 可以是字符串,也可以是 tf.saved_model.signature_constants 文件下预定义的变量,
值为 signatureDef protobuf(protobuf 是一种结构化的数据存储格式)
  assets_collection: 略
  clear_devices: 如果需要清除默认图上的设备信息,则设置为 true
  main_op: 这个参数包括后面一系列与其相关的东西没有弄明白
  strip_default_attrs: 如果设置为 True,将从 NodeDefs 中删除默认值属性
  saver: tf.train.Saver 的一个实例,用于导出元图并保存变量
""" #
.add_meta_graph()
"""
作用:
  其除了没有 sess 参数以外,其他参数和 .add_meta_graph_and_variables() 一模一样
  调用此方法之前必须先调用 .add_meta_graph_and_variables() 方法
""" #
.save(as_text=False)
"""
作用:
  将内建的 savedModel protobuf 写入磁盘
"""

除了这个最重要的类以外,tf.saved_model 模块还提供了一些方便构建 builder 和加载模型的函数方法。

#
tf.saved_model.utils.build_tensor_info(tensor)
"""
作用:
构建 TensorInfo protobuf,根据输入的 tensor 构建相应的 protobuf,返回的 TensorInfo 中包含输入 tensor 的 name,shape,dtype 信息
参数:
tensor: Tensor 或 SparseTensor
""" #
tf.saved_model.signature_def_utils.build_signature_def(inputs=None, outputs=None, method_name=None)
"""
作用:
构建 SignatureDef protobuf,并返回 SignatureDef protobuf
参数:
inputs: 一个字典,键为字符串类型,值为关于 tensor 的信息,也就是上述的 .build_tensor_info() 函数返回的 TensorInfo protobuf
outputs: 一个字典,同上
method_name: SignatureDef 名称
""" #
tf.saved_model.utils.get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None)
"""
作用:
根据一个 TensorInfo protobuf 解析出一个 tensor
参数:
tensor_info: 一个 TensorInfo protobuf
graph: tensor 所存在的 graph,参数为 None 时,使用默认图
import_scope: 给 tensor 的 name 加上前缀
""" #
tf.saved_model.loader.load(sess, tags, export_dir, import_scope=None, **saver_kwargs)
"""
作用:
加载已存储的模型
参数:
sess: 用于恢复模型的 tf.Session() 对象
tags: 用于标识 MetaGraphDef 的标记,应该和存储模型时使用的此参数完全一致
export_dir: 模型存储路径
import_scope: 加前缀
"""

除了这些以外,还有一些 TensorFlow 为了方便而预定义的一些变量,这些变量完全可以使用自定义字符串代替,不再赘述。详情:https://tensorflow.google.cn/api_docs/python/tf/saved_model

如果只看这些内容的话,确实会使人产生巨大的疑惑,下面是具体实践的例子:

import tensorflow as tf
from tensorflow import saved_model as sm # 首先定义一个极其简单的计算图
X = tf.placeholder(tf.float32, shape=(3, ))
scale = tf.Variable([10, 11, 12], dtype=tf.float32)
y = tf.multiply(X, scale) # 在会话中运行
with tf.Session() as sess:
sess.run(tf.initializers.global_variables())
value = sess.run(y, feed_dict={X: [1., 2., 3.]})
print(value) # 准备存储模型
path = '/home/×××/tf_model/model_1'
builder = sm.builder.SavedModelBuilder(path) # 构建需要在新会话中恢复的变量的 TensorInfo protobuf
X_TensorInfo = sm.utils.build_tensor_info(X)
scale_TensorInfo = sm.utils.build_tensor_info(scale)
y_TensorInfo = sm.utils.build_tensor_info(y) # 构建 SignatureDef protobuf
SignatureDef = sm.signature_def_utils.build_signature_def(
inputs={'input_1': X_TensorInfo, 'input_2': scale_TensorInfo},
outputs={'output': y_TensorInfo},
method_name='what'
) # 将 graph 和变量等信息写入 MetaGraphDef protobuf
# 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,TensorFlow 为了方便使用,不在新地方将自定义的字符串忘记,可以使用预定义的这些值
builder.add_meta_graph_and_variables(sess, tags=[sm.tag_constants.TRAINING],
signature_def_map={sm.signature_constants.CLASSIFY_INPUTS: SignatureDef}
)  # 将 MetaGraphDef 写入磁盘
builder.save()

这样我们就把模型整体存储到了磁盘中,而且我们将三个变量 X, scale, y 全部序列化后存储到了其中,所以恢复模型时便可以将他们完全解析出来:

import tensorflow as tf
from tensorflow import saved_model as sm # 需要建立一个会话对象,将模型恢复到其中
with tf.Session() as sess:
path = '/home/×××/tf_model/model_1'
MetaGraphDef = sm.loader.load(sess, tags=[sm.tag_constants.TRAINING], export_dir=path) # 解析得到 SignatureDef protobuf
SignatureDef_d = MetaGraphDef.signature_def
SignatureDef = SignatureDef_d[sm.signature_constants.CLASSIFY_INPUTS] # 解析得到 3 个变量对应的 TensorInfo protobuf
X_TensorInfo = SignatureDef.inputs['input_1']
scale_TensorInfo = SignatureDef.inputs['input_2']
y_TensorInfo = SignatureDef.outputs['output'] # 解析得到具体 Tensor
# .get_tensor_from_tensor_info() 函数中可以不传入 graph 参数,TensorFlow 自动使用默认图
X = sm.utils.get_tensor_from_tensor_info(X_TensorInfo, sess.graph)
scale = sm.utils.get_tensor_from_tensor_info(scale_TensorInfo, sess.graph)
y = sm.utils.get_tensor_from_tensor_info(y_TensorInfo, sess.graph) print(sess.run(scale))
print(sess.run(y, feed_dict={X: [3., 2., 1.]})) # 输出
[10. 11. 12.]
[30. 22. 12.]

可以看出模型整体和变量个体都被完整地保存了下来。其中涉及的关于 protobuf 的知识,需要补习,在 TensorFlow 中好多地方都用到了相关的知识。上述恢复模型的代码中对具体的 TensorInfo protobuf 解析时,还可以使用另一种方式得到相应的 Tensor:

# 已知 X_TensorInfo, scale_TensorInfo, y_TensorInfo
X = sess.graph.get_tensor_by_name(X_TensorInfo.name)
scale = sess.grpah.get_tensor_by_name(scale_TensorInfo.name)
y = sess.graph.get_tensor_by_name(y_TensorInfo.name) # 因为 TensorFlow 构建 TensorInfo protobuf 时,使用了 Tensor 的 name 信息,所以可以直接读出来使用

记录:tf.saved_model 模块的简单使用(TensorFlow 模型存储与恢复)的更多相关文章

  1. 一份快速完整的Tensorflow模型保存和恢复教程(译)(转载)

    该文章转自https://blog.csdn.net/sinat_34474705/article/details/78995196 我在进行图像识别使用ckpt文件预测的时候,这个文章给我提供了极大 ...

  2. 学习笔记TF049:TensorFlow 模型存储加载、队列线程、加载数据、自定义操作

    生成检查点文件(chekpoint file),扩展名.ckpt,tf.train.Saver对象调用Saver.save()生成.包含权重和其他程序定义变量,不包含图结构.另一程序使用,需要重新创建 ...

  3. ROS学习记录(一)————创建简单的机器人模型smartcar

    这是我在古月居上找的(http://www.guyuehome.com/243),但直接运行的话,没办法跑起来,我也是查了好多博客和日志,才实现最后的功能的,所以,记录下来,以备后用吧,也欢迎其他和我 ...

  4. TensorFlow使用记录 (九): 模型保存与恢复

    模型文件 tensorflow 训练保存的模型注意包含两个部分:网络结构和参数值. .meta .meta 文件以 “protocol buffer”格式保存了整个模型的结构图,模型上定义的操作等信息 ...

  5. TensorFlow 模型文件

    在这篇 TensorFlow 教程中,我们将学习如下内容: TensorFlow 模型文件是怎么样的? 如何保存一个 TensorFlow 模型? 如何恢复一个 TensorFlow 模型? 如何使用 ...

  6. tensorflow 模型前向传播 保存ckpt tensorbard查看 ckpt转pb pb 转snpe dlc 实例

    参考: TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式 TensorFlow 模型保存与恢复 snpe tensorflow 模型前向传播 保存ckpt  tensor ...

  7. Tensorflow 的saved_model模块学习

    saved_model模块主要用于TensorFlow Serving.TF Serving是一个将训练好的模型部署至生产环境的系统,主要的优点在于可以保持Server端与API不变的情况下,部署新的 ...

  8. TensorFlow saved_model 模块

    最近在学tensorflow serving 模块,一直对接口不了解,后面看到这个文章就豁然开朗了, 主要的困难在于   tf.saved_model.builder.SavedModelBuilde ...

  9. Tensorflow模型加载与保存、Tensorboard简单使用

    先上代码: from __future__ import absolute_import from __future__ import division from __future__ import ...

随机推荐

  1. teradata 数据定义

    teradata 数据定义 创建表的可选项 是否允许记录重复 set 不允许记录重复 multiset 允许记录重复 数据保护 fallback       fallback    使用fallbac ...

  2. 另开一篇 https

    https 流程 1.加密传输:对称加密传输信息 2.身份认证:非对称加密.通过证书来保障客户端给服务器的密钥唯一性. 因为中间层要是伪装公钥和证书,但是又无法解密原有的发送的数据,那么发给服务器的数 ...

  3. 使用JavaScript实现简单的小游戏-贪吃蛇

    最近初学JavaScript,在这里分享贪吃蛇小游戏的实现过程, 希望能看到的前辈们能指出这个程序的不足之处. 大致思路 首先要解决的问题 随着蛇头的前进,尾巴也要前进. 用键盘控制蛇的运动方向. 初 ...

  4. 2.Dubbo2.5.3注册中心和监控中心部署

    转载请出自出处:http://www.cnblogs.com/hd3013779515/ 1.注册中心Zookeeper安装 (1)搭建要求 zk服务器集群规模不小于3个节点要求各服务器之间系统时间要 ...

  5. PHPer是草根吗

    以下文字并没有非常多的技术词汇,所以只要对PHP感兴趣的人都可以看看. PHPer是草根吗? 从PHP诞生之日起,PHP就开始在Web应用方面为广大的程序员服务.同时,作为针对Web开发量身定制的脚本 ...

  6. 【转】 最新版chrome谷歌浏览器Ajax跨域调试问题

    Ajax本身是不支持跨域的,而我们在开发工作中,可能会遇到本地开发环境未配置相关代码,需要到其他服务器上获取数据的情况,尤其在用HTML5开发app的过程中,前后台完全分离,使用Ajax进行数据交互, ...

  7. vagrant特性——基于docker开发环境(docker和vagrant的结合)-4-简单例子-有问题

    运行一个十分简单的例子: Vagrant.configure() do |config| config.vm.provider "docker" do |d| d.image = ...

  8. 剑指offer.找出数组中重复的数字

    题目: 给定一个长度为 n 的整数数组 nums,数组中所有的数字都在 0∼n−1 的范围内.数组中某些数字是重复的,但不知道有几个数字重复了,也不知道每个数字重复了几次.请找出数组中任意一个重复的数 ...

  9. 20175310 《Java程序设计》第1周学习总结(2)

    20175310 <Java程序设计>第1周学习总结(2) 教材学习内容总结 本周学习了教材的第一章内容,通过看微课的方式,自主学习,教材上讲的比较简单,主要的问题都在调试代码上,还有一两 ...

  10. JS实现拖动div层移动

    JS实现拖动div层移动 在谈到拖动div层之前,我们有必要来了解下 下面JS几个属性的区别----  pageX,pageY,layerX,layerY,clientX,clientY,screen ...