转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html

一、模型的保存

使用tensorflow训练模型的过程中,需要适时对模型进行保存,以及对保存的模型进行restore,以便后续对模型进行处理。如:测试、部署、拿别的模型进行fine-tune等。

保存模型是整个内容的第一步,操作十分简单,只需要创建一个saver,并在一个Session里完成保存。

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess, model_name)

以上代码在0.11以下版本的tensorflow里会保存与下面类似的3个文件

checkpoint
model.ckpt-1000.meta
model.ckpt-1000.ckpt

其中checkpoint列出保存的所有模型以及最近的模型;meta文件是模型定义的内容;ckpt(或data和index)文件是保存的模型数据。

除了上面最简单的保存方式,也可以指定保存的步数,多长时间保存一次,磁盘上最多保存几个模型(将前面的删除以保持固定个数),需要做的是在创建saver时指定参数

saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)

其中,savable_variables指定待保存的变量,比如指定为tf.global_variables()保存所有global变量;指定为[v1, v2]保存v1和v2两个变量,如果省略,则保存所有。

max_to_keep指定磁盘上最多保存有几个模型。

keep_checkpoint_every_n_hours指定多少小时保存一次。

保存模型时指定参数

saver.save(sess, 'model_name', global_step=step, write_meta_graph=False)

其中,可以指定模型文件名,步数,write_meta_graph则用来指定是否保存meta文件记录graph,等等。

二、模型的恢复及查看模型参数

with tf.Session() as sess:
# 加载模型定义的graph
saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
# 方式一:加载指定文件夹下最近保存的一个模型的数据
saver.restore(sess, tf.train.latest_checkpoint('./'))
# 方式二:指定具体某个数据,需要注意的是,指定的文件不要包含后缀
# saver.restore(sess, os.path.join(path, 'model.ckpt-1000')) # 查看模型中的trainable variables
tvs = [v for v in tf.trainable_variables()]
for v in tvs:
print(v.name)
print(sess.run(v)) # 查看模型中的所有tensor或者operations
gv = [v for v in tf.global_variables()]
for v in gv:
print(v.name) # 获得几乎所有的operations相关的tensor
ops = [o for o in sess.graph.get_operations()]
for o in ops:
print(o.name)

说明:

1、global_variables()比trainable_variables()多了一些非trainable的变量,比如定义时指定为trainable=False的变量,或Optimizer相关的变量。

2、sess.graph.get_operations()可以换为tf.get_default_graph().get_operations(),二者区别无非是graph明确的时候可以直接使用前者,否则需要使用后者。

三、将ckpt转化为pb

freeze_graph就是将模型固化,具体说就是将训练数据和模型固化成pb文件。

参数: (必选: 表示必须有值;可选: 表示可以为空):
1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)
2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。
3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False
4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。
5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。
6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all
7、filename_tensor_name:(可选)已弃用。默认:save/Const:0
8、output_graph:(必选)用来保存整合后的模型输出文件。
9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)
10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。
11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。

if __name__ == '__main__':
args = parse_args() # model path
demonet = args.demo_net
dataset = args.dataset
tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0]) if not os.path.isfile(tfmodel + '.meta'):
print(tfmodel)
raise IOError(('{:s} not found.\nDid you download the proper networks from '
'our server and place them properly?').format(tfmodel + '.meta')) # set config
tfconfig = tf.ConfigProto(allow_soft_placement=True)
tfconfig.gpu_options.allow_growth = True # init session
sess = tf.Session(config=tfconfig)
# load network
if demonet == 'vgg16':
net = vgg16(batch_size=1)
else:
raise NotImplementedError net.create_architecture(sess, "TEST", 4,
tag='default', anchor_scales=[8, 16, 32])
saver = tf.train.Saver()
saver.restore(sess, tfmodel) # 保存图
tf.train.write_graph(sess.graph_def, 'pb/pb_model', 'model.pb')
# 把图和参数结构一起
freeze_graph.freeze_graph('pb/pb_model/model.pb',
'',
False,
tfmodel,
'vgg_16/cls_score/BiasAdd,vgg_16/cls_prob,vgg_16/bbox_pred/BiasAdd,vgg_16/rois/PyFunc',
'save/restore_all',
'save/Const:0',
'pb/pb_model/frozen_model.pb',
False,
"")

