在进行深度学习训练时,同一模型往往可以训练出不同的效果,这就是炼丹这件事的玄学所在。使用一些trick能够让你更容易追上目前SOTA的效果,一些流行的开源代码中已经集成了不少trick,值得学习一番。本节介绍EMA这一方法。

1.原理:

EMA也就是指数移动平均(Exponential moving average)。其公式非常简单,如下所示:

\(\theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}\)

\(\theta_{t}\)是t时刻的网络参数,\(\theta_{\text{EMA}, t}\)是t时刻滑动平均后的网络参数,那么t+1时刻的滑动平均结果就是这两者的加权融合。这里 \(\lambda\)通常会取接近于1的数,比如0.9995,数字越大平均的效果就比较强。

值得注意的是,这里可以看成有两个模型,基础模型其参数按照常规的前后向传播来更新,另外一个模型则是基础模型的滑动平均版本,它并不直接参与前后向传播,仅仅是利用基础模型的参数结果来更新自己。

EMA为什么会有效呢?大概是因为在训练的时候,会使用验证集来衡量模型精度,但其实验证集精度并不和测试集一致,在训练后期阶段,模型可能已经在测试集最佳精度附近波动,所以使用滑动平均的结果会比使用单一结果更加可靠。感兴趣的话可以看看这几篇论文,论文1,论文2,论文3

2.实现:

Pytorch其实已经为我们实现了这一功能,为了避免自己造轮子可能引入的错误,这里直接学习一下官方的代码。这个类的名称就叫做AveragedModel。代码如下所示。

我们需要做的是提供avg_fn这个函数,avg_fn用来指定以何种方式进行平均。

  1. class AveragedModel(Module):
  2. """
  3. You can also use custom averaging functions with `avg_fn` parameter.
  4. If no averaging function is provided, the default is to compute
  5. equally-weighted average of the weights.
  6. """
  7. def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
  8. super(AveragedModel, self).__init__()
  9. self.module = deepcopy(model)
  10. if device is not None:
  11. self.module = self.module.to(device)
  12. self.register_buffer('n_averaged',
  13. torch.tensor(0, dtype=torch.long, device=device))
  14. if avg_fn is None:
  15. def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
  16. return averaged_model_parameter + \
  17. (model_parameter - averaged_model_parameter) / (num_averaged + 1)
  18. self.avg_fn = avg_fn
  19. self.use_buffers = use_buffers
  20. def forward(self, *args, **kwargs):
  21. return self.module(*args, **kwargs)
  22. def update_parameters(self, model):
  23. self_param = (
  24. itertools.chain(self.module.parameters(), self.module.buffers())
  25. if self.use_buffers else self.parameters()
  26. )
  27. model_param = (
  28. itertools.chain(model.parameters(), model.buffers())
  29. if self.use_buffers else model.parameters()
  30. )
  31. for p_swa, p_model in zip(self_param, model_param):
  32. device = p_swa.device
  33. p_model_ = p_model.detach().to(device)
  34. if self.n_averaged == 0:
  35. p_swa.detach().copy_(p_model_)
  36. else:
  37. p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
  38. self.n_averaged.to(device)))
  39. self.n_averaged += 1
  40. @torch.no_grad()
  41. def update_bn(loader, model, device=None):
  42. r"""Updates BatchNorm running_mean, running_var buffers in the model.
  43. It performs one pass over data in `loader` to estimate the activation
  44. statistics for BatchNorm layers in the model.
  45. Args:
  46. loader (torch.utils.data.DataLoader): dataset loader to compute the
  47. activation statistics on. Each data batch should be either a
  48. tensor, or a list/tuple whose first element is a tensor
  49. containing data.
  50. model (torch.nn.Module): model for which we seek to update BatchNorm
  51. statistics.
  52. device (torch.device, optional): If set, data will be transferred to
  53. :attr:`device` before being passed into :attr:`model`.
  54. Example:
  55. >>> loader, model = ...
  56. >>> torch.optim.swa_utils.update_bn(loader, model)
  57. .. note::
  58. The `update_bn` utility assumes that each data batch in :attr:`loader`
  59. is either a tensor or a list or tuple of tensors; in the latter case it
  60. is assumed that :meth:`model.forward()` should be called on the first
  61. element of the list or tuple corresponding to the data batch.
  62. """
  63. momenta = {}
  64. for module in model.modules():
  65. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  66. module.running_mean = torch.zeros_like(module.running_mean)
  67. module.running_var = torch.ones_like(module.running_var)
  68. momenta[module] = module.momentum
  69. if not momenta:
  70. return
  71. was_training = model.training
  72. model.train()
  73. for module in momenta.keys():
  74. module.momentum = None
  75. module.num_batches_tracked *= 0
  76. for input in loader:
  77. if isinstance(input, (list, tuple)):
  78. input = input[0]
  79. if device is not None:
  80. input = input.to(device)
  81. model(input)
  82. for bn_module in momenta.keys():
  83. bn_module.momentum = momenta[bn_module]
  84. model.train(was_training)

这里同样参考官方的示例代码,给出滑动平均的实现。ExponentialMovingAverage继承了AveragedModel,并且复写了init方法,其实更直接的方法是将ema_avg函数作为参数传递给AveragedModel,这里可能是为了可读性,避免出现一个孤零零的ema_avg函数。

  1. class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
  2. """Maintains moving averages of model parameters using an exponential decay.
  3. ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
  4. `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
  5. is used to compute the EMA.
  6. """
  7. def __init__(self, model, decay, device="cpu"):
  8. def ema_avg(avg_model_param, model_param, num_averaged):
  9. return decay * avg_model_param + (1 - decay) * model_param
  10. super().__init__(model, device, ema_avg, use_buffers=True)

