前言

TensorFlow Lite 提供了转换 TensorFlow 模型,并在移动端(mobile)、嵌入式(embeded)和物联网(IoT)设备上运行 TensorFlow 模型所需的所有工具。之前想部署tensorflow模型,需要转换成tflite模型。

实现过程

1.不同模型的调用函数接口稍微有些不同

  1. # Converting a SavedModel to a TensorFlow Lite model.
  2. converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
  3. tflite_model = converter.convert()
  4.  
  5. # Converting a tf.Keras model to a TensorFlow Lite model.
  6. converter = lite.TFLiteConverter.from_keras_model(model)
  7. tflite_model = converter.convert()
  8.  
  9. # Converting ConcreteFunctions to a TensorFlow Lite model.
  10. converter = lite.TFLiteConverter.from_concrete_functions([func])
  11. tflite_model = converter.convert()

2. 完整的实现

  1. import tensorflow as tf
  2. saved_model_dir = './mobilenet/'
  3. converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  4. converter.experimental_new_converter = True
  5. tflite_model = converter.convert()
  6. open('model_tflite.tflite', 'wb').write(tflite_model)

其中,

  1. @classmethod
  2. from_saved_model(
  3. cls,
  4. saved_model_dir,
  5. signature_keys=None,
  6. tags=None
  7. )

另外

  1. For more complex SavedModels, the optional parameters that can be passed into TFLiteConverter.from_saved_model() are input_arrays, input_shapes, output_arrays, tag_set and signature_key. Details of each parameter are available by running help(tf.lite.TFLiteConverter).

对于如何查看模型的操作op,可查看here;