tensorflow模型的保存与恢复,以及ckpt到pb的转化的更多相关文章

  1. [翻译] Tensorflow模型的保存与恢复

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  2. tensorflow模型的保存与恢复

    1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...

  3. tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...

  4. tensorflow 1.0 学习:模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  5. tensorflow 1.0 学习:模型的保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  6. TensorFlow笔记-模型的保存,恢复,实现线性回归

    模型的保存 tf.train.Saver(var_list=None,max_to_keep=5) •var_list:指定将要保存和还原的变量.它可以作为一个 dict或一个列表传递. •max_t ...

  7. Tensorflow Learning1 模型的保存和恢复

    CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢 ...

  8. Tensorflow模型变量保存

    Tensorflow:模型变量保存 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 pyt ...

  9. 超详细的Tensorflow模型的保存和加载(理论与实战详解)

    1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...

随机推荐

  1. Jmeter 逻辑控制器 之 循环控制器

    今天和大家分享下循环控制器的使用. 一.认识循环控制器 如下图:新增一个循环控制器 循环控制器的设置界面: 循环次数:永远和自定义次数,这个应该比较好理解. 二.使用循环控制器 其实大家对Jmeter ...

  2. Python3 Selenium自动化web测试 ==> 第十一节 WebDriver高级应用 -- 显示等待 + 二次封装

    学习目的: 掌握显示等待 掌握二次封装 正式步骤: step1:显示等待的代码示例 # -*- coding:utf-8 -*- from selenium import webdriver from ...

  3. Go语言中defer语句使用小结

    defer是Go语言中的延迟执行语句,用来添加函数结束时执行的代码,常用于释放某些已分配的资源.关闭数据库连接.断开socket连接.解锁一个加锁的资源.Go语言机制担保一定会执行defer语句中的代 ...

  4. NDK学习笔记-JNI的引用

    JNI中的引用意在告知虚拟机何时回收一个JNI变量 JNI引用变量分为局部引用和全局引用 局部引用 局部引用,通过DeletLocalRef手动释放对象 原因 访问一个很大的Java对象,使用之后还用 ...

  5. [转载]由浅入深探究mysql索引结构原理、性能分析与优化

    第一部分:基础知识第二部分:MYISAM和INNODB索引结构1. 简单介绍B-tree B+ tree树 2. MyisAM索引结构 3. Annode索引结构 4. MyisAM索引与InnoDB ...

  6. eNSP——Hybrid接口的应用

    原理: Hybrid接口既可以连接普通终端的接入链路又可以连接交换机间的干道链路,它允许多个VLAN的帧通过,并可以在出接口方向将某些VLAN帧的标签剥掉. Hybrid接口处理VLAN帧的过程如下: ...

  7. jq+js获取到table标签中的value

    前端jsp页面,(这里接收后端的参数方式没有放在上面) <table> <tbody id="fPzQwQwzbrList"> <tr id=&quo ...

  8. [转帖]关于USB3.0以及type-C

    忘记来源页面了.. 但是昨天晚上 usb 4.0 发布了 跟雷电C 安全一样的标准 双向40gb 的带宽. 而且 以后只有usb type-C的接口了. 我们办公机器上面的 typeC 同事用 ngf ...

  9. 使用power designer,PL/SQL,cmd建立oracle数据库

    这一系列操作需要powerDesigner,PL/SQL工具 1.首先使用powerDesigner建立概念模型 2.概念模型界面例子 3.其中建立概念模型操作图标详解 4.建立物理模型 5.生成数据 ...

  10. 2019 Multi-University Training Contest 2: 1010 Just Skip The Problem 自闭记

    2019 Multi-University Training Contest 2: 1010 Just Skip The Problem 自闭记 题意 多测.每次给你一个数\(n\),你可以同时问无数 ...