如何使用呢?方式是比较简单的,首先是利用当前模型创建出一个滑动平均模型。

  1. model_ema = utils.ExponentialMovingAverage(model, device=device, decay=ema_decay)

然后是进行基础模型的前后向传播,更新结束后再对滑动平均版的模型进行参数更新。

  1. output = model(image)
  2. loss = criterion(output, target)
  3. optimizer.zero_grad()
  4. loss.backward()
  5. optimizer.step()
  6. model_ema.update_parameters(model)

【炼丹Trick】EMA的原理与实现的更多相关文章

  1. 【优化技巧】指数移动平均EMA的原理

    前言 在深度学习中,经常会使用EMA(exponential moving average)方法对模型的参数做平滑或者平均,以求提高测试指标,增加模型鲁棒性. 参考 1. [优化技巧]指数移动平均(E ...

  2. 炼丹的一些trick

    采摘一些大佬的果实: 知乎:如何理解深度学习分布式训练中的large batch size与learning rate的关系? https://blog.csdn.net/shanglianlm/ar ...

  3. PHP 底层的运行机制与原理

    PHP说简单,但是要精通也不是一件简单的事.我们除了会使用之外,还得知道它底层的工作原理. PHP是一种适用于web开发的动态语言.具体点说,就是一个用C语言实现包含大量组件的软件框架.更狭义点看,可 ...

  4. JSPatch 实现原理详解

    原文地址https://github.com/bang590/JSPatch/wiki/JSPatch-%E5%AE%9E%E7%8E%B0%E5%8E%9F%E7%90%86%E8%AF%A6%E8 ...

  5. PHP的运行机制与原理(底层) [转]

    说到php的运行机制还要先给大家介绍php的模块,PHP总共有三个模块:内核.Zend引擎.以及扩展层:PHP内核用来处理请求.文件流.错误处理等相关操作:Zend引擎(ZE)用以将源文件转换成机器语 ...

  6. Linux进程调度原理

    Linux进程调度原理 Linux进程调度机制 Linux进程调度的目标 1.高效性:高效意味着在相同的时间下要完成更多的任务.调度程序会被频繁的执行,所以调度程序要尽可能的高效: 2.加强交互性能: ...

  7. PHP底层的运行机制与原理

    PHP说简单,但是要精通也不是一件简单的事.我们除了会使用之外,还得知道它底层的工作原理. PHP是一种适用于web开发的动态语言.具体点说,就是一个用C语言实现包含大量组件的软件框架.更狭义点看,可 ...

  8. 单片微机原理P0:80C51结构原理

    本来我真的不想让51的东西出现在我的博客上的,因为51这种东西真的太low了,学了最多就所谓的垃圾科创利用一下,但是想一下这门课我也要考试,还是写一点东西顺便放博客上吧. 这一系列主要参考<单片 ...

  9. Kernel PCA 原理和演示

    Kernel PCA 原理和演示 主成份(Principal Component Analysis)分析是降维(Dimension Reduction)的重要手段.每一个主成分都是数据在某一个方向上的 ...

随机推荐

  1. WPF样式和触发器

    理解样式 样式可以定义通用的格式化特征集合. Style 类的属性 Setters.Triggers.Resources.BasedOn.TargetType <Style x:Key=&quo ...

  2. esp8266 esp01s wifi继电器 初步点灯成功!艰难的历程啊,期间差点烧了

    0x00 前言说明 放假这几天,在淘宝买了esp01s,和一个搭配esp01s的wifi继电器准备做一些IOT(物联网)实验,踩了不少的坑,总算是点灯成功了!下面记录一些实验的拍照吧~ 继电器参数说明 ...

  3. shell脚本实现MySQL全量备份+异地备份

    一.知识储备工作: Mysql导出数据库语法: mysqldump -u用户名 -p密码 数据库名 > 数据库名.sql shell脚本for循环及if条件判断基本语法 gzip压缩文件用法 r ...

  4. SSH 证书登录教程

    开源Linux 专注分享开源技术知识 SSH 是服务器登录工具,提供密码登录和密钥登录. 但是,SSH 还有第三种登录方法,那就是证书登录.很多情况下,它是更合理.更安全的登录方法,本文就介绍这种登录 ...

  5. 撸了一个 Feign 增强包 V2.0 升级版

    前言 大概在两年前我写过一篇 撸了一个 Feign 增强包,当时准备是利用 SpringBoot + K8s 构建应用,这个库可以类似于 SpringCloud 那样结合 SpringBoot 使用声 ...

  6. .net 项目使用 JSON Schema

    .net 项目使用 JSON Schema 最近公司要做配置项的改造,要把appsettings.json的内容放到数据库,经过分析还是用json的方式存储最为方便,项目改动性最小,这就牵扯到一个问题 ...

  7. vs code 终端字体间距过大(全角的样子)

    文件-首选项-设置 将 terminal.integrated.fontFamily 配置为 Consolas, 'Courier New', monospace 或其他想要的字体,或者点击齿轮按钮重 ...

  8. 153. Find Minimum in Rotated Sorted Array - LeetCode

    Question 153. Find Minimum in Rotated Sorted Array Solution 题目大意:给一个按增序排列的数组,其中有一段错位了[1,2,3,4,5,6]变成 ...

  9. 安装Suberversion[SVN]到CentOS(YUM)

    运行环境 系统版本:CentOS Linux release 7.3.1611 (Core) 软件版本:Suberversion-1.7.14 硬件要求:无 安装过程 1.安装YUM-EPEL源 Su ...

  10. 协议层安全相关《http请求走私与CTF利用》

    0x00 前言 最近刷题的时候多次遇到HTTP请求走私相关的题目,但之前都没怎么接触到相关的知识点,只是在GKCTF2021--hackme中使用到了 CVE-2019-20372(Nginx< ...