本文大部分内容为对 ONNX 官方资料的总结和翻译,部分知识点参考网上质量高的博客。

一,ONNX 概述

深度学习算法大多通过计算数据流图来完成神经网络的深度学习过程。 一些框架(例如CNTK,Caffe2,Theano和TensorFlow)使用静态图形,而其他框架(例如 PyTorch 和 Chainer)使用动态图形。 但是这些框架都提供了接口,使开发人员可以轻松构建计算图和运行时,以优化的方式处理图。 这些图用作中间表示(IR),捕获开发人员源代码的特定意图,有助于优化和转换在特定设备(CPU,GPU,FPGA等)上运行。

ONNX 的本质只是一套开放的 ML 模型标准,模型文件存储的只是网络的拓扑结构和权重(其实每个深度学习框架最后保存的模型都是类似的),脱离开框架是没办法对模型直接进行 inference

1.1,为什么使用通用 IR

现在很多的深度学习框架提供的功能都是类似的,但是在 API、计算图和 runtime 方面却是独立的,这就给 AI 开发者在不同平台部署不同模型带来了很多困难和挑战,ONNX 的目的在于提供一个跨框架的模型中间表达框架,用于模型转换和部署。ONNX 提供的计算图是通用的,格式也是开源的。

二,ONNX 规范

Open Neural Network Exchange Intermediate Representation (ONNX IR) Specification.

ONNX 结构的定义文件 .proto.prpto3 可以在 onnx folder 目录下找到,文件遵循的是谷歌 Protobuf 协议。ONNX 是一个开放式规范,由以下组件组成:

  • 可扩展计算图模型的定义
  • 标准数据类型的定义
  • 内置运算符的定义

IR6 版本的 ONNX 只能用于推理(inference),从 IR7 开始 ONNX 支持训练(training)。onnx.proto 主要的对象如下:

  • ModelProto
  • GraphProto
  • NodeProto
  • AttributeProto
  • ValueInfoProto
  • TensorProto

他们之间的关系:ONNX 模型 load 之后,得到的是一个 ModelProto,它包含了一些版本信息,生产者信息和一个非常重要的 GraphProto;在 GraphProto 中包含了四个关键的 repeated 数组,分别是node (NodeProto 类型),input(ValueInfoProto 类型),output(ValueInfoProto 类型)和 initializer (TensorProto 类型),其中 node 中存放着模型中的所有计算节点,input 中存放着模型所有的输入节点,output 存放着模型所有的输出节点,initializer 存放着模型所有的权重;节点与节点之间的拓扑定义可以通过 input 和output 这两个 string 数组的指向关系得到,这样利用上述信息我们可以快速构建出一个深度学习模型的拓扑图。最后每个计算节点当中还包含了一个 AttributeProto 数组,用于描述该节点的属性,例如 Conv 层的属性包含 grouppadsstrides 等等,具体每个计算节点的属性、输入和输出可以参考这个 Operators.md 文档。

需要注意的是,上面所说的 GraphProto 中的 input 输入数组不仅仅包含我们一般理解中的图片输入的那个节点,还包含了模型当中所有权重。举例,Conv 层中的 W 权重实体是保存在 initializer 当中的,那么相应的会有一个同名的输入在 input 当中,其背后的逻辑应该是把权重也看作是模型的输入,并通过 initializer 中的权重实体来对这个输入做初始化(也就是把值填充进来)

2.1,Model

模型结构的主要目的是将元数据( meta data)与图形(graph)相关联,图形包含所有可执行元素。 首先,读取模型文件时使用元数据,为实现提供所需的信息,以确定它是否能够:执行模型,生成日志消息,错误报告等功能。此外元数据对工具很有用,例如IDE和模型库,它需要它来告知用户给定模型的目的和特征。

每个 model 有以下组件:

