前言

为了深入理解ONNX Runtime的底层机制,本文将对 Graph::SetGraphInputsOutputs() 的代码逐行分析。

正文

首先判断Graph是否从ONNX文件中加载所得:

if (is_loaded_from_model_file_) return Status::OK();

如果是,可直接返回;如果不是,则需要解析Graph中的节点,从而设置模型的输入和输出。

Graph中的成员变量 value_info_graph_inputs_excluding_initializers_graph_inputs_including_initializers_ 以及 graph_outputs_ 全部清空:

value_info_.clear();

graph_inputs_excluding_initializers_.clear();

if (!graph_inputs_manually_set_) {
graph_inputs_including_initializers_.clear();
} else {
std::unordered_set<std::string> existing_names;
for (auto arg : graph_inputs_including_initializers_) {
const std::string& name = arg->Name();
if (existing_names.count(name) == 0) {
graph_inputs_excluding_initializers_.push_back(arg);
existing_names.insert(name);
}
}
} if (!graph_outputs_manually_set_) {
graph_outputs_.clear();
}

设置一些局部变量,方便下面的使用分析:

std::unordered_map<std::string, size_t> output_name_to_node_arg_index;
std::vector<const NodeArg*> output_node_args_in_order;
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};

统计所有节点的输出,添加到以上局部变量(output_name_to_node_arg_index 和 output_node_args_in_order)中:

for (const auto& node : Nodes()) {
for (const auto* output_def : node.OutputDefs()) {
if (output_def->Exists()) {
output_node_args_in_order.push_back(output_def);
output_name_to_node_arg_index.insert({output_def->Name(), output_node_args_in_order.size() - 1});
}
}
}
auto graph_output_args = output_name_to_node_arg_index; // 拷贝一份输出节点map

然后遍历图中每个节点以及每个节点的输入:

for (const auto& node : Nodes()) {
// Go thru all node's inputs.
for (const auto* input_arg : node.InputDefs()) {
...
}
}

在输出节点name列表中查找当前输入name

auto output_arg_iter = output_name_to_node_arg_index.find(input_arg->Name());

如果没有找到,说明这个节点的输入就是图的输入,接下来还需要判断这个输入是否已经放在局部变量added_input_names中:

if (output_name_to_node_arg_index.end() == output_arg_iter) {
// This input arg is not the output of another node so must come from either a graph input or an initializer.
const std::string& name = input_arg->Name();
if (added_input_names.end() == added_input_names.find(name)) {
...
}
}

如果已经放到局部变量added_input_names中,就可以判断节点的下一个输入或者下一个节点的输入。如果没有放到局部变量added_input_names中:

bool is_initializer = name_to_initial_tensor_.find(name) != name_to_initial_tensor_.end();  // 判断当前input_arg是否已初始化过的tensor,如果是就不可以再放置到 graph_inputs_excluding_initializers_ 中
if (!graph_inputs_manually_set_) { // 如果未主动调用 SetInputs() 方法
// if IR version < 4 all initializers must have a matching graph input
// (even though the graph input is not allowed to override the initializer).
// if IR version >= 4 initializers are not required to have a matching graph input.
// any graph inputs that are to override initializers must be specified by calling SetInputs.
if (!is_initializer || ir_version_ < 4) {
graph_inputs_including_initializers_.push_back(input_arg);
}
if (!is_initializer) {
// If input_arg is not of an initializer, we add it into graph_inputs_excluding_initializers_.
graph_inputs_excluding_initializers_.push_back(input_arg);
}
} else { // 如果主动调用了 SetInputs() 方法
// graph_inputs_including_initializers_ has been manually populated by SetInputs.
// Validation: the <input_arg> must be in graph inputs or initializers when it's manually set.
if (!is_initializer) {
const auto& inputs = graph_inputs_including_initializers_;
bool in_inputs = std::find(inputs.begin(), inputs.end(), input_arg) != inputs.end();
if (!in_inputs) {
return Status(ONNXRUNTIME, FAIL,
name + " must be either specified in graph inputs or graph initializers.");
}
} else {
// If arg_input is of an initializer, we remove it from graph_inputs_excluding_initializers_
// whose initial content has both initializers and non-initializers.
auto input_pos = std::find(graph_inputs_excluding_initializers_.begin(),
graph_inputs_excluding_initializers_.end(),
input_arg);
if (input_pos != graph_inputs_excluding_initializers_.end()) {
graph_inputs_excluding_initializers_.erase(input_pos);
}
}
}
added_input_names.insert(name);

