摘要:TensorFlow 模型训练完成后,通常会通过frozen过程保存一个最终的pb模型。

本文分享自华为云社区《TensorFlow pb模型修改和优化》,作者:luchangli。

TensorFlow 模型训练完成后,通常会通过frozen过程保存一个最终的pb模型。保存的pb模型是以GraphDef数据结构保存的,可以序列化保存为二进制pb模型或者文本pbtxt模型。GraphDef本质上是一个DAG有向无环图,里面主要是存放了一个算子node list,每个算子具有名称,attr等内容,以及通过input包含了node之间的连接关系。

整个GraphDef的输入节点是以Placeholder节点来标识的,模型参数权重通常是以Const节点来保存的。不同于onnx,GraphDef没有对输出进行标识,好处是可以通过node_name:idx来引用获取任意一个节点的输出,缺点是一般需要通过netron手动打开查看模型输出,或者通过代码分析没有输出节点的node作为模型输出节点。下面简单介绍下pb模型常用的一些处理方法。

pb模型保存

  1. # write pb model
  2. with tf.io.gfile.GFile(model_path, "wb") as f:
  3. f.write(graph_def.SerializeToString())
  4. # write pbtxt model
  5. tf.io.write_graph(graph_def, os.path.dirname(model_path), os.path.basename(model_path))

创建node

  1. from tensorflow.core.framework import attr_value_pb2
  2. from tensorflow.core.framework import node_def_pb2
  3. from tensorflow.python.framework import tensor_util
  4. pld_node = node_def_pb2.NodeDef()
  5. pld_node.name = name
  6. pld_node.op = "Placeholder"
  7. shape = tf.TensorShape([None, 3, 256, 256])
  8. pld_node.attr["shape"].CopyFrom(attr_value_pb2.AttrValue(shape=shape.as_proto()))
  9. dtype = tf.dtypes.as_dtype("float32")
  10. pld_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type=dtype.as_datatype_enum))
  11. # other commonly used setting
  12. node.input.extend(in_node_names)
  13. node.attr["value"].CopyFrom(
  14. attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
  15. np_array, np_array.type, np_array.shape)))

构建模型和保存

  1. import tensorflow as tf
  2. import numpy as np
  3. tf.compat.v1.disable_eager_execution()
  4. tf.compat.v1.reset_default_graph()
  5. m = 200
  6. k = 256
  7. n = 128
  8. a_shape = [m, k]
  9. b_shape = [k, n]
  10. np.random.seed(0)
  11. input_np = np.random.uniform(low=0.0, high=1.0, size=a_shape).astype("float32")
  12. kernel_np = np.random.uniform(low=0.0, high=1.0, size=b_shape).astype("float32")
  13. # 构建模型
  14. pld1 = tf.compat.v1.placeholder(dtype="float32", shape=a_shape, name="input1")
  15. kernel = tf.constant(kernel_np, dtype="float32")
  16. feed_dict = {pld1: input_np}
  17. result_tf = tf.raw_ops.MatMul(a=pld1, b=kernel, transpose_a=False, transpose_b=False)
  18. with tf.compat.v1.Session() as sess:
  19. results = sess.run(result_tf, feed_dict=feed_dict)
  20. print("results:", results)
  21. # 保存模型
  22. dump_model_name = "matmul_graph.pb"
  23. graph = tf.compat.v1.get_default_graph()
  24. graph_def = graph.as_graph_def()
  25. with tf.io.gfile.GFile(dump_model_name, "wb") as f:
  26. f.write(graph_def.SerializeToString())

当然一般用其他方式而不是raw_ops构建模型。

pb模型读取

  1. from google.protobuf import text_format
  2. graph_def = tf.compat.v1.GraphDef()
  3. # read pb model
  4. with tf.io.gfile.GFile(model_path, "rb") as f:
  5. graph_def.ParseFromString(f.read())
  6. # read pbtxt model
  7. with open(model_path, "r") as pf:
  8. text_format.Parse(pf.read(), graph_def)

node信息打印

常用信息:

  1. node.name
  2. node.op
  3. node.input
  4. node.device
  5. # please ref https://www.tensorflow.org/api_docs/python/tf/compat/v1/AttrValue
  6. node.attr[attr_name].f # b, i, tensor, etc.
  7. # graph_def中node遍历:
  8. for node in graph_def.node:
  9. ##

对于node的input,一般用node_name:idx如node_name:0来表示输入来自上一个算子的第idx个输出。:0省略则是默认为第0个输出。 名称前面加^符号是控制边。这个input是一个string list,这里面的顺序也对应这个node的各个输入的顺序。

创建GraphDef和添加node

  1. graph_def_n = tf.compat.v1.GraphDef()
  2. for node in graph_def_o.node:
  3. node_n = node_def_pb2.NodeDef()
  4. node_n.CopyFrom(node)
  5. graph_def_n.node.extend([node_n])
  6. # you probably need copy other value like version, etc. from old graph
  7. graph_def_n.version = graph_def_o.version
  8. graph_def_n.library.CopyFrom(graph_def_o.library)
  9. graph_def_n.versions.CopyFrom(graph_def_o.versions)

return graph_def_n

没有onnx模型往graph里面添加节点的topo排序要求

设置placeholder的shape

参考前面创建node部分,通过修改Placeholder的shape属性。

模型shape推导

需要导入模型到tf:tf.import_graph_def(graph_def, name='')。当然需要先设置正确的pld的shape。

然后获取node的输出tensor:graph.get_tensor_by_name(node_name + ":0")。

最后可以从tensor里面获取shape和dtype。

pb模型图优化

思路一般比较简单:

1,子图连接关系匹配,比如要匹配conv2d+bn+relu这个pattern连接关系。由于每个node只保存其输入的node连接关系,要进行DFS/BFS遍历图一般需要每个node的输入输出,这可以首先读取所有的node连接关系并根据input信息同时创建一个output信息map。