Name Type Description
ir_version int64 The ONNX version assumed by the model.
opset_import OperatorSetId A collection of operator set identifiers made available to the model. An implementation must support all operators in the set or reject the model.
producer_name string The name of the tool used to generate the model.
producer_version string The version of the generating tool.
domain string A reverse-DNS name to indicate the model namespace or domain, for example, 'org.onnx'
model_version int64 The version of the model itself, encoded in an integer.
doc_string string Human-readable documentation for this model. Markdown is allowed.
graph Graph The parameterized graph that is evaluated to execute the model.
metadata_props map<string,string> Named metadata values; keys should be distinct.
training_info TrainingInfoProto[] An optional extension that contains information for training.

2.2,Operators Sets

每个模型必须明确命名它依赖于其功能的运算符集。 操作员集定义可用的操作符,其版本和状态。 每个模型按其域定义导入的运算符集。 所有模型都隐式导入默认的 ONNX 运算符集。

运算符集(Operators Sets)对象的属性如下:

Name Type Description
magic string T ‘ONNXOPSET’
ir_version int32 The ONNX version corresponding to the operators.
ir_version_prerelease string The prerelease component of the SemVer of the IR.
ir_build_metadata string The build metadata of this version of the operator set.
domain string The domain of the operator set. Must be unique among all sets.
opset_version int64 The version of the operator set.
doc_string string Human-readable documentation for this operator set. Markdown is allowed.
operator Operator[] The operators contained in this operator set.

2.3,ONNX Operator

图( graph)中使用的每个运算符必须由模型(model)导入的一个运算符集明确声明。

运算符(Operator)对象定义的属性如下:

Name Type Description
op_type string The name of the operator, as used in graph nodes. MUST be unique within the operator set’s domain.
since_version int64 The version of the operator set when this operator was introduced.
status OperatorStatus One of ‘EXPERIMENTAL’ or ‘STABLE.’
doc_string string A human-readable documentation string for this operator. Markdown is allowed.

2.4,ONNX Graph

序列化图由一组元数据字段(metadata),模型参数列表(a list of model parameters,)和计算节点列表组成(a list of computation nodes)。每个计算数据流图被构造为拓扑排序的节点列表,这些节点形成图形,其必须没有周期。 每个节点代表对运营商的呼叫。 每个节点具有零个或多个输入以及一个或多个输出。

图表(Graph)对象具有以下属性:

Name Type Description
name string 模型计算图的名称
node Node[] 节点列表,基于输入/输出数据依存关系形成部分排序的计算图,拓扑顺序排列。
initializer Tensor[] 命名张量值的列表。 当 initializer 与计算图 graph输入名称相同,输入指定一个默认值,否则指定一个常量值。
doc_string string 用于阅读模型的文档
input ValueInfo[] 计算图 graph 的输入参数,在 ‘initializer.’ 中可能能找到默认的初始化值。
output ValueInfo[] 计算图 graph 的输出参数。
value_info ValueInfo[] 用于存储除输入、输出值之外的类型和形状信息。

2.5,ValueInfo

ValueInfo 对象属性如下:

Name Type Description
name string The name of the value/parameter.
type Type The type of the value including shape information.
doc_string string Human-readable documentation for this value. Markdown is allowed.

2.6,Standard data types

ONNX 标准有两个版本,主要区别在于支持的数据类型和算子不同。计算图 graphs、节点 nodes和计算图的 initializers 支持的数据类型如下。原始数字,字符串和布尔类型必须用作张量的元素。

2.6.1,Tensor Element Types

Group Types Description
Floating Point Types float16, float32, float64 浮点数遵循IEEE 754-2008标准。
Signed Integer Types int8, int16, int32, int64 支持 8-64 位宽的有符号整数。
Unsigned Integer Types uint8, uint16 支持 816 位的无符号整数。
Complex Types complex64, complex128 具有 32 位或 64 位实部和虚部的复数。
Other string 字符串代表的文本数据。 所有字符串均使用UTF-8编码。
Other bool 布尔值类型,表示的数据只有两个值,通常为 truefalse

2.6.2,Input / Output Data Types

以下类型用于定义计算图和节点输入和输出的类型。

