模型保存为单一个pb文件

背景

参考连接: https://www.yuque.com/g/jesse-ztr2k/nkke46/ss4rlv/collaborator/join?token=XUVZNORisVWEWyst#

注意有些时候需要添加一个pb文件。 而不是tensorflow 提供的save 方法生成的一个目录里面包含了若干pb文件。

load时候直接填写这个目录即可。 但是有些时候需要合成一个pb文件。

tf2生成pb 目录描述

  1 目录结构

-assets

-variables

-variables.data-00000-of-00001

-variables.index

-saved_model.pb

  2 作用

    其中 variables 记录模型参数 , pb文件记录模型结构

tf2 都是保存的 权重和 结构分开的, 如果需要兼容tf V1的代码,即导入一个pb文件,就需要 1 )保存常量计算图 2)frozen graph  pb格式。

tf1 生成pb脚本

环境准备:

注意 一定在tf v1 环境下生成pb

  1 import cv2
2 import numpy as np
3 import tensorflow as tf
4 import os
5 from tensorflow.python.framework import graph_util
6
7 # 参考连接 https://blog.csdn.net/tensorflowforum/article/details/112352764 代码
8 # 参考连接 参数详解:https://blog.csdn.net/weixin_43529465/article/details/124721583
9 # https://blog.csdn.net/rain6789/article/details/78754516
10
11 class SingleCnn(tf.keras.Model):
12 def __init__(self):
13 super(SingleCnn, self).__init__()
14 # filters=1 卷积核数目,相当于卷积核的channel
15 self.conv = tf.keras.layers.Conv2D(filters=1,
16 kernel_size=[1, 1],
17 # valid表示不填充, same表示合理填充
18 padding='valid',
19 # data_format='channels_last',-> 表示HWC,输入可以定义批次
20 data_format='channels_last',
21 use_bias=False,
22 kernel_initializer=tf.keras.initializers.he_uniform(seed=None),
23 name="conv")
24
25 def call(self, inputs):
26 x = self.conv(inputs)
27 return x
28 if __name__ == "__main__":
29 # 构建场景输入数据
30
31 # images=tf.random.uniform((1, 300, 300, 3))
32
33 # 图像数据
34 imagefile = r"catanddog\cat\5.JPG"
35 img = cv2.imread(imagefile)
36 img = cv2.resize(img, (64, 64))
37 img = np.expand_dims(img, axis=0)
38 print(img.shape, type(img), img.dtype)
39
40 # 未量化的model不支持int32和int8
41 # img = img.astype(np.int32)
42 img = tf.convert_to_tensor(img, np.float32)
43 print(img.shape, type(img), img.dtype)
44 singlecnn = SingleCnn()
45
46 output = singlecnn(img)
47 print(output.shape, type(output))
48 print(output[0][2:10][2:6])
49 # =========== ckpt保存 with session的写法tf2 已不再使用 ===========
50 # with tf.Session(graph=tf.Graph()) as sess:
51 # constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
52
53 # 保存参考 https://zhuanlan.zhihu.com/p/146243327
54 # save_format='tf' 代表保存pb
55 # singlecnn.save('./pbmodel/singlecnn', save_format='tf')
56 # tf.saved_model.save(singlecnn, './pbmodel/singlecnn')
57 tf.keras.models.save_model(singlecnn, './pbmodel/singlecnn_0',
58 save_format="tf",
59 include_optimizer=False, save_traces=False)
60
61 # 加载模型 验证可以加载
62 new_model = tf.keras.models.load_model('./pbmodel/singlecnn_0', compile=False)
63 # new_model = tf.saved_model.load('./pbmodel/singlecnn_0')
64 # output_ = new_model(img)
65 # # print(output_.shape, output_[0][2:6][2:6])
66 # print(output_.shape)
67 #
68 # 查看结构
69 new_model.summary()
70
71 # print("----------------")
72 # # 加载模型
73 # saved_model = tf.saved_model.load('./pbmodel/singlecnn_0')
74 # # 将模型转换为pb格式 还是目录方法。
75 # converter = tf.saved_model.save(saved_model, "model.pb")
76
77 def change_pb(pretrained_model):
78 """tf v1 选用tf1 跑这个脚本生成pb"""
79 from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
80 # 重点
81 # Convert Keras model to ConcreteFunction
82 # MobileNet is a function
83 full_model = tf.function(lambda x: pretrained_model(x))
84
85 # 指定shape和dtype对tf function进行重新追踪
86 full_model = full_model.get_concrete_function(
87 tf.TensorSpec(pretrained_model.inputs[0].shape, pretrained_model.inputs[0].dtype))
88
89 # Get frozen ConcreteFunction,将计算图中的变量及其取值通过常量的方式保存
90 frozen_func = convert_variables_to_constants_v2(full_model)
91 frozen_func.graph.as_graph_def()
92
93 layers = [op.name for op in frozen_func.graph.get_operations()]
94 print("-" * 50)
95 print("Frozen model layers: ")
96 for layer in layers:
97 print(layer)
98
99 print("-" * 50)
100 print("Frozen model inputs: ")
101 print(frozen_func.inputs)
102 print("Frozen model outputs: ")
103 print(frozen_func.outputs)
104
105 # Save frozen graph from frozen ConcreteFunction to hard drive
106 # as_text: If True, writes the graph as an ASCII proto; otherwise, The graph is written as a text proto
107 tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
108 logdir="./frozen_models",
109 name="frozen_graph.pb",
110 as_text=True)
111
112
113 change_pb(new_model)

