接上一篇完成的pytorch模型训练结果,模型结构为ResNet18+fc,参数量约为11M,最终测试集Acc达到94.83%。接下来有分两个部分:导出onnx和使用onnxruntime推理。

一、pytorch导出onnx

直接放函数吧,这部分我是直接放在test.py里面的,直接从dataloader中拿到一个batch的数据走一遍推理即可。

def export_onnx(net, testloader, output_file):
net.eval()
with torch.no_grad():
for data in testloader:
images, labels = data torch.onnx.export(net,
(images),
output_file,
training=False,
do_constant_folding=True,
input_names=["img"],
output_names=["output"],
dynamic_axes={"img": {0: "b"},"output": {0: "b"}}
)
print("onnx export done!")
break

上面函数中几个比较重要的参数:do_constant_folding是常量折叠,建议打开;输入张量通过一个tuple传入,并且最好指定每个输入和输出的名称,此外,为保证使用onnxruntime推理的时候batchsize可变,dynamic_axes的第一维需要像上述一样设置为动态的。如果是全卷积做分割的网络,类似的输入h和w也应该是动态的。

单独运行test.py计算测试集效果和平均相应时间,为方便比较,这里batch_size设置为1,结果为:

Test Acc is: 94.84%
Average response time cost: 8.703978610038757 ms

二、使用onnxruntime推理

这里我们使用gpu版本的onnxruntime库进行推理,其python包可直接pip install onnxruntime-gpu安装。onnxruntime推理代码和测试集推理代码很类似,如下:

import numpy as np
import onnxruntime as ort
import argparse, os
from lib import CIFARDataset def onnxruntime_test(session, testloader):
print("Start Testing!")
input_name = session.get_inputs()[0].name
correct = 0
total = 0 # 计数归零(初始化)
for data in testloader:
images, labels = data
images, labels = images.numpy(), labels.numpy()
outputs = session.run(None, {input_name:images})
predicted = np.argmax(outputs[0], axis=1) # 取得分最高的那个类
total += labels.shape[0] # 累加样本总数
correct += (predicted == labels).sum() # 累加预测正确的样本个数
acc = correct / total
print('ONNXRuntime Test Acc is: %.2f%%' % (100*acc)) if __name__ == '__main__':
# 命令行参数解析
parser = argparse.ArgumentParser("CNN backbone on cifar10")
parser.add_argument('--onnx', default='./output/test_resnet18_10_autoaug/densenet_best.onnx')
args = parser.parse_args() NUM_CLASS =10
BATCH_SIZE = 1 # 批处理尺寸(batch_size) # 数据集迭代器
data_path="./data"
dataset = CIFARDataset(dataset_path=data_path, batchsize=BATCH_SIZE)
_, testloader = dataset.get_cifar10_dataloader() # 构建session
sess = ort.InferenceSession(args.onnx, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) #onnxruntime推理
import time
start = time.time()
onnxruntime_test(sess, testloader)
end = time.time()
print(f"Average response time cost: {1000*(end-start)/len(testloader.dataset)} ms")

使用onnxruntime加载导出的onnx模型,计算测试集效果和平均响应时间,结果为:

ONNXRuntime Test Acc is: 94.83%
Average response time cost: 3.1050602436065673 ms

三、小结

分析上面的pytorch和onnxruntime的测试结果可知,最终测试集效果是一致的,Acc分别为94.84%和94.83%,相当于10000个样本里面只有1个的预测结果不一致,这是可以接受范围内。但onnxruntime的效率更高,平均耗时只有3.1ms,比pytorch的8.7ms快了将近3倍。这在实际部署中的优势是非常明显的。目前Python端的结论比最初目标设定的50ms高很多,如果说需要进一步优化,两个方向:模型量化或并行化推理(拼batch或多线程)。下一篇再分析。