Variant Type Description
ONNX dense tensors 张量是向量和矩阵的一般化
ONNX sequence sequence (序列)是有序的稠密元素集合。
ONNX map 映射是关联表,由键类型和值类型定义。

ONNX 现阶段没有定义稀疏张量类型

三,ONNX版本控制

四,主要算子概述

五,Python API 使用

5.1,加载模型

1,Loading an ONNX model

import onnx
# onnx_model is an in-mempry ModelProto
onnx_model = onnx.load('path/to/the/model.onnx') # 加载 onnx 模型

2,Loading an ONNX Model with External Data

  • 【默认加载模型方式】如果外部数据(external data)和模型文件在同一个目录下,仅使用 onnx.load() 即可加载模型,方法见上小节。
  • 如果外部数据(external data)和模型文件不在同一个目录下,在使用 onnx_load() 函数后还需使用 load_external_data_for_model() 函数指定外部数据路径。
import onnx
from onnx.external_data_helper import load_external_data_for_model onnx_model = onnx.load('path/to/the/model.onnx', load_external_data=False)
load_external_data_for_model(onnx_model, 'data/directory/path/')
# Then the onnx_model has loaded the external data from the specific directory

3,Converting an ONNX Model to External Data

from onnx.external_data_helper import convert_model_to_external_data

# onnx_model is an in-memory ModelProto
onnx_model = ...
convert_model_to_external_data(onnx_model, all_tensors_to_one_file=True, location='filename', size_threshold=1024, convert_attribute=False)
# Then the onnx_model has converted raw data as external data
# Must be followed by save

5.2,保存模型

1,Saving an ONNX Model

import onnx

# onnx_model is an in-memory ModelProto
onnx_model = ... # Save the ONNX model
onnx.save(onnx_model, 'path/to/the/model.onnx')

2,Converting and Saving an ONNX Model to External Data

import onnx

# onnx_model is an in-memory ModelProto
onnx_model = ...
onnx.save_model(onnx_model, 'path/to/save/the/model.onnx', save_as_external_data=True, all_tensors_to_one_file=True, location='filename', size_threshold=1024, convert_attribute=False)
# Then the onnx_model has converted raw data as external data and saved to specific directory

5.3,Manipulating TensorProto and Numpy Array

import numpy
import onnx
from onnx import numpy_helper # Preprocessing: create a Numpy array
numpy_array = numpy.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=float)
print('Original Numpy array:\n{}\n'.format(numpy_array)) # Convert the Numpy array to a TensorProto
tensor = numpy_helper.from_array(numpy_array)
print('TensorProto:\n{}'.format(tensor)) # Convert the TensorProto to a Numpy array
new_array = numpy_helper.to_array(tensor)
print('After round trip, Numpy array:\n{}\n'.format(new_array)) # Save the TensorProto
with open('tensor.pb', 'wb') as f:
f.write(tensor.SerializeToString()) # Load a TensorProto
new_tensor = onnx.TensorProto()
with open('tensor.pb', 'rb') as f:
new_tensor.ParseFromString(f.read())
print('After saving and loading, new TensorProto:\n{}'.format(new_tensor))

5.4,创建ONNX模型

可以通过 helper 模块提供的函数 helper.make_graph 完成创建 ONNX 格式的模型。创建 graph 之前,需要先创建相应的 NodeProto(node),参照文档设定节点的属性,指定该节点的输入与输出,如果该节点带有权重那还需要创建相应的ValueInfoProtoTensorProto 分别放入 graph 中的 inputinitializer 中,以上步骤缺一不可。

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto # The protobuf definition can be found here:
# https://github.com/onnx/onnx/blob/master/onnx/onnx.proto # Create one input (ValueInfoProto)
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [3, 2])
pads = helper.make_tensor_value_info('pads', TensorProto.FLOAT, [1, 4]) value = helper.make_tensor_value_info('value', AttributeProto.FLOAT, [1]) # Create one output (ValueInfoProto)
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 4]) # Create a node (NodeProto) - This is based on Pad-11
node_def = helper.make_node(
'Pad', # name
['X', 'pads', 'value'], # inputs
['Y'], # outputs
mode='constant', # attributes
) # Create the graph (GraphProto)
graph_def = helper.make_graph(
[node_def], # nodes
'test-model', # name
[X, pads, value], # inputs
[Y], # outputs
) # Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example') print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
print('The model is checked!')

