继上一篇计划的实践项目,这篇记录我训练模型相关的工作。

  • 首先要确定总体目标:训练一个pytorch模型,CIFAR-100数据集测试集acc达到90%;部署后推理效率达到50ms/张, 部署平台为window10+3050Ti+RX5800h.
  • 训练模型的话,最好是有一套完备的代码,像谷歌的models,FB的detectron2,商汤的mm系列等等框架,这些是建立在深度学习框架tf或pth基础上的进一步封装,提供一些更高级的写好的模块可以调用,如Resnet、FPN、、proposal、NMS等等。但凡事都有两面,封装度越高意味着稳定性更好但修改的灵活性越差。只调用API对我们理解底层实现是不利的。之前我写过一个基于Pytorch的图像分类训练推理代码,现在又可以拿出来用一用了,地址:https://github.com/lee-zq/CNN-Backbone ,我在之前训练CIFAR-10的基础上又添加了CIFAR-100数据集的Dataloader创建代码。
  1. 首先,我尝试了CIFAR10+DenseNet,最后测试效果Acc=85%;然后尝试了CIFAR10+ResNet18,收敛较慢,但最终Acc=91.02%;基于此,。我尝试了CIFAR100+ResNet18,收敛很慢,大概到73Epoch稳定下来,但最终训练集Acc能达到90.62%,但测试集Acc为65.67%。大概率原因是模型拟合能力够用但是训练集多样性太差。模型结构如下:
  1. ResNet(
  2. (conv1): Sequential(
  3. (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  4. (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  5. (2): ReLU()
  6. )
  7. (layer1): Sequential(
  8. (0): ResidualBlock(
  9. (left): Sequential(
  10. (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  11. (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  12. (2): ReLU(inplace=True)
  13. (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  14. (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  15. )
  16. (shortcut): Sequential()
  17. )
  18. (1): ResidualBlock(
  19. (left): Sequential(
  20. (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  21. (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  22. (2): ReLU(inplace=True)
  23. (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  24. (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  25. )
  26. (shortcut): Sequential()
  27. )
  28. )
  29. (layer2): Sequential(
  30. (0): ResidualBlock(
  31. (left): Sequential(
  32. (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  33. (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  34. (2): ReLU(inplace=True)
  35. (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  36. (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  37. )
  38. (shortcut): Sequential(
  39. (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
  40. (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  41. )
  42. )
  43. (1): ResidualBlock(
  44. (left): Sequential(
  45. (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  46. (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  47. (2): ReLU(inplace=True)
  48. (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  49. (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  50. )
  51. (shortcut): Sequential()
  52. )
  53. )
  54. (layer3): Sequential(
  55. (0): ResidualBlock(
  56. (left): Sequential(
  57. (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  58. (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  59. (2): ReLU(inplace=True)
  60. (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  61. (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  62. )
  63. (shortcut): Sequential(
  64. (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
  65. (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  66. )
  67. )
  68. (1): ResidualBlock(
  69. (left): Sequential(
  70. (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  71. (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  72. (2): ReLU(inplace=True)
  73. (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  74. (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  75. )
  76. (shortcut): Sequential()
  77. )
  78. )
  79. (layer4): Sequential(
  80. (0): ResidualBlock(
  81. (left): Sequential(
  82. (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  83. (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  84. (2): ReLU(inplace=True)
  85. (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  86. (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  87. )
  88. (shortcut): Sequential(
  89. (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
  90. (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  91. )
  92. )
  93. (1): ResidualBlock(
  94. (left): Sequential(
  95. (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  96. (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  97. (2): ReLU(inplace=True)
  98. (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  99. (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  100. )
  101. (shortcut): Sequential()
  102. )
  103. )
  104. (fc): Linear(in_features=512, out_features=10, bias=True)
  105. )
  106. Total number of parameters: 11173962

总参数量约11M,既然CIFAR-100效果太差,那就暂且还是用CIFAR-10做后面的训练测试吧,我又在之前的数据增强基础上加了RandomGrayscale和RandomAffine,最终的数据增强如下:

  1. self.mean = [0.4914, 0.4822, 0.4465]
  2. self.std = [0.2023, 0.1994, 0.2010]
  3. self.num_workers= num_workers
  4. self.transform_train = transforms.Compose([# 数据增强
  5. transforms.RandomCrop(32, padding=4),
  6. transforms.RandomHorizontalFlip(),
  7. transforms.RandomGrayscale(0.15),
  8. transforms.RandomAffine((-30,30)),
  9. transforms.RandomRotation(20),
  10. transforms.ToTensor(),
  11. transforms.Normalize(self.mean, self.std),
  12. transforms.RandomErasing(),
  13. ])
  1. 然后微调继续训练,测试集Acc进一步提升到92.28%,可见数据多样性的重要性。进一步的,torchvision提供了AutoAugment数据增强方法的接口,可以直接调用,最终数据增强代码如下:
  1. self.mean = [0.4914, 0.4822, 0.4465]
  2. self.std = [0.2023, 0.1994, 0.2010]
  3. self.num_workers= num_workers
  4. self.transform_train = transforms.Compose([# 数据增强
  5. transforms.RandomCrop(32, padding=4),
  6. transforms.RandomHorizontalFlip(),
  7. transforms.autoaugment.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
  8. transforms.ToTensor(),
  9. transforms.Normalize(self.mean, self.std),
  10. transforms.RandomErasing()
  11. ])
  1. 训练epoch数为80,优化器Adam,初始学习率0.01,每20epoch衰减,衰减因子gamma为0.1,目前还在训练ing,要花两个小时。完整重头训练估计要花4个小时,在之前的基础上微调会快很多,最终测试集Acc达到94.83%,达到预期。下一篇记录利用onnxruntime推理进行测试的过程。

ONNXRuntime学习笔记(二)的更多相关文章

  1. WPF的Binding学习笔记(二)

    原文: http://www.cnblogs.com/pasoraku/archive/2012/10/25/2738428.htmlWPF的Binding学习笔记(二) 上次学了点点Binding的 ...

  2. AJax 学习笔记二(onreadystatechange的作用)

    AJax 学习笔记二(onreadystatechange的作用) 当发送一个请求后,客户端无法确定什么时候会完成这个请求,所以需要用事件机制来捕获请求的状态XMLHttpRequest对象提供了on ...

  3. [Firefly引擎][学习笔记二][已完结]卡牌游戏开发模型的设计

    源地址:http://bbs.9miao.com/thread-44603-1-1.html 在此补充一下Socket的验证机制:socket登陆验证.会采用session会话超时的机制做心跳接口验证 ...

  4. JMX学习笔记(二)-Notification

    Notification通知,也可理解为消息,有通知,必然有发送通知的广播,JMX这里采用了一种订阅的方式,类似于观察者模式,注册一个观察者到广播里,当有通知时,广播通过调用观察者,逐一通知. 这里写 ...

  5. java之jvm学习笔记二(类装载器的体系结构)

    java的class只在需要的时候才内转载入内存,并由java虚拟机的执行引擎来执行,而执行引擎从总的来说主要的执行方式分为四种, 第一种,一次性解释代码,也就是当字节码转载到内存后,每次需要都会重新 ...

  6. Java IO学习笔记二

    Java IO学习笔记二 流的概念 在程序中所有的数据都是以流的方式进行传输或保存的,程序需要数据的时候要使用输入流读取数据,而当程序需要将一些数据保存起来的时候,就要使用输出流完成. 程序中的输入输 ...

  7. 《SQL必知必会》学习笔记二)

    <SQL必知必会>学习笔记(二) 咱们接着上一篇的内容继续.这一篇主要回顾子查询,联合查询,复制表这三类内容. 上一部分基本上都是简单的Select查询,即从单个数据库表中检索数据的单条语 ...

  8. NumPy学习笔记 二

    NumPy学习笔记 二 <NumPy学习笔记>系列将记录学习NumPy过程中的动手笔记,前期的参考书是<Python数据分析基础教程 NumPy学习指南>第二版.<数学分 ...

  9. Learning ROS for Robotics Programming Second Edition学习笔记(二) indigo tools

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

随机推荐

  1. 为什么要用 Dubbo?

    随着服务化的进一步发展,服务越来越多,服务之间的调用和依赖关系也越来越 复杂,诞生了面向服务的架构体系(SOA), 也因此衍生出了一系列相应的技术,如对服务提供.服务调用.连接处理.通信 协议.序列化 ...

  2. Kafka 分区数可以增加或减少吗?为什么?

    我们可以使用 bin/kafka-topics.sh 命令对 Kafka 增加 Kafka 的分区数据,但是 Kafka 不支持减少分区数. Kafka 分区数据不支持减少是由很多原因的,比如减少的分 ...

  3. linux设置java环境变量与开机自启

    一.下载jdk并放置在指定位置 二.编辑profile文件 vim /etc/profile  或者  将/etc下的profile 文件修改好再上传覆盖源文件 修改方式即添加以下内容至文件最底部即可 ...

  4. Java 中怎么创建 ByteBuffer?

    byte[] bytes = new byte[10]; ByteBuffer buf = ByteBuffer.wrap(bytes);

  5. Java中带参数的方法和JavaScript中带参数的函数有什么不同?

    javascript是动态语言,是弱类型语言,其参数的使用很灵活:java则是强类型语言,参数的类型必须明确的

  6. spring重点知识分享

    前言: spring是一个轻量级的开源的控制反转(Inversion of Control,IOC)和面向切面(AOP)的容器框架,它的主要目的是简化企业开发.这两个模块使得java开发更加简单.IO ...

  7. 使用 Blueprint 要注意 render_template 函数

    此文章主要是为了记录在使用 Flask 的过程中遇到的问题.本章主要讨论 render_template 函数的问题. 使用 Flask 的同学都应该知道,项目中的 url 和视图函数是在字典里一一对 ...

  8. java中如何创建自定义异常Create Custom Exception

    9.创建自定义异常 Create Custom Exception 马克-to-win:我们可以创建自己的异常:checked或unchecked异常都可以, 规则如前面我们所介绍,反正如果是chec ...

  9. Value注解获取值一直为Null

    @Value("${jwt.tokenHeader}") private String tokenHeader; 常见的错误解决办法如下: 1.使用static或final修饰了t ...

  10. 微信小程序和公众号和H5之间相互跳转

    参考链接:https://www.imooc.com/article/22900 一.小程序和公众号 答案是:可以相互关联. 在微信公众号里可以添加小程序. 可关联已有的小程序或快速创建小程序.已关联 ...