Tensorflow Learning1 模型的保存和恢复
CKPT->pb
tensorflow的模型保存有两种形式:
1. ckpt:可以恢复图和变量,继续做训练
2. pb : 将图序列化,变量成为固定的值,,只可以做inference;不能继续训练
Demo
- 1 def freeze_graph(input_checkpoint,output_graph):
- 2
- 3 '''
- 4 :param input_checkpoint:
- 5 :param output_graph: PB模型保存路径
- 6 :return
- 7 void
- 8 '''
- 9
- 10 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
- 11 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
- 12
- 13 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
- 14 output_node_names = "InceptionV3/Logits/SpatialSqueeze" # 如果是多个输出节点,使用 ‘,’号隔开
- 15
- 16 ############################ Step1: 从ckpt中恢复图: #############################################
- 17 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
- 18 graph = tf.get_default_graph() # 获得默认的图, 可以省略
- 19 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图,可以省略
- 20
- 21 with tf.Session() as sess: # 会使用默认的图 作为当前的图
- 22 saver.restore(sess, input_checkpoint) #恢复图并得到数据
- 23
- 24 ######################## Step2: 创建持久化对象,指定sess,图、以及输出的序列化节点信息 ##############
- 25 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
- 26 sess=sess,
- 27 input_graph_def=input_graph_def,# 等于:sess.graph_def
- 28 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
- 29 ######################### Step3: 模型持久化 #######################################################
- 30 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
- 31 f.write(output_graph_def.SerializeToString()) #序列化输出
- 32 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
- 33 # for op in graph.get_operations():
- 34
- 35 # print(op.name, op.values())
- 36
- 37
- 38 ########################### 调用方式 ################################
- 39 # 输入ckpt模型路径
- 40 input_checkpoint='models/model.ckpt-10000'
- 41 # 输出pb模型的路径
- 42 out_pb_path="models/pb/frozen_model.pb"
- 43 # 调用freeze_graph将ckpt转为pb
- 44 freeze_graph(input_checkpoint,out_pb_path)解析函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。在保存pb的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称; tensor name 和 node name 的区别node name 是 图 的节点,里面包含了很多操作和tensortensor 是 node 里面的一个组成部分;以input 为例,“input:0”是张量的名称,而"input"表示的是节点的名称PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"input:0"
Tensorflow Learning1 模型的保存和恢复的更多相关文章
- TensorFlow笔记-模型的保存,恢复,实现线性回归
模型的保存 tf.train.Saver(var_list=None,max_to_keep=5) •var_list:指定将要保存和还原的变量.它可以作为一个 dict或一个列表传递. •max_t ...
- 第六节,TensorFlow编程基础案例-保存和恢复模型(中)
在我们使用TensorFlow的时候,有时候需要训练一个比较复杂的网络,比如后面的AlexNet,ResNet,GoogleNet等等,由于训练这些网络花费的时间比较长,因此我们需要保存模型的参数. ...
- tensorflow模型的保存与恢复
1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...
- [翻译] Tensorflow模型的保存与恢复
翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...
- tensorflow 1.0 学习:模型的保存与恢复(Saver)
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...
- tensorflow 1.0 学习:模型的保存与恢复
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...
- AI - TensorFlow - 示例05:保存和恢复模型
保存和恢复模型(Save and restore models) 官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_mo ...
- tensorflow模型的保存与恢复,以及ckpt到pb的转化
转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
随机推荐
- Vue基础第二章
1.数据绑定与数据声明 Vue中的数据绑定就是让与Vue实例绑定的DOM节点或script标签内的变量之间数据更新互相影响,即数据绑定后Vue实例的数据修改会使DOM节点的数据或者script标签内的 ...
- VMware主机使用无线上网
VMware主机使用无线上网,默认的NAT连接在ubuntu下上不了网,需要把网络适配器改成桥接模式.
- Java中InputStream和String之间的转换方法
1.InputStream转化为String1.1 JDK原生提供方法一:byte[] bytes = new byte[0];bytes = new byte[inputStream.availab ...
- div中的图片跑出来
一:div中的图片跑出来 <style> /* 图片在一行 */ #div1 li{ float: left; list-style: none; } </style> < ...
- springMVC项目访问URL链接时遇到某一段然后忽略后面的部分
背景:一个链接URL:http:localhost/tq/asf/218732,配置URL使遇到/asf后直接跳转不识别/asf后面的218732 因为是在ssm框架下做的项目,所以用spring的注 ...
- UVA 11090 : Going in Cycle!! 【spfa】
题目链接 题意及题解参见lrj训练指南 #include<bits/stdc++.h> using namespace std; const double INF=1e18; ; ; in ...
- 容器"共享"宿主机的hosts文件(终极方案)
0.背景 有时候制作docker镜像生成容器时需要宿主机的hosts文件共享到容器中.首先想的是通过挂载的方式共享hosts文件,但是实践时发现根本行不通,hosts文件在/etc/目录下,如进行挂载 ...
- 16.合并两个排序的链表(python)
题目描述 输入两个单调递增的链表,输出两个链表合成后的链表,当然我们需要合成后的链表满足单调不减规则. class Solution: # 返回合并后列表 def Merge(self, pHead1 ...
- vue学习-day03(动画,组件)
目录: 1.品牌列表-从数据库获取列表 2.品牌列表-完成添加功能 3.品牌列表-完成删除功能 4.品牌列表-全局配置数据接口的根域名 5.品牌列表-全局配置emulateJS ...
- html audio标签 语法
html audio标签 语法 audio标签的作用是什么? 作用:<audio> 标签定义声音,比如音乐和视频或其他音频资源,使用audio标签可以不用Flash插件就可以听音乐看视频, ...