1. 从 ckpt-.data,ckpt-.index 和 .meta 生成 frozenpb

import os
import tensorflow as tf
from tensorflow.python.framework import graph_util def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
output_node_names = "outputs"
saver = tf.train.import_meta_graph(os.path.join(os.path.split(input_checkpoint)[0], 'graph.meta'), clear_devices=True) with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #恢复图并得到数据
output_graph_def = graph_util.convert_variables_to_constants(
# 模型持久化,将变量值固定
sess=sess,
input_graph_def=sess.graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node))
#得到当前图有几个操作节点 if __name__ == "__main__":
# 输入ckpt模型路径
input_checkpoint='ckpt_path/ckpt-10000'
# 输出pb模型的路径
out_pb_path="some_path/frozen_model.pb"
# 调用freeze_graph将ckpt转为pb
freeze_graph(input_checkpoint,out_pb_path)

2. 从网络代码和 ckpt-.data 文件生成 frozenpb

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph import network # 导入网络结构 os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 设置GPU
model_path = "ckpt_path/ckpt-10000" def main():
tf.reset_default_graph()
input_node = tf.placeholder(
tf.float32, shape=(None,112, 96, 3)
)
input_node = tf.identity(input_node,name="inputs") # 设置输入节点的名字,这里可以自定义名称
flow = network(input_node)
flow = tf.identity(flow, name="outs") # 设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, model_path)
# 保存图
tf.train.write_graph(sess.graph_def, "logdir/", "graph.pb")
# 把图和参数结构一起
freeze_graph.freeze_graph(
"logdir/graph.pb", # 上面保存的图结构 graph.pb
"",
False,
model_path,
"outs",
"save/restore_all", # 默认恢复所有
"save/Const:0", # 默认常量
"some_path/frozen.pb", # 保存frozen.pb
False,
"",
)
print("done") if __name__ == "__main__":
main()

3. 打印 网络中节点的名字

import tensorflow as tf

if __name__ == "__main__":
checkpoint_path = '../model_fintune/ckpt-1400'
reader = tf.train.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map() for key in var_to_shape_map:
print("tensor name: ", key)
# print(reader.get_tensor(key))

或者通过

import tensorflow as tf

def printTensors(pb_file):

    # read pb into graph_def
with tf.gfile.GFile(pb_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read()) # import graph_def
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def) # print operations
for op in graph.get_operations():
print(op.name) printTensors("path-to-my-pbfile.pb")

4. 两种方法对比

如果是自己的代码训练的模型,有网络结构,有 ckpt 文件,最好是使用第二种方法,使用起来很灵活,可以进行各种自定义,比如修改输入输出的节点名字,网络有多个路径的时候可以自定义输出路径。第一种方法,应该也能达到第二种方法的效果,因为它们本来就是等价的,可能会有些麻烦。第一种方法的好处就是快,不要去翻那些杂糅在一起的网络结构。

