定义模型结构

首先使用 PyTorch 定义一个简单的网络模型:

  1. class ConvBnReluBlock(nn.Module):
  2. def __init__(self) -> None:
  3. super().__init__()
  4. self.conv1 = nn.Conv2d(3, 64, 3)
  5. self.bn1 = nn.BatchNorm2d(64)
  6. self.maxpool1 = nn.MaxPool2d(3, 1)
  7. self.conv2 = nn.Conv2d(64, 32, 3)
  8. self.bn2 = nn.BatchNorm2d(32)
  9. self.relu = nn.ReLU()
  10. def forward(self, x):
  11. out = self.conv1(x)
  12. out = self.bn1(out)
  13. out = self.relu(out)
  14. out = self.maxpool1(out)
  15. out = self.conv2(out)
  16. out = self.bn2(out)
  17. out = self.relu(out)
  18. return out

在导出模型之前,需要提前定义一些变量:

  1. model = ConvBnReluBlock() # 定义模型对象
  2. x = torch.randn(2, 3, 255, 255) # 定义输入张量

然后使用 PyTorch 官方 API(torch.onnx.export)导出 ONNX 格式的模型:

  1. # way1:
  2. torch.onnx.export(model, (x), "conv_bn_relu_evalmode.onnx", input_names=["input"], output_names=['output'])
  3. # way2:
  4. import torch._C as _C
  5. TrainingMode = _C._onnx.TrainingMode
  6. torch.onnx.export(model, (x), "conv_bn_relu_trainmode.onnx", input_names=["input"], output_names=['output'],
  7. opset_version=12, # 默认版本为9,但是如果低于12,将不能正确导出 Dropout 和 BatchNorm 节点
  8. training=TrainingMode.TRAINING, # 默认模式为 TrainingMode.EVAL
  9. do_constant_folding=False) # 常量折叠,默认为True,但是如果使用TrainingMode.TRAINING模式,则需要将其关闭
  10. # way3
  11. torch.onnx.export(model,
  12. (x),
  13. "conv_bn_relu_dynamic.onnx",
  14. input_names=['input'],
  15. output_names=['output'],
  16. dynamic_axes={'input': {0: 'batch_size', 2: 'input_width', 3: 'input_height'},
  17. 'output': {0: 'batch_size', 2: 'output_width', 3: 'output_height'}})

可以看到,这里主要以三种方式导出模型,下面分别介绍区别:

  • way1:如果模型中存在 BatchNorm 或者 Dropout,我们在导出模型前会首先将其设置成 eval 模式,但是这里我们即使忘记设置也无所谓,因为在导出模型时会自动设置(export函数中training参数的默认值为TrainingMode.EVAL)。
  • way2:如果我们想导出完整的模型结构,包括 BatchNorm 和 Dropout,则应该将 training 属性设置为 train 模式。
  • way3:如果想要导出动态输入的模型结构,则需要设置 dynamic_axes 属性,比如这里我们将第一、三和四维设置成动态结构,那么我们就可以输入任何Batch大小、任何长宽尺度的RGB图像。

下图分别将这三种导出方式的模型结构使用 Netron 可视化:

分析模型结构

这里参考了BBuf大佬的讲解:【传送门:https://zhuanlan.zhihu.com/p/346511883】

接下来主要针对 way1 方式导出的ONNX模型进行深入分析。

ONNX格式定义:https://github.com/onnx/onnx/blob/master/onnx/onnx.proto

在这个文件中,定义了多个核心对象:ModelProto、GraphProto、NodeProto、ValueInfoProto、TensorProto 和 AttributeProto。

在加载ONNX模型之后,就获得了一个ModelProto,其中包含一些

  • 版本信息(本例中:ir_version = 7)
  • 生成者信息:producer_name: pytorch,producer_version: 1.10,这两个属性主要用来说明由哪些框架哪个版本导出的onnx
  • 核心组件:GraphProto