2,子图替换,先创建新的算子,再把旧的算子替换为新的算子。这个需要创建新的node或者直接修改原来的node。旧的不要的算子可以创建个新图拷贝时丢弃,新的node可以直接extend到graph_def。

3,如果替换为TF内置的算子,算子定义可以参考tensorflow raw_ops中的定义,但是有些属性(例如数据类型attr "T")没有列出来:https://www.tensorflow.org/api_docs/python/tf/raw_ops

当然也可以替换为自定义算子,这就需要用户开发和注册自定义算子:https://www.tensorflow.org/guide/create_op

如上所述,TensorFlow的pb模型修改优化可以直接使用python代码实现,极大简化开发过程。当然TensorFlow也可以注册grappler和post rewrite图优化pass在C++层面进行图优化,后者除了可以用于推理,也可以用于训练优化。

saved model与pb模型的相互转换

可以参考:tensorflow 模型导出总结 - 知乎

saved model保存的是一整个训练图,并且参数没有冻结。而只用于模型推理serving并不需要完整的训练图,并且参数不冻结无法进行转TensorRT等极致优化。当然也可以saved_model->frozen pb->saved model来同时利用两者的优点。

pb转onnx

使用tf2onnx库GitHub - onnx/tensorflow-onnx: Convert TensorFlow, Keras, Tensorflow.js and Tflite models to ONNX

  1. #!/bin/bash
  2. graphdef=input_model.pb
  3. inputs=Placeholder_1:0,Placeholder_2:0
  4. outputs=output0:0,output1:0
  5. output=${graphdef}.onnx
  6. python -m tf2onnx.convert \
  7. --graphdef ${graphdef} \
  8. --output ${output} \
  9. --inputs ${inputs} \
  10. --outputs ${outputs}\
  11. --opset 12

点击关注,第一时间了解华为云新鲜技术~

带你了解TensorFlow pb模型常用处理方法的更多相关文章

  1. 查看tensorflow pb模型文件的节点信息

    查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...

  2. MxNet 模型转Tensorflow pb模型

    用mmdnn实现模型转换 参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af 安装mmdnn pip install mmdnn 准备好mx ...

  3. 查看tensorflow Pb模型所有层的名字

    代码如下: import tensorflow as tf def get_all_layernames(): """get all layers name"& ...

  4. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  5. tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...

  6. [Tensorflow]模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    文章目录 [Tensorflow]模型持久化的原理,将CKPT转为pb文件,使用pb模型预测 一.模型持久化 1.持久化代码实现 convert_variables_to_constants固化模型结 ...

  7. tensorflow c++ API加载.pb模型文件并预测图片

    tensorflow  python创建模型,训练模型,得到.pb模型文件后,用c++ api进行预测 #include <iostream> #include <map> # ...

  8. 将keras的h5模型转换为tensorflow的pb模型

    h5_to_pb.py from keras.models import load_model import tensorflow as tf import os import os.path as ...

  9. tensorflow机器学习模型的跨平台上线

    在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法 ...

  10. (四) tensorflow笔记:常用函数说明

    tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...

随机推荐

  1. 给wordpress后台侧栏菜单添加自定义字段的方法

    我们在使用wordpress做网站的时候,难免有一些需要在后台设置侧栏菜单下添加自定义字段的情况.下面就简单说说一下,如何在后台设置侧栏菜单下添加自定义字段? 在这里我们主要是使用wordpress的 ...

  2. Docker磁盘&内存&CPU资源实战

    Docker 资源实战:cpu/内存配置: #查看帮助 docker run --help docker update --help #配置容器使用cpu /内存大小--privileged 给与容器 ...

  3. Mach-O Inside: 命令行工具集 otool objdump od 与 dwarfdump

    1 otool otool 命令行工具用来查看 Mach-O 文件的结构. 1.1 查看文件头 otool -h -v 文件路径 -h选项表明查看 Mach-O 文件头. -v 选项表明将展示的内容进 ...

  4. 阿里云创建BUCKET脚本

    创建BUCKET脚本 安装模块 pip install pymysql pip install aliyun-python-sdk-core pip install aliyun-python-sdk ...

  5. MySQL高级SQL语句

    MySQL高级SQL语句 围绕两张表 Location表 Store_Info表  #select选择  SELECT Store_Name FROM Store_Info;  #distinct去重 ...

  6. http协议与apache

    http协议与apache 1.httpd协议 两台主机通信需要socket文件  yum insatll -y nc  ​  [root@localhost ~]#nc -l 8000  #主机1 ...

  7. 【scipy 基础】--积分和微分方程

    对于手工计算来说,积分计算是非常困难的,对于一些简单的函数,我们可以直接通过已知的积分公式来求解,但在更多的情况下,原函数并没有简单的表达式,因此确定积分的反函数变得非常困难. 另外,相对于微分运算来 ...

  8. AcWing 1064. 小国王

    状态:f[i][j][k]表示第i行放了j个皇帝,状态为k的方案. 那么首先预处理出所有可行的方案,以及两两可以相互转移的答案. 从b状态转移到a状态就是 :\(f[i][j][a] += f[i - ...

  9. Codeforces Round #707 (Div. 2)A~C题解

    写在前边 链接:Codeforces Round #707 (Div. 2) 心态真的越来越不好了,看A没看懂,赛后模拟了一遍就过了,B很简单,但是漏了个判断重复的条件. A. Alexey and ...

  10. CentOS 7替换默认软件源

    安装CentOS 7后,默认源在国外,可以替换为国内的源以提升访问速度 参考https://mirrors.ustc.edu.cn/help/centos.html sudo vi /etc/yum. ...