help(tf.lite.TFLiteConverter)结果

  1. Help on class TFLiteConverterV2 in module tensorflow.lite.python.lite:
  2.  
  3. class TFLiteConverterV2(TFLiteConverterBase)
  4. | TFLiteConverterV2(funcs, trackable_obj=None)
  5. |
  6. | Converts a TensorFlow model into TensorFlow Lite model.
  7. |
  8. | Attributes:
  9. | allow_custom_ops: Boolean indicating whether to allow custom operations.
  10. | When false any unknown operation is an error. When true, custom ops are
  11. | created for any op that is unknown. The developer will need to provide
  12. | these to the TensorFlow Lite runtime with a custom resolver.
  13. | (default False)
  14. | target_spec: Experimental flag, subject to change. Specification of target
  15. | device.
  16. | optimizations: Experimental flag, subject to change. A list of optimizations
  17. | to apply when converting the model. E.g. `[Optimize.DEFAULT]
  18. | representative_dataset: A representative dataset that can be used to
  19. | generate input and output samples for the model. The converter can use the
  20. | dataset to evaluate different optimizations.
  21. | experimental_enable_mlir_converter: Experimental flag, subject to change.
  22. | Enables the MLIR converter instead of the TOCO converter.
  23. |
  24. | Example usage:
  25. |
  26. | ```python
  27. | # Converting a SavedModel to a TensorFlow Lite model.
  28. | converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
  29. | tflite_model = converter.convert()
  30. |
  31. | # Converting a tf.Keras model to a TensorFlow Lite model.
  32. | converter = lite.TFLiteConverter.from_keras_model(model)
  33. | tflite_model = converter.convert()
  34. |
  35. | # Converting ConcreteFunctions to a TensorFlow Lite model.
  36. | converter = lite.TFLiteConverter.from_concrete_functions([func])
  37. | tflite_model = converter.convert()
  38. | ```
  39. |
  40. | Method resolution order:
  41. | TFLiteConverterV2
  42. | TFLiteConverterBase
  43. | builtins.object
  44. |
  45. | Methods defined here:
  46. |
  47. | __init__(self, funcs, trackable_obj=None)
  48. | Constructor for TFLiteConverter.
  49. |
  50. | Args:
  51. | funcs: List of TensorFlow ConcreteFunctions. The list should not contain
  52. | duplicate elements.
  53. | trackable_obj: tf.AutoTrackable object associated with `funcs`. A
  54. | reference to this object needs to be maintained so that Variables do not
  55. | get garbage collected since functions have a weak reference to
  56. | Variables. This is only required when the tf.AutoTrackable object is not
  57. | maintained by the user (e.g. `from_saved_model`).
  58. |
  59. | convert(self)
  60. | Converts a TensorFlow GraphDef based on instance variables.
  61. |
  62. | Returns:
  63. | The converted data in serialized format.
  64. |
  65. | Raises:
  66. | ValueError:
  67. | Multiple concrete functions are specified.
  68. | Input shape is not specified.
  69. | Invalid quantization parameters.
  70. |
  71. | ----------------------------------------------------------------------
  72. | Class methods defined here:
  73. |
  74. | from_concrete_functions(funcs) from builtins.type
  75. | Creates a TFLiteConverter object from ConcreteFunctions.
  76. |
  77. | Args:
  78. | funcs: List of TensorFlow ConcreteFunctions. The list should not contain
  79. | duplicate elements.
  80. |
  81. | Returns:
  82. | TFLiteConverter object.
  83. |
  84. | Raises:
  85. | Invalid input type.
  86. |
  87. | from_keras_model(model) from builtins.type
  88. | Creates a TFLiteConverter object from a Keras model.
  89. |
  90. | Args:
  91. | model: tf.Keras.Model
  92. |
  93. | Returns:
  94. | TFLiteConverter object.
  95. |
  96. | from_saved_model(saved_model_dir, signature_keys=None, tags=None) from builtins.type
  97. | Creates a TFLiteConverter object from a SavedModel directory.
  98. |
  99. | Args:
  100. | saved_model_dir: SavedModel directory to convert.
  101. | signature_keys: List of keys identifying SignatureDef containing inputs
  102. | and outputs. Elements should not be duplicated. By default the
  103. | `signatures` attribute of the MetaGraphdef is used. (default
  104. | saved_model.signatures)
  105. | tags: Set of tags identifying the MetaGraphDef within the SavedModel to
  106. | analyze. All tags in the tag set must be present. (default set(SERVING))
  107. |
  108. | Returns:
  109. | TFLiteConverter object.
  110. |
  111. | Raises:
  112. | Invalid signature keys.
  113. |
  114. | ----------------------------------------------------------------------
  115. | Data descriptors inherited from TFLiteConverterBase:
  116. | __dict__
  117. | dictionary for instance variables (if defined)
  118. |
  119. | __weakref__
  120. | list of weak references to the object (if defined)

问题:

使用tf_saved_model中生成mobilenet网络模型转换成tfLite能够成功,为什么使用另一个设计的模型进行转换却出现问题了呢??

  1. Traceback (most recent call last):
  2. File "pb2tflite.py", line , in <module>
  3. tflite_model = converter.convert()
  4. File "~/.local/lib/python3.7/site-packages/tensorflow_core/lite/python/lite.py", line , in convert
  5. "invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list))
  6. ValueError: None is only supported in the 1st dimension. Tensor 'images' has invalid shape '[None, None, None, None]'.

facebox模型节点:

  1. (tf_test) ~/workspace/test_code/github_test/faceboxes-tensorflow$ saved_model_cli show --dir model/detector/ --all
  2.  
  3. MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
  4.  
  5. signature_def['__saved_model_init_op']:
  6. The given SavedModel SignatureDef contains the following input(s):
  7. The given SavedModel SignatureDef contains the following output(s):
  8. outputs['__saved_model_init_op'] tensor_info:
  9. dtype: DT_INVALID
  10. shape: unknown_rank
  11. name: NoOp
  12. Method name is:
  13.  
  14. signature_def['serving_default']:
  15. The given SavedModel SignatureDef contains the following input(s):
  16. inputs['images'] tensor_info:
  17. dtype: DT_FLOAT
  18. shape: (-, -, -, -)
  19. name: serving_default_images:
  20. The given SavedModel SignatureDef contains the following output(s):
  21. outputs['boxes'] tensor_info:
  22. dtype: DT_FLOAT
  23. shape: (-, , )
  24. name: StatefulPartitionedCall:
  25. outputs['num_boxes'] tensor_info:
  26. dtype: DT_INT32
  27. shape: (-)
  28. name: StatefulPartitionedCall:
  29. outputs['scores'] tensor_info:
  30. dtype: DT_FLOAT
  31. shape: (-, )
  32. name: StatefulPartitionedCall:
  33. Method name is: tensorflow/serving/predict

mobilenet的模型节点

  1. ~/workspace/test_code/github_test/faceboxes-tensorflow/mobilenet$ saved_model_cli show --dir ./ --all
  2.  
  3. MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
  4.  
  5. signature_def['__saved_model_init_op']:
  6. The given SavedModel SignatureDef contains the following input(s):
  7. The given SavedModel SignatureDef contains the following output(s):
  8. outputs['__saved_model_init_op'] tensor_info:
  9. dtype: DT_INVALID
  10. shape: unknown_rank
  11. name: NoOp
  12. Method name is:
  13.  
  14. signature_def['serving_default']:
  15. The given SavedModel SignatureDef contains the following input(s):
  16. inputs['input_1'] tensor_info:
  17. dtype: DT_FLOAT
  18. shape: (-, , , )
  19. name: serving_default_input_1:
  20. The given SavedModel SignatureDef contains the following output(s):
  21. outputs['act_softmax'] tensor_info:
  22. dtype: DT_FLOAT
  23. shape: (-, )
  24. name: StatefulPartitionedCall:
  25. Method name is: tensorflow/serving/predict

得到大神指点,tflite是静态图,需要指定hwc的值,在此谢过,那么问题来了,怎么指定hwc呢?

  1. import tensorflow as tf
  2. saved_model_dir = './model/detector/'
  3. model = tf.saved_model.load(saved_model_dir)
  4. concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  5. concrete_func.inputs[0].set_shape([1, 512, 512, 3])
  6. converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
  7. # converter.experimental_new_converter = True
  8. tflite_model = converter.convert()
  9. open('model_tflite_facebox.tflite', 'wb').write(tflite_model)

error

  1. Some of the operators in the model are not supported by the standard TensorFlow Lite runtime. If those are native TensorFlow operators, you might be able to use the extended runtime by passing --enable_select_tf_ops, or by setting target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling tf.lite.TFLiteConverter(). Otherwise, if you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.lite.TFLiteConverter(). Here is a list of builtin operators you are using: ADD, AVERAGE_POOL_2D, CONCATENATION, CONV_2D, MAXIMUM, MINIMUM, MUL, NEG, PACK, RELU, RESHAPE, SOFTMAX, STRIDED_SLICE, SUB, UNPACK. Here is a list of operators for which you will need custom implementations: TensorListFromTensor, TensorListReserve, TensorListStack, While.

TensorFlow Lite 已经内置了很多运算符,并且还在不断扩展,但是仍然还有一部分 TensorFlow 运算符没有被 TensorFlow Lite 原生支持。这些不被支持的运算符会给 TensorFlow Lite 的模型转换带来一些阻力。

  1. import tensorflow as tf
  2. saved_model_dir = './model/detector/'
  3. model = tf.saved_model.load(saved_model_dir)
  4. concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  5. concrete_func.inputs[0].set_shape([1, 512, 512, 3])
  6. converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
  7. converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
  8. # converter.experimental_new_converter = True
  9. tflite_model = converter.convert()
  10. open('model_tflite_facebox.tflite', 'wb').write(tflite_model)

还是有点问题。。。

参考

1. tf.lite.TFLiteConverter

2. stackoverflow_how-to-create-a-tflite-file-from-saved-model-ssd-mobilenet;

3. tfv1-模型文件转换

4. github_keras_lstm;

5. tf_saved_model;

6. tf_tflite_get_start;

7. tflite_convert_python_api;

8. ops_select;

【tensorflow-v2.0】如何将模型转换成tflite模型的更多相关文章

  1. 「新手必看」Python+Opencv实现摄像头调用RGB图像并转换成HSV模型

    在ROS机器人的应用开发中,调用摄像头进行机器视觉处理是比较常见的方法,现在把利用opencv和python语言实现摄像头调用并转换成HSV模型的方法分享出来,希望能对学习ROS机器人的新手们一点帮助 ...

  2. 三分钟快速上手TensorFlow 2.0 (中)——常用模块和模型的部署

    本文学习笔记参照来源:https://tf.wiki/zh/basic/basic.html 前文:三分钟快速上手TensorFlow 2.0 (上)——前置基础.模型建立与可视化 tf.train. ...

  3. TensorFlow v2.0实现逻辑斯谛回归

    使用TensorFlow v2.0实现逻辑斯谛回归 此示例使用简单方法来更好地理解训练过程背后的所有机制 MNIST数据集概览 此示例使用MNIST手写数字.该数据集包含60,000个用于训练的样本和 ...

  4. 利用反射将Datatable、SqlDataReader转换成List模型

    1. DataTable转IList public class DataTableToList<T>whereT :new() { ///<summary> ///利用反射将D ...

  5. h5模型文件转换成pb模型文件

      本文主要记录Keras训练得到的.h5模型文件转换成TensorFlow的.pb文件 #*-coding:utf-8-* """ 将keras的.h5的模型文件,转换 ...

  6. 三分钟快速上手TensorFlow 2.0 (上)——前置基础、模型建立与可视化

    本文学习笔记参照来源:https://tf.wiki/zh/basic/basic.html 学习笔记类似提纲,具体细节参照上文链接 一些前置的基础 随机数 tf.random uniform(sha ...

  7. 使用TensorFlow v2.0构建多层感知器

    使用TensorFlow v2.0构建一个两层隐藏层完全连接的神经网络(多层感知器). 这个例子使用低级方法来更好地理解构建神经网络和训练过程背后的所有机制. 神经网络概述 MNIST 数据集概述 此 ...

  8. 使用TensorFlow v2.0构建卷积神经网络

    使用TensorFlow v2.0构建卷积神经网络. 这个例子使用低级方法来更好地理解构建卷积神经网络和训练过程背后的所有机制. CNN 概述 MNIST 数据集概述 此示例使用手写数字的MNIST数 ...

  9. TensorFlow v2.0实现Word2Vec算法

    使用TensorFlow v2.0实现Word2Vec算法计算单词的向量表示,这个例子是使用一小部分维基百科文章来训练的. 更多信息请查看论文: Mikolov, Tomas et al. " ...

随机推荐

  1. Angular pagination分页模块 只提供分页参数处理 不处理分页记录数据

    <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...

  2. npm install 命令。默认会找到当前路径下的package.json。然后安装其中的依赖

    npm install 命令.默认会找到当前路径下的package.json.然后安装其中的依赖 By default, npm install will install all modules li ...

  3. materialize 读取单选按钮

    $('input[name='xxx']:checked')

  4. Game of Cards Gym - 101128G (SG函数)

    Problem G: Game of Cards \[ Time Limit: 1 s \quad Memory Limit: 256 MiB \] 题意 题意就是给出\(n\)堆扑克牌,然后给出一个 ...

  5. WinDbg常用命令系列---!uniqstack

    简介 这个!uniqstack扩展扩展显示的所有线程的堆栈的所有当前进程,不包括显示为具有重复项的堆栈中. 使用形式 !uniqstack [ -b | -v | -p ] [ -n ] 参数 -b将 ...

  6. Noip2018/Csp2019 ------退役记

    退役记 上记 不知道为啥,自从今下午某大佬的人生第一次政治运动(虽然最后被镇压,现在小命难保)后,仿佛有一种看破感. 以下有点在自作多情,不喜者可以不看. 学信竞快一年了.可以说有收获也有失去吧. 收 ...

  7. 洛谷P4170 [CQOI2007]涂色题解

    废话: 这个题我第一眼看就是贪心呐, 可能是我之前那做过一道类似的题这俩题都是关于染色的 现在由于我帅气无比的学长的指导, 我已经豁然开朗, 这题贪心不对啊, 当时感觉自己好厉害贪心都能想出来 差点就 ...

  8. gethostname、gethostbyname

    gethostname():返回本地主机的标准主机名 原型: #include<unistd.h> int gethostname(char *name, size_t len); 参数说 ...

  9. 市值TOP10,人类进化及中美坐标

    题记:观察人类进化,以及各国.各民族在这个进化中所起的作用.所处的位置,市值 TOP 10 的变迁,会是一个再好不过的指标! 2008年,经历了全球金融危机后,原油期货一路飙升,创出了147.27美元 ...

  10. 【Gamma】Scrum Meeting 4 & 助教参会记录

    目录 前言 任务分配 燃尽图 会议照片 签入记录 上周助教交流总结 技术博客 一些说明 前言 第4次会议于5月29日22:00线上交流形式召开. 交流确认了各自的任务进度,并与助教进行了沟通.时长20 ...