ONNXRuntime学习笔记(三)的更多相关文章

  1. Oracle学习笔记三 SQL命令

    SQL简介 SQL 支持下列类别的命令: 1.数据定义语言(DDL) 2.数据操纵语言(DML) 3.事务控制语言(TCL) 4.数据控制语言(DCL)  

  2. [Firefly引擎][学习笔记三][已完结]所需模块封装

    原地址:http://www.9miao.com/question-15-54671.html 学习笔记一传送门学习笔记二传送门 学习笔记三导读:        笔记三主要就是各个模块的封装了,这里贴 ...

  3. JSP学习笔记(三):简单的Tomcat Web服务器

    注意:每次对Tomcat配置文件进行修改后,必须重启Tomcat 在E盘的DATA文件夹中创建TomcatDemo文件夹,并将Tomcat安装路径下的webapps/ROOT中的WEB-INF文件夹复 ...

  4. java之jvm学习笔记三(Class文件检验器)

    java之jvm学习笔记三(Class文件检验器) 前面的学习我们知道了class文件被类装载器所装载,但是在装载class文件之前或之后,class文件实际上还需要被校验,这就是今天的学习主题,cl ...

  5. VSTO学习笔记(三) 开发Office 2010 64位COM加载项

    原文:VSTO学习笔记(三) 开发Office 2010 64位COM加载项 一.加载项简介 Office提供了多种用于扩展Office应用程序功能的模式,常见的有: 1.Office 自动化程序(A ...

  6. Java IO学习笔记三

    Java IO学习笔记三 在整个IO包中,实际上就是分为字节流和字符流,但是除了这两个流之外,还存在了一组字节流-字符流的转换类. OutputStreamWriter:是Writer的子类,将输出的 ...

  7. NumPy学习笔记 三 股票价格

    NumPy学习笔记 三 股票价格 <NumPy学习笔记>系列将记录学习NumPy过程中的动手笔记,前期的参考书是<Python数据分析基础教程 NumPy学习指南>第二版.&l ...

  8. Learning ROS for Robotics Programming Second Edition学习笔记(三) 补充 hector_slam

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

  9. Learning ROS for Robotics Programming Second Edition学习笔记(三) indigo rplidar rviz slam

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

随机推荐

  1. Redis 集群,集群的原理是什么?

    1).Redis Sentinal 着眼于高可用,在 master 宕机时会自动将 slave 提升为master,继续提供服务. 2).Redis Cluster 着眼于扩展性,在单个 redis ...

  2. java-設計模式-抽象工場模式

    抽象工廠模式AbstractFactory 一种创建型设计模式, 它能创建一系列相关的对象, 而无需指定其具体类. 工廠方法模式中考虑的是一类产品的生产,如畜牧场只养动物.电视机厂只生产电视机,同种类 ...

  3. Java语言的特点有哪些?

    1.简单 Java最初是为对家用电器进行集成控制而设计的一种语言,因此它必须简单明了.Java语言的简单性主要体现在以下三个方面: 1) Java的风格类似于C++,因而C++程序员是非常熟悉的.从某 ...

  4. Redis 集群的主从复制模型是怎样的?

    为了使在部分节点失败或者大部分节点无法通信的情况下集群仍然可用,所 以集群使用了主从复制模型,每个节点都会有 N-1 个复制品.

  5. 使用 Docker, 7 个命令部署一个 Mesos 集群

    这个教程将给你展示怎样使用 Docker 容器提供一个单节点的 Mesos 集群(未来的一篇文章将展示怎样很容易的扩展这个到多个节点或者是见底部更新).这意味着你可以使用 7 个命令启动整个集群!不需 ...

  6. 关于elementUI如何在表格循环列表里分别新增Tag的设计使用

    话不多说,直接上代码.想要实现的目的是这样的,想要在表格里单独添加每一个tag 那么,需要解决的问题就是如何定义这每一个插槽里的输入框,把每个输入框以及里面插入的数据区分开. 研究了很久,最后选择了对 ...

  7. java中final变量的用法

    4.4 final变量    final变量的数值不能在初始化之后进行改变(你希望a=3,有很多用到a的场合, 你当然不能在程序中就用3来代替a). 比如: final int h = 0; 想像有一 ...

  8. wx:key报错does not look like a valid key name

    把花括号去掉就行了,  现在改版了,  要注意了     wx:key="index"  

  9. 浅谈ES6中的Async函数

    转载地址:https://www.cnblogs.com/sghy/p/7987640.html 定义:Async函数是一个异步操作函数,本质上,Async函数是Generator函数的语法糖.asy ...

  10. 在 Mac 上开发 .NET MAUI

    .NET 多平台应用程序 UI (.NET MAUI) 是一个跨平台框架,用于使用 C# 和 XAML 创建本机移动和桌面应用程序,这些应用程序可以从单个共享代码库在 Android.iOS.macO ...