可以看到,这里会把当前的 input_arg 分别放到 graph_inputs_including_initializers_graph_inputs_excluding_initializers_ 中,并将name放在added_input_names中。

如果该输入的name已经在输出节点name列表中,说明这个节点是中间输出结果,而非整个图的输出,因此应该将其从图的输出(graph_output_args)中删除,并放在 value_info_ 中:

if (output_name_to_node_arg_index.end() == output_arg_iter) {
...
}else if(graph_output_args.erase(output_arg_iter->first) >= 1){
value_info_.insert(input_arg);
}

以上我们对Graph的三个成员变量:graph_inputs_including_initializers_graph_inputs_excluding_initializers_value_info_分别进行了赋值,其中前两者存储输入,后者存储中间结果。我们还需要处理图的输出结果:`graph_outputs_`:

if (!graph_outputs_manually_set_) {
// Set graph outputs in order.
std::vector<size_t> graph_output_args_index;
graph_output_args_index.reserve(graph_output_args.size());
for (const auto& output_arg : graph_output_args) { // graph_output_args原本存储了所有节点的输出,但是前面的代码已经把中间节点的输出给移除了,因此剩下的就是整个Graph的输出
graph_output_args_index.push_back(output_arg.second);
} std::sort(graph_output_args_index.begin(), graph_output_args_index.end());
for (auto& output_arg_index : graph_output_args_index) {
graph_outputs_.push_back(output_node_args_in_order[output_arg_index]);
}
}

最后,还需要对 graph_overridable_initializers_ 进行处理:

ComputeOverridableInitializers();

进入这个函数内部:

void Graph::ComputeOverridableInitializers() {
graph_overridable_initializers_.clear();
if (CanOverrideInitializer()) {
// graph_inputs_excluding_initializers_ and graph_inputs_including_initializers_
// are inserted in the same order. So we walk and compute the difference.
auto f_incl = graph_inputs_including_initializers_.cbegin();
const auto l_incl = graph_inputs_including_initializers_.cend();
auto f_excl = graph_inputs_excluding_initializers_.cbegin();
const auto l_excl = graph_inputs_excluding_initializers_.cend(); while (f_incl != l_incl) {
// Equal means not an initializer
if (f_excl != l_excl && *f_incl == *f_excl) {
++f_incl;
++f_excl;
continue;
}
graph_overridable_initializers_.push_back(*f_incl);
++f_incl;
}
}
}

这是一个很简单的算法,通过比较 graph_inputs_including_initializers_graph_inputs_excluding_initializers_,提取出 initializer 并放置到 graph_overridable_initializers_ 中。

至此,我们完成了对 Graph::SetGraphInputsOutputs() 函数的解析。

总结

针对这个函数的解析不仅理解了如何从Graph的nodes中分析出graph的输入和输出,而且懂得了graph_overridable_initializers_以及value_info_的作用。