5.5,检查模型

在完成 ONNX 模型加载或者创建后,有必要对模型进行检查,使用 onnx.check.check_model() 函数。

import onnx

# Preprocessing: load the ONNX model
model_path = 'path/to/the/model.onnx'
onnx_model = onnx.load(model_path) print('The model is:\n{}'.format(onnx_model)) # Check the model
try:
onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
print('The model is invalid: %s' % e)
else:
print('The model is valid!')

5.6,实用功能函数

函数 extract_model() 可以从 ONNX 模型中提取子模型,子模型由输入和输出张量的名称定义。这个功能方便我们 debug 原模型和转换后的 ONNX 模型输出结果是否一致(误差小于某个阈值),不再需要我们手动去修改 ONNX 模型。

import onnx

input_path = 'path/to/the/original/model.onnx'
output_path = 'path/to/save/the/extracted/model.onnx'
input_names = ['input_0', 'input_1', 'input_2']
output_names = ['output_0', 'output_1'] onnx.utils.extract_model(input_path, output_path, input_names, output_names)

5.7,工具

函数 update_inputs_outputs_dims() 可以将模型输入和输出的维度更新为参数中指定的值,可以使用 dim_param 提供静态和动态尺寸大小。

import onnx
from onnx.tools import update_model_dims model = onnx.load('path/to/the/model.onnx')
# Here both 'seq', 'batch' and -1 are dynamic using dim_param.
variable_length_model = update_model_dims.update_inputs_outputs_dims(model, {'input_name': ['seq', 'batch', 3, -1]}, {'output_name': ['seq', 'batch', 1, -1]})
# need to check model after the input/output sizes are updated
onnx.checker.check_model(variable_length_model )

参考资料

  1. ONNX--跨框架的模型中间表达框架
  2. 深度学习模型转换与部署那些事(含ONNX格式详细分析)
  3. onnx

ONNX模型分析与使用的更多相关文章

  1. 【推理引擎】ONNX 模型解析

    定义模型结构 首先使用 PyTorch 定义一个简单的网络模型: class ConvBnReluBlock(nn.Module): def __init__(self) -> None: su ...

  2. 数据挖掘应用案例:RFM模型分析与客户细分(转)

    正好刚帮某电信行业完成一个数据挖掘工作,其中的RFM模型还是有一定代表性,就再把数据挖掘RFM模型的建模思路细节与大家分享一下吧!手机充值业务是一项主要电信业务形式,客户的充值行为记录正好满足RFM模 ...

  3. dlib人脸关键点检测的模型分析与压缩

    本文系原创,转载请注明出处~ 小喵的博客:https://www.miaoerduo.com 博客原文(排版更精美):https://www.miaoerduo.com/c/dlib人脸关键点检测的模 ...

  4. 高级设计总监的设计方法论——5W1H需求分析法 KANO模型分析法

    本期开始进入设计方法论的学习,大湿自己也是边学边分享,算是巩固一遍吧: 另外这些理论基本都是交叉结合来应用于工作中,我们学习理论但不要拘泥于理论的框架中,掌握后要灵活运用一点- 这些理论一部分来自于我 ...

  5. 基于Python的信用评分卡模型分析(二)

    上一篇文章基于Python的信用评分卡模型分析(一)已经介绍了信用评分卡模型的数据预处理.探索性数据分析.变量分箱和变量选择等.接下来我们将继续讨论信用评分卡的模型实现和分析,信用评分的方法和自动评分 ...

  6. No.1_NABCD模型分析

        Reminder 之 NABCD模型分析           定位 多平台的闹钟提醒软件. 在安卓市场发布软件,发布后一周的用户量为1000.           N (Need 需求) 这个 ...

  7. Task 6.1 校友聊之NABCD模型分析

    我们团队开发的一款软件是“校友聊”--一个在局域网内免流量进行文字.语音.视频聊天的软件.下面将对此进行NABCD的模型分析. N(Need需求):现如今,随着网络的迅速普及,手机和电脑已经成为每个大 ...

  8. (小组)第六次作业:NABCD模型分析。产品Backlog。

    NABCD模型分析: NABCD模型分析 1.N——need需求 随着时代的进步,人们生活水平的提高,现在手机的普及率已经非常高了,而且现在的家长很多时候会忙于工作,很少会花时间出来给自己读小学的孩子 ...

  9. libevent-select模型分析

    下面内容为windows下select模型分析,原博客链接 http://blog.csdn.net/fish_55_66/article/details/50352080 https://www.c ...

  10. 产品需求分析神器:KANO模型分析法

    前言: 任何一个互联网产品,哪怕是一个简单的页面,也会涉及到很多的需求,产品经理也会经常遇到这样的情况:老板,业务提的各种新需求一下子都扎堆,哪个需求对用户来说最重要,用户对我们的新功能是否满意?开发 ...

