ONNXRuntime学习笔记(二)
继上一篇计划的实践项目,这篇记录我训练模型相关的工作。
- 首先要确定总体目标:训练一个pytorch模型,CIFAR-100数据集测试集acc达到90%;部署后推理效率达到50ms/张, 部署平台为window10+3050Ti+RX5800h.
- 训练模型的话,最好是有一套完备的代码,像谷歌的models,FB的detectron2,商汤的mm系列等等框架,这些是建立在深度学习框架tf或pth基础上的进一步封装,提供一些更高级的写好的模块可以调用,如Resnet、FPN、、proposal、NMS等等。但凡事都有两面,封装度越高意味着稳定性更好但修改的灵活性越差。只调用API对我们理解底层实现是不利的。之前我写过一个基于Pytorch的图像分类训练推理代码,现在又可以拿出来用一用了,地址:https://github.com/lee-zq/CNN-Backbone ,我在之前训练CIFAR-10的基础上又添加了CIFAR-100数据集的Dataloader创建代码。
- 首先,我尝试了CIFAR10+DenseNet,最后测试效果Acc=85%;然后尝试了CIFAR10+ResNet18,收敛较慢,但最终Acc=91.02%;基于此,。我尝试了CIFAR100+ResNet18,收敛很慢,大概到73Epoch稳定下来,但最终训练集Acc能达到90.62%,但测试集Acc为65.67%。大概率原因是模型拟合能力够用但是训练集多样性太差。模型结构如下:
ResNet(
(conv1): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(layer1): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer2): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer3): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(layer4): Sequential(
(0): ResidualBlock(
(left): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResidualBlock(
(left): Sequential(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(shortcut): Sequential()
)
)
(fc): Linear(in_features=512, out_features=10, bias=True)
)
Total number of parameters: 11173962
总参数量约11M,既然CIFAR-100效果太差,那就暂且还是用CIFAR-10做后面的训练测试吧,我又在之前的数据增强基础上加了RandomGrayscale和RandomAffine,最终的数据增强如下:
self.mean = [0.4914, 0.4822, 0.4465]
self.std = [0.2023, 0.1994, 0.2010]
self.num_workers= num_workers
self.transform_train = transforms.Compose([# 数据增强
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.RandomGrayscale(0.15),
transforms.RandomAffine((-30,30)),
transforms.RandomRotation(20),
transforms.ToTensor(),
transforms.Normalize(self.mean, self.std),
transforms.RandomErasing(),
])
- 然后微调继续训练,测试集Acc进一步提升到92.28%,可见数据多样性的重要性。进一步的,torchvision提供了AutoAugment数据增强方法的接口,可以直接调用,最终数据增强代码如下:
self.mean = [0.4914, 0.4822, 0.4465]
self.std = [0.2023, 0.1994, 0.2010]
self.num_workers= num_workers
self.transform_train = transforms.Compose([# 数据增强
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.autoaugment.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize(self.mean, self.std),
transforms.RandomErasing()
])
- 训练epoch数为80,优化器Adam,初始学习率0.01,每20epoch衰减,衰减因子gamma为0.1,目前还在训练ing,要花两个小时。完整重头训练估计要花4个小时,在之前的基础上微调会快很多,最终测试集Acc达到94.83%,达到预期。下一篇记录利用onnxruntime推理进行测试的过程。
ONNXRuntime学习笔记(二)的更多相关文章
- WPF的Binding学习笔记(二)
原文: http://www.cnblogs.com/pasoraku/archive/2012/10/25/2738428.htmlWPF的Binding学习笔记(二) 上次学了点点Binding的 ...
- AJax 学习笔记二(onreadystatechange的作用)
AJax 学习笔记二(onreadystatechange的作用) 当发送一个请求后,客户端无法确定什么时候会完成这个请求,所以需要用事件机制来捕获请求的状态XMLHttpRequest对象提供了on ...
- [Firefly引擎][学习笔记二][已完结]卡牌游戏开发模型的设计
源地址:http://bbs.9miao.com/thread-44603-1-1.html 在此补充一下Socket的验证机制:socket登陆验证.会采用session会话超时的机制做心跳接口验证 ...
- JMX学习笔记(二)-Notification
Notification通知,也可理解为消息,有通知,必然有发送通知的广播,JMX这里采用了一种订阅的方式,类似于观察者模式,注册一个观察者到广播里,当有通知时,广播通过调用观察者,逐一通知. 这里写 ...
- java之jvm学习笔记二(类装载器的体系结构)
java的class只在需要的时候才内转载入内存,并由java虚拟机的执行引擎来执行,而执行引擎从总的来说主要的执行方式分为四种, 第一种,一次性解释代码,也就是当字节码转载到内存后,每次需要都会重新 ...
- Java IO学习笔记二
Java IO学习笔记二 流的概念 在程序中所有的数据都是以流的方式进行传输或保存的,程序需要数据的时候要使用输入流读取数据,而当程序需要将一些数据保存起来的时候,就要使用输出流完成. 程序中的输入输 ...
- 《SQL必知必会》学习笔记二)
<SQL必知必会>学习笔记(二) 咱们接着上一篇的内容继续.这一篇主要回顾子查询,联合查询,复制表这三类内容. 上一部分基本上都是简单的Select查询,即从单个数据库表中检索数据的单条语 ...
- NumPy学习笔记 二
NumPy学习笔记 二 <NumPy学习笔记>系列将记录学习NumPy过程中的动手笔记,前期的参考书是<Python数据分析基础教程 NumPy学习指南>第二版.<数学分 ...
- Learning ROS for Robotics Programming Second Edition学习笔记(二) indigo tools
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...
随机推荐
- XMLBeanFactory?
最常用的就是 org.springframework.beans.factory.xml.XmlBeanFactory ,它根据XML文件中的定义加载beans.该容器从XML 文件读取配置元数据并用 ...
- try{}里有一个return语句,那么紧跟在这个try后的finally{}里的代码会不会被执行,什么时候被执行,在return前还是后?
答:会执行,在方法返回调用者前执行.
- vue开发chrome扩展,数据通过storage对象获取
开发chrome插件时遇到一个问题,那就是单文件组件的data数据需要从chrome提供的storage对象中获取,但是 chrome.storage.sync.get 方法是异步获取数据的,需要通过 ...
- java-中的代理
静态代理: 例子: 接口: public interface InterfaceBase { void proxy(); } 接口实现类: public class InterfaceBaseReal ...
- 百度移动统计调用api教程,少进坑(82001错误)
相信很多小伙伴使用了百度统计,来查看自己应用使用的情况,但是会发现百度移动统计在官网没有api调用取数据的接口, 现在我就以自己成功调用api并且成功拿到数据,将这个步骤给大家参考,(末尾有调用移动统 ...
- python学习笔记(五)——静态方法、类方法、运算符重载
我们都知道类名是不能够直接调用类方法的.在C++中,把成员方法声明为 static 静态方法后可以通过类名调用.同样的在python中也可以通过定义静态方法的方式让类名直接调用. 静态方法 使用 @s ...
- C++ | 虚表的写入时机
虚表 在C++的多态机制中,使用了 virtual 关键字声明的函数称之为虚函数,每个有虚函数的类或者虚继承的子类,编译器都会为它生成一个虚拟函数表(简称:虚表,以下用 vftable表示),表中的每 ...
- 【静态页面架构】CSS之链接和图像
CSS架构 一.链接: 链接元素:通过使用a元素的href属性设置跳转到指定页面地址 <style> a{ color: blue; text-decoration: none; } a: ...
- 国际化相对时间格式化API:Intl.RelativeTimeFormat
原文:The Intl.RelativeTimeFormat API 作者:Mathias Bynens(@mathias) 现代 Web 应用程序通常使用"昨天","4 ...
- C#编写程序,找一找一个二维数组中的鞍点
编写程序,找一找一个二维数组中的鞍点(即该位置上的元素值在行中最大,在该列上最小.有可能数组没有鞍点).要求: 1.二维数组的大小.数组元素的值在运行时输入: 2.程序有友好的提示信息. 代码: us ...