两种从 TensorFlow 的 checkpoint生成 frozenpb 的方法的更多相关文章

  1. linux两种增加交换分区(swap)的方法

    在安装Oracle后,为使Oracle流畅运行,需要手动增加linux的交换分区(相当于Windows下的虚拟内存)的大小,本文介绍两种增加交换分区(swap)的方法. 第一种方法:新建分区 1.fd ...

  2. JAVA 中两种判断输入的是否是数字的方法__正则化_

    JAVA 中两种判断输入的是否是数字的方法 package t0806; import java.io.*; import java.util.regex.*; public class zhengz ...

  3. 两种常用的jquery事件加载的方法 的区别

    两种常用的jquery事件加载的方法   $(function(){});  window.onload=function(){}  第一个呢,是在DOM结构渲染完成以后调用的,这时候网页中一些资源还 ...

  4. Android中两种设置全屏或者无标题的方法

    在开发中我们经常需要把我们的应用设置为全屏或者不想要title, 这里是有两种方法的,一种是在代码中设置,另一种方法是在配置文件里改: 一.在代码中设置: package jason.tutor; i ...

  5. JSONP和CORS两种跨域方式的优缺点及使用方法原理介绍

    随着软件开发分工趋于精细,前后端开发分离成为趋势,前端同事负责前端页面的展示及页面逻辑处理,服务端同事负责业务逻辑处理同时通过API为前端提供数据也为前端提供数据的持久化能力,考虑到前后端同事开发工具 ...

  6. 两种解决IE6不支持固定定位的方法

    有两种让IE6支持position:fixed1.用CSS执行表达式 *{margin:0;padding:0;} * html,* html body{ background-image:url(a ...

  7. 【DevCloud · 敏捷智库】两种你必须了解的常见敏捷估算方法

    背景 在某开发团队辅导的回顾会议上,团队成员对于优化估计具体方法上达成了一致意见.询问是否有什么具体的估计方法来做估算. 问题分析 回顾意见上大家对本次Sprint的效果做回顾,其中80%的成员对于本 ...

  8. Django—Form两种解决表单数据无法动态刷新的方法

    一.无法动态更新数据的实例 1. 如下,数据库中创建了班级表和教师表,两张表的对应关系为“多对多” from django.db import models class Classes(models. ...

  9. 【Django】Django—Form两种解决表单数据无法动态刷新的方法

    一.无法动态更新数据的实例 1. 如下,数据库中创建了班级表和教师表,两张表的对应关系为“多对多” from django.db import models class Classes(models. ...

随机推荐

  1. UNIX 版本

    一般UNIX系统都来源于AT&T公司的System V UNIX系统,BSD UNIX或其他类UNIX系统. System V UNIX:当今市场上大多数主要的商业UNIX系统都是基于AT&a ...

  2. Jconsole或者VisualVM监控远程主机(阿里云,jdk11或者8)

    准备: 1 一个war包或者jar包,这里我用springboot的 2 linux环境,安装tomcat,jdk,我用的jdk11和tomcat9,jdk11和8的拷贝权限文件路径有点不一样,这个需 ...

  3. Arduino系列之智能家居蓝牙语音遥控灯(四)

    用到的材料 Arduino uno hc-05   蓝牙模块 安卓手机 安卓APP AMR—voice 通过安卓手机连接Arduino的蓝牙模块Hc-05,通过语音识别软件AMR-voice识别语音, ...

  4. LoadIcon的使用

    LoadIcon msdn: Loads the specified icon resource from the executable (.exe) file associated with an ...

  5. python 函数3(模块)

    1.将函数存储在模块中 1.1.导入整个模块 要将函数导入,得先创建模块,模块 是扩展名为.py的文件,包含要导入到程序中的代码. 首先定义编写一个.py的文件,命名为pizza.py,代码如下: d ...

  6. BZOJ 4034 [HAOI2015]树上操作(欧拉序+线段树)

    题意: 有一棵点数为 N 的树,以点 1 为根,且树点有边权.然后有 M 个 操作,分为三种: 操作 1 :把某个节点 x 的点权增加 a . 操作 2 :把某个节点 x 为根的子树中所有点的点权都增 ...

  7. Disk:磁盘管理之LVM和系统磁盘扩容

    简介 小伙伴们好,好久不见,今天想给大家介绍一下关于磁盘管理的方法和心得:磁盘管理可谓运维工作中的重要内容,主要包括磁盘的合理规划以及扩缩容 常用的磁盘管理方法为LVM(Logical Volume ...

  8. OpenCV3入门(八)图像边缘检测

    1.边缘检测基础 图像的边缘是图像的基本特征,边缘点是灰度阶跃变化的像素点,即灰度值的导数较大或极大的地方,边缘检测是图像识别的第一步.用图像的一阶微分和二阶微分来增强图像的灰度跳变,而边缘也就是灰度 ...

  9. PHP在程序处理过程中动态输出内容

    在安装discuz或其他一些开源产品的时候,在安装数据库时页面上的安装信息都是动态输出出来的,主要通过php两个函数来实现的, flush();ob_flush(); 代码如下 <html xm ...

  10. golang 引入 和 创建 包

    /* 单个包: improt "包目录的路径" 多个包: improt ("包目录的路径", "包目录的路径") improt ( &quo ...