在 GraphProto 中,有如下几个属性需要注意:

  • name:本例中:name = 'torch-jit-export'
  • input 数组:
    1. [name: "input"
    2. type {
    3. tensor_type {
    4. elem_type: 1
    5. shape {
    6. dim {
    7. dim_value: 2
    8. }
    9. dim {
    10. dim_value: 3
    11. }
    12. dim {
    13. dim_value: 255
    14. }
    15. dim {
    16. dim_value: 255
    17. }
    18. }
    19. }
    20. }
    21. ]
  • output 数组:
    1. [name: "output"
    2. type {
    3. tensor_type {
    4. elem_type: 1
    5. shape {
    6. dim {
    7. dim_value: 2
    8. }
    9. dim {
    10. dim_value: 32
    11. }
    12. dim {
    13. dim_value: 249
    14. }
    15. dim {
    16. dim_value: 249
    17. }
    18. }
    19. }
    20. }
    21. ]
  • node 数组,该数组中包含了模型中所有的计算节点(本例中:"Conv_0"、"Relu_1"、"MaxPool_2"、"Conv_3"、"Relu_4"),以及各个节点的属性,:
    1. [input: "input"
    2. input: "23"
    3. input: "24"
    4. output: "22"
    5. name: "Conv_0"
    6. op_type: "Conv"
    7. attribute {
    8. name: "dilations"
    9. ints: 1
    10. ints: 1
    11. type: INTS
    12. }
    13. attribute {
    14. name: "group"
    15. i: 1
    16. type: INT
    17. }
    18. attribute {
    19. name: "kernel_shape"
    20. ints: 3
    21. ints: 3
    22. type: INTS
    23. }
    24. attribute {
    25. name: "pads"
    26. ints: 0
    27. ints: 0
    28. ints: 0
    29. ints: 0
    30. type: INTS
    31. }
    32. attribute {
    33. name: "strides"
    34. ints: 1
    35. ints: 1
    36. type: INTS
    37. }
    38. ,
    39. input: "22"
    40. output: "17"
    41. name: "Relu_1"
    42. op_type: "Relu"
    43. , input: "17"
    44. output: "18"
    45. name: "MaxPool_2"
    46. op_type: "MaxPool"
    47. attribute {
    48. name: "kernel_shape"
    49. ints: 3
    50. ints: 3
    51. type: INTS
    52. }
    53. attribute {
    54. name: "pads"
    55. ints: 0
    56. ints: 0
    57. ints: 0
    58. ints: 0
    59. type: INTS
    60. }
    61. attribute {
    62. name: "strides"
    63. ints: 1
    64. ints: 1
    65. type: INTS
    66. }
    67. ,
    68. input: "18"
    69. input: "26"
    70. input: "27"
    71. output: "25"
    72. name: "Conv_3"
    73. op_type: "Conv"
    74. attribute {
    75. name: "dilations"
    76. ints: 1
    77. ints: 1
    78. type: INTS
    79. }
    80. attribute {
    81. name: "group"
    82. i: 1
    83. type: INT
    84. }
    85. attribute {
    86. name: "kernel_shape"
    87. ints: 3
    88. ints: 3
    89. type: INTS
    90. }
    91. attribute {
    92. name: "pads"
    93. ints: 0
    94. ints: 0
    95. ints: 0
    96. ints: 0
    97. type: INTS
    98. }
    99. attribute {
    100. name: "strides"
    101. ints: 1
    102. ints: 1
    103. type: INTS
    104. }
    105. ,
    106. input: "25"
    107. output: "output"
    108. name: "Relu_4"
    109. op_type: "Relu"
    110. ]

    通过以上 node 的输入输出信息,可提取出节点之间的拓扑关系,构建出一个完整的神经网络。

  • initializer 数组:存放模型的权重参数。
    1. [dims: 64
    2. dims: 3
    3. dims: 3
    4. dims: 3
    5. data_type: 1
    6. name: "23"
    7. raw_data: "\220\251\001>\232\326&>\253\227\372 ... 省略一眼望不到头的内容 ... "
    8. dims: 64
    9. data_type: 1
    10. name: "24"
    11. raw_data: "Rt\347\275\005\203\0 ..."
    12. dims: 32
    13. dims: 64
    14. dims: 3
    15. dims: 3
    16. data_type: 1
    17. name: "26"
    18. raw_data: "9\022\273;+^\004\2 ..."
    19. ...

至此,我们已经分析完 GraphProto 的内容,下面根据图中的一个节点可视化说明以上内容:

从图中可以发现,Conv 节点的输入包含三个部分:输入的图像(input)、权重(这里以数字23代表该节点权重W的名字)以及偏置(这里以数字24表示该节点偏置B的名字);输出内容的名字为22;属性信息包括dilations、group、kernel_shape、pads和strides,不同节点会具有不同的属性信息。在initializer数组中,我们可以找到该Conv节点权重(name:23)对应的值(raw_data),并且可以清楚地看到维度信息(64X3X3X3)。

【推理引擎】ONNX 模型解析的更多相关文章

  1. 阿里开源!轻量级深度学习端侧推理引擎 MNN

    阿里妹导读:近日,阿里正式开源轻量级深度学习端侧推理引擎“MNN”. AI科学家贾扬清如此评价道:“与 Tensorflow.Caffe2 等同时覆盖训练和推理的通用框架相比,MNN 更注重在推理时的 ...

  2. 阿里开源首个移动AI项目,淘宝同款推理引擎

    淘宝上用的移动AI技术,你也可以用在自己的产品中了. 刚刚,阿里巴巴宣布,开源自家轻量级的深度神经网络推理引擎MNN(Mobile Neural Network),用于在智能手机.IoT设备等端侧加载 ...

  3. 滴滴推理引擎IFX:千万规模设备下AI部署实践

    桔妹导读:「滴滴技术」将于本月开始,联合各技术团队为大家带来精彩分享.你想了解的技术干货,深度专访,团队及招聘将于每周三与你准时见面.本月为「滴滴云平台事业群分享月」,在今天的内容中,云平台事业群-机 ...

  4. 【推理引擎】从源码看ONNXRuntime的执行流程

    目录 前言 准备工作 构造 InferenceSession 对象 & 初始化 让模型 Run 总结 前言 在上一篇博客中:[推理引擎]ONNXRuntime 的架构设计,主要从文档上对ONN ...

  5. 全场景AI推理引擎MindSpore Lite, 助力HMS Core视频编辑服务打造更智能的剪辑体验

    移动互联网的发展给人们的社交和娱乐方式带来了很大的改变,以vlog.短视频等为代表的新兴文化样态正受到越来越多人的青睐.同时,随着AI智能.美颜修图等功能在图像视频编辑App中的应用,促使视频编辑效率 ...

  6. Caffe学习笔记(一):Caffe架构及其模型解析

    Caffe学习笔记(一):Caffe架构及其模型解析 写在前面:关于caffe平台如何快速搭建以及如何在caffe上进行训练与预测,请参见前面的文章<caffe平台快速搭建:caffe+wind ...

  7. 人体姿态和形状估计的视频推理:CVPR2020论文解析

    人体姿态和形状估计的视频推理:CVPR2020论文解析 VIBE: Video Inference for Human Body Pose and Shape Estimation 论文链接:http ...

  8. 【模型推理】Tengine 模型转换及量化

      欢迎关注我的公众号 [极智视界],回复001获取Google编程规范   O_o   >_<   o_O   O_o   ~_~   o_O   本文介绍一下 Tengine 模型转换 ...

  9. 【推理引擎】ONNXRuntime 的架构设计

    ONNXRuntime,深度学习领域的神经网络模型推理框架,从名字中可以看出它和 ONNX 的关系:以 ONNX 模型作为中间表达(IR)的运行时(Runtime). 本文许多内容翻译于官方文档:ht ...

随机推荐

  1. Solution -「Gym 102956A」Belarusian State University

    \(\mathcal{Description}\)   Link.   给定两个不超过 \(2^n-1\) 次的多项式 \(A,B\),对于第 \(i\in[0,n)\) 个二进制位,定义任意一个二元 ...

  2. GCC 使用库文件名进行链接

    使用 GCC 进行 C/C++ 代码编译时,如果代码中使用到了库函数,需要使用 -l 选项指定该库函数所在的库.如:-lm.-lrt.-lpthread等.这种方式使用的是库的缩写.一个库的文件名如果 ...

  3. Java老码农心得:卷了这么多年,您真的卷会了吗?

    前言 大家好,我是福隆苑居士,今天跟大家聊一下程序员在当下内卷成风的情况下,使用什么方法可以了解行业发展趋势,知道哪些该学,哪些可以略过,今年应该掌握什么,可以放弃什么,让自己时刻紧跟行业的步伐永不掉 ...

  4. Spring MVC 是什么? 核心总结

    SpringMVC是一个基于Java的实现了MVC设计模式的请求驱动类型的轻量级Web框架,通过把Model,View,Controller分离,将web层进行职责解耦,把复杂的web应用分成逻辑清晰 ...

  5. python中try...except的用法

    num = [1,2,0,3,1.5,'6'] for x in num: try: # 尝试执行下列代码 print (6/x) except ZeroDivisionError: print('0 ...

  6. HMS Core Discovery第13期回顾长文——构建手游中的真实世界

    HMS Core Discovery第13期直播<来吧!构建手游中的真实世界>,已于2月24日圆满结束,本期直播我们同三七游戏的专家一同向小伙伴们分享了HMS Core图形引擎服务(Sce ...

  7. Golang 包管理机制

    Golang 包管理机制 1. 历史 在go1.11之前, 并没有官方的包管理机制(Godep算个半官方), 主流的包管理机制有: GoVendor Glide Godep 在go1.11之后, 官方 ...

  8. 你别告诉我你还在用Excel做数据透视分析吧,太low了!

    来到大数据分析的时代,大量的大数据分析软件涌现,尽管如此,如果今天有人问起最常用的数据透视分析工具是什么的时候,我猜想Excel应该是大家的不二之选. 但是其实我想说,用现在的手机来打比方,Excel ...

  9. 【C# 异常处理】 开端

    异常概述 在使用计算机语言进行项目开发的过程中,即使程序员把代码写得尽善尽美,在系统的运行过程中仍然会遇到一些问题,因为很多问题不是靠代码能够避免的,比如:客户输入数据的格式,读取文件是否存在,网络是 ...

  10. httpHelper 从URL获取值

    /// <summary> /// 从URL获取值(字符串) /// </summary> public static string GetValueFromUrl(strin ...