Pytorch中的BatchNorm的API主要有:

1 torch.nn.BatchNorm1d(num_features,
2
3 eps=1e-05,
4
5 momentum=0.1,
6
7 affine=True,
8
9 track_running_stats=True)

一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。

同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainningaffinetrack_running_stats

  • 其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True[10]
  • trainningtrack_running_statstrack_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。

一般来说,trainningtrack_running_stats有四种组合[7]

  1. trainning=True, track_running_stats=True。这个是期望中的训练阶段的设置,此时BN将会跟踪整个训练过程中batch的统计特性。
  2. trainning=True, track_running_stats=False。此时BN只会计算当前输入的训练batch的统计特性,可能没法很好地描述全局的数据统计特性。
  3. trainning=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_meanrunning_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。[6,8]
  4. trainning=False, track_running_stats=False 效果同(2),只不过是位于测试状态,这个一般不采用,这个只是用测试输入的batch的统计特性,容易造成统计特性的偏移,导致糟糕效果。

同时,我们要注意到,BN层中的running_meanrunning_var的更新是在forward()操作中进行的,而不是optimizer.step()中进行的,因此如果处于训练状态,就算你不进行手动step(),BN的统计特性也会变化的。如

 1 model.train() # 处于训练状态
2
3
4 for data, label in self.dataloader:
5
6 pred = model(data)
7
8 # 在这里就会更新model中的BN的统计特性参数,running_mean, running_var
9
10 loss = self.loss(pred, label)
11
12 # 就算不要下列三行代码,BN的统计特性参数也会变化
13
14 opt.zero_grad()
15
16 loss.backward()
17
18 opt.step()

这个时候要将model.eval()转到测试阶段,才能固定住running_meanrunning_var。有时候如果是先预训练模型然后加载模型,重新跑测试的时候结果不同,有一点性能上的损失,这个时候十有八九是trainningtrack_running_stats设置的不对,这里需要多注意。 [8]

Reference

[1]. 用pytorch踩过的坑

[2]. Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456.

[3]. <深度学习优化策略-1>Batch Normalization(BN)

[4]. 详解深度学习中的Normalization,BN/LN/WN

[5]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24

[6]. https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870

[7]. BatchNorm2d增加的参数track_running_stats如何理解?

[8]. Why track_running_stats is not set to False during eval

[9]. How to train with frozen BatchNorm?

[10]. Proper way of fixing batchnorm layers during training

[11]. 大白话《Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift》

PyTorch中的Batch Normalization的更多相关文章

  1. Pytorch中的Batch Normalization操作

    之前一直和小伙伴探讨batch normalization层的实现机理,作用在这里不谈,知乎上有一篇paper在讲这个,链接 这里只探究其具体运算过程,我们假设在网络中间经过某些卷积操作之后的输出的f ...

  2. 使用TensorFlow中的Batch Normalization

    问题 训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题.但是却不能保证在训练过程中不出现该问题, ...

  3. 在tensorflow中使用batch normalization

    问题 训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题,但是却不能保证在训练过程中不出现该问题, ...

  4. tensorflow中使用Batch Normalization

    在深度学习中为了提高训练速度,经常会使用一些正正则化方法,如L2.dropout,后来Sergey Ioffe 等人提出Batch Normalization方法,可以防止数据分布的变化,影响神经网络 ...

  5. 神经网络中使用Batch Normalization 解决梯度问题

    BN本质上解决的是反向传播过程中的梯度问题. 详细点说,反向传播时经过该层的梯度是要乘以该层的参数的,即前向有: 那么反向传播时便有: 那么考虑从l层传到k层的情况,有: 上面这个 便是问题所在.因为 ...

  6. tensorflow中batch normalization的用法

    网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下: 1.原理 公式如下: y=γ(x-μ)/σ+β 其中x是输入,y是输出,μ是均值,σ ...

  7. Batch Normalization原理及其TensorFlow实现——为了减少深度神经网络中的internal covariate shift,论文中提出了Batch Normalization算法,首先是对”每一层“的输入做一个Batch Normalization 变换

    批标准化(Bactch Normalization,BN)是为了克服神经网络加深导致难以训练而诞生的,随着神经网络深度加深,训练起来就会越来越困难,收敛速度回很慢,常常会导致梯度弥散问题(Vanish ...

  8. Batch Normalization&Dropout浅析

    一. Batch Normalization 对于深度神经网络,训练起来有时很难拟合,可以使用更先进的优化算法,例如:SGD+momentum.RMSProp.Adam等算法.另一种策略则是高改变网络 ...

  9. 《RECURRENT BATCH NORMALIZATION》

    原文链接 https://arxiv.org/pdf/1603.09025.pdf Covariate 协变量:在实验的设计中,协变量是一个独立变量(解释变量),不为实验者所操纵,但仍影响实验结果. ...

随机推荐

  1. 解析一个body片断

    问题 假如你有一个HTML片断 (比如. 一个 div 包含一对 p 标签; 一个不完整的HTML文档) 想对它进行解析.这个HTML片断可以是用户提交的一条评论或在一个CMS页面中编辑body部分. ...

  2. ProjectEuler 009题

    题目: A Pythagorean triplet is a set of three natural numbers, a b c, for which, a2 + b2 = c2 For exam ...

  3. Python代码阅读(第1篇):列表映射后的平均值

    本篇阅读的代码实现了将列表进行映射,并求取映射后的平均值. 本篇阅读的代码片段来自于30-seconds-of-python. average_by def average_by(lst, fn=la ...

  4. 高德地图——添加标记的两种方法&删除地标记的两种方法

    添加标记: 1.marker.setMap(map); 2.marker.add([marker]); 删除标记: 1.marker.setMap(null); 2 map.remove([marke ...

  5. 你的域名是如何变成 IP 地址的?

    我的 个人网站 上线了,上面可以更好的检索历史文章,并且可以对文章进行留言,欢迎大家访问 可能大家都知道或者被问过一个问题,那就是很经典的「从浏览器输入 URL 再到页面展示,都发生了什么」.这个问题 ...

  6. Tomcat 端口配置及原理详解

    1. tomcat 文件配置详细说明 tomcat服务器需配置三个端口才能启动,安装时默认启用了这三个端口,当要运行多个tomcat服务时需要修改这三个端口,不能相同.端口配置路径为tomcat\ c ...

  7. 从synchronize到CSA和

    目录 导论 悲观锁和乐观锁 公平锁和非公平锁 可重入锁和不可重入锁 Synchronized 关键字 实现原理 Java 对象头 Monitor JVM 对 synchronized 的处理 JVM ...

  8. Structs2的作用是什么??

    struts2是一种重量级的框架,位于MVC架构中的controller,可以分析出来,它是用于接受页面信息然后通过内部处理,将结果返回. 同时struts2也是一个web层的MVC框架,那么什么是s ...

  9. Appium问题解决方案(4)- Error while obtaining UI hierarchy XML file: com.android.ddmlib.SyncException

    背景 操作步骤 运行 uiautomatorviewer.bat 点击左上角的 Device ScreensShot 报错 截图 解决方法 网上还是有很多方法的,可能造成的原因不同,我是第六种方法解决 ...

  10. 【LeetCode】862. 和至少为 K 的最短子数组

    862. 和至少为 K 的最短子数组 知识点:单调:队列:前缀和 题目描述 返回 A 的最短的非空连续子数组的长度,该子数组的和至少为 K . 如果没有和至少为 K 的非空子数组,返回 -1 . 示例 ...