随机推荐

  1. React动画实现方案之 Framer Motion,让你的页面“自己”动起来

    前言 相信很多前端同学都或多或少和动画打过交道.有的时候是产品想要的过度效果:有的时候是UI想要的酷炫动画.但是有没有人考虑过,是不是我们的页面上面的每一次变化,都可以像是自然而然的变化:是不是每一次 ...

  2. C#--@符号的使用(逐字字符串,跨行,声明关键字变量名)

    ---对字符串的使用 @可以定义逐字字符串 注意:@只对字符串常量有用 1)不需要用\\来转义非转义符号的\号   例如:@"\"="\\"2)可以实现多行字符 ...

  3. 前后端分离项目(九):实现"添加"功能(后端接口)

    好家伙,来了来了,"查"已经完成了,现在是"增" 前端的视图已经做好了,现在我们来完善后端 后端目录结构   完整代码在前后端分离项目(五):数据分页查询(后端 ...

  4. packet Capture 手机抓包工具

    packet Capture packet Capture 是一款免root的app, 运行在安卓平台上,用于捕获http/https网络流量嗅探的应用程序 特点: 捕获网络数据包,并记录太慢,使用中 ...

  5. Python基础部分:8、for循环和range的使用

    目录 一.while循环补充说明 1.死循环 2.嵌套及全局标志位 二.for...循环 1.for...循环特点 2.for...循环语法结构 三.range方法 1.什么是range 2.不同版本 ...

  6. 论文笔记 - RETRIEVE: Coreset Selection for Efficient and Robust Semi-Supervised Learning

    Motivation 虽然半监督学习减少了大量数据标注的成本,但是对计算资源的要求依然很高(无论是在训练中还是超参搜索过程中),因此提出想法:由于计算量主要集中在大量未标注的数据上,能否从未标注的数据 ...

  7. while、for循环结合else

    """1.while else,当while循环正常结束时,才走else里的代码块,也就是没有被break打断的情况下2.此处只是不被break打断,也就是遇到break ...

  8. SCI简介和写作顺序

    一.SCI论文组成部分简介 一篇完整的 sci 论文主要包括以下几个主要的组成部分,从前往后依次分别是 Title 就是说这个文章的标题其次是 Abstract 也就是这个文章的摘要.接下来是 Int ...

  9. C#字典出错“集合已经修改,可能无法执行枚举操作”

    出现这个现象的原因是由于线程安全考虑,如果你边对字典循环,又同时移除字典中的某个键值对, 那么将会出现这种错误,解决这种问题的方法是你没次remove某个键值对后需要break结束对字典的循环.

  10. WCF实现大文件上传

    一.文件服务接口 1.文件上传 2.文件传输(上传按钮) 3.文件传输停止 服务地址: 在客端添加服务器引用,从而实现客户端调用服务器的功能. 二.契约 服务契约[ServiceContract]:定 ...