ONNX Runtime 源码阅读:Graph::SetGraphInputsOutputs() 函数的更多相关文章

  1. linux源码阅读笔记 fork函数

    在阅读源码的过程中,发现找不到fork函数的定义.后来在linux/init/main.c中找到了这样一条语句 static inline _syscall0(int,fork) 原来这里就是fork ...

  2. [PHP源码阅读]number_format函数

    上次讲到PHP是如何解析大整数的,一笔带过了number_format的处理,再详细阅读该函数的源码,以下是小分析. 函数原型 string number_format ( float $number ...

  3. linux源码阅读笔记 asm函数

    在linux源码中经常遇到__asm__函数.它其实是函数asm的宏定义 #define __asm__ asm,asm函数让系统执行汇编语句. __asm__常常与__volatile__一起出现. ...

  4. Runtime 源码阅读

    Runtime 属性说明 /** * 每一个 Java 应用程序都有一个关联的运行时对象 * * @author unascribed * @see java.lang.Runtime#getRunt ...

  5. [PHP源码阅读]explode和implode函数

    explode和implode函数主要用作字符串和数组间转换的操作,比如获取一段参数后根据某个字符分割字符串,或者将一个数组的结果使用一个字符合并成一个字符串输出.在PHP中经常会用到这两个函数,因此 ...

  6. CI框架源码阅读笔记3 全局函数Common.php

    从本篇开始,将深入CI框架的内部,一步步去探索这个框架的实现.结构和设计. Common.php文件定义了一系列的全局函数(一般来说,全局函数具有最高的加载优先权,因此大多数的框架中BootStrap ...

  7. 3 EventTime 事件时间类和TimeNow函数——Live555源码阅读(一)基本组件类

    这是Live555源码阅读的第一部分,包括了时间类,延时队列类,处理程序描述类,哈希表类这四个大类. 这里是时间相关类的第三个部分,也是最后一个部分. EventTime 事件时间类 这个类和Dela ...

  8. PHP源码阅读笔记一(explode和implode函数分析)

    PHP源码阅读笔记一一.explode和implode函数array explode ( string separator, string string [, int limit] )此函数返回由字符 ...

  9. [PHP源码阅读]strtolower和strtoupper函数

    字符串的操作函数中,字符串的大小写转换也算是比较常用的函数,其底层实现也比较简单,下面来一探究竟. 我在github上有对PHP源码更详细的注解.感兴趣的可以围观一下,给个star.PHP5.4源码注 ...

随机推荐

  1. synchronized 关键字的用法?

    synchronized 关键字可以将对象或者方法标记为同步,以实现对对象和方法的互 斥访问,可以用 synchronized(对象) { - }定义同步代码块,或者在声明方法时 将 synchron ...

  2. Zookeeper 下 Server 工作状态 ?

    服务器具有四种状态,分别是 LOOKING.FOLLOWING.LEADING.OBSERVING. 1.LOOKING:寻找 Leader 状态.当服务器处于该状态时,它会认为当前集群中 没有 Le ...

  3. Java 中,DOM 和 SAX 解析器有什么不同?

    DOM 解析器将整个 XML 文档加载到内存来创建一棵 DOM 模型树,这样可以 更快的查找节点和修改 XML 结构,而 SAX 解析器是一个基于事件的解析器, 不会将整个 XML 文档加载到内存.由 ...

  4. Spring源码分析笔记--事务管理

    核心类 InfrastructureAdvisorAutoProxyCreator 本质是一个后置处理器,和AOP的后置处理器类似,但比AOP的使用级别低.当开启AOP代理模式后,优先使用AOP的后置 ...

  5. 学习heartbeat-05 实现web服务高可用

    一.环境介绍 说明:所有案例在虚拟机(VMware)上完成 操作系统:centos 6.5 64bit 高可用软件:heartbeat 3.0.4 Web应用服务器:apache httpd 2.2. ...

  6. 前端性能优化(JavaScript篇)

    正巧看到在送书,于是乎找了找自己博客上记录过的一些东西来及其无耻的蹭书了~~~ 小广告:更多内容可以看我的博客 优化循环 如果现在有个一个data[]数组,需要对其进行遍历,应当怎么做?最简单的代码是 ...

  7. js压缩图片到2m以下

    用的canvas.这个问题测试妹子反馈了好几次bug,解决了好多次,虽然用了比较僵硬的办法,但总算最终解决了. 因为php的同事说,页面上的图片要直接调用七牛的接口上传到七牛,所以后端那边不能处理,必 ...

  8. pdm的说明

    软件行业的JAVA代码静态分析工具 PMD是一种开源分析Java代码错误的工具.与其他分析工具不同的是,PMD通过静态分析获知代码错误.也就是说,在不运行Java程序的情况下报告错误.PMD附带了许多 ...

  9. Java中的反射原理以及简单运用(原理+例子)

    @ 目录 学习总结 1. 为什么要使用反射 2. 反射的概念 3. Java反射加载过程 4. 反射优缺点 5. 字节码对象理解 6. 获取字节码对象(.class)的三种方式 7. 反射常用API ...

  10. linux(Ubuntu)安装python

    Linux下安装python 提前安装一个依赖环境 (1)ubuntu/Debian: sudo apt-get install -y gcc make cmake build-essential l ...