model_getpb

python download_and_convert_data.py --dataset_name=flowers --dataset_dir="tmp/dataset"

:)模型保存为单一个pb文件的更多相关文章

  1. 读取.properties配置文件并保存到另一个.properties文件内

    代码如下 import java.io.BufferedInputStream; import java.io.FileInputStream; import java.io.FileOutputSt ...

  2. tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...

  3. tensorflow实战笔记(19)----使用freeze_graph.py将ckpt转为pb文件

    一.作用: https://blog.csdn.net/yjl9122/article/details/78341689 这节是关于tensorflow的Freezing,字面意思是冷冻,可理解为整合 ...

  4. 如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件

    这篇薄荷主要是讲了如何用tensorflow去训练好一个模型,然后生成相应的pb文件.最后会将如何重新加载这个pb文件. 首先先放出PO主的github: https://github.com/ppp ...

  5. keras中的模型保存和加载

    tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...

  6. tensorflow 保存训练模型ckpt 查看ckpt文件中的变量名和对应值

    TensorFlow 模型保存与恢复 一个快速完整的教程,以保存和恢复Tensorflow模型. 在本教程中,我将会解释: TensorFlow模型是什么样的? 如何保存TensorFlow模型? 如 ...

  7. 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)

    学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...

  8. 把ResNet-L152模型的ckpt文件转化为pb文件

    import tensorflow as tf from tensorflow.python.tools import freeze_graph #os.environ['CUDA_VISIBLE_D ...

  9. Express下使用formidable实现POST表单上传文件并保存

    Express下使用formidable实现POST表单上传文件并保存 在上一篇文章中使用formidable实现了上传文件,但没将它保存下来. 一开始,我也以为是只得到了文件的相关信息,需要用fs. ...

  10. 今天在Mac机器上使用了Flex Builder编辑了一个源代码文件,保存后使用vim命令去打开时发现系统自动在每一行的结尾添加了^M符号,其实^M在Linux/Unix中是非常常见的,也就是我们在Win中见过的/r回车符号。由于编辑软件的编码问题,某些IDE的编辑器在编辑完文件之后会自动加上这个^M符号。看起来对我们的源代码没有任何影响,其实并不然,当我们把源代码文件Check In到svn之类

    今天在Mac机器上使用了Flex Builder编辑了一个源代码文件,保存后使用vim命令去打开时发现系统自动在每一行的结尾添加了^M符号,其实^M在Linux/Unix中是非常常见的,也就是我们在W ...

随机推荐

  1. 题解 P7623 [AHOI2021初中组] 收衣服

    我还在小学的时候以现在初中名义我大五十牛逼参加了这次,然后身败名裂死磕这道题不会,现在觉得自己好傻啊 233333 显然这是要统计每个区间的贡献,所以我们可以打出来这个暴力,统计每个区间的次数,对于 ...

  2. echar 多个图形显示时,点击显示隐藏然后样式缺失,变得非常小

    原因:Echarts 图表是根据你定义的div 的样式来确定图表的大小,当图表隐藏时,Echarts会找不到div的宽和高,再次显示时它会给自己一个非常小的默认宽高值,所以在隐藏显示后会发现它变得非常 ...

  3. ABP微服务系列学习-搭建自己的微服务结构(三)

    上一篇我们基础服务初步搭建完毕,接下来我们整一下认证和网关. 搭建认证服务 认证服务的话,ABP CLI生成的所有模板都包括了一个AuthServer.我们直接生成模板然后微调一下就可以直接用了. a ...

  4. Abp返回时间格式化

    private void ConfigureDateTime() { Configure<MvcNewtonsoftJsonOptions>(options => { options ...

  5. 如何用HP 39GS计算器画出双曲线图像

    1.双曲线标准方程和参数方程 2.计算器上的操作 1.打开APLET->Parametric->START 2.设置X1(T)=3/COS(T),X2(T)=4*TAN(T) 3.SHIF ...

  6. SQL group by date (hour),数据库按小时分组统计数据量

    SELECT COUNT(1), TRUNC(BEGINTIME, 'HH24') FROM TASK -- WHERE BEGINTIME > '2022-03-01' GROUP BY TR ...

  7. 代码随想录训练营day 1 |704 二分查找 27移除算法

    LeetCode 704.二分查找(C++) 题目链接 704.二分查找 题目描述:给定一个 n 个元素有序的(升序)整型数组 nums 和一个目标值 target ,写一个函数搜索 nums 中的 ...

  8. doskey: windows版 Alias

    1.编辑doskey.bat文件 2.打开注册表寻找.HKEY_CURRENT_USER \ Software \ Microsoft \ Command Processor (自行百度) 3.添加d ...

  9. Windows 设置当前路径 临时环境变量 查看、修改、删除与添加

    需求 有些程序依赖的Python版本不同,安装了Python2.7和Python3.10(3.x没有向下兼容),需要设置当前路径的 python 版本(指定使用2或3). 也不止Python,类似的情 ...

  10. 生成数据库文档 —— Spring Boot + Screw

    1.创建一个SpringBoot项目(本人使用的是IntelliJ IDEA 2020.1 x64) 最佳简单的项目配置如下: 2.添加相关依赖 <!--screw依赖--> <de ...