ONNXRuntime学习笔记(二)
继上一篇计划的实践项目,这篇记录我训练模型相关的工作。
- 首先要确定总体目标:训练一个pytorch模型,CIFAR-100数据集测试集acc达到90%;部署后推理效率达到50ms/张, 部署平台为window10+3050Ti+RX5800h.
- 训练模型的话,最好是有一套完备的代码,像谷歌的models,FB的detectron2,商汤的mm系列等等框架,这些是建立在深度学习框架tf或pth基础上的进一步封装,提供一些更高级的写好的模块可以调用,如Resnet、FPN、、proposal、NMS等等。但凡事都有两面,封装度越高意味着稳定性更好但修改的灵活性越差。只调用API对我们理解底层实现是不利的。之前我写过一个基于Pytorch的图像分类训练推理代码,现在又可以拿出来用一用了,地址:https://github.com/lee-zq/CNN-Backbone ,我在之前训练CIFAR-10的基础上又添加了CIFAR-100数据集的Dataloader创建代码。
- 首先,我尝试了CIFAR10+DenseNet,最后测试效果Acc=85%;然后尝试了CIFAR10+ResNet18,收敛较慢,但最终Acc=91.02%;基于此,。我尝试了CIFAR100+ResNet18,收敛很慢,大概到73Epoch稳定下来,但最终训练集Acc能达到90.62%,但测试集Acc为65.67%。大概率原因是模型拟合能力够用但是训练集多样性太差。模型结构如下:
ResNet(
(conv1): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(layer1): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer2): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer3): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer4): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(fc): Linear(in_features=512, out_features=10, bias=True)
)
Total number of parameters: 11173962
总参数量约11M,既然CIFAR-100效果太差,那就暂且还是用CIFAR-10做后面的训练测试吧,我又在之前的数据增强基础上加了RandomGrayscale和RandomAffine,最终的数据增强如下:
self.mean = [0.4914, 0.4822, 0.4465]
self.std = [0.2023, 0.1994, 0.2010]
self.num_workers= num_workers
self.transform_train = transforms.Compose([# 数据增强
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.RandomGrayscale(0.15),
transforms.RandomAffine((-30,30)),
transforms.RandomRotation(20),
transforms.ToTensor(),
transforms.Normalize(self.mean, self.std),
transforms.RandomErasing(),
])
- 然后微调继续训练,测试集Acc进一步提升到92.28%,可见数据多样性的重要性。进一步的,torchvision提供了AutoAugment数据增强方法的接口,可以直接调用,最终数据增强代码如下:
self.mean = [0.4914, 0.4822, 0.4465]
self.std = [0.2023, 0.1994, 0.2010]
self.num_workers= num_workers
self.transform_train = transforms.Compose([# 数据增强
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.autoaugment.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize(self.mean, self.std),
transforms.RandomErasing()
])
- 训练epoch数为80,优化器Adam,初始学习率0.01,每20epoch衰减,衰减因子gamma为0.1,目前还在训练ing,要花两个小时。完整重头训练估计要花4个小时,在之前的基础上微调会快很多,最终测试集Acc达到94.83%,达到预期。下一篇记录利用onnxruntime推理进行测试的过程。
ONNXRuntime学习笔记(二)的更多相关文章
- WPF的Binding学习笔记(二)
原文: http://www.cnblogs.com/pasoraku/archive/2012/10/25/2738428.htmlWPF的Binding学习笔记(二) 上次学了点点Binding的 ...
- AJax 学习笔记二(onreadystatechange的作用)
AJax 学习笔记二(onreadystatechange的作用) 当发送一个请求后,客户端无法确定什么时候会完成这个请求,所以需要用事件机制来捕获请求的状态XMLHttpRequest对象提供了on ...
- [Firefly引擎][学习笔记二][已完结]卡牌游戏开发模型的设计
源地址:http://bbs.9miao.com/thread-44603-1-1.html 在此补充一下Socket的验证机制:socket登陆验证.会采用session会话超时的机制做心跳接口验证 ...
- JMX学习笔记(二)-Notification
Notification通知,也可理解为消息,有通知,必然有发送通知的广播,JMX这里采用了一种订阅的方式,类似于观察者模式,注册一个观察者到广播里,当有通知时,广播通过调用观察者,逐一通知. 这里写 ...
- java之jvm学习笔记二(类装载器的体系结构)
java的class只在需要的时候才内转载入内存,并由java虚拟机的执行引擎来执行,而执行引擎从总的来说主要的执行方式分为四种, 第一种,一次性解释代码,也就是当字节码转载到内存后,每次需要都会重新 ...
- Java IO学习笔记二
Java IO学习笔记二 流的概念 在程序中所有的数据都是以流的方式进行传输或保存的,程序需要数据的时候要使用输入流读取数据,而当程序需要将一些数据保存起来的时候,就要使用输出流完成. 程序中的输入输 ...
- 《SQL必知必会》学习笔记二)
<SQL必知必会>学习笔记(二) 咱们接着上一篇的内容继续.这一篇主要回顾子查询,联合查询,复制表这三类内容. 上一部分基本上都是简单的Select查询,即从单个数据库表中检索数据的单条语 ...
- NumPy学习笔记 二
NumPy学习笔记 二 <NumPy学习笔记>系列将记录学习NumPy过程中的动手笔记,前期的参考书是<Python数据分析基础教程 NumPy学习指南>第二版.<数学分 ...
- Learning ROS for Robotics Programming Second Edition学习笔记(二) indigo tools
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...
随机推荐
- 什么是openssl?
在计算机网络上,OpenSSL是一个开放源代码的软件库包,应用程序可以使用这个包来进行安全通信,避免窃听,同时确认另一端连接者的身份.这个包广泛被应用在互联网的网页服务器上.
- Jinkins流水线脚本使用curl命令调用服务接口,并且使用url传参。
curl http://xxx.xx.xx.xx:xxxx/jenkins/publish?fileName=${fileName}&tag_name=${tag_name} 如图调用不符合c ...
- 在 Spring AOP 中,关注点和横切关注的区别是什么?
关注点是应用中一个模块的行为,一个关注点可能会被定义成一个我们想实现的 一个功能. 横切关注点是一个关注点,此关注点是整个应用都会使用的功能,并影响整个应 用,比如日志,安全和数据传输,几乎应用的每个 ...
- 记一次 Nuxt 3 在 Windows 下的打包问题
0. 背景 之前用 Nuxt 3 写了公司的官网,包括了样式.字体图标.图片.视频等,其中样式和字体图标放在了 assets/styles 和 assets/fonts 目录下,而图片和视频则放在了 ...
- GlusterFS(GFS) 分布式存储
GlusterFS(GFS) 分布式存储 GFS 分布式文件系统 目录 一: GlusterFS 概述 1.1 GlusterFS 简介 1.2 GlusterFS特点 1.2.1 扩展性和高性能 ...
- 【wepy入门教程】48小时开发看美女微信小程序,万花阁
说明:本文只做小程序的开发过程记录:小程序仅供学习参考,严禁用于商业及非法用途 准备 不管是做网站还是做小程序,只要是To C,就少不了做内容,因此第一步依然是数据准备,从网上找到两个网站: http ...
- SQL之总结(二)
4.关于取两个日期之间的年份: ceil(MONTHS_BETWEEN(sysdate, c.sendtime)/12) workTime ceil(n) 取大于等于n的最小整数 floor(n) 取 ...
- 使用element UI el-upload组件实现视频文件上传及上传进度显示方法总结
实现效果: 上传中: 上传完成: 代码: <el-form-item label="视频上传" prop="Video"> <!-- acti ...
- vue后台管理系统组件弹窗
//addFormVisibleIcon可在data中设置true与falsehttps://element.eleme.io/#/zh-CN/component/installation <e ...
- MongoDB 集群-主从复制(一主二从)
MongoDB 集群-主从复制(一主二从) 官方文档 https://docs.mongodb.com/manual/tutorial/deploy-replica-